From 56cd947037ae0f2af6495e219ba282724242dfd9 Mon Sep 17 00:00:00 2001 From: Simeon Keske Date: Thu, 4 Jun 2020 15:57:59 +0200 Subject: [PATCH] fix username formatter --- .../username_picker.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/matrix_synapse_saml_mozilla/username_picker.py b/matrix_synapse_saml_mozilla/username_picker.py index 5d2eeba..101d95e 100644 --- a/matrix_synapse_saml_mozilla/username_picker.py +++ b/matrix_synapse_saml_mozilla/username_picker.py @@ -26,7 +26,7 @@ from twisted.web.static import File import synapse.module_api from synapse.module_api import run_in_background from synapse.module_api.errors import SynapseError - +from synapse.types import UserID from synapse.api.errors import Codes, LoginError from matrix_synapse_saml_mozilla._sessions import ( @@ -170,7 +170,10 @@ class SubmitResource(AsyncResource): return else: password = request.args[b"password"][0].decode("utf-8", errors="replace") - registered_user_id = '@{}:localhost'.format(localpart) + if localpart.startswith("@"): + registered_user_id = localpart + else: + registered_user_id = UserID(localpart, self._module_api._hs.hostname).to_string() success = False try: @@ -247,17 +250,17 @@ class CheckResource(AsyncResource): async def async_render_POST(self, request: Request): # make sure that there is a valid mapping session, to stop people dictionary- # scanning for accounts - # session_id = request.getCookie(SESSION_COOKIE_NAME) - # if not session_id: - # _return_json({"error": "missing session_id"}, request) - # return - # - # session_id = session_id.decode("ascii", errors="replace") - # session = get_mapping_session(session_id) - # if not session: - # logger.info("Couldn't find session id %s", session_id) - # _return_json({"error": "unknown session"}, request) - # return + session_id = request.getCookie(SESSION_COOKIE_NAME) + if not session_id: + _return_json({"error": "missing session_id"}, request) + return + + session_id = session_id.decode("ascii", errors="replace") + session = get_mapping_session(session_id) + if not session: + logger.info("Couldn't find session id %s", session_id) + _return_json({"error": "unknown session"}, request) + return if b"username" not in request.args: _return_json({"error": "missing username"}, request) @@ -269,7 +272,10 @@ class CheckResource(AsyncResource): return password = request.args[b"password"][0].decode("utf-8", errors="replace") - uid = '@{}:localhost'.format(localpart) + if localpart.startswith("@"): + uid = localpart + else: + uid = UserID(localpart, self._module_api._hs.hostname).to_string() success = False try: