# -*- coding: utf-8 -*-
# Copyright (c) 2010-2018 OneLogin, Inc.
# MIT License
from base64 import b64decode
from datetime import datetime
from os.path import basename
import time
from urllib import urlencode
from urlparse import parse_qs
from xml.dom.minidom import Document, parseString
from xml.sax.saxutils import escape
import zlib

from lxml import etree
import dm.xmlsec.binding as xmlsec

from saml2.constants import OneLogin_Saml2_Constants
from saml2.utils import OneLogin_Saml2_Utils
class OneLogin_Saml2_Logout_Request:
    """
    SAML 2.0 Logout Request.

    An instance either builds a new LogoutRequest XML message from the
    settings, or wraps an already encoded one. The static helpers extract
    data from, and validate, a LogoutRequest given as an XML string or a
    minidom Document.
    """

    def __init__(self, settings, request=None, name_id=None, session_index=None):
        """
        Constructs the Logout Request object.

        Arguments are:
            * (OneLogin_Saml2_Settings)   settings. Setting data
            * (string)                    request. Optional base64 encoded
              (optionally raw-deflated) Logout Request to wrap instead of
              building a new one
            * (string)                    name_id. Optional NameID value;
              a freshly generated unique id is used when omitted
            * (string)                    session_index. Optional
              SessionIndex added to the built request
        """
        self.__settings = settings

        if request is not None:
            # Wrap an existing message instead of building a new one.
            self.__logout_request = OneLogin_Saml2_Logout_Request._decode_request(request)
            return

        sp_data = self.__settings.get_sp_data()
        idp_data = self.__settings.get_idp_data()
        security = self.__settings.get_security_data()

        uid = OneLogin_Saml2_Utils.generate_unique_id()
        if name_id is None:
            # Keep the previous behaviour when the caller gives no NameID.
            name_id = OneLogin_Saml2_Utils.generate_unique_id()
        # int(time.time()) is the portable form of strftime("%s"), which is a
        # glibc-only extension (unsupported on e.g. Windows).
        issue_instant = OneLogin_Saml2_Utils.parse_time_to_SAML(int(time.time()))

        # When 'nameIdEncrypted' is enabled the IdP certificate is handed to
        # generate_name_id (presumably to encrypt the NameID).
        key = None
        if 'nameIdEncrypted' in security and security['nameIdEncrypted']:
            key = idp_data['x509cert']

        name_id_xml = OneLogin_Saml2_Utils.generate_name_id(
            name_id,
            sp_data['entityId'],
            sp_data['NameIDFormat'],
            key
        )

        if session_index:
            # Escape so that '&', '<' or '>' in the value cannot break the XML.
            session_index_xml = '<samlp:SessionIndex>%s</samlp:SessionIndex>' % escape(session_index)
        else:
            session_index_xml = ''

        logout_request = """<samlp:LogoutRequest
    xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
    xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
    ID="%(id)s"
    Version="2.0"
    IssueInstant="%(issue_instant)s"
    Destination="%(single_logout_url)s">
    <saml:Issuer>%(entity_id)s</saml:Issuer>
    %(name_id)s
    %(session_index)s
</samlp:LogoutRequest>""" % {
            'id': uid,
            'issue_instant': issue_instant,
            'single_logout_url': idp_data['singleLogoutService']['url'],
            'entity_id': sp_data['entityId'],
            'name_id': name_id_xml,
            'session_index': session_index_xml,
        }

        self.__logout_request = logout_request

    @staticmethod
    def _decode_request(request):
        """
        Decodes a base64 encoded Logout Request that may also be raw-deflated.

        :param request: base64 encoded (optionally raw-deflated) message
        :type request: string
        :return: The Logout Request XML
        :rtype: str object
        """
        decoded = b64decode(request)
        try:
            # wbits=-15: raw DEFLATE stream without a zlib header.
            return zlib.decompress(decoded, -15)
        except zlib.error:
            # Not deflated: the base64 payload already is the XML.
            return decoded

    def get_request(self):
        """
        Returns the Logout Request deflated, base64encoded
        :return: Deflated base64 encoded Logout Request
        :rtype: str object
        """
        return OneLogin_Saml2_Utils.deflate_and_base64_encode(self.__logout_request)

    @staticmethod
    def get_id(request):
        """
        Returns the ID of the Logout Request
        :param request: Logout Request Message
        :type request: string|DOMDocument
        :return: string ID
        :rtype: str object
        """
        dom = request if isinstance(request, Document) else parseString(request)
        return dom.documentElement.getAttribute('ID')

    @staticmethod
    def get_name_id_data(request, key=None):
        """
        Gets the NameID Data of the Logout Request
        :param request: Logout Request Message
        :type request: string|DOMDocument
        :param key: The SP key (required when the NameID is encrypted)
        :type key: string
        :return: Name ID Data (Value, Format, NameQualifier, SPNameQualifier)
        :rtype: dict
        :raises Exception: when the key is missing for an EncryptedID, the
            EncryptedID has no EncryptedData, or no NameID is found
        """
        if isinstance(request, Document):
            request = request.toxml()
        dom = etree.fromstring(request)

        name_id = None
        encrypted_entries = OneLogin_Saml2_Utils.query(dom, '/samlp:LogoutRequest/saml:EncryptedID')
        if len(encrypted_entries) == 1:
            if key is None:
                raise Exception('Key is required in order to decrypt the NameID')

            # The EncryptedID subtree is re-parsed with minidom before it is
            # handed to decrypt_element.
            elem = parseString(etree.tostring(encrypted_entries[0]))
            encrypted_data_nodes = elem.documentElement.getElementsByTagName('xenc:EncryptedData')
            if not encrypted_data_nodes:
                raise Exception('No EncryptedData found in the EncryptedID of the Logout Request')
            encrypted_data = encrypted_data_nodes[0]

            xmlsec.initialize()
            # Key.load is given a file name, so the key goes to a temp file.
            # The try/finally ensures the file is closed even if loading fails.
            file_key = OneLogin_Saml2_Utils.write_temp_file(key)  # FIXME avoid writing a file
            try:
                enc_key = xmlsec.Key.load(file_key.name, xmlsec.KeyDataFormatPem, None)
                enc_key.name = basename(file_key.name)
            finally:
                file_key.close()
            enc_ctx = xmlsec.EncCtx()
            enc_ctx.encKey = enc_key
            name_id = OneLogin_Saml2_Utils.decrypt_element(encrypted_data, enc_ctx)
        else:
            entries = OneLogin_Saml2_Utils.query(dom, '/samlp:LogoutRequest/saml:NameID')
            if len(entries) == 1:
                name_id = entries[0]

        if name_id is None:
            raise Exception('NameID not found in the Logout Request')

        name_id_data = {
            'Value': name_id.text
        }
        for attr in ('Format', 'SPNameQualifier', 'NameQualifier'):
            if attr in name_id.attrib:
                name_id_data[attr] = name_id.attrib[attr]
        return name_id_data

    @staticmethod
    def get_name_id(request, key=None):
        """
        Gets the NameID of the Logout Request Message
        :param request: Logout Request Message
        :type request: string|DOMDocument
        :param key: The SP key
        :type key: string
        :return: Name ID Value
        :rtype: string
        """
        name_id = OneLogin_Saml2_Logout_Request.get_name_id_data(request, key)
        return name_id['Value']

    @staticmethod
    def get_issuer(request):
        """
        Gets the Issuer of the Logout Request Message
        :param request: Logout Request Message
        :type request: string|DOMDocument
        :return: The Issuer, or None unless exactly one Issuer is present
        :rtype: string
        """
        if isinstance(request, Document):
            request = request.toxml()
        dom = etree.fromstring(request)
        issuer_nodes = OneLogin_Saml2_Utils.query(dom, '/samlp:LogoutRequest/saml:Issuer')
        if len(issuer_nodes) == 1:
            return issuer_nodes[0].text
        return None

    @staticmethod
    def get_session_indexes(request):
        """
        Gets the SessionIndexes from the Logout Request
        :param request: Logout Request Message
        :type request: string|DOMDocument
        :return: The SessionIndex values, in document order
        :rtype: list
        """
        if isinstance(request, Document):
            request = request.toxml()
        dom = etree.fromstring(request)
        session_index_nodes = OneLogin_Saml2_Utils.query(dom, '/samlp:LogoutRequest/samlp:SessionIndex')
        return [node.text for node in session_index_nodes]

    @staticmethod
    def _build_signed_query(get_data, sign_alg):
        """
        Rebuilds the query string covered by the Logout Request signature.

        :param get_data: GET parameters; must contain 'SAMLRequest' and may
            contain 'RelayState'
        :type get_data: dict
        :param sign_alg: Signature algorithm URI
        :type sign_alg: string
        :return: 'SAMLRequest=...[&RelayState=...]&SigAlg=...', each value
            form-encoded by urlencode
        :rtype: string
        """
        # urlencode needs (key, value) pairs; passing it a bare string raises
        # TypeError. A list keeps the parameter order the signature relies on.
        params = [('SAMLRequest', get_data['SAMLRequest'])]
        if 'RelayState' in get_data:
            params.append(('RelayState', get_data['RelayState']))
        params.append(('SigAlg', sign_alg))
        return urlencode(params)

    @staticmethod
    def is_valid(settings, request, get_data, debug=False):
        """
        Checks if the Logout Request received is valid
        :param settings: Settings
        :type settings: OneLogin_Saml2_Settings
        :param request: Logout Request Message
        :type request: string|DOMDocument
        :param get_data: GET parameters of the current request
        :type get_data: dict
        :param debug: Passed to the schema validation
        :type debug: bool
        :return: If the Logout Request is or not valid
        :rtype: boolean
        """
        try:
            if isinstance(request, Document):
                dom = request
            else:
                dom = parseString(request)

            idp_data = settings.get_idp_data()
            idp_entity_id = idp_data['entityId']

            if settings.is_strict():
                res = OneLogin_Saml2_Utils.validate_xml(dom, 'saml-schema-protocol-2.0.xsd', debug)
                if not isinstance(res, Document):
                    raise Exception('Invalid SAML Logout Request. Not match the saml-schema-protocol-2.0.xsd')

                security = settings.get_security_data()
                current_url = OneLogin_Saml2_Utils.get_self_url_no_query(get_data)
                root = dom.documentElement

                # Check NotOnOrAfter
                if root.hasAttribute('NotOnOrAfter'):
                    na = OneLogin_Saml2_Utils.parse_SAML_to_time(root.getAttribute('NotOnOrAfter'))
                    # NOTE(review): parse_time_to_SAML is fed epoch seconds, so
                    # parse_SAML_to_time presumably returns epoch seconds too.
                    # Comparing a number with a datetime fails, so compare like
                    # with like until the helper's return type is confirmed.
                    now = datetime.now() if isinstance(na, datetime) else time.time()
                    if na <= now:
                        raise Exception('Timing issues (please check your clock settings)')

                # Check destination (getAttribute returns '' when absent)
                destination = root.getAttribute('Destination')
                if destination and current_url not in destination:
                    raise Exception(
                        'The LogoutRequest was received at %s instead of %s' % (current_url, destination)
                    )

                # Check issuer
                issuer = OneLogin_Saml2_Logout_Request.get_issuer(dom)
                if issuer is None or issuer != idp_entity_id:
                    raise Exception('Invalid issuer in the Logout Request')

                if security['wantMessagesSigned'] and 'Signature' not in get_data:
                    raise Exception('The Message of the Logout Request is not signed and the SP require it')

            if 'Signature' in get_data:
                if 'SigAlg' in get_data:
                    sign_alg = get_data['SigAlg']
                else:
                    sign_alg = OneLogin_Saml2_Constants.RSA_SHA1

                if sign_alg != OneLogin_Saml2_Constants.RSA_SHA1:
                    raise Exception('Invalid signAlg in the recieved Logout Request')

                signed_query = OneLogin_Saml2_Logout_Request._build_signed_query(get_data, sign_alg)

                if 'x509cert' not in idp_data or idp_data['x509cert'] is None:
                    raise Exception('In order to validate the sign on the Logout Request, the x509cert of the IdP is required')
                cert = idp_data['x509cert']

                xmlsec.initialize()
                # NOTE(review): get_name_id_data passes xmlsec.Key.load a file
                # name, whereas here it gets the certificate contents, and
                # KeyDataFormatPem may be the wrong format for an X.509
                # certificate. Confirm against the dm.xmlsec.binding API.
                objkey = xmlsec.Key.load(cert, xmlsec.KeyDataFormatPem, None)  # FIXME is this right?
                if not objkey.verifySignature(signed_query, b64decode(get_data['Signature'])):
                    raise Exception('Signature validation failed. Logout Request rejected')

            return True
        except Exception as e:
            # Any failure means "not valid". Print the reason when debug is on
            # (plain exceptions have no .strerror, so print the exception).
            if settings.is_debug_active():
                print(e)
            return False