Allow users to pick a username on login (#1)

This is essentially a rewrite to collect the username from the user when they first log in, rather than try to determine it algorithmically from SAML attributes.
master
Richard van der Hoff 4 years ago committed by GitHub
parent ccbb42d66b
commit f85ec19465
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,10 @@
include *.in include *.in
include *.py
include LICENSE include LICENSE
include tox.ini include tox.ini
include requirements.txt include requirements.txt
prune doc
recursive-include matrix_synapse_saml_mozilla *.py
graft matrix_synapse_saml_mozilla/res
recursive-include tests *.py recursive-include tests *.py

@ -1,7 +1,7 @@
# Synapse Mozilla SAML MXID Mapper # Synapse Mozilla SAML MXID Mapper
Custom SAML auth response -> MXID mapping algorithm used during the Mozilla A Synapse plugin module which allows users to choose their username when they
Matrix trial run. first log in.
## Installation ## Installation
@ -11,8 +11,6 @@ This plugin can be installed via [PyPi](https://pypi.org):
pip install matrix-synapse-saml-mozilla pip install matrix-synapse-saml-mozilla
``` ```
## Usage
### Config ### Config
Add the following in your Synapse config: Add the following in your Synapse config:
@ -21,8 +19,22 @@ Add the following in your Synapse config:
saml2_config: saml2_config:
user_mapping_provider: user_mapping_provider:
module: "matrix_synapse_saml_mozilla.SamlMappingProvider" module: "matrix_synapse_saml_mozilla.SamlMappingProvider"
config: ```
mxid_source_attribute: "uid"
Also, under the HTTP client `listener`, configure an `additional_resource` as per
the below:
```yaml
listeners:
- port: <port>
type: http
resources:
- names: [client]
additional_resources:
"/_matrix/saml2/pick_username":
module: "matrix_synapse_saml_mozilla.pick_username_resource"
``` ```
### Configuration Options ### Configuration Options
@ -30,11 +42,13 @@ Add the following in your Synapse config:
Synapse allows SAML mapping providers to specify custom configuration through the Synapse allows SAML mapping providers to specify custom configuration through the
`saml2_config.user_mapping_provider.config` option. `saml2_config.user_mapping_provider.config` option.
The options supported by this provider are currently: There are no options currently supported by this provider.
## Implementation notes
The login flow looks something like this:
* `mxid_source_attribute` - The SAML attribute (after mapping via the ![login flow](doc/login_flow.svg)
attribute maps) to use to derive the Matrix
ID from. 'uid' by default.
## Development and Testing ## Development and Testing
@ -42,7 +56,7 @@ This repository uses `tox` to run linting and tests.
### Linting ### Linting
Code is linted with the `flake8` tool. Run `tox -e pep8` to check for linting Code is linted with the `flake8` tool. Run `tox -e lint` to check for linting
errors in the codebase. errors in the codebase.
### Tests ### Tests

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 33 KiB

@ -0,0 +1,65 @@
title Mozilla matrix login flow
participant Riot
participant "(Embedded) Browser" as B
participant "Synapse" as HS
participant "SAML2 IdP" as IdP
activate Riot
Riot->HS:""GET /login
activate HS
Riot<--HS:"""type":"m.login.sso"
deactivate HS
create B
Riot->B:
activate B
B->HS:""GET /login/sso/redirect\n--?redirectUrl=<clienturl>--""
activate HS
HS->HS:Generate SAML request
B<--HS:302 to IdP
deactivate HS
B->IdP: ""GET https://auth.mozilla.auth0.com/samlp/...\n--?SAMLRequest=<SAML request>
activate IdP
B<--IdP: 200 login form
deactivate IdP
B->IdP: submit login form with auth credentials
activate IdP
IdP-->B:200: auto-submitting HTML form including SAML Response
deactivate IdP
B->HS:""POST /_matrix/saml2/authn_response\n--SAMLResponse=<response>
activate HS
HS->HS:Check if known user
B<--HS:302 to username picker\n--including ""username_mapping_session"" cookie
deactivate HS
B->HS:""GET /_matrix/saml2/pick_username/
activate HS
B<--HS: 200 with form page
deactivate HS
B->HS:""GET /_matrix/saml2/pick_username/check\n--?username=<username>
activate HS
B<--HS:200 ""{"available": true/false}""\n--or 200 ""{"error": "..."}""
deactivate HS
B->HS:""POST /_matrix/saml2/pick_username/submit\n--username=<username>
activate HS
B<--HS:302 to original clienturl with loginToken
deactivate HS
Riot<-B:
deactivate B
destroysilent B
Riot->HS: ""POST /login\n--{"type": "m.login.token", "token": "<token>"}
activate HS
Riot<--HS:""access token"" etc
deactivate HS

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from matrix_synapse_saml_mozilla.mapping_provider import SamlMappingProvider
from matrix_synapse_saml_mozilla.username_picker import pick_username_resource
__version__ = "0.0.1"
__all__ = ["SamlMappingProvider", "pick_username_resource"]

@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
import attr
import time
SESSION_COOKIE_NAME = b"username_mapping_session"
logger = logging.getLogger(__name__)
@attr.s
class UsernameMappingSession:
"""Data we track about SAML2 sessions"""
# user ID on the SAML server
remote_user_id = attr.ib(type=str)
# displayname, per the SAML attributes
displayname = attr.ib(type=Optional[str])
# where to redirect the client back to
client_redirect_url = attr.ib(type=str)
# expiry time for the session, in milliseconds
expiry_time_ms = attr.ib(type=int)
# a map from session id to session data
username_mapping_sessions = {} # type: dict[str, UsernameMappingSession]
def expire_old_sessions(gettime=time.time):
"""Delete any sessions which have passed their expiry_time"""
to_expire = []
now = int(gettime() * 1000)
for session_id, session in username_mapping_sessions.items():
if session.expiry_time_ms <= now:
to_expire.append(session_id)
for session_id in to_expire:
logger.info("Expiring mapping session %s", session_id)
del username_mapping_sessions[session_id]
def get_mapping_session(session_id: str) -> Optional[UsernameMappingSession]:
"""Look up the given session id, first expiring any old sessions"""
expire_old_sessions()
return username_mapping_sessions.get(session_id, None)

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,100 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import random
import string
import time
from typing import Tuple from typing import Tuple
import re
import attr import attr
import string
import saml2.response import saml2.response
__version__ = "0.0.1" import synapse.module_api
from synapse.module_api.errors import RedirectException
from matrix_synapse_saml_mozilla._sessions import (
UsernameMappingSession,
username_mapping_sessions,
expire_old_sessions,
SESSION_COOKIE_NAME,
)
logger = logging.getLogger(__name__)
MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
@attr.s @attr.s
class SamlConfig(object): class SamlConfig(object):
mxid_source_attribute = attr.ib() pass
class SamlMappingProvider(object): class SamlMappingProvider(object):
def __init__(self, parsed_config: SamlConfig): def __init__(
self, parsed_config: SamlConfig, module_api: synapse.module_api.ModuleApi
):
"""A Mozilla-flavoured, Synapse user mapping provider """A Mozilla-flavoured, Synapse user mapping provider
Args: Args:
parsed_config: A configuration object. The result of self.parse_config parsed_config: A configuration object. The result of self.parse_config
""" """
self._mxid_source_attribute = parsed_config.mxid_source_attribute self._random = random.SystemRandom()
mxid_localpart_allowed_characters = set(
"_-./=" + string.ascii_lowercase + string.digits
)
self._dot_replace_pattern = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
self._multiple_to_single_dot_pattern = re.compile(r"\.{2,}")
self._string_end_dot_pattern = re.compile(r"\.$")
def saml_response_to_user_attributes( def saml_response_to_user_attributes(
self, self,
saml_response: saml2.response.AuthnResponse, saml_response: saml2.response.AuthnResponse,
failures: int = 0, failures: int,
client_redirect_url: str,
) -> dict: ) -> dict:
"""Maps some text from a SAML response to attributes of a new user """Maps some text from a SAML response to attributes of a new user
Args: Args:
saml_response: A SAML auth response object saml_response: A SAML auth response object
failures: How many times a call to this function with this failures: How many times a call to this function with this
saml_response has resulted in a failure saml_response has resulted in a failure
client_redirect_url: where the client wants to redirect back to
Returns: Returns:
dict: A dict containing new user attributes. Possible keys: dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid * mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user * displayname (str): The displayname of the user
""" """
# The calling function will catch the KeyError if this fails remote_user_id = saml_response.ava["uid"][0]
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
# Truncate the username to the first found '@' character to prevent complete
# emails being leaked
pos = mxid_source.find("@")
if pos >= 0:
mxid_source = mxid_source[:pos]
mxid_localpart = self._dotreplace_for_mxid(mxid_source)
# Append suffix integer if last call to this function failed to produce
# a usable mxid
localpart = mxid_localpart + (str(failures) if failures else "")
# Retrieve the display name from the saml response
displayname = saml_response.ava.get("displayName", [None])[0] displayname = saml_response.ava.get("displayName", [None])[0]
return { expire_old_sessions()
"mxid_localpart": localpart,
"displayname": displayname,
}
def _dotreplace_for_mxid(self, username: str) -> str:
"""Replace non-allowed mxid characters with a '.'
Args:
username (str): The username to process
Returns: # make up a cryptorandom session id
str: The processed username session_id = "".join(
""" self._random.choice(string.ascii_letters) for _ in range(16)
username = username.lower() )
username = self._dot_replace_pattern.sub(".", username)
# regular mxids aren't allowed to start with an underscore either now = int(time.time() * 1000)
username = re.sub("^_", "", username) session = UsernameMappingSession(
remote_user_id=remote_user_id,
displayname=displayname,
client_redirect_url=client_redirect_url,
expiry_time_ms=now + MAPPING_SESSION_VALIDITY_PERIOD_MS,
)
# Change all instances of multiple dots together into a single dot username_mapping_sessions[session_id] = session
username = self._multiple_to_single_dot_pattern.sub(".", username) logger.info("Recorded registration session id %s", session_id)
# Remove any trailing dots # Redirect to the username picker
username = self._string_end_dot_pattern.sub("", username) e = RedirectException(b"/_matrix/saml2/pick_username/")
return username e.cookies.append(
b"%s=%s; path=/" % (SESSION_COOKIE_NAME, session_id.encode("ascii"),)
)
raise e
@staticmethod @staticmethod
def parse_config(config: dict) -> SamlConfig: def parse_config(config: dict) -> SamlConfig:
@ -115,8 +109,7 @@ class SamlMappingProvider(object):
Returns: Returns:
SamlConfig: A custom config object SamlConfig: A custom config object
""" """
mxid_source_attribute = config.get("mxid_source_attribute", "uid") return SamlConfig()
return SamlConfig(mxid_source_attribute)
@staticmethod @staticmethod
def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
@ -131,4 +124,4 @@ class SamlMappingProvider(object):
second set consists of those attributes which can be used if second set consists of those attributes which can be used if
available, but are not necessary available, but are not necessary
""" """
return {"uid", config.mxid_source_attribute}, {"displayName"} return {"uid"}, {"displayName"}

@ -0,0 +1,20 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>Synapse Login</title>
<link rel="stylesheet" href="style.css" type="text/css" />
</head>
<body>
<div class="card">
<form method="post" class="form__input" id="form" action="submit">
<input type="text" name="username" id="field-username" autofocus="">
<label for="field-username">
<span><span aria-hidden="true">Please pick your </span>username</span>
</label>
<input type="button" class="button button--full-width" id="button-submit" value="Submit">
</form>
<div role=alert class="tooltip hidden" id="message"></div>
<script src="script.js"></script>
</div>
</body>
</html>

@ -0,0 +1,117 @@
let inputField = document.getElementById("field-username");
let inputForm = document.getElementById("form");
let submitButton = document.getElementById("button-submit");
let message = document.getElementById("message");
// Remove input field placeholder if the text field is not empty
let switchClass = function(input) {
if (input.value.length > 0) {
input.classList.add('has-contents');
}
else {
input.classList.remove('has-contents');
}
};
// Submit username and receive response
let showMessage = function(messageText) {
// Unhide the message text
message.classList.remove("hidden");
message.innerHTML = messageText;
};
let onResponse = function(response, success) {
// Display message
showMessage(response);
if(success) {
inputForm.submit();
return;
}
// Enable submit button and input field
submitButton.classList.remove('button--disabled');
submitButton.value = "Submit"
};
// We allow upper case characters here, but then lowercase before sending to the server
let allowedUsernameCharacters = RegExp("[^a-zA-Z0-9\\.\\_\\=\\-\\/]");
let usernameIsValid = function(username) {
return !allowedUsernameCharacters.test(username);
}
let allowedCharactersString = "" +
"<code>a-z</code>, " +
"<code>0-9</code>, " +
"<code>.</code>, " +
"<code>_</code>, " +
"<code>-</code>, " +
"<code>/</code>, " +
"<code>=</code>";
let buildQueryString = function(params) {
return Object.keys(params)
.map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k]))
.join('&');
}
let submitUsername = function(username) {
if(username.length == 0) {
onResponse("Please enter a username.", false);
return;
}
if(!usernameIsValid(username)) {
onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString, false);
return;
}
let check_uri = 'check?' + buildQueryString({"username": username});
fetch(check_uri, {
"credentials": "include",
}).then((response) => {
if(!response.ok) {
// for non-200 responses, raise the body of the response as an exception
return response.text().then((text) => { throw text });
} else {
return response.json()
}
}).then((json) => {
if(json.error) {
throw json.error;
} else if(json.available) {
onResponse("Success. Please wait a moment for your browser to redirect.", true);
} else {
onResponse("This username is not available, please choose another.", false);
}
}).catch((err) => {
onResponse("Error checking username availability: " + err, false);
});
}
let clickSubmit = function() {
if(submitButton.classList.contains('button--disabled')) { return; }
// Disable submit button and input field
submitButton.classList.add('button--disabled');
// Submit username
submitButton.value = "Checking...";
submitUsername(inputField.value);
};
submitButton.onclick = clickSubmit;
// Listen for events on inputField
inputField.addEventListener('keypress', function(event) {
// Listen for Enter on input field
if(event.which === 13) {
event.preventDefault();
clickSubmit();
return true;
}
switchClass(inputField);
});
inputField.addEventListener('change', function() {
switchClass(inputField);
});

@ -0,0 +1,178 @@
body {
background: #ededf0;
color: #737373;
font-family: "Open Sans", sans-serif;
letter-spacing: 0.03em;
margin: 0;
padding: 0;
display: grid;
grid-template-rows: auto 1fr; }
.card {
background-color: #fff;
padding: 2em;
position: relative;
margin: auto;
width: 100%;
box-shadow: 0 0.25em 0.25em 0 rgba(210, 210, 210, 0.5);
border-radius: 0.125em; }
@media (min-width: 25em) {
.card {
max-width: 26em;
padding: 2.5em; } }
@supports (display: grid) {
.card {
grid-row-start: 2; }
@media (min-height: 50em) {
.card {
top: -3em;
/* compensate for negative margin for footer links */ } } }
.card__back {
margin-bottom: 1em; }
.card__heading {
font-size: 1.4em;
font-weight: 400;
text-transform: capitalize;
padding-left: 2.125em;
position: relative;
min-height: 1.5em; }
.card__heading--iconless {
padding-left: 0; }
.card__heading img {
width: 1.5em;
height: 1.5em;
position: absolute;
left: 0;
top: 0; }
.card__heading--success {
color: #12bc00; }
.card__heading--error {
color: #ff0039; }
.card [data-screen]:focus {
outline: none; }
* {
box-sizing: border-box; }
form {
margin: 0; }
form * {
font-family: inherit; }
label {
margin: 2em 0;
display: block; }
input[type="text"],
input[type="email"],
input[type="password"] {
font-size: 100%;
background-color: #ededf0;
border: 1px solid #fff;
border-radius: .2em;
padding: .5em .9em;
display: block;
width: 100%;
margin-bottom: 1em; }
input[type="text"]:hover, input[type="text"]:focus,
input[type="email"]:hover,
input[type="email"]:focus,
input[type="password"]:hover,
input[type="password"]:focus {
border: 1px solid #0060df;
outline: none; }
.focus-styles input[type="text"]:focus, .focus-styles
input[type="email"]:focus, .focus-styles
input[type="password"]:focus {
border-color: transparent; }
.form__input {
position: relative; }
p + .form__input {
margin-top: 2.5em;
/* leave space to fit a paragraph above a field */ }
.form__input label {
margin: 0;
position: absolute;
top: .5em;
left: .9em; }
.form__input input:focus + label,
.form__input input.has-contents + label {
position: absolute;
top: -1.5em;
color: #0060df;
font-weight: bold; }
.form__input input:focus + label > span,
.form__input input.has-contents + label > span {
font-size: 0.75em; }
html,
body {
height: 100%; }
.button {
text-align: center;
text-decoration: none;
padding: .93em 2em;
display: block;
font-size: 87.5%;
letter-spacing: .04em;
line-height: 1.57;
font-family: inherit;
border-radius: 2em;
background-color: #0060df;
color: #fff;
border: 1px solid transparent;
transition: background-color .1s ease-in-out;
-webkit-appearance: none;
-moz-appearance: none;
appearance: none; }
.button:hover {
background-color: #fff;
color: #0060df;
border-color: currentColor;
text-decoration: none; }
.button:active {
background-color: #0060df;
color: #fff;
border-color: #0060df; }
.button--full-width {
width: 100%; }
.button--secondary {
border-color: #b1b1b3;
background-color: transparent;
color: #000;
text-transform: none; }
.button--secondary:hover {
background-color: #000;
color: #fff;
border-color: transparent; }
.button--secondary:hover svg > path {
fill: #fff; }
.button--secondary:active {
background-color: transparent;
border-color: #000;
color: #000; }
.button--disabled {
border-color: #fff;
background-color: transparent;
color: #000;
text-transform: none; }
.button--disabled:hover {
background-color: #fff;
color: #000;
border-color: transparent; }
.hidden {
display: none; }
.tooltip {
background-color: #f9f9fa;
padding: 1em;
margin: 1em 0; }
.tooltip p:last-child {
margin-bottom: 0; }
.tooltip a:last-child {
margin-left: .5em; }
.tooltip:target {
display: block; }

@ -0,0 +1,270 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import json
import logging
import urllib.parse
from typing import Any
import pkg_resources
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import File
import synapse.module_api
from synapse.module_api import run_in_background
from synapse.module_api.errors import SynapseError
from matrix_synapse_saml_mozilla._sessions import (
get_mapping_session,
username_mapping_sessions,
SESSION_COOKIE_NAME,
)
"""
This file implements the "username picker" resource, which is mapped as an
additional_resource into the synapse resource tree.
The top-level resource is just a File resource which serves up the static files in the
"res" directory, but it has a couple of children:
* "submit", which does the mechanics of registering the new user, and redirects the
browser back to the client URL
* "check" (TODO): checks if a userid is free.
"""
logger = logging.getLogger(__name__)
def pick_username_resource(
parsed_config, module_api: synapse.module_api.ModuleApi
) -> Resource:
"""Factory method to generate the top-level username picker resource"""
base_path = pkg_resources.resource_filename("matrix_synapse_saml_mozilla", "res")
res = File(base_path)
res.putChild(b"submit", SubmitResource(module_api))
res.putChild(b"check", AvailabilityCheckResource(module_api))
return res
def parse_config(config: dict):
return None
pick_username_resource.parse_config = parse_config
HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
<html lang=en>
<head>
<meta charset="utf-8">
<title>Error {code}</title>
</head>
<body>
<p>{msg}</p>
</body>
</html>
"""
def _wrap_for_html_exceptions(f):
async def wrapped(self, request):
try:
return await f(self, request)
except Exception:
logger.exception("Error handling request %s" % (request,))
_return_html_error(500, "Internal server error", request)
return wrapped
def _wrap_for_text_exceptions(f):
async def wrapped(self, request):
try:
return await f(self, request)
except Exception:
logger.exception("Error handling request %s" % (request,))
body = b"Internal server error"
request.setResponseCode(500)
request.setHeader(b"Content-Type", b"text/plain; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))
request.write(body)
request.finish()
return wrapped
class AsyncResource(Resource):
"""Extends twisted.web.Resource to add support for async_render_X methods"""
def render(self, request: Request):
method = request.method.decode("ascii")
m = getattr(self, "async_render_" + method, None)
if not m and method == "HEAD":
m = getattr(self, "async_render_GET", None)
if not m:
return super().render(request)
async def run():
with request.processing():
return await m(request)
run_in_background(run)
return NOT_DONE_YET
class SubmitResource(AsyncResource):
def __init__(self, module_api: synapse.module_api.ModuleApi):
super().__init__()
self._module_api = module_api
@_wrap_for_html_exceptions
async def async_render_POST(self, request: Request):
session_id = request.getCookie(SESSION_COOKIE_NAME)
if not session_id:
_return_html_error(400, "missing session_id", request)
return
session_id = session_id.decode("ascii", errors="replace")
session = get_mapping_session(session_id)
if not session:
logger.info("Session ID %s not found", session_id)
_return_html_error(403, "Unknown session", request)
return
# we don't clear the session from the dict until the ID is successfully
# registered, so the user can go round and have another go if need be.
#
# this means there's theoretically a race where a single user can register
# two accounts. I'm going to assume that's not a dealbreaker.
if b"username" not in request.args:
_return_html_error(400, "missing username", request)
return
localpart = request.args[b"username"][0].decode("utf-8", errors="replace")
logger.info("Registering username %s", localpart)
try:
registered_user_id = await self._module_api.register_user(
localpart=localpart, displayname=session.displayname
)
except SynapseError as e:
logger.warning("Error during registration: %s", e)
_return_html_error(e.code, e.msg, request)
return
await self._module_api.record_user_external_id(
"saml", session.remote_user_id, registered_user_id
)
del username_mapping_sessions[session_id]
login_token = self._module_api.generate_short_term_login_token(
registered_user_id
)
redirect_url = _add_login_token_to_redirect_url(
session.client_redirect_url, login_token
)
# delete the cookie
request.addCookie(
SESSION_COOKIE_NAME,
b"",
expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
path=b"/",
)
request.redirect(redirect_url)
request.finish()
class AvailabilityCheckResource(AsyncResource):
def __init__(self, module_api: synapse.module_api.ModuleApi):
super().__init__()
self._module_api = module_api
@_wrap_for_text_exceptions
async def async_render_GET(self, request: Request):
# make sure that there is a valid mapping session, to stop people dictionary-
# scanning for accounts
session_id = request.getCookie(SESSION_COOKIE_NAME)
if not session_id:
_return_json({"error": "missing session_id"}, request)
return
session_id = session_id.decode("ascii", errors="replace")
session = get_mapping_session(session_id)
if not session:
logger.info("Couldn't find session id %s", session_id)
_return_json({"error": "unknown session"}, request)
return
if b"username" not in request.args:
_return_json({"error": "missing username"}, request)
return
localpart = request.args[b"username"][0].decode("utf-8", errors="replace")
logger.info("Checking for availability of username %s", localpart)
try:
user_id = self._module_api.get_qualified_user_id(localpart)
registered_id = await self._module_api.check_user_exists(user_id)
available = registered_id is None
except Exception as e:
logger.warning(
"Error checking for availability of %s: %s %s" % (localpart, type(e), e)
)
available = False
response = {"available": available}
_return_json(response, request)
def _add_login_token_to_redirect_url(url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
def _return_html_error(code: int, msg: str, request: Request):
"""Sends an HTML error page"""
body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))
request.write(body)
try:
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
def _return_json(json_obj: Any, request: Request):
json_bytes = json.dumps(json_obj).encode("utf-8")
request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
request.setHeader(b"Access-Control-Allow-Origin", b"*")
request.setHeader(
b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS"
)
request.setHeader(
b"Access-Control-Allow-Headers",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
)
request.write(json_bytes)
try:
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)

@ -2,3 +2,15 @@
max-line-length = 90 max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. # W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
ignore = W503 ignore = W503
[isort]
line_length = 88
not_skip = __init__.py
sections = FUTURE,STDLIB,THIRDPARTY,SYNAPSE,FIRSTPARTY,TESTS,LOCALFOLDER
default_section = THIRDPARTY
known_synapse = synapse
known_first_party = matrix_synapse_saml_mozilla
known_tests = tests
multi_line_output = 3
include_trailing_comma = true
combine_as_imports = true

@ -34,25 +34,22 @@ def exec_file(path_segments, name):
the constant and executing it.""" the constant and executing it."""
result = {} result = {}
code = read_file(path_segments) code = read_file(path_segments)
lines = [line for line in code.split('\n') if line.startswith(name)] lines = [line for line in code.split("\n") if line.startswith(name)]
exec("\n".join(lines), result) exec("\n".join(lines), result)
return result[name] return result[name]
setup( setup(
name="matrix-synapse-saml-mozilla", name="matrix-synapse-saml-mozilla",
version=exec_file(("matrix_synapse_saml_mozilla.py",), "__version__"), version=exec_file(("matrix_synapse_saml_mozilla/__init__.py",), "__version__"),
py_modules=["matrix-synapse-saml-mozilla"], py_modules=["matrix-synapse-saml-mozilla"],
description="An Mozilla-flavoured SAML MXID mapper for Synapse", description="An Mozilla-flavoured SAML MXID mapper for Synapse",
install_requires=[ install_requires=["attr>=0.3.1", "pysaml2>=4.5.0"],
"attr>=0.3.1",
"pysaml2>=4.5.0",
],
long_description=read_file(("README.md",)), long_description=read_file(("README.md",)),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', "Development Status :: 4 - Beta",
'License :: OSI Approved :: Apache Software License', "License :: OSI Approved :: Apache Software License",
'Programming Language :: Python :: 3', "Programming Language :: Python :: 3",
], ],
) )

@ -1,16 +1,15 @@
from typing import Tuple from typing import Tuple
from matrix_synapse_saml_mozilla import SamlMappingProvider from matrix_synapse_saml_mozilla import SamlMappingProvider
def create_mapping_provider() -> Tuple[SamlMappingProvider, dict]: def create_mapping_provider() -> Tuple[SamlMappingProvider, dict]:
# Default configuration # Default configuration
config_dict = { config_dict = {}
"mxid_source_attribute": "uid"
}
# Convert the config dictionary to a SamlMappingProvider.SamlConfig object # Convert the config dictionary to a SamlMappingProvider.SamlConfig object
config = SamlMappingProvider.parse_config(config_dict) config = SamlMappingProvider.parse_config(config_dict)
# Create a new instance of the provider with the specified config # Create a new instance of the provider with the specified config
# Return the config dict as well for other test methods to use # Return the config dict as well for other test methods to use
return SamlMappingProvider(config), config_dict return SamlMappingProvider(config, None), config_dict

@ -14,103 +14,58 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re
import time
import unittest import unittest
from typing import Optional
from . import create_mapping_provider from synapse.api.errors import RedirectException
logging.basicConfig()
from matrix_synapse_saml_mozilla._sessions import username_mapping_sessions
def _make_test_saml_response( from . import create_mapping_provider
provider_config: dict,
source_attribute_value: str,
display_name: Optional[str] = None
):
"""Create a fake object based off of saml2.response.AuthnResponse
Args:
provider_config: The config dictionary used when creating the provider object
source_attribute_value: The desired value that the mapping provider will
pull out of the response object to turn into a Matrix UserID.
display_name: The desired displayname that the mapping provider will pull
out of the response object to turn into a Matrix user displayname.
Returns:
An object masquerading as a saml2.response.AuthnResponse object
"""
class FakeResponse(object): logging.basicConfig()
def __init__(self):
self.ava = {
provider_config["mxid_source_attribute"]: [source_attribute_value],
}
if display_name: class FakeResponse:
self.ava["displayName"] = [display_name] def __init__(self, source_uid, display_name):
self.ava = {
"uid": [source_uid],
}
return FakeResponse() if display_name:
self.ava["displayName"] = [display_name]
class SamlUserAttributeTestCase(unittest.TestCase): class SamlUserAttributeTestCase(unittest.TestCase):
def test_redirect(self):
def _attribute_test( """Creates a dummy response, feeds it to the provider and checks that it
self, redirects to the username picker.
input_uid: str,
input_displayname: Optional[str],
output_localpart: str,
output_displayname: Optional[str],
):
"""Creates a dummy response, feeds it to the provider and checks the output
Args:
input_uid: The value of the mxid_source_attribute that the provider will
base the generated localpart off of.
input_displayname: The saml auth response displayName value that the
provider will generate a Matrix user displayname from.
output_localpart: The expected mxid localpart.
output_displayname: The expected matrix displayname.
""" """
provider, config = create_mapping_provider() provider, config = create_mapping_provider()
response = _make_test_saml_response(config, input_uid, input_displayname) response = FakeResponse(123435, "Jonny")
attribute_dict = provider.saml_response_to_user_attributes(response) # we expect this to redirect to the username picker
self.assertEqual(attribute_dict["mxid_localpart"], output_localpart) with self.assertRaises(RedirectException) as cm:
self.assertEqual(attribute_dict["displayname"], output_displayname) provider.saml_response_to_user_attributes(response, 0, "http://client/")
self.assertEqual(cm.exception.location, b"/_matrix/saml2/pick_username/")
def test_normal_user(self):
self._attribute_test("john*doe2000#@example.com", None, "john.doe2000", None) cookieheader = cm.exception.cookies[0]
regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);")
def test_normal_user_displayname(self): m = regex.search(cookieheader)
self._attribute_test( if not m:
"john*doe2000#@example.com", "Jonny", "john.doe2000", "Jonny" self.fail("cookie header %s does not match %s" % (cookieheader, regex))
)
session_id = m.group(1).decode("ascii")
def test_multiple_adjacent_symbols(self): self.assertIn(
self._attribute_test("bob%^$&#!bobby@example.com", None, "bob.bobby", None) session_id, username_mapping_sessions, "session id not found in map"
def test_username_does_not_end_with_dot(self):
"""This is allowed in mxid syntax, but is not aesthetically pleasing"""
self._attribute_test("bob.bobby$@example.com", None, "bob.bobby", None)
def test_username_no_email(self):
self._attribute_test("bob.bobby", None, "bob.bobby", None)
def test_username_starting_with_underscore(self):
self._attribute_test(
"_twilight (sparkle)@somewhere.com", None, "twilight.sparkle", None
) )
session = username_mapping_sessions[session_id]
def test_existing_user(self): self.assertEqual(session.remote_user_id, 123435)
provider, config = create_mapping_provider() self.assertEqual(session.displayname, "Jonny")
response = _make_test_saml_response(config, "wibble%@wobble.com", None) self.assertEqual(session.client_redirect_url, "http://client/")
attribute_dict = provider.saml_response_to_user_attributes(response) # the expiry time should be about 15 minutes away
expected_expiry = (time.time() + 15 * 60) * 1000
# Simulate a failure on the first attempt self.assertGreaterEqual(session.expiry_time_ms, expected_expiry - 1000)
attribute_dict = provider.saml_response_to_user_attributes(response, failures=1) self.assertLessEqual(session.expiry_time_ms, expected_expiry + 1000)
self.assertEqual(attribute_dict["mxid_localpart"], "wibble1")
self.assertEqual(attribute_dict["displayname"], None)

@ -1,25 +1,33 @@
[tox] [tox]
envlist = packaging, pep8 envlist = packaging, lint, tests
[testenv] [testenv]
setenv = setenv =
PYTHONDONTWRITEBYTECODE = no_byte_code PYTHONDONTWRITEBYTECODE = no_byte_code
PYTHONPATH = .
[testenv:tests] [testenv:tests]
deps =
git+git://github.com/matrix-org/synapse@rav/mozilla_username_hacks#egg=matrix-synapse
commands = commands =
python -m unittest discover python -m unittest discover
[testenv:packaging] [testenv:packaging]
skip_install = True
deps = deps =
check-manifest check-manifest
commands = commands =
check-manifest check-manifest
[testenv:pep8] [testenv:lint]
skip_install = True skip_install = True
basepython = python3 basepython = python3
deps = deps =
flake8 flake8
# We pin so that our tests don't start failing on new releases of black.
black==19.10b0
isort
commands = commands =
flake8 saml_mapping_provider.py tests/* python -m black --check --diff .
flake8 matrix_synapse_saml_mozilla tests
isort -c -df -sp setup.cfg -rc matrix_synapse_saml_mozilla tests

Loading…
Cancel
Save