# -*- coding: utf-8 -*-
# Copyright (c) 2010-2018 OneLogin, Inc.
# MIT License
from datetime import datetime
import json
import re
from os.path import dirname, exists, join, sep
from xml.dom.minidom import Document
from saml2.constants import OneLogin_Saml2_Constants
from saml2.errors import OneLogin_Saml2_Error
from saml2.metadata import OneLogin_Saml2_Metadata
from saml2.utils import OneLogin_Saml2_Utils
# Regex from Django Software Foundation and individual contributors.
# Released under a BSD 3-Clause License
url_regex = re.compile(
    ''.join((
        r'^(?:[a-z0-9\.\-]*)://',  # scheme is validated separately
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+'
        r'(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|',  # domain...
        r'localhost|',  # localhost...
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|',  # ...or ipv4
        r'\[?[A-F0-9]*:[A-F0-9:]+\]?)',  # ...or ipv6
        r'(?::\d+)?',  # optional port
        r'(?:/?|[/?]\S+)$',  # optional path / query
    )),
    re.IGNORECASE,
)
# Schemes accepted by validate_url (compared lower-cased).
url_schemes = 'http https ftp ftps'.split()
def validate_url(url):
    """
    Checks whether a URL has a supported scheme and a well-formed host.

    :param url: The URL to check
    :type url: string

    :returns: True if the text before the first '://' (compared
              case-insensitively) is one of ``url_schemes`` and the whole
              value matches ``url_regex``; False otherwise
    :rtype: boolean
    """
    # partition()[0] is the text before the first '://', or the whole
    # string when the separator is absent (then the regex rejects it).
    scheme = url.partition('://')[0].lower()
    return scheme in url_schemes and url_regex.search(url) is not None
 
class OneLogin_Saml2_Settings:
    """
    Loads, validates and stores the SAML toolkit settings (SP, IdP,
    security, contacts, organization) together with the toolkit folder
    paths.
    """

    def __init__(self, settings=None, custom_base_path=None):
        """
        Initializes the settings:
        - Sets the paths of the different folders
        - Loads settings info from settings file or dict provided

        :param settings: SAML Toolkit Settings. When None, they are read from
                         'settings.json' (merged with 'advanced_settings.json'
                         when present) in the base path.
        :type settings: dict|None

        :param custom_base_path: Toolkit base folder. When None, a default
                                 derived from this module's location is used.
        :type custom_base_path: string|None

        :raises OneLogin_Saml2_Error: If the settings are invalid or the
                                      settings file cannot be found.
        :raises Exception: If settings is neither None nor a dict.
        """
        self.__paths = {}
        self.__strict = False
        self.__debug = False
        self.__sp = {}
        self.__idp = {}
        # Bound up front: settings without a 'security' section would
        # otherwise leave this attribute undefined, and it is read
        # unconditionally when default values are added.
        self.__security = {}
        self.__contacts = {}
        self.__organization = {}
        self.__errors = []

        self.__load_paths(base_path=custom_base_path)
        # A 'custom_base_path' key inside a settings dict overrides the
        # constructor argument.
        self.__update_paths(settings)

        if settings is None:
            if not self.__load_settings_from_file():
                raise OneLogin_Saml2_Error(
                    'Invalid file settings: %s',
                    OneLogin_Saml2_Error.SETTINGS_INVALID,
                    ','.join(self.__errors)
                )
            # Idempotent: only fills keys that are still missing.
            self.__add_default_values()
        elif isinstance(settings, dict):
            if not self.__load_settings_from_dict(settings):
                raise OneLogin_Saml2_Error(
                    'Invalid dict settings: %s',
                    OneLogin_Saml2_Error.SETTINGS_INVALID,
                    ','.join(self.__errors)
                )
        else:
            raise Exception('Unsupported settings object')

        # NOTE(review): format_idp_cert is defined outside this view;
        # presumably it normalises the IdP x509 certificate - confirm.
        self.format_idp_cert()
    def __load_paths(self, base_path=None):
        """
        Sets the paths of the different folders
        """
        if base_path is None:
            base_path = dirname(dirname(dirname(__file__)))
        base_path += sep
        self.__paths = {
            'base': base_path,
            'cert': base_path + 'certs' + sep,
            'lib': base_path + 'lib' + sep,
            'extlib': base_path + 'extlib' + sep,
        }
    def __update_paths(self, settings):
        """
        Set custom paths if necessary
        """
        if not isinstance(settings, dict):
            return
        if 'custom_base_path' in settings:
            base_path = settings['custom_base_path']
            base_path = join(dirname(__file__), base_path)
            self.__load_paths(base_path)
[docs]    def get_base_path(self):
        """
        Returns base path
        :return: The base toolkit folder path
        :rtype: string
        """
        return self.__paths['base']
 
[docs]    def get_cert_path(self):
        """
        Returns cert path
        :return: The cert folder path
        :rtype: string
        """
        return self.__paths['cert']
 
[docs]    def get_lib_path(self):
        """
        Returns lib path
        :return: The library folder path
        :rtype: string
        """
        return self.__paths['lib']
 
[docs]    def get_ext_lib_path(self):
        """
        Returns external lib path
        :return: The external library folder path
        :rtype: string
        """
        return self.__paths['extlib']
 
[docs]    def get_schemas_path(self):
        """
        Returns schema path
        :return: The schema folder path
        :rtype: string
        """
        return self.__paths['lib'] + 'schemas/'
 
    def __load_settings_from_dict(self, settings):
        """
        Loads settings info from a settings Dict
        :param settings: SAML Toolkit Settings
        :type settings: dict
        :returns: True if the settings info is valid
        :rtype: boolean
        """
        errors = self.check_settings(settings)
        if len(errors) == 0:
            self.__errors = []
            self.__sp = settings['sp']
            self.__idp = settings['idp']
            if 'strict' in settings:
                self.__strict = settings['strict']
            if 'debug' in settings:
                self.__debug = settings['debug']
            if 'security' in settings:
                self.__security = settings['security']
            if 'contactPerson' in settings:
                self.__contacts = settings['contactPerson']
            if 'organization' in settings:
                self.__organization = settings['organization']
            self.__add_default_values()
            return True
        self.__errors = errors
        return False
    def __load_settings_from_file(self):
        """
        Loads settings info from the settings json file
        :returns: True if the settings info is valid
        :rtype: boolean
        """
        filename = self.get_base_path() + 'settings.json'
        if not exists(filename):
            raise OneLogin_Saml2_Error(
                'Settings file not found: %s',
                OneLogin_Saml2_Error.SETTINGS_FILE_NOT_FOUND,
                filename
            )
        # In the php toolkit instead of being a json file it is a php file and
        # it is directly included
        json_data = open(filename, 'r')
        settings = json.load(json_data)
        json_data.close()
        advanced_filename = self.get_base_path() + 'advanced_settings.json'
        if exists(advanced_filename):
            json_data = open(advanced_filename, 'r')
            settings.update(json.load(json_data))  # Merge settings
            json_data.close()
        return self.__load_settings_from_dict(settings)
    def __add_default_values(self):
        """
        Add default values if the settings info is not complete
        """
        if 'binding' not in self.__sp['assertionConsumerService']:
            self.__sp['assertionConsumerService']['binding'] = OneLogin_Saml2_Constants.BINDING_HTTP_POST
        if 'singleLogoutService' in self.__sp and 'binding' not in self.__sp['singleLogoutService']:
            self.__sp['singleLogoutService']['binding'] = OneLogin_Saml2_Constants.BINDING_HTTP_REDIRECT
        # Related to nameID
        if 'NameIDFormat' not in self.__sp:
            self.__sp['NameIDFormat'] = OneLogin_Saml2_Constants.NAMEID_PERSISTENT
        if 'nameIdEncrypted' not in self.__security:
            self.__security['nameIdEncrypted'] = False
        # Sign provided
        if 'authnRequestsSigned' not in self.__security:
            self.__security['authnRequestsSigned'] = False
        if 'logoutRequestSigned' not in self.__security:
            self.__security['logoutRequestSigned'] = False
        if 'logoutResponseSigned' not in self.__security:
            self.__security['logoutResponseSigned'] = False
        if 'signMetadata' not in self.__security:
            self.__security['signMetadata'] = False
        # Sign expected
        if 'wantMessagesSigned' not in self.__security:
            self.__security['wantMessagesSigned'] = False
        if 'wantAssertionsSigned' not in self.__security:
            self.__security['wantAssertionsSigned'] = False
        # Encrypt expected
        if 'wantAssertionsEncrypted' not in self.__security:
            self.__security['wantAssertionsEncrypted'] = False
        if 'wantNameIdEncrypted' not in self.__security:
            self.__security['wantNameIdEncrypted'] = False
        if 'x509cert' not in self.__idp:
            self.__idp['x509cert'] = ''
        if 'certFingerprint' not in self.__idp:
            self.__idp['certFingerprint'] = ''
[docs]    def check_settings(self, settings):
        """
        Checks the settings info.
        :param settings: Dict with settings data
        :type settings: dict
        :returns: Errors found on the settings data
        :rtype: list
        """
        assert isinstance(settings, dict)
        errors = []
        if not isinstance(settings, dict) or len(settings) == 0:
            errors.append('invalid_syntax')
            return errors
        if 'idp' not in settings or len(settings['idp']) == 0:
            errors.append('idp_not_found')
        else:
            idp = settings['idp']
            if 'entityId' not in idp or len(idp['entityId']) == 0:
                errors.append('idp_entityId_not_found')
            if ('singleSignOnService' not in idp or
                'url' not in idp['singleSignOnService'] or
                    len(idp['singleSignOnService']['url']) == 0):
                errors.append('idp_sso_not_found')
            elif not validate_url(idp['singleSignOnService']['url']):
                errors.append('idp_sso_url_invalid')
            if ('singleLogoutService' in idp and
                'url' in idp['singleLogoutService'] and
                len(idp['singleLogoutService']['url']) > 0 and
                    not validate_url(idp['singleLogoutService']['url'])):
                errors.append('idp_slo_url_invalid')
        if 'sp' not in settings or len(settings['sp']) == 0:
            errors.append('sp_not_found')
        else:
            sp = settings['sp']
            security = {}
            if 'security' in settings:
                security = settings['security']
            if 'entityId' not in sp or len(sp['entityId']) == 0:
                errors.append('sp_entityId_not_found')
            if ('assertionConsumerService' not in sp or
                'url' not in sp['assertionConsumerService'] or
                    len(sp['assertionConsumerService']['url']) == 0):
                errors.append('sp_acs_not_found')
            elif not validate_url(sp['assertionConsumerService']['url']):
                errors.append('sp_acs_url_invalid')
            if ('singleLogoutService' in sp and
                'url' in sp['singleLogoutService'] and
                len(sp['singleLogoutService']['url']) > 0 and
                    not validate_url(sp['singleLogoutService']['url'])):
                errors.append('sp_sls_url_invalid')
            if 'signMetadata' in security and isinstance(security['signMetadata'], dict):
                if ('keyFileName' not in security['signMetadata'] or
                        'certFileName' not in security['signMetadata']):
                    errors.append('sp_signMetadata_invalid')
            if ((('authnRequestsSigned' in security and security['authnRequestsSigned']) or
                 ('logoutRequestSigned' in security and security['logoutRequestSigned']) or
                 ('logoutResponseSigned' in security and security['logoutResponseSigned']) or
                 ('wantAssertionsEncrypted' in security and security['wantAssertionsEncrypted']) or
                 ('wantNameIdEncrypted' in security and security['wantNameIdEncrypted'])) and
                    not self.check_sp_certs()):
                errors.append('sp_cert_not_found_and_required')
            exists_X509 = ('idp' in settings and
                           'x509cert' in settings['idp'] and
                           len(settings['idp']['x509cert']) > 0)
            exists_fingerprint = ('idp' in settings and
                                  'certFingerprint' in settings['idp'] and
                                  len(settings['idp']['certFingerprint']) > 0)
            if ((('wantAssertionsSigned' in security and security['wantAssertionsSigned']) or
                 ('wantMessagesSigned' in security and security['wantMessagesSigned'])) and
                    not(exists_X509 or exists_fingerprint)):
                errors.append('idp_cert_or_fingerprint_not_found_and_required')
            if ('nameIdEncrypted' in security and security['nameIdEncrypted']) and not exists_X509:
                errors.append('idp_cert_not_found_and_required')
        if 'contactPerson' in settings:
            types = settings['contactPerson'].keys()
            valid_types = ['technical', 'support', 'administrative', 'billing', 'other']
            for t in types:
                if t not in valid_types:
                    errors.append('contact_type_invalid')
                    break
            for t in settings['contactPerson']:
                contact = settings['contactPerson'][t]
                if (('givenName' not in contact or len(contact['givenName']) == 0) or
                        ('emailAddress' not in contact or len(contact['emailAddress']) == 0)):
                    errors.append('contact_not_enought_data')
                    break
        if 'organization' in settings:
            for o in settings['organization']:
                organization = settings['organization'][o]
                if (('name' not in organization or len(organization['name']) == 0) or
                    ('displayname' not in organization or len(organization['displayname']) == 0) or
                        ('url' not in organization or len(organization['url']) == 0)):
                    errors.append('organization_not_enought_data')
                    break
        return errors
 
[docs]    def check_sp_certs(self):
        """
        Checks if the x509 certs of the SP exists and are valid.
        :returns: If the x509 certs of the SP exists and are valid
        :rtype: boolean
        """
        key = self.get_sp_key()
        cert = self.get_sp_cert()
        return key is not None and cert is not None
 
[docs]    def get_sp_key(self):
        """
        Returns the x509 private key of the SP.
        :returns: SP private key
        :rtype: string
        """
        key = None
        key_file = self.__paths['cert'] + 'sp.key'
        if exists(key_file):
            f = open(key_file, 'r')
            key = f.read()
            f.close()
        return key
 
[docs]    def get_sp_cert(self):
        """
        Returns the x509 public cert of the SP.
        :returns: SP public cert
        :rtype: string
        """
        cert = None
        cert_file = self.__paths['cert'] + 'sp.crt'
        if exists(cert_file):
            f = open(cert_file, 'r')
            cert = f.read()
            f.close()
        return cert
 
[docs]    def get_idp_data(self):
        """
        Gets the IdP data.
        :returns: IdP info
        :rtype: dict
        """
        return self.__idp
 
[docs]    def get_sp_data(self):
        """
        Gets the SP data.
        :returns: SP info
        :rtype: dict
        """
        return self.__sp
 
[docs]    def get_security_data(self):
        """
        Gets security data.
        :returns: Security info
        :rtype: dict
        """
        return self.__security
 
[docs]    def get_organization(self):
        """
        Gets organization data.
        :returns: Organization info
        :rtype: dict
        """
        return self.__organization
 
[docs]    def get_errors(self):
        """
        Returns an array with the errors, the array is empty when the settings is ok.
        :returns: Errors
        :rtype: list
        """
        return self.__errors
 
[docs]    def set_strict(self, value):
        """
        Activates or deactivates the strict mode.
        :param xml: Strict parameter
        :type xml: boolean
        """
        assert isinstance(value, bool)
        self.__strict = value
 
[docs]    def is_strict(self):
        """
        Returns if the 'strict' mode is active.
        :returns: Strict parameter
        :rtype: boolean
        """
        return self.__strict
 
[docs]    def is_debug_active(self):
        """
        Returns if the debug is active.
        :returns: Debug parameter
        :rtype: boolean
        """
        return self.__debug