|
|
@ -22,9 +22,10 @@ from saml2.config import SPConfig
|
|
|
|
from saml2.response import AuthnResponse
|
|
|
|
from saml2.response import AuthnResponse
|
|
|
|
from saml2.sigver import CryptoBackend, SecurityContext
|
|
|
|
from saml2.sigver import CryptoBackend, SecurityContext
|
|
|
|
|
|
|
|
|
|
|
|
from synapse.api.errors import RedirectException
|
|
|
|
from synapse.api.errors import CodeMessageException, RedirectException
|
|
|
|
|
|
|
|
|
|
|
|
from matrix_synapse_saml_mozilla._sessions import username_mapping_sessions
|
|
|
|
from matrix_synapse_saml_mozilla._sessions import username_mapping_sessions
|
|
|
|
|
|
|
|
from matrix_synapse_saml_mozilla.mapping_provider import SamlConfig, SamlMappingProvider
|
|
|
|
|
|
|
|
|
|
|
|
from . import create_mapping_provider
|
|
|
|
from . import create_mapping_provider
|
|
|
|
|
|
|
|
|
|
|
@ -33,6 +34,7 @@ class FakeResponse:
|
|
|
|
def __init__(self, source_uid, display_name):
|
|
|
|
def __init__(self, source_uid, display_name):
|
|
|
|
self.ava = {
|
|
|
|
self.ava = {
|
|
|
|
"uid": [source_uid],
|
|
|
|
"uid": [source_uid],
|
|
|
|
|
|
|
|
"emails": [],
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if display_name:
|
|
|
|
if display_name:
|
|
|
@ -49,7 +51,13 @@ def _load_test_response() -> AuthnResponse:
|
|
|
|
).decode("utf-8")
|
|
|
|
).decode("utf-8")
|
|
|
|
|
|
|
|
|
|
|
|
config = SPConfig()
|
|
|
|
config = SPConfig()
|
|
|
|
config.load({})
|
|
|
|
config.load(
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"attribute_map_dir": pkg_resources.resource_filename(
|
|
|
|
|
|
|
|
"matrix_synapse_saml_mozilla", "saml_maps"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
)
|
|
|
|
assert config.attribute_converters is not None
|
|
|
|
assert config.attribute_converters is not None
|
|
|
|
|
|
|
|
|
|
|
|
response = AuthnResponse(
|
|
|
|
response = AuthnResponse(
|
|
|
@ -57,6 +65,7 @@ def _load_test_response() -> AuthnResponse:
|
|
|
|
attribute_converters=config.attribute_converters,
|
|
|
|
attribute_converters=config.attribute_converters,
|
|
|
|
entity_id="https://host/_matrix/saml2/metadata.xml",
|
|
|
|
entity_id="https://host/_matrix/saml2/metadata.xml",
|
|
|
|
allow_unsolicited=True,
|
|
|
|
allow_unsolicited=True,
|
|
|
|
|
|
|
|
allow_unknown_attributes=True,
|
|
|
|
# tell it not to check the `destination`
|
|
|
|
# tell it not to check the `destination`
|
|
|
|
asynchop=False,
|
|
|
|
asynchop=False,
|
|
|
|
# tell it not to check the issue time
|
|
|
|
# tell it not to check the issue time
|
|
|
@ -70,7 +79,7 @@ def _load_test_response() -> AuthnResponse:
|
|
|
|
class SamlUserAttributeTestCase(unittest.TestCase):
|
|
|
|
class SamlUserAttributeTestCase(unittest.TestCase):
|
|
|
|
def test_get_remote_user_id_from_name_id(self):
|
|
|
|
def test_get_remote_user_id_from_name_id(self):
|
|
|
|
resp = _load_test_response()
|
|
|
|
resp = _load_test_response()
|
|
|
|
provider = create_mapping_provider({"use_name_id_for_remote_uid": True})
|
|
|
|
provider = create_mapping_provider()
|
|
|
|
remote_user_id = provider.get_remote_user_id(resp, "",)
|
|
|
|
remote_user_id = provider.get_remote_user_id(resp, "",)
|
|
|
|
self.assertEqual(remote_user_id, "test@domain.com")
|
|
|
|
self.assertEqual(remote_user_id, "test@domain.com")
|
|
|
|
|
|
|
|
|
|
|
@ -78,7 +87,7 @@ class SamlUserAttributeTestCase(unittest.TestCase):
|
|
|
|
"""Creates a dummy response, feeds it to the provider and checks that it
|
|
|
|
"""Creates a dummy response, feeds it to the provider and checks that it
|
|
|
|
redirects to the username picker.
|
|
|
|
redirects to the username picker.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
provider = create_mapping_provider()
|
|
|
|
provider = create_mapping_provider({"use_name_id_for_remote_uid": False})
|
|
|
|
response = FakeResponse(123435, "Jonny")
|
|
|
|
response = FakeResponse(123435, "Jonny")
|
|
|
|
|
|
|
|
|
|
|
|
# we expect this to redirect to the username picker
|
|
|
|
# we expect this to redirect to the username picker
|
|
|
@ -105,3 +114,15 @@ class SamlUserAttributeTestCase(unittest.TestCase):
|
|
|
|
expected_expiry = (time.time() + 15 * 60) * 1000
|
|
|
|
expected_expiry = (time.time() + 15 * 60) * 1000
|
|
|
|
self.assertGreaterEqual(session.expiry_time_ms, expected_expiry - 1000)
|
|
|
|
self.assertGreaterEqual(session.expiry_time_ms, expected_expiry - 1000)
|
|
|
|
self.assertLessEqual(session.expiry_time_ms, expected_expiry + 1000)
|
|
|
|
self.assertLessEqual(session.expiry_time_ms, expected_expiry + 1000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reject_blacklisted_email(self):
|
|
|
|
|
|
|
|
config = SamlConfig(
|
|
|
|
|
|
|
|
use_name_id_for_remote_uid=True, domain_block_list={"otherdomain.com"}
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
provider = SamlMappingProvider(config, None)
|
|
|
|
|
|
|
|
resp = _load_test_response()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with self.assertRaises(CodeMessageException) as e:
|
|
|
|
|
|
|
|
provider.saml_response_to_user_attributes(resp, 0, "http://client/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.assertEqual(e.exception.code, 403)
|
|
|
|