diff --git a/README.md b/README.md index bcbcb6f..0d86591 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Add the following in your Synapse config: ```yaml saml2_config: user_mapping_provider: - module: "saml_mapping_provider_mozilla.SamlMappingProvider" + module: "matrix_synapse_saml_mozilla.SamlMappingProvider" config: mxid_source_attribute: "uid" ``` diff --git a/saml_mapping_provider_mozilla.py b/matrix_synapse_saml_mozilla.py similarity index 100% rename from saml_mapping_provider_mozilla.py rename to matrix_synapse_saml_mozilla.py diff --git a/setup.py b/setup.py index 3f03471..c355f8c 100755 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def exec_file(path_segments, name): setup( name="matrix-synapse-saml-mozilla", - version=exec_file(("saml_mapping_provider_mozilla.py",), "__version__"), + version=exec_file(("matrix_synapse_saml_mozilla.py",), "__version__"), py_modules=["matrix-synapse-saml-mozilla"], description="An Mozilla-flavoured SAML MXID mapper for Synapse", install_requires=[ diff --git a/tests/__init__.py b/tests/__init__.py index bc9c4dd..c04c70f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,16 @@ -from saml_mapping_provider_mozilla import SamlMappingProvider +from typing import Tuple +from matrix_synapse_saml_mozilla import SamlMappingProvider -def create_mapping_provider() -> SamlMappingProvider: - return SamlMappingProvider() +def create_mapping_provider() -> Tuple[SamlMappingProvider, dict]: + # Default configuration + config_dict = { + "mxid_source_attribute": "uid" + } + + # Convert the config dictionary to a SamlMappingProvider.SamlConfig object + config = SamlMappingProvider.parse_config(config_dict) + + # Create a new instance of the provider with the specified config + # Return the config dict as well for other test methods to use + return SamlMappingProvider(config), config_dict diff --git a/tests/test_username.py b/tests/test_username.py index 05998b8..9d713ed 100644 --- a/tests/test_username.py +++ b/tests/test_username.py @@ -15,26 +15,59 @@ import logging import unittest +from typing import Optional from . import create_mapping_provider logging.basicConfig() +def _make_test_saml_response( + provider_config: dict, + source_attribute_value: str, + display_name: Optional[str] = None +): + """Create a fake object based off of saml2.response.AuthnResponse + + Args: + provider_config: The config dictionary used when creating the provider object + source_attribute_value: The desired value that the mapping provider will + pull out of the response object to turn into a Matrix UserID. + display_name: The desired displayname that the mapping provider will pull + out of the response object to turn into a Matrix user displayname. + + Returns: + An object masquerading as a saml2.response.AuthnResponse object + """ + + class FakeResponse(object): + + def __init__(self): + self.ava = { + provider_config["mxid_source_attribute"]: [source_attribute_value], + } + + if display_name: + self.ava["displayName"] = display_name + + return FakeResponse() + + class SamlUsernameTestCase(unittest.TestCase): def test_normal_user(self): - provider = create_mapping_provider() + provider, config = create_mapping_provider() + response = _make_test_saml_response(config, "john*doe2000#@example.com", None) - username = "john*doe2000#@example.com" - localpart = provider.mxid_source_to_mxid_localpart(username) - self.assertEqual(localpart, "john.doe2000") + attribute_dict = provider.saml_response_to_user_attributes(response) + self.assertEqual(attribute_dict["mxid_localpart"], "john.doe2000") + self.assertEqual(attribute_dict["displayname"], "john.doe2000") def test_multiple_adjacent_symbols(self): provider = create_mapping_provider() username = "bob%^$&#!bobby@example.com" - localpart = provider.mxid_source_to_mxid_localpart(username) + localpart = provider.saml_response_to_user_attributes(username) self.assertEqual(localpart, "bob.bobby") def test_username_does_not_end_with_dot(self): @@ -42,29 +75,29 @@ class SamlUsernameTestCase(unittest.TestCase): provider = create_mapping_provider() username = "bob.bobby$@example.com" - localpart = provider.mxid_source_to_mxid_localpart(username) + localpart = provider.saml_response_to_user_attributes(username) self.assertEqual(localpart, "bob.bobby") def test_username_no_email(self): provider = create_mapping_provider() username = "bob.bobby" - localpart = provider.mxid_source_to_mxid_localpart(username) + localpart = provider.saml_response_to_user_attributes(username) self.assertEqual(localpart, "bob.bobby") def test_username_starting_with_underscore(self): provider = create_mapping_provider() username = "_twilight (sparkle)@somewhere.com" - localpart = provider.mxid_source_to_mxid_localpart(username) + localpart = provider.saml_response_to_user_attributes(username) self.assertEqual(localpart, "twilight.sparkle") def test_existing_user(self): provider = create_mapping_provider() username = "wibble%@wobble.com" - localpart = provider.mxid_source_to_mxid_localpart(username) + localpart = provider.saml_response_to_user_attributes(username) # Simulate a failure on the first attempt - localpart = provider.mxid_source_to_mxid_localpart(username, failures=1) + localpart = provider.saml_response_to_user_attributes(username, failures=1) self.assertEqual(localpart, "wibble1")