|
|
|
@ -16,7 +16,7 @@ import logging
|
|
|
|
|
import random
|
|
|
|
|
import string
|
|
|
|
|
import time
|
|
|
|
|
from typing import Tuple
|
|
|
|
|
from typing import Set, Tuple
|
|
|
|
|
|
|
|
|
|
import attr
|
|
|
|
|
import saml2.response
|
|
|
|
@ -40,7 +40,8 @@ MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
|
|
|
|
|
|
|
|
|
|
@attr.s
|
|
|
|
|
class SamlConfig(object):
|
|
|
|
|
use_name_id_for_remote_uid = attr.ib(type=bool)
|
|
|
|
|
use_name_id_for_remote_uid = attr.ib(type=bool, default=True)
|
|
|
|
|
domain_block_list = attr.ib(type=Set[str], default={})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SamlMappingProvider(object):
|
|
|
|
@ -55,6 +56,8 @@ class SamlMappingProvider(object):
|
|
|
|
|
self._random = random.SystemRandom()
|
|
|
|
|
self._config = parsed_config
|
|
|
|
|
|
|
|
|
|
logger.info("Domain block list: %s", self._config.domain_block_list)
|
|
|
|
|
|
|
|
|
|
def get_remote_user_id(
|
|
|
|
|
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
|
|
|
|
|
):
|
|
|
|
@ -69,7 +72,7 @@ class SamlMappingProvider(object):
|
|
|
|
|
try:
|
|
|
|
|
return saml_response.ava["uid"][0]
|
|
|
|
|
except KeyError:
|
|
|
|
|
logger.warning("SAML2 response lacks a 'uid' attestation")
|
|
|
|
|
logger.warning("SAML2 response lacks a 'uid' attribute")
|
|
|
|
|
raise CodeMessageException(400, "'uid' not in SAML2 response")
|
|
|
|
|
|
|
|
|
|
def saml_response_to_user_attributes(
|
|
|
|
@ -97,6 +100,29 @@ class SamlMappingProvider(object):
|
|
|
|
|
|
|
|
|
|
expire_old_sessions()
|
|
|
|
|
|
|
|
|
|
# check the user's emails against our block list
|
|
|
|
|
if "emails" not in saml_response.ava:
|
|
|
|
|
logger.warning("SAML2 response lacks an 'emails' attribute")
|
|
|
|
|
raise CodeMessageException(400, "'emails' not in SAML2 response")
|
|
|
|
|
|
|
|
|
|
for email in saml_response.ava["emails"]:
|
|
|
|
|
parts = email.rsplit("@", 1)
|
|
|
|
|
if len(parts) != 2:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Rejecting registration from remote user %s with unparsable email %s",
|
|
|
|
|
remote_user_id,
|
|
|
|
|
email,
|
|
|
|
|
)
|
|
|
|
|
raise CodeMessageException(403, "Forbidden")
|
|
|
|
|
|
|
|
|
|
if parts[1].lower() in self._config.domain_block_list:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Rejecting registration from remote user %s with blacklisted email %s",
|
|
|
|
|
remote_user_id,
|
|
|
|
|
email,
|
|
|
|
|
)
|
|
|
|
|
raise CodeMessageException(403, "Forbidden")
|
|
|
|
|
|
|
|
|
|
# make up a cryptorandom session id
|
|
|
|
|
session_id = "".join(
|
|
|
|
|
self._random.choice(string.ascii_letters) for _ in range(16)
|
|
|
|
@ -128,10 +154,24 @@ class SamlMappingProvider(object):
|
|
|
|
|
Returns:
|
|
|
|
|
SamlConfig: A custom config object
|
|
|
|
|
"""
|
|
|
|
|
return SamlConfig(
|
|
|
|
|
use_name_id_for_remote_uid=config.get("use_name_id_for_remote_uid"),
|
|
|
|
|
parsed = SamlConfig()
|
|
|
|
|
if "use_name_id_for_remote_uid" in config:
|
|
|
|
|
parsed.use_name_id_for_remote_uid = config["use_name_id_for_remote_uid"]
|
|
|
|
|
|
|
|
|
|
domain_block_file = config.get("domain_block_file")
|
|
|
|
|
if domain_block_file:
|
|
|
|
|
try:
|
|
|
|
|
with open(domain_block_file, encoding="ascii") as fh:
|
|
|
|
|
parsed.domain_block_list = {
|
|
|
|
|
line.strip().lower() for line in fh.readlines()
|
|
|
|
|
}
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise Exception(
|
|
|
|
|
"Error reading domain block file %s: %s" % (domain_block_file, e)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return parsed
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
|
|
|
|
|
"""Returns the required and optional attributes of a SAML auth response object
|
|
|
|
@ -145,4 +185,10 @@ class SamlMappingProvider(object):
|
|
|
|
|
second set consists of those attributes which can be used if
|
|
|
|
|
available, but are not necessary
|
|
|
|
|
"""
|
|
|
|
|
return {"uid"}, {"displayName"}
|
|
|
|
|
required = set()
|
|
|
|
|
optional = {"uid", "emails", "displayName"}
|
|
|
|
|
|
|
|
|
|
if not config.use_name_id_for_remote_uid:
|
|
|
|
|
required += "uid"
|
|
|
|
|
|
|
|
|
|
return required, optional
|
|
|
|
|