Update to latest synapse provider spec

master
Andrew Morgan 5 years ago
parent 962cbce291
commit 2be3b40e9c

@ -21,7 +21,14 @@ __version__ = "0.0.1"
class SamlMappingProvider(object): class SamlMappingProvider(object):
def __init__(self): def __init__(self, parsed_config):
"""A Mozilla-flavoured, Synapse user mapping provider
Args:
parsed_config: A configuration object. The result of self.parse_config
"""
self._mxid_source_attribute = parsed_config.mxid_source_attribute
mxid_localpart_allowed_characters = set( mxid_localpart_allowed_characters = set(
"_-./=" + string.ascii_lowercase + string.digits "_-./=" + string.ascii_lowercase + string.digits
) )
@ -33,15 +40,12 @@ class SamlMappingProvider(object):
def saml_response_to_user_attributes( def saml_response_to_user_attributes(
self, self,
config: dict,
saml_response: saml2.response.AuthnResponse, saml_response: saml2.response.AuthnResponse,
failures: int = 0, failures: int = 0,
) -> dict: ) -> dict:
"""Maps some text from a SAML response to attributes of a new user """Maps some text from a SAML response to attributes of a new user
Args: Args:
config: A configuration dictionary
saml_response: A SAML auth response object saml_response: A SAML auth response object
failures: How many times a call to this function with this failures: How many times a call to this function with this
@ -53,7 +57,7 @@ class SamlMappingProvider(object):
* displayname (str): The displayname of the user * displayname (str): The displayname of the user
""" """
# The calling function will catch the KeyError if this fails # The calling function will catch the KeyError if this fails
mxid_source = saml_response.ava[config["mxid_source_attribute"]][0] mxid_source = saml_response.ava[self._mxid_source_attribute][0]
# Truncate the username to the first found '@' character to prevent complete # Truncate the username to the first found '@' character to prevent complete
# emails being leaked # emails being leaked
@ -95,3 +99,35 @@ class SamlMappingProvider(object):
# Remove any trailing dots # Remove any trailing dots
username = self._string_end_dot_pattern.sub("", username) username = self._string_end_dot_pattern.sub("", username)
return username return username
@staticmethod
def parse_config(config: dict):
"""Parse the dict provided by the homeserver's config
Args:
config: A dictionary containing configuration options for this provider
Returns:
_SamlConfig: A custom config object
"""
pass
class _SamlConfig(object):
pass
saml_config = _SamlConfig()
saml_config.mxid_source_attribute = config["mxid_source_attribute"]
return saml_config
@staticmethod
def get_required_saml_attributes(config: dict):
"""Returns the required attributes of a SAML
Args:
config: A dictionary containing configuration options for this provider
Returns:
tuple[set,set]: The first set equates to the saml auth response attributes that
are required for the module to function, whereas the second set consists of
those attributes which can be used if available, but are not necessary
"""
saml_config = SamlMappingProvider.parse_config(config)
return {"uid", saml_config.mxid_source_attribute}, {"displayName"}

Loading…
Cancel
Save