From 4c15ff5fd2a2db18cf0bacf4927ca9d69abb007f Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <richard@matrix.org>
Date: Mon, 13 Jan 2020 22:47:29 +0000
Subject: [PATCH] Implement new get_remote_user_id method

---
 .../mapping_provider.py                       | 23 ++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/matrix_synapse_saml_mozilla/mapping_provider.py b/matrix_synapse_saml_mozilla/mapping_provider.py
index 4392d91..87f2d86 100644
--- a/matrix_synapse_saml_mozilla/mapping_provider.py
+++ b/matrix_synapse_saml_mozilla/mapping_provider.py
@@ -22,6 +22,7 @@ import attr
 import saml2.response
 
 import synapse.module_api
+from synapse.api.errors import CodeMessageException
 from synapse.module_api.errors import RedirectException
 
 from matrix_synapse_saml_mozilla._sessions import (
@@ -39,7 +40,7 @@ MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
 
 @attr.s
 class SamlConfig(object):
-    pass
+    use_name_id_for_remote_uid = attr.ib(type=bool)
 
 
 class SamlMappingProvider(object):
@@ -52,6 +53,20 @@ class SamlMappingProvider(object):
             parsed_config: A configuration object. The result of self.parse_config
         """
         self._random = random.SystemRandom()
+        self._config = parsed_config
+
+    def get_remote_user_id(
+        self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+    ):
+        """Extracts the remote user id from the SAML response"""
+        if self._config.use_name_id_for_remote_uid:
+            return saml_response.name_id
+        else:
+            try:
+                return saml_response.ava["uid"][0]
+            except KeyError:
+                logger.warning("SAML2 response lacks a 'uid' attestation")
+                raise CodeMessageException(400, "'uid' not in SAML2 response")
 
     def saml_response_to_user_attributes(
         self,
@@ -73,7 +88,7 @@ class SamlMappingProvider(object):
                 * mxid_localpart (str): Required. The localpart of the user's mxid
                 * displayname (str): The displayname of the user
         """
-        remote_user_id = saml_response.ava["uid"][0]
+        remote_user_id = self.get_remote_user_id(saml_response, client_redirect_url)
         displayname = saml_response.ava.get("displayName", [None])[0]
 
         expire_old_sessions()
@@ -109,7 +124,9 @@ class SamlMappingProvider(object):
         Returns:
             SamlConfig: A custom config object
         """
-        return SamlConfig()
+        return SamlConfig(
+            use_name_id_for_remote_uid=config.get("use_name_id_for_remote_uid"),
+        )
 
     @staticmethod
     def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: