# -*- coding: utf-8 -*-
""" OneLogin_Saml2_Utils class
MIT License
Auxiliary class of Python Toolkit.
"""
import base64
from copy import deepcopy
from datetime import datetime
import calendar
from hashlib import sha1, sha256, sha384, sha512
from isodate import parse_duration as duration_parser
from lxml import etree
from os.path import basename, dirname, join
import re
from sys import stderr
from tempfile import NamedTemporaryFile
from textwrap import wrap
from urllib import quote_plus
from urlparse import urlsplit, urlunsplit
from uuid import uuid4
from xml.dom.minidom import Document, Element
from defusedxml.minidom import parseString
from functools import wraps
import zlib
import dm.xmlsec.binding as xmlsec
from dm.xmlsec.binding.tmpl import EncData, Signature
from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.errors import OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError
from onelogin.saml2.xmlparser import tostring, fromstring
# Initialise the xmlsec library only once: the module-level 'xmlsec_setup'
# flag records that initialisation already happened, so executing this module
# body again does not call xmlsec.initialize() a second time.
if not globals().get('xmlsec_setup', False):
    xmlsec.initialize()
    globals()['xmlsec_setup'] = True
def return_false_on_exception(func):
    """
    Decorator that turns any exception raised by ``func`` into a ``False``
    return value.

    The behaviour can be switched off for a single call by passing the keyword
    argument ``raise_exceptions=True``; that keyword is consumed by the wrapper
    and is never forwarded to ``func``.
    """
    @wraps(func)
    def exceptfalse(*args, **kwargs):
        propagate = kwargs.pop('raise_exceptions', False)
        if propagate:
            # Caller asked for the real exception: call straight through.
            return func(*args, **kwargs)
        try:
            return func(*args, **kwargs)
        except Exception:
            return False
    return exceptfalse
def print_xmlsec_errors(filename, line, func, error_object, error_subject, reason, msg):
    """
    Auxiliary method. It overrides the default xmlsec debug message.

    Prints a single line of the form ``<filename>:<line>(<func>) <details>``
    where the details are the non-default pieces of information supplied by
    xmlsec. Nothing is printed when there are no details to report.

    :param filename: Source file reported by xmlsec
    :type: string
    :param line: Line number reported by xmlsec
    :type: int
    :param func: Function name reported by xmlsec
    :type: string
    :param error_object: Error object name ("unknown" is skipped)
    :type: string
    :param error_subject: Error subject name ("unknown" is skipped)
    :type: string
    :param reason: Numeric error code (the value 1 is skipped)
    :type: int
    :param msg: Free-text message (skipped when blank)
    :type: string
    """
    info = []
    if error_object != "unknown":
        info.append("obj=" + error_object)
    if error_subject != "unknown":
        info.append("subject=" + error_subject)
    if msg.strip():
        info.append("msg=" + msg)
    if reason != 1:
        info.append("errno=%d" % reason)
    if info:
        # Build one string: with the Python 2 print statement, print(a, b)
        # would output the repr of a tuple instead of the two values.
        print("%s:%d(%s) %s" % (filename, line, func, " ".join(info)))
class OneLogin_Saml2_Utils(object):
    """
    Auxiliary class that contains several utility methods to parse time,
    urls, add sign, encrypt, decrypt, sign validation, handle xml ...
    """
    # XPath of a signature that is a direct child of the Response element.
    RESPONSE_SIGNATURE_XPATH = '/samlp:Response/ds:Signature'
    # XPath of a signature that is a direct child of an Assertion in the Response.
    ASSERTION_SIGNATURE_XPATH = '/samlp:Response/saml:Assertion/ds:Signature'
    # strptime formats for SAML timestamps without / with a fractional part.
    TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
    TIME_FORMAT_2 = "%Y-%m-%dT%H:%M:%S.%fZ"
    # Fallback pattern: captures the whole-second part and tolerates a
    # fractional part of any length and a missing trailing 'Z'.
    TIME_FORMAT_WITH_FRAGMENT = re.compile(r'^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})(\.\d*)?Z?$')
[docs] @staticmethod
def decode_base64_and_inflate(value):
"""
base64 decodes and then inflates according to RFC1951
:param value: a deflated and encoded string
:type value: string
:returns: the string after decoding and inflating
:rtype: string
"""
decoded = base64.b64decode(value)
# We try to inflate
try:
result = zlib.decompress(decoded, -15)
except Exception:
result = decoded
return result.decode('utf-8')
[docs] @staticmethod
def deflate_and_base64_encode(value):
"""
Deflates and then base64 encodes a string
:param value: The string to deflate and encode
:type value: string
:returns: The deflated and encoded string
:rtype: string
"""
return base64.b64encode(zlib.compress(value.encode('utf-8'))[2:-4])
@staticmethod
def validate_xml(xml, schema, debug=False):
    """
    Validates a xml against a schema

    :param xml: The xml that will be validated
    :type: string|DomDocument
    :param schema: The schema file name, looked up in the 'schemas'
                   directory next to this module
    :type: string
    :param debug: If debug is active, the validation errors are written
                  to stderr
    :type: bool
    :returns: The error code 'unloaded_xml' (xml could not be parsed) or
              'invalid_xml' (schema validation failed), otherwise the
              parsed document of the xml
    :rtype: string|Document
    """
    assert isinstance(xml, basestring) or isinstance(xml, Document) or isinstance(xml, etree._Element)
    assert isinstance(schema, basestring)

    if isinstance(xml, Document):
        xml = xml.toxml()
    elif isinstance(xml, etree._Element):
        xml = tostring(xml, encoding='unicode')

    # Switch to lxml for schema validation
    try:
        dom = fromstring(xml.encode('utf-8'), forbid_dtd=True)
    except Exception:
        return 'unloaded_xml'

    schema_file = join(dirname(__file__), 'schemas', schema)
    # The context manager closes the schema file even if parsing it raises.
    with open(schema_file, 'r') as f_schema:
        schema_doc = etree.parse(f_schema)
    xmlschema = etree.XMLSchema(schema_doc)

    if not xmlschema.validate(dom):
        if debug:
            stderr.write('Errors validating the metadata')
            stderr.write(':\n\n')
            for error in xmlschema.error_log:
                stderr.write('%s\n' % error.message)
        return 'invalid_xml'

    return parseString(tostring(dom, encoding='unicode').encode('utf-8'), forbid_dtd=True, forbid_entities=True, forbid_external=True)
@staticmethod
def element_text(node):
    """
    Returns the text of an lxml node.

    ``etree.strip_tags`` is applied to the node for comment nodes before
    its text is read.

    :param node: The node whose text is wanted
    :type: lxml.etree.Element
    :returns: The value of ``node.text``
    :rtype: string
    """
    # Double check: the LXML parser is expected to have removed comments already.
    etree.strip_tags(node, etree.Comment)
    text = node.text
    return text
[docs] @staticmethod
def redirect(url, parameters={}, request_data={}):
"""
Executes a redirection to the provided url (or return the target url).
:param url: The target url
:type: string
:param parameters: Extra parameters to be passed as part of the url
:type: dict
:param request_data: The request as a dict
:type: dict
:returns: Url
:rtype: string
"""
assert isinstance(url, basestring)
assert isinstance(parameters, dict)
if url.startswith('/'):
url = '%s%s' % (OneLogin_Saml2_Utils.get_self_url_host(request_data), url)
# Verify that the URL is to a http or https site.
if re.search('^https?://', url) is None:
raise OneLogin_Saml2_Error(
'Redirect to invalid URL: ' + url,
OneLogin_Saml2_Error.REDIRECT_INVALID_URL
)
# Add encoded parameters
if url.find('?') < 0:
param_prefix = '?'
else:
param_prefix = '&'
for name, value in parameters.items():
if value is None:
param = quote_plus(name)
elif isinstance(value, list):
param = ''
for val in value:
param += quote_plus(name) + '[]=' + quote_plus(val) + '&'
if len(param) > 0:
param = param[0:-1]
else:
param = quote_plus(name) + '=' + quote_plus(value)
if param:
url += param_prefix + param
param_prefix = '&'
return url
[docs] @staticmethod
def get_self_url_host(request_data):
"""
Returns the protocol + the current host + the port (if different than
common ports).
:param request_data: The request as a dict
:type: dict
:return: Url
:rtype: string
"""
current_host = OneLogin_Saml2_Utils.get_self_host(request_data)
port = ''
if OneLogin_Saml2_Utils.is_https(request_data):
protocol = 'https'
else:
protocol = 'http'
if 'server_port' in request_data and request_data['server_port'] is not None:
port_number = str(request_data['server_port'])
port = ':' + port_number
if protocol == 'http' and port_number == '80':
port = ''
elif protocol == 'https' and port_number == '443':
port = ''
return '%s://%s%s' % (protocol, current_host, port)
[docs] @staticmethod
def get_self_host(request_data):
"""
Returns the current host.
:param request_data: The request as a dict
:type: dict
:return: The current host
:rtype: string
"""
if 'http_host' in request_data:
current_host = request_data['http_host']
elif 'server_name' in request_data:
current_host = request_data['server_name']
else:
raise Exception('No hostname defined')
if ':' in current_host:
current_host_data = current_host.split(':')
possible_port = current_host_data[-1]
try:
possible_port = float(possible_port)
current_host = current_host_data[0]
except ValueError:
current_host = ':'.join(current_host_data)
return current_host
[docs] @staticmethod
def is_https(request_data):
"""
Checks if https or http.
:param request_data: The request as a dict
:type: dict
:return: False if https is not active
:rtype: boolean
"""
is_https = 'https' in request_data and request_data['https'] != 'off'
is_https = is_https or ('server_port' in request_data and str(request_data['server_port']) == '443')
return is_https
[docs] @staticmethod
def get_self_url_no_query(request_data):
"""
Returns the URL of the current host + current view.
:param request_data: The request as a dict
:type: dict
:return: The url of current host + current view
:rtype: string
"""
self_url_host = OneLogin_Saml2_Utils.get_self_url_host(request_data)
script_name = request_data['script_name']
if script_name:
if script_name[0] != '/':
script_name = '/' + script_name
else:
script_name = ''
self_url_no_query = self_url_host + script_name
if 'path_info' in request_data:
self_url_no_query += request_data['path_info']
return self_url_no_query
[docs] @staticmethod
def get_self_routed_url_no_query(request_data):
"""
Returns the routed URL of the current host + current view.
:param request_data: The request as a dict
:type: dict
:return: The url of current host + current view
:rtype: string
"""
self_url_host = OneLogin_Saml2_Utils.get_self_url_host(request_data)
route = ''
if 'request_uri' in request_data.keys() and request_data['request_uri']:
route = request_data['request_uri']
if 'query_string' in request_data.keys() and request_data['query_string']:
route = route.replace(request_data['query_string'], '')
return self_url_host + route
[docs] @staticmethod
def get_self_url(request_data):
"""
Returns the URL of the current host + current view + query.
:param request_data: The request as a dict
:type: dict
:return: The url of current host + current view + query
:rtype: string
"""
self_url_host = OneLogin_Saml2_Utils.get_self_url_host(request_data)
request_uri = ''
if 'request_uri' in request_data:
request_uri = request_data['request_uri']
if not request_uri.startswith('/'):
match = re.search('^https?://[^/]*(/.*)', request_uri)
if match is not None:
request_uri = match.groups()[0]
return self_url_host + request_uri
[docs] @staticmethod
def generate_unique_id():
"""
Generates an unique string (used for example as ID for assertions).
:return: A unique string
:rtype: string
"""
return 'ONELOGIN_%s' % sha1(uuid4().hex).hexdigest()
[docs] @staticmethod
def parse_time_to_SAML(time):
r"""
Converts a UNIX timestamp to SAML2 timestamp on the form
yyyy-mm-ddThh:mm:ss(\.s+)?Z.
:param time: The time we should convert (DateTime).
:type: string
:return: SAML2 timestamp.
:rtype: string
"""
data = datetime.utcfromtimestamp(float(time))
return data.strftime(OneLogin_Saml2_Utils.TIME_FORMAT)
[docs] @staticmethod
def parse_SAML_to_time(timestr):
r"""
Converts a SAML2 timestamp on the form yyyy-mm-ddThh:mm:ss(\.s+)?Z
to a UNIX timestamp. The sub-second part is ignored.
:param time: The time we should convert (SAML Timestamp).
:type: string
:return: Converted to a unix timestamp.
:rtype: int
"""
try:
data = datetime.strptime(timestr, OneLogin_Saml2_Utils.TIME_FORMAT)
except ValueError:
try:
data = datetime.strptime(timestr, OneLogin_Saml2_Utils.TIME_FORMAT_2)
except ValueError:
elem = OneLogin_Saml2_Utils.TIME_FORMAT_WITH_FRAGMENT.match(timestr)
if not elem:
raise Exception("time data %s does not match format %s" % (timestr, r'yyyy-mm-ddThh:mm:ss(\.s+)?Z'))
data = datetime.strptime(elem.groups()[0] + "Z", OneLogin_Saml2_Utils.TIME_FORMAT)
return calendar.timegm(data.utctimetuple())
[docs] @staticmethod
def now():
"""
:return: unix timestamp of actual time.
:rtype: int
"""
return calendar.timegm(datetime.utcnow().utctimetuple())
@staticmethod
def parse_duration(duration, timestamp=None):
    """
    Interprets a ISO8601 duration value relative to a given timestamp.

    :param duration: The duration, as a string.
    :type: string
    :param timestamp: The unix timestamp we should apply the duration to.
                      Optional, default to the current time.
    :type: int
    :return: The new timestamp, after the duration is applied.
    :rtype: int
    """
    assert isinstance(duration, basestring)
    assert timestamp is None or isinstance(timestamp, int)

    offset = duration_parser(duration)
    start = datetime.utcnow() if timestamp is None else datetime.utcfromtimestamp(timestamp)
    shifted = start + offset
    return calendar.timegm(shifted.utctimetuple())
[docs] @staticmethod
def get_expire_time(cache_duration=None, valid_until=None):
"""
Compares 2 dates and returns the earliest.
:param cache_duration: The duration, as a string.
:type: string
:param valid_until: The valid until date, as a string or as a timestamp
:type: string
:return: The expiration time.
:rtype: int
"""
expire_time = None
if cache_duration is not None:
expire_time = OneLogin_Saml2_Utils.parse_duration(cache_duration)
if valid_until is not None:
if isinstance(valid_until, int):
valid_until_time = valid_until
else:
valid_until_time = OneLogin_Saml2_Utils.parse_SAML_to_time(valid_until)
if expire_time is None or expire_time > valid_until_time:
expire_time = valid_until_time
if expire_time is not None:
return '%d' % expire_time
return None
@staticmethod
def query(dom, query, context=None, tagid=None):
    """
    Extracts nodes that match the query from the Element

    The XPath is evaluated against ``context`` when one is given and
    against ``dom`` otherwise, using the toolkit namespace map. When
    ``tagid`` is given it is passed to the XPath evaluation as the
    ``tagid`` variable.

    :param dom: The root of the lxml object
    :type: Element
    :param query: Xpath Expression
    :type: string
    :param context: Context Node
    :type: DOMElement
    :param tagid: Tag ID
    :type: string
    :returns: The queried nodes
    :rtype: list
    """
    source = dom if context is None else context
    xpath_variables = {} if tagid is None else {'tagid': tagid}
    return source.xpath(query, namespaces=OneLogin_Saml2_Constants.NSMAP, **xpath_variables)
[docs] @staticmethod
def delete_local_session(callback=None):
"""
Deletes the local session.
"""
if callback is not None:
callback()
[docs] @staticmethod
def calculate_x509_fingerprint(x509_cert, alg='sha1'):
"""
Calculates the fingerprint of a formatted x509cert.
:param x509_cert: x509 cert formatted
:type: string
:param alg: The algorithm to build the fingerprint
:type: string
:returns: fingerprint
:rtype: string
"""
assert isinstance(x509_cert, basestring)
lines = x509_cert.split('\n')
data = ''
inData = False
for line in lines:
# Remove '\r' from end of line if present.
line = line.rstrip()
if not inData:
if line == '-----BEGIN CERTIFICATE-----':
inData = True
elif line == '-----BEGIN PUBLIC KEY-----' or line == '-----BEGIN RSA PRIVATE KEY-----':
# This isn't an X509 certificate.
return None
else:
if line == '-----END CERTIFICATE-----':
break
# Append the current line to the certificate data.
data += line
if not data:
return None
decoded_data = base64.b64decode(data)
if alg == 'sha512':
fingerprint = sha512(decoded_data)
elif alg == 'sha384':
fingerprint = sha384(decoded_data)
elif alg == 'sha256':
fingerprint = sha256(decoded_data)
else:
fingerprint = sha1(decoded_data)
return fingerprint.hexdigest().lower()
@staticmethod
def generate_name_id(value, sp_nq, sp_format=None, cert=None, debug=False, nq=None):
    """
    Generates a nameID.

    Without a cert the plain ``saml:NameID`` element is serialized. With a
    cert the NameID is encrypted (AES-128-CBC data, RSA-OAEP wrapped key)
    and returned wrapped in a ``saml:EncryptedID`` element.

    :param value: The text content of the NameID
    :type: string
    :param sp_nq: SP Name Qualifier
    :type: string
    :param sp_format: SP Format
    :type: string
    :param cert: IdP Public Cert to encrypt the nameID
    :type: string
    :param debug: Activate the xmlsec debug
    :type: bool
    :param nq: IDP Name Qualifier
    :type: string
    :returns: The serialized NameID or EncryptedID element
    :rtype: string
    """
    doc = Document()
    name_id_container = doc.createElementNS(OneLogin_Saml2_Constants.NS_SAML, 'container')
    name_id_container.setAttribute("xmlns:saml", OneLogin_Saml2_Constants.NS_SAML)

    name_id = doc.createElement('saml:NameID')
    if sp_nq is not None:
        name_id.setAttribute('SPNameQualifier', sp_nq)
    if nq is not None:
        name_id.setAttribute('NameQualifier', nq)
    if sp_format is not None:
        name_id.setAttribute('Format', sp_format)
    name_id.appendChild(doc.createTextNode(value))
    name_id_container.appendChild(name_id)

    if cert is not None:
        xml = name_id_container.toxml()
        elem = fromstring(xml, forbid_dtd=True)

        error_callback_method = None
        if debug:
            error_callback_method = print_xmlsec_errors
        xmlsec.set_error_callback(error_callback_method)

        # Load the public cert
        mngr = xmlsec.KeysMngr()
        file_cert = OneLogin_Saml2_Utils.write_temp_file(cert)
        try:
            key_data = xmlsec.Key.load(file_cert.name, xmlsec.KeyDataFormatCertPem, None)
            key_data.name = basename(file_cert.name)
            mngr.addKey(key_data)
        finally:
            # Close (and thereby delete) the temp file even if loading fails.
            file_cert.close()

        # Prepare for encryption
        enc_data = EncData(xmlsec.TransformAes128Cbc, type=xmlsec.TypeEncElement)
        enc_data.ensureCipherValue()
        key_info = enc_data.ensureKeyInfo()
        # The session key is wrapped with RSA-OAEP.
        enc_key = key_info.addEncryptedKey(xmlsec.TransformRsaOaep)
        enc_key.ensureCipherValue()

        # Encrypt!
        enc_ctx = xmlsec.EncCtx(mngr)
        enc_ctx.encKey = xmlsec.Key.generate(xmlsec.KeyDataAes, 128, xmlsec.KeyDataTypeSession)

        edata = enc_ctx.encryptXml(enc_data, elem[0])

        newdoc = parseString(tostring(edata, encoding='unicode').encode('utf-8'), forbid_dtd=True, forbid_entities=True, forbid_external=True)

        if newdoc.hasChildNodes():
            child = newdoc.firstChild
            child.removeAttribute('xmlns')
            child.removeAttribute('xmlns:saml')
            child.setAttribute('xmlns:xenc', OneLogin_Saml2_Constants.NS_XENC)
            child.setAttribute('xmlns:dsig', OneLogin_Saml2_Constants.NS_DS)

        # Re-prefix every element: KeyInfo goes to 'dsig:', the rest to 'xenc:'.
        nodes = newdoc.getElementsByTagName("*")
        for node in nodes:
            if node.tagName == 'ns0:KeyInfo':
                node.tagName = 'dsig:KeyInfo'
                node.removeAttribute('xmlns:ns0')
                node.setAttribute('xmlns:dsig', OneLogin_Saml2_Constants.NS_DS)
            else:
                node.tagName = 'xenc:' + node.tagName

        # Wrap the encrypted data in a saml:EncryptedID element.
        encrypted_id = newdoc.createElement('saml:EncryptedID')
        encrypted_data = newdoc.replaceChild(encrypted_id, newdoc.firstChild)
        encrypted_id.appendChild(encrypted_data)
        return newdoc.saveXML(encrypted_id)
    else:
        return doc.saveXML(name_id)
@staticmethod
def get_status(dom):
    """
    Gets Status from a Response.

    'code' is the first attribute value of the StatusCode element. 'msg'
    is the StatusMessage text when exactly one is present; when there is
    no StatusMessage it is the first attribute value of a single nested
    StatusCode, if any; otherwise it is the empty string.

    :param dom: The Response as XML
    :type: Document
    :returns: The Status, a dict with the keys 'code' and 'msg'.
    :rtype: dict
    :raises OneLogin_Saml2_ValidationError: if there is not exactly one
        Status element, or not exactly one StatusCode element
    """
    status_nodes = OneLogin_Saml2_Utils.query(dom, '/samlp:Response/samlp:Status')
    if len(status_nodes) != 1:
        raise OneLogin_Saml2_ValidationError(
            'Missing Status on response',
            OneLogin_Saml2_ValidationError.MISSING_STATUS
        )
    status_node = status_nodes[0]

    code_nodes = OneLogin_Saml2_Utils.query(dom, '/samlp:Response/samlp:Status/samlp:StatusCode', status_node)
    if len(code_nodes) != 1:
        raise OneLogin_Saml2_ValidationError(
            'Missing Status Code on response',
            OneLogin_Saml2_ValidationError.MISSING_STATUS_CODE
        )

    status = {'code': code_nodes[0].values()[0], 'msg': ''}

    message_nodes = OneLogin_Saml2_Utils.query(dom, '/samlp:Response/samlp:Status/samlp:StatusMessage', status_node)
    if len(message_nodes) == 1:
        status['msg'] = OneLogin_Saml2_Utils.element_text(message_nodes[0])
    elif len(message_nodes) == 0:
        # No message: fall back to the nested (second-level) status code.
        subcode_nodes = OneLogin_Saml2_Utils.query(dom, '/samlp:Response/samlp:Status/samlp:StatusCode/samlp:StatusCode', status_node)
        if len(subcode_nodes) == 1:
            status['msg'] = subcode_nodes[0].values()[0]

    return status
@staticmethod
def decrypt_element(encrypted_data, key, debug=False, inplace=False):
    """
    Decrypts an encrypted element.

    minidom elements and strings are parsed into lxml first. An lxml
    element is deep-copied before decryption unless ``inplace`` is true.

    :param encrypted_data: The encrypted data.
    :type: lxml.etree.Element | DOMElement | basestring
    :param key: The key (PEM).
    :type: string
    :param debug: Activate the xmlsec debug
    :type: bool
    :param inplace: update passed data with decrypted result
    :type: bool
    :returns: The decrypted element.
    :rtype: lxml.etree.Element
    """
    if isinstance(encrypted_data, Element):
        encrypted_data = fromstring(str(encrypted_data.toxml()), forbid_dtd=True)
    elif isinstance(encrypted_data, basestring):
        encrypted_data = fromstring(str(encrypted_data), forbid_dtd=True)
    elif not inplace and isinstance(encrypted_data, etree._Element):
        encrypted_data = deepcopy(encrypted_data)

    xmlsec.set_error_callback(print_xmlsec_errors if debug else None)

    keys_manager = xmlsec.KeysMngr()
    keys_manager.addKey(xmlsec.Key.loadMemory(key, xmlsec.KeyDataFormatPem, None))
    decryption_ctx = xmlsec.EncCtx(keys_manager)
    return decryption_ctx.decrypt(encrypted_data)
[docs] @staticmethod
def write_temp_file(content):
"""
Writes some content into a temporary file and returns it.
:param content: The file content
:type: string
:returns: The temporary file
:rtype: file-like object
"""
f_temp = NamedTemporaryFile(delete=True)
f_temp.file.write(content)
f_temp.file.flush()
return f_temp
@staticmethod
def add_sign(xml, key, cert, debug=False, sign_algorithm=OneLogin_Saml2_Constants.RSA_SHA256, digest_algorithm=OneLogin_Saml2_Constants.SHA256):
    """
    Adds signature key and senders certificate to an element (Message or
    Assertion).

    The signature is inserted right after the first ``saml:Issuer`` when
    there is one (its parent is then the signed element); otherwise it is
    inserted at the start of the root (metadata EntityDescriptor) or of the
    root's first child. A signed element without an ID attribute receives
    a generated one.

    :param xml: The element we should sign
    :type: string | Document | Element | lxml.etree.Element
    :param key: The private key (PEM)
    :type: string
    :param cert: The public cert (PEM)
    :type: string
    :param debug: Activate the xmlsec debug
    :type: bool
    :param sign_algorithm: Signature algorithm method; unknown values fall
                           back to RSA-SHA1
    :type sign_algorithm: string
    :param digest_algorithm: Digest algorithm method; unknown values fall
                             back to SHA1
    :type digest_algorithm: string
    :returns: Signed XML
    :rtype: string
    :raises Exception: if xml is empty or of an unsupported type
    """
    if xml is None or xml == '':
        raise Exception('Empty string supplied as input')
    elif isinstance(xml, etree._Element):
        elem = xml
    elif isinstance(xml, Document):
        xml = xml.toxml()
        elem = fromstring(xml.encode('utf-8'), forbid_dtd=True)
    elif isinstance(xml, Element):
        # Declare the SAML namespaces on the fragment before serializing it.
        xml.setAttributeNS(
            unicode(OneLogin_Saml2_Constants.NS_SAMLP),
            'xmlns:samlp',
            unicode(OneLogin_Saml2_Constants.NS_SAMLP)
        )
        xml.setAttributeNS(
            unicode(OneLogin_Saml2_Constants.NS_SAML),
            'xmlns:saml',
            unicode(OneLogin_Saml2_Constants.NS_SAML)
        )
        xml = xml.toxml()
        elem = fromstring(xml.encode('utf-8'), forbid_dtd=True)
    elif isinstance(xml, basestring):
        elem = fromstring(xml.encode('utf-8'), forbid_dtd=True)
    else:
        raise Exception('Error parsing xml string')

    error_callback_method = None
    if debug:
        error_callback_method = print_xmlsec_errors
    xmlsec.set_error_callback(error_callback_method)

    sign_algorithm_transform_map = {
        OneLogin_Saml2_Constants.DSA_SHA1: xmlsec.TransformDsaSha1,
        OneLogin_Saml2_Constants.RSA_SHA1: xmlsec.TransformRsaSha1,
        OneLogin_Saml2_Constants.RSA_SHA256: xmlsec.TransformRsaSha256,
        OneLogin_Saml2_Constants.RSA_SHA384: xmlsec.TransformRsaSha384,
        OneLogin_Saml2_Constants.RSA_SHA512: xmlsec.TransformRsaSha512
    }
    sign_algorithm_transform = sign_algorithm_transform_map.get(sign_algorithm, xmlsec.TransformRsaSha1)

    signature = Signature(xmlsec.TransformExclC14N, sign_algorithm_transform, nsPrefix='ds')

    issuer = OneLogin_Saml2_Utils.query(elem, '//saml:Issuer')
    if len(issuer) > 0:
        issuer = issuer[0]
        issuer.addnext(signature)
        elem_to_sign = issuer.getparent()
    else:
        entity_descriptor = OneLogin_Saml2_Utils.query(elem, '//md:EntityDescriptor')
        if len(entity_descriptor) > 0:
            elem.insert(0, signature)
        else:
            elem[0].insert(0, signature)
        elem_to_sign = elem

    elem_id = elem_to_sign.get('ID', None)
    if elem_id is not None:
        # An existing but empty ID stays empty: no Reference URI is set below.
        if elem_id:
            elem_id = '#' + elem_id
    else:
        generated_id = OneLogin_Saml2_Utils.generate_unique_id()
        elem_id = '#' + generated_id
        elem_to_sign.attrib['ID'] = generated_id

    xmlsec.addIDs(elem_to_sign, ["ID"])

    digest_algorithm_transform_map = {
        OneLogin_Saml2_Constants.SHA1: xmlsec.TransformSha1,
        OneLogin_Saml2_Constants.SHA256: xmlsec.TransformSha256,
        OneLogin_Saml2_Constants.SHA384: xmlsec.TransformSha384,
        OneLogin_Saml2_Constants.SHA512: xmlsec.TransformSha512
    }
    digest_algorithm_transform = digest_algorithm_transform_map.get(digest_algorithm, xmlsec.TransformSha1)

    ref = signature.addReference(digest_algorithm_transform)
    if elem_id:
        ref.attrib['URI'] = elem_id

    ref.addTransform(xmlsec.TransformEnveloped)
    ref.addTransform(xmlsec.TransformExclC14N)

    key_info = signature.ensureKeyInfo()
    key_info.addX509Data()

    dsig_ctx = xmlsec.DSigCtx()

    sign_key = xmlsec.Key.loadMemory(key, xmlsec.KeyDataFormatPem, None)

    file_cert = OneLogin_Saml2_Utils.write_temp_file(cert)
    try:
        sign_key.loadCert(file_cert.name, xmlsec.KeyDataFormatCertPem)
    finally:
        # Close (and thereby delete) the temp file even if loading fails.
        file_cert.close()

    dsig_ctx.signKey = sign_key
    dsig_ctx.sign(signature)

    return tostring(elem, encoding='unicode').encode('utf-8')
@staticmethod
@return_false_on_exception
def validate_sign(xml, cert=None, fingerprint=None, fingerprintalg='sha1', validatecert=False, debug=False, xpath=None, multicerts=None):
    """
    Validates a signature (Message or Assertion).

    The signature node is located with ``xpath`` when given; otherwise the
    Response signature is looked up first and the Assertion signature
    second. Exactly one signature node must be found.

    :param xml: The element we should validate
    :type: string | Document
    :param cert: The public cert
    :type: string
    :param fingerprint: The fingerprint of the public cert
    :type: string
    :param fingerprintalg: The algorithm used to build the fingerprint
    :type: string
    :param validatecert: If true, will verify the signature and if the cert is valid.
    :type: bool
    :param debug: Activate the xmlsec debug
    :type: bool
    :param xpath: The xpath of the signed element
    :type: string
    :param multicerts: Multiple public certs
    :type: list
    :param raise_exceptions: Whether to return false on failure or raise an exception
    :type raise_exceptions: Boolean
    """
    if xml is None or xml == '':
        raise Exception('Empty string supplied as input')

    if isinstance(xml, etree._Element):
        elem = xml
    elif isinstance(xml, Document):
        elem = fromstring(str(xml.toxml()), forbid_dtd=True)
    elif isinstance(xml, Element):
        # Declare the SAML namespaces on the fragment before serializing it.
        for namespace, attribute in (
            (OneLogin_Saml2_Constants.NS_SAMLP, 'xmlns:samlp'),
            (OneLogin_Saml2_Constants.NS_SAML, 'xmlns:saml'),
        ):
            xml.setAttributeNS(unicode(namespace), attribute, unicode(namespace))
        elem = fromstring(str(xml.toxml()), forbid_dtd=True)
    elif isinstance(xml, basestring):
        elem = fromstring(str(xml), forbid_dtd=True)
    else:
        raise Exception('Error parsing xml string')

    xmlsec.set_error_callback(print_xmlsec_errors if debug else None)
    xmlsec.addIDs(elem, ["ID"])

    if xpath:
        signature_nodes = OneLogin_Saml2_Utils.query(elem, xpath)
    else:
        signature_nodes = OneLogin_Saml2_Utils.query(elem, OneLogin_Saml2_Utils.RESPONSE_SIGNATURE_XPATH)
        if len(signature_nodes) == 0:
            signature_nodes = OneLogin_Saml2_Utils.query(elem, OneLogin_Saml2_Utils.ASSERTION_SIGNATURE_XPATH)

    if len(signature_nodes) != 1:
        raise OneLogin_Saml2_ValidationError('Expected exactly one signature node; got {}.'.format(len(signature_nodes)), OneLogin_Saml2_ValidationError.WRONG_NUMBER_OF_SIGNATURES)

    signature_node = signature_nodes[0]

    if not multicerts:
        return OneLogin_Saml2_Utils.validate_node_sign(signature_node, elem, cert, fingerprint, fingerprintalg, validatecert, debug, raise_exceptions=True)

    # If multiple certs are provided, the cert and fingerprint passed to
    # this method are ignored and only the certs in multicerts are tried.
    fingerprint = fingerprintalg = None
    for cert in multicerts:
        if OneLogin_Saml2_Utils.validate_node_sign(signature_node, elem, cert, fingerprint, fingerprintalg, validatecert, False, raise_exceptions=False):
            return True
    raise OneLogin_Saml2_ValidationError('Signature validation failed. SAML Response rejected.')
@staticmethod
@return_false_on_exception
def validate_node_sign(signature_node, elem, cert=None, fingerprint=None, fingerprintalg='sha1', validatecert=False, debug=False):
    """
    Validates a signature node.

    When no cert is supplied but a fingerprint is, the certificate
    embedded in the signature is used, provided its fingerprint matches.

    :param signature_node: The signature node
    :type: Node
    :param elem: The element we should validate
    :type: Document
    :param cert: The public cert
    :type: string
    :param fingerprint: The fingerprint of the public cert
    :type: string
    :param fingerprintalg: The algorithm used to build the fingerprint
    :type: string
    :param validatecert: If true, will verify the signature and if the cert is valid.
    :type: bool
    :param debug: Activate the xmlsec debug
    :type: bool
    :param raise_exceptions: Whether to return false on failure or raise an exception
    :type raise_exceptions: Boolean
    :returns: True when the signature verifies
    :rtype: bool
    """
    error_callback_method = None
    if debug:
        error_callback_method = print_xmlsec_errors
    xmlsec.set_error_callback(error_callback_method)

    xmlsec.addIDs(elem, ["ID"])

    if (cert is None or cert == '') and fingerprint:
        x509_certificate_nodes = OneLogin_Saml2_Utils.query(signature_node, '//ds:Signature/ds:KeyInfo/ds:X509Data/ds:X509Certificate')
        if len(x509_certificate_nodes) > 0:
            x509_certificate_node = x509_certificate_nodes[0]
            x509_cert_value = OneLogin_Saml2_Utils.element_text(x509_certificate_node)
            x509_cert_value_formatted = OneLogin_Saml2_Utils.format_cert(x509_cert_value)
            x509_fingerprint_value = OneLogin_Saml2_Utils.calculate_x509_fingerprint(x509_cert_value_formatted, fingerprintalg)
            # Only trust the embedded cert when it matches the expected fingerprint.
            if fingerprint == x509_fingerprint_value:
                cert = x509_cert_value_formatted

    # Check if Reference URI is empty
    # reference_elem = OneLogin_Saml2_Utils.query(signature_node, '//ds:Reference')
    # if len(reference_elem) > 0:
    #     if reference_elem[0].get('URI') == '':
    #         reference_elem[0].set('URI', '#%s' % signature_node.getparent().get('ID'))

    if cert is None or cert == '':
        raise OneLogin_Saml2_Error(
            'Could not validate node signature: No certificate provided.',
            OneLogin_Saml2_Error.CERT_NOT_FOUND
        )

    file_cert = OneLogin_Saml2_Utils.write_temp_file(cert)
    try:
        if validatecert:
            mngr = xmlsec.KeysMngr()
            mngr.loadCert(file_cert.name, xmlsec.KeyDataFormatCertPem, xmlsec.KeyDataTypeTrusted)
            dsig_ctx = xmlsec.DSigCtx(mngr)
        else:
            dsig_ctx = xmlsec.DSigCtx()
            dsig_ctx.signKey = xmlsec.Key.load(file_cert.name, xmlsec.KeyDataFormatCertPem, None)
    finally:
        # Close (and thereby delete) the temp file even if loading fails.
        file_cert.close()

    dsig_ctx.setEnabledKeyData([xmlsec.KeyDataX509])

    try:
        dsig_ctx.verify(signature_node)
    except Exception as err:
        raise OneLogin_Saml2_ValidationError(
            'Signature validation failed. SAML Response rejected. %s',
            OneLogin_Saml2_ValidationError.INVALID_SIGNATURE,
            err.__str__()
        )

    return True
@staticmethod
@return_false_on_exception
def validate_binary_sign(signed_query, signature, cert=None, algorithm=OneLogin_Saml2_Constants.RSA_SHA1, debug=False):
    """
    Validates signed binary data (Used to validate GET Signature).

    :param signed_query: The element we should validate
    :type: string
    :param signature: The signature that will be validate
    :type: string
    :param cert: The public cert
    :type: string
    :param algorithm: Signature algorithm; unknown values fall back to
                      RSA-SHA1
    :type: string
    :param debug: Activate the xmlsec debug
    :type: bool
    :param raise_exceptions: Whether to return false on failure or raise an exception
    :type raise_exceptions: Boolean
    :returns: True when the verification call completes without raising
    :rtype: bool
    """
    error_callback_method = None
    if debug:
        error_callback_method = print_xmlsec_errors
    xmlsec.set_error_callback(error_callback_method)

    dsig_ctx = xmlsec.DSigCtx()

    file_cert = OneLogin_Saml2_Utils.write_temp_file(cert)
    try:
        dsig_ctx.signKey = xmlsec.Key.load(file_cert.name, xmlsec.KeyDataFormatCertPem, None)
    finally:
        # Close (and thereby delete) the temp file even if loading fails.
        file_cert.close()

    # Pick the transform matching the algorithm the signature claims to use.
    sign_algorithm_transform_map = {
        OneLogin_Saml2_Constants.DSA_SHA1: xmlsec.TransformDsaSha1,
        OneLogin_Saml2_Constants.RSA_SHA1: xmlsec.TransformRsaSha1,
        OneLogin_Saml2_Constants.RSA_SHA256: xmlsec.TransformRsaSha256,
        OneLogin_Saml2_Constants.RSA_SHA384: xmlsec.TransformRsaSha384,
        OneLogin_Saml2_Constants.RSA_SHA512: xmlsec.TransformRsaSha512
    }
    sign_algorithm_transform = sign_algorithm_transform_map.get(algorithm, xmlsec.TransformRsaSha1)

    # Verify the signature over the raw signed query with the cert's key.
    dsig_ctx.verifyBinary(signed_query, sign_algorithm_transform, signature)
    return True
[docs] @staticmethod
def get_encoded_parameter(get_data, name, default=None, lowercase_urlencoding=False):
"""Return a URL encoded get parameter value
Prefer to extract the original encoded value directly from query_string since URL
encoding is not canonical. The encoding used by ADFS 3.0 is not compatible with
python's quote_plus (ADFS produces lower case hex numbers and quote_plus produces
upper case hex numbers)
"""
if name not in get_data:
return OneLogin_Saml2_Utils.case_sensitive_urlencode(default, lowercase_urlencoding)
if 'query_string' in get_data:
return OneLogin_Saml2_Utils.extract_raw_query_parameter(get_data['query_string'], name)
return OneLogin_Saml2_Utils.case_sensitive_urlencode(get_data[name], lowercase_urlencoding)
[docs] @staticmethod
def case_sensitive_urlencode(to_encode, lowercase=False):
encoded = quote_plus(to_encode)
return re.sub(r"%[A-F0-9]{2}", lambda m: m.group(0).lower(), encoded) if lowercase else encoded
[docs] @staticmethod
def normalize_url(url):
"""
Returns normalized URL for comparison.
This method converts the netloc to lowercase, as it should be case-insensitive (per RFC 4343, RFC 7617)
If standardization fails, the original URL is returned
Python documentation indicates that URL split also normalizes query strings if empty query fields are present
:param url: URL
:type url: String
:returns: A normalized URL, or the given URL string if parsing fails
:rtype: String
"""
try:
scheme, netloc, path, query, fragment = urlsplit(url)
normalized_url = urlunsplit((scheme.lower(), netloc.lower(), path, query, fragment))
return normalized_url
except Exception:
return url