From 4c15ff5fd2a2db18cf0bacf4927ca9d69abb007f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 13 Jan 2020 22:47:29 +0000 Subject: [PATCH] Implement new get_remote_user_id method --- .../mapping_provider.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/matrix_synapse_saml_mozilla/mapping_provider.py b/matrix_synapse_saml_mozilla/mapping_provider.py index 4392d91..87f2d86 100644 --- a/matrix_synapse_saml_mozilla/mapping_provider.py +++ b/matrix_synapse_saml_mozilla/mapping_provider.py @@ -22,6 +22,7 @@ import attr import saml2.response import synapse.module_api +from synapse.api.errors import CodeMessageException from synapse.module_api.errors import RedirectException from matrix_synapse_saml_mozilla._sessions import ( @@ -39,7 +40,7 @@ MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000 @attr.s class SamlConfig(object): - pass + use_name_id_for_remote_uid = attr.ib(type=bool) class SamlMappingProvider(object): @@ -52,6 +53,20 @@ class SamlMappingProvider(object): parsed_config: A configuration object. The result of self.parse_config """ self._random = random.SystemRandom() + self._config = parsed_config + + def get_remote_user_id( + self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str + ): + """Extracts the remote user id from the SAML response""" + if self._config.use_name_id_for_remote_uid: + return saml_response.name_id + else: + try: + return saml_response.ava["uid"][0] + except KeyError: + logger.warning("SAML2 response lacks a 'uid' attestation") + raise CodeMessageException(400, "'uid' not in SAML2 response") def saml_response_to_user_attributes( self, @@ -73,7 +88,7 @@ class SamlMappingProvider(object): * mxid_localpart (str): Required. The localpart of the user's mxid * displayname (str): The displayname of the user """ - remote_user_id = saml_response.ava["uid"][0] + remote_user_id = self.get_remote_user_id(saml_response, client_redirect_url) displayname = saml_response.ava.get("displayName", [None])[0] expire_old_sessions() @@ -109,7 +124,9 @@ class SamlMappingProvider(object): Returns: SamlConfig: A custom config object """ - return SamlConfig() + return SamlConfig( + use_name_id_for_remote_uid=config.get("use_name_id_for_remote_uid"), + ) @staticmethod def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: