diff --git a/saml_mapping_provider.py b/saml_mapping_provider.py index f3f6f3d..82f0a0e 100644 --- a/saml_mapping_provider.py +++ b/saml_mapping_provider.py @@ -29,18 +29,19 @@ class SamlMappingProvider(object): ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) self._multiple_to_single_dot_pattern = re.compile(r"\.{2,}") - self._string_end_dot_pattern = re.compile(r"\.$") - self._mxid_source_attribute = None def saml_response_to_user_attributes( self, + config: dict, saml_response: saml2.response.AuthnResponse, failures: int = 0, ) -> dict: """Maps some text from a SAML response to attributes of a new user Args: + config: A configuration dictionary + saml_response: A SAML auth response object failures: How many times a call to this function with this @@ -52,7 +53,7 @@ class SamlMappingProvider(object): * displayname (str): The displayname of the user """ # The calling function will catch the KeyError if this fails - mxid_source = saml_response.ava[self._mxid_source_attribute][0] + mxid_source = saml_response.ava[config["mxid_source_attribute"]][0] # Truncate the username to the first found '@' character to prevent complete # emails being leaked @@ -94,8 +95,3 @@ class SamlMappingProvider(object): # Remove any trailing dots username = self._string_end_dot_pattern.sub("", username) return username - - def parse_config(self, config): - """Parse the dict provided by the homeserver config""" - self._mxid_source_attribute = config.get("mxid_source_attribute", "uid") -