# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional

import attr
import time


logger = logging.getLogger(__name__)

# Name (as bytes) of the cookie associated with username mapping sessions.
SESSION_COOKIE_NAME = b"username_mapping_session"


@attr.attrs
class UsernameMappingSession:
    """State kept about a single username mapping session.

    Every attribute below is a required constructor argument, accepted in
    declaration order. The original author describes this as the data
    tracked about SAML2 sessions.
    """

    # the user's ID as known to the SAML server
    remote_user_id = attr.attrib(type=str)

    # display name taken from the SAML attributes; may be None
    displayname = attr.attrib(type=Optional[str])

    # URL the client should be redirected back to
    client_redirect_url = attr.attrib(type=str)

    # absolute expiry time in milliseconds; expire_old_sessions() discards
    # the session once the current time reaches this value
    expiry_time_ms = attr.attrib(type=int)


# Registry of username mapping sessions, keyed by session ID. Entries are
# removed by expire_old_sessions() once their expiry time has been reached.
username_mapping_sessions = dict()  # type: dict[str, UsernameMappingSession]


def expire_old_sessions(gettime=time.time):
    """Remove every mapping session whose expiry time has been reached.

    Args:
        gettime: zero-argument callable returning the current time in
            seconds (``time.time`` by default). The value is scaled to
            milliseconds and truncated to an int before it is compared with
            each session's ``expiry_time_ms``; a session expiring exactly
            "now" is removed.
    """
    now_ms = int(gettime() * 1000)

    # Snapshot the IDs first: deleting entries while iterating the dict
    # itself would raise a RuntimeError.
    expired_ids = [
        sid
        for sid, sess in username_mapping_sessions.items()
        if sess.expiry_time_ms <= now_ms
    ]

    for sid in expired_ids:
        logger.info("Expiring mapping session %s", sid)
        del username_mapping_sessions[sid]


def get_mapping_session(session_id: str) -> Optional[UsernameMappingSession]:
    """Return the session stored under ``session_id``, or None if absent.

    Sessions past their expiry time are purged before the lookup, so an
    expired session is reported as missing rather than returned.
    """
    # Purge first so that stale entries cannot be handed back to the caller.
    expire_old_sessions()
    return username_mapping_sessions.get(session_id)