Implement new get_remote_user_id method

master
Richard van der Hoff 5 years ago
parent 091ca2bcff
commit 4c15ff5fd2

@ -22,6 +22,7 @@ import attr
import saml2.response import saml2.response
import synapse.module_api import synapse.module_api
from synapse.api.errors import CodeMessageException
from synapse.module_api.errors import RedirectException from synapse.module_api.errors import RedirectException
from matrix_synapse_saml_mozilla._sessions import ( from matrix_synapse_saml_mozilla._sessions import (
@ -39,7 +40,7 @@ MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
@attr.s @attr.s
class SamlConfig(object): class SamlConfig(object):
pass use_name_id_for_remote_uid = attr.ib(type=bool)
class SamlMappingProvider(object): class SamlMappingProvider(object):
@ -52,6 +53,20 @@ class SamlMappingProvider(object):
parsed_config: A configuration object. The result of self.parse_config parsed_config: A configuration object. The result of self.parse_config
""" """
self._random = random.SystemRandom() self._random = random.SystemRandom()
self._config = parsed_config
def get_remote_user_id(
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
):
"""Extracts the remote user id from the SAML response"""
if self._config.use_name_id_for_remote_uid:
return saml_response.name_id
else:
try:
return saml_response.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise CodeMessageException(400, "'uid' not in SAML2 response")
def saml_response_to_user_attributes( def saml_response_to_user_attributes(
self, self,
@ -73,7 +88,7 @@ class SamlMappingProvider(object):
* mxid_localpart (str): Required. The localpart of the user's mxid * mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user * displayname (str): The displayname of the user
""" """
remote_user_id = saml_response.ava["uid"][0] remote_user_id = self.get_remote_user_id(saml_response, client_redirect_url)
displayname = saml_response.ava.get("displayName", [None])[0] displayname = saml_response.ava.get("displayName", [None])[0]
expire_old_sessions() expire_old_sessions()
@ -109,7 +124,9 @@ class SamlMappingProvider(object):
Returns: Returns:
SamlConfig: A custom config object SamlConfig: A custom config object
""" """
return SamlConfig() return SamlConfig(
use_name_id_for_remote_uid=config.get("use_name_id_for_remote_uid"),
)
@staticmethod @staticmethod
def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:

Loading…
Cancel
Save