diff --git a/tlslite/handshakehashes.py b/tlslite/handshakehashes.py
index 137b20b..e4b3fcc 100644
--- a/tlslite/handshakehashes.py
+++ b/tlslite/handshakehashes.py
@@ -27,7 +27,13 @@ class HandshakeHashes(object):
:param bytearray data: serialized TLS handshake message
"""
- pass
+ self._handshakeMD5.update(compat26Str(data))
+ self._handshakeSHA.update(compat26Str(data))
+ self._handshakeSHA224.update(compat26Str(data))
+ self._handshakeSHA256.update(compat26Str(data))
+ self._handshakeSHA384.update(compat26Str(data))
+ self._handshakeSHA512.update(compat26Str(data))
+ self._handshake_buffer += data
def digest(self, digest=None):
"""
@@ -37,7 +43,22 @@ class HandshakeHashes(object):
:param str digest: name of digest to return
"""
- pass
+ if digest is None:
+ return self._handshakeMD5.digest() + self._handshakeSHA.digest()
+ elif digest == 'md5':
+ return self._handshakeMD5.digest()
+ elif digest == 'sha1':
+ return self._handshakeSHA.digest()
+ elif digest == 'sha224':
+ return self._handshakeSHA224.digest()
+ elif digest == 'sha256':
+ return self._handshakeSHA256.digest()
+ elif digest == 'sha384':
+ return self._handshakeSHA384.digest()
+ elif digest == 'sha512':
+ return self._handshakeSHA512.digest()
+ else:
+ raise ValueError("Unknown digest type")
def digestSSL(self, masterSecret, label):
"""
@@ -48,7 +69,11 @@ class HandshakeHashes(object):
:param bytearray masterSecret: value of the master secret
:param bytearray label: label to include in the calculation
"""
- pass
+        md5_inner, sha_inner = self._handshakeMD5.copy(), self._handshakeSHA.copy()  # copies: running hashes stay usable
+        md5_inner.update(compat26Str(label + masterSecret + bytearray([0x36] * 48)))  # pad1 is 48 bytes for MD5
+        sha_inner.update(compat26Str(label + masterSecret + bytearray([0x36] * 40)))  # pad1 is 40 bytes for SHA-1
+        md5_outer = MD5(masterSecret + bytearray([0x5c] * 48) + bytearray(md5_inner.digest()))
+        return md5_outer + SHA1(masterSecret + bytearray([0x5c] * 40) + bytearray(sha_inner.digest()))
def copy(self):
"""
@@ -59,4 +84,12 @@ class HandshakeHashes(object):
:rtype: HandshakeHashes
"""
- pass
\ No newline at end of file
+ new = HandshakeHashes()
+ new._handshakeMD5 = self._handshakeMD5.copy()
+ new._handshakeSHA = self._handshakeSHA.copy()
+ new._handshakeSHA224 = self._handshakeSHA224.copy()
+ new._handshakeSHA256 = self._handshakeSHA256.copy()
+ new._handshakeSHA384 = self._handshakeSHA384.copy()
+ new._handshakeSHA512 = self._handshakeSHA512.copy()
+ new._handshake_buffer = self._handshake_buffer[:]
+ return new
\ No newline at end of file
diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py
index 2f8e245..5877645 100644
--- a/tlslite/handshakesettings.py
+++ b/tlslite/handshakesettings.py
@@ -316,11 +316,37 @@ class HandshakeSettings(object):
def _init_key_settings(self):
"""Create default variables for key-related settings."""
- pass
+ self.minKeySize = 1023
+ self.maxKeySize = 8193
+ self.rsaSigHashes = list(RSA_SIGNATURE_HASHES)
+ self.dsaSigHashes = list(DSA_SIGNATURE_HASHES)
+ self.ecdsaSigHashes = list(ECDSA_SIGNATURE_HASHES)
+ self.more_sig_schemes = list(SIGNATURE_SCHEMES)
+ self.rsa_schemes = list(RSA_SCHEMES)
+ self.eccCurves = list(CURVE_NAMES)
+ self.dhGroups = list(ALL_DH_GROUP_NAMES)
+ self.defaultCurve = "secp256r1"
+ self.keyShares = ["secp256r1", "x25519"]
+ self.certificateTypes = list(CERTIFICATE_TYPES)
def _init_misc_extensions(self):
"""Default variables for assorted extensions."""
- pass
+ self.useExtendedMasterSecret = True
+ self.requireExtendedMasterSecret = False
+ self.useEncryptThenMAC = True
+ self.useExperimentalTackExtension = False
+ self.sendFallbackSCSV = False
+ self.use_heartbeat_extension = True
+ self.heartbeat_response_callback = None
+ self.record_size_limit = None
+ self.padding_cb = None
+ self.pskConfigs = []
+ self.psk_modes = list(PSK_MODES)
+ self.max_early_data = 0
+ self.ticketKeys = []
+ self.ticketCipher = "aes256gcm"
+ self.ticketLifetime = 24 * 60 * 60 # 1 day
+ self.ticket_count = 1
def __init__(self):
"""Initialise default values for settings."""
@@ -337,86 +363,327 @@ class HandshakeSettings(object):
@staticmethod
def _sanityCheckKeySizes(other):
"""Check if key size limits are sane"""
- pass
+ if other.minKeySize < 512:
+ raise ValueError("minKeySize too small")
+ if other.minKeySize > 16384:
+ raise ValueError("minKeySize too large")
+ if other.maxKeySize < 512:
+ raise ValueError("maxKeySize too small")
+ if other.maxKeySize > 16384:
+ raise ValueError("maxKeySize too large")
+ if other.maxKeySize < other.minKeySize:
+ raise ValueError("maxKeySize smaller than minKeySize")
@staticmethod
def _not_matching(values, sieve):
"""Return list of items from values that are not in sieve."""
- pass
+ return [val for val in values if val not in sieve]
@staticmethod
def _sanityCheckCipherSettings(other):
"""Check if specified cipher settings are known."""
- pass
+ not_matching = HandshakeSettings._not_matching(other.cipherNames,
+ ALL_CIPHER_NAMES)
+ if not_matching:
+ raise ValueError("Unknown cipher name: {0}".format(not_matching))
+
+ not_matching = HandshakeSettings._not_matching(other.macNames,
+ ALL_MAC_NAMES)
+ if not_matching:
+ raise ValueError("Unknown MAC name: {0}".format(not_matching))
+
+ not_matching = HandshakeSettings._not_matching(
+ other.cipherImplementations, CIPHER_IMPLEMENTATIONS)
+ if not_matching:
+ raise ValueError("Unknown cipher implementation: {0}"
+ .format(not_matching))
@staticmethod
def _sanityCheckECDHSettings(other):
"""Check ECDHE settings if they are sane."""
- pass
+ not_matching = HandshakeSettings._not_matching(other.eccCurves,
+ ALL_CURVE_NAMES)
+ if not_matching:
+ raise ValueError("Unknown ECC curve name: {0}".format(not_matching))
+
+ if other.defaultCurve not in ALL_CURVE_NAMES:
+ raise ValueError("Unknown default ECC curve name: {0}"
+ .format(other.defaultCurve))
@staticmethod
def _sanityCheckDHSettings(other):
"""Check if (EC)DHE settings are sane."""
- pass
+ not_matching = HandshakeSettings._not_matching(other.dhGroups,
+ ALL_DH_GROUP_NAMES)
+ if not_matching:
+ raise ValueError("Unknown DH group name: {0}".format(not_matching))
@staticmethod
def _sanityCheckPrimitivesNames(other):
"""Check if specified cryptographic primitive names are known"""
- pass
+ not_matching = HandshakeSettings._not_matching(other.rsaSigHashes,
+ ALL_RSA_SIGNATURE_HASHES)
+ if not_matching:
+ raise ValueError("Unknown RSA signature hash: {0}"
+ .format(not_matching))
+
+ not_matching = HandshakeSettings._not_matching(other.dsaSigHashes,
+ DSA_SIGNATURE_HASHES)
+ if not_matching:
+ raise ValueError("Unknown DSA signature hash: {0}"
+ .format(not_matching))
+
+ not_matching = HandshakeSettings._not_matching(other.ecdsaSigHashes,
+ ECDSA_SIGNATURE_HASHES)
+ if not_matching:
+ raise ValueError("Unknown ECDSA signature hash: {0}"
+ .format(not_matching))
+
+ not_matching = HandshakeSettings._not_matching(other.more_sig_schemes,
+ SIGNATURE_SCHEMES)
+ if not_matching:
+ raise ValueError("Unknown signature scheme: {0}"
+ .format(not_matching))
+
+ not_matching = HandshakeSettings._not_matching(other.rsa_schemes,
+ RSA_SCHEMES)
+ if not_matching:
+ raise ValueError("Unknown RSA scheme: {0}"
+ .format(not_matching))
@staticmethod
def _sanityCheckProtocolVersions(other):
"""Check if set protocol version are sane"""
- pass
+ if other.minVersion > other.maxVersion:
+ raise ValueError("Versions set incorrectly")
+ if other.minVersion not in KNOWN_VERSIONS:
+ raise ValueError("Unknown minimum protocol version")
+ if other.maxVersion not in KNOWN_VERSIONS:
+ raise ValueError("Unknown maximum protocol version")
@staticmethod
def _sanityCheckEMSExtension(other):
"""Check if settings for EMS are sane."""
- pass
+ if other.requireExtendedMasterSecret and \
+ not other.useExtendedMasterSecret:
+ raise ValueError("Require Extended Master Secret but don't use it")
@staticmethod
def _sanityCheckExtensions(other):
"""Check if set extension settings are sane"""
- pass
+ if other.record_size_limit is not None:
+ if not 64 <= other.record_size_limit <= 2**14+1:
+ raise ValueError("Record size limit must be between 64 and 16385")
@staticmethod
def _not_allowed_len(values, sieve):
"""Return True if length of any item in values is not in sieve."""
- pass
+ return any(len(i) not in sieve for i in values)
@staticmethod
def _sanityCheckPsks(other):
"""Check if the set PSKs are sane."""
- pass
+ if other.psk_modes:
+ not_matching = HandshakeSettings._not_matching(other.psk_modes,
+ PSK_MODES)
+ if not_matching:
+ raise ValueError("Unknown PSK mode: {0}".format(not_matching))
+
+ if not all(isinstance(i, (bytes, bytearray)) for i in other.pskConfigs):
+ raise ValueError("PSK identity must be a bytes-like object")
@staticmethod
def _sanityCheckTicketSettings(other):
"""Check if the session ticket settings are sane."""
- pass
+ if other.ticketCipher not in TICKET_CIPHERS:
+ raise ValueError("Unknown ticket cipher")
+
+ if other.ticketLifetime < 0:
+ raise ValueError("Ticket lifetime must be a positive integer")
+
+ if other.ticket_count < 0:
+ raise ValueError("Ticket count must be a positive integer")
def _copy_cipher_settings(self, other):
"""Copy values related to cipher selection."""
- pass
+ other.cipherNames = self.cipherNames
+ other.macNames = self.macNames
+ other.keyExchangeNames = self.keyExchangeNames
+ other.cipherImplementations = self.cipherImplementations
+ other.minVersion = self.minVersion
+ other.maxVersion = self.maxVersion
+ other.versions = self.versions
def _copy_extension_settings(self, other):
"""Copy values of settings related to extensions."""
- pass
+ other.useExtendedMasterSecret = self.useExtendedMasterSecret
+ other.requireExtendedMasterSecret = self.requireExtendedMasterSecret
+ other.useEncryptThenMAC = self.useEncryptThenMAC
+ other.useExperimentalTackExtension = self.useExperimentalTackExtension
+ other.sendFallbackSCSV = self.sendFallbackSCSV
+ other.use_heartbeat_extension = self.use_heartbeat_extension
+ other.heartbeat_response_callback = self.heartbeat_response_callback
+ other.record_size_limit = self.record_size_limit
+ other.padding_cb = self.padding_cb
+ other.pskConfigs = self.pskConfigs
+ other.psk_modes = self.psk_modes
+ other.max_early_data = self.max_early_data
+ other.ticketKeys = self.ticketKeys
+ other.ticketCipher = self.ticketCipher
+ other.ticketLifetime = self.ticketLifetime
+ other.ticket_count = self.ticket_count
@staticmethod
def _remove_all_matches(values, needle):
"""Remove all instances of needle from values."""
- pass
+ while needle in values:
+ values.remove(needle)
+
+ def getCertificateTypes(self):
+ """Get list of certificate types as IDs."""
+ ret = []
+ if "x509" in self.certificateTypes:
+ ret.append(CertificateType.x509)
+ return ret
+
+ def validate(self):
+ """
+ Validate the settings, filter out unsupported ciphersuites and return
+ a copy of object.
+
+ This method checks if the settings are consistent and if they can be
+ used for a connection. It checks if all selected algorithms are
+ supported.
+
+ :rtype: HandshakeSettings
+ :returns: a copy of self with all settings validated
+ :raises ValueError: when settings are invalid, inconsistent or unsupported
+ """
+ other = HandshakeSettings()
+
+ # Copy values
+ other.minKeySize = self.minKeySize
+ other.maxKeySize = self.maxKeySize
+ other.cipherNames = self.cipherNames
+ other.macNames = self.macNames
+ other.keyExchangeNames = self.keyExchangeNames
+ other.cipherImplementations = self.cipherImplementations
+ other.rsaSigHashes = self.rsaSigHashes
+ other.dsaSigHashes = self.dsaSigHashes
+ other.ecdsaSigHashes = self.ecdsaSigHashes
+ other.more_sig_schemes = self.more_sig_schemes
+ other.rsa_schemes = self.rsa_schemes
+ other.eccCurves = self.eccCurves
+ other.dhGroups = self.dhGroups
+ other.defaultCurve = self.defaultCurve
+ other.keyShares = self.keyShares
+ other.certificateTypes = self.certificateTypes
+ other.minVersion = self.minVersion
+ other.maxVersion = self.maxVersion
+ other.versions = self.versions
+ other.useExtendedMasterSecret = self.useExtendedMasterSecret
+ other.requireExtendedMasterSecret = self.requireExtendedMasterSecret
+ other.useEncryptThenMAC = self.useEncryptThenMAC
+ other.useExperimentalTackExtension = self.useExperimentalTackExtension
+ other.sendFallbackSCSV = self.sendFallbackSCSV
+ other.use_heartbeat_extension = self.use_heartbeat_extension
+ other.heartbeat_response_callback = self.heartbeat_response_callback
+ other.record_size_limit = self.record_size_limit
+ other.padding_cb = self.padding_cb
+ other.pskConfigs = self.pskConfigs
+ other.psk_modes = self.psk_modes
+ other.max_early_data = self.max_early_data
+ other.ticketKeys = self.ticketKeys
+ other.ticketCipher = self.ticketCipher
+ other.ticketLifetime = self.ticketLifetime
+ other.ticket_count = self.ticket_count
+
+ # Perform sanity checks
+ self._sanityCheckKeySizes(other)
+ self._sanityCheckCipherSettings(other)
+ self._sanityCheckECDHSettings(other)
+ self._sanityCheckDHSettings(other)
+ self._sanityCheckPrimitivesNames(other)
+ self._sanityCheckProtocolVersions(other)
+ self._sanityCheckEMSExtension(other)
+ self._sanityCheckExtensions(other)
+ self._sanityCheckPsks(other)
+ self._sanityCheckTicketSettings(other)
+
+ return other
def _sanity_check_ciphers(self, other):
"""Remove unsupported ciphers in current configuration."""
- pass
+ if not other.cipherNames:
+ raise ValueError("No cipher names specified")
+
+ if not other.macNames:
+ raise ValueError("No MAC names specified")
+ if not other.keyExchangeNames:
+ raise ValueError("No key exchange algorithms specified")
+
+ if not other.cipherImplementations:
+ raise ValueError("No cipher implementations specified")
+
+ # Remove ciphers that are not supported
+ for cipher in other.cipherNames[:]:
+ if cipher not in ALL_CIPHER_NAMES:
+ other.cipherNames.remove(cipher)
+
+ # Remove MACs that are not supported
+ for mac in other.macNames[:]:
+ if mac not in ALL_MAC_NAMES:
+ other.macNames.remove(mac)
+
+ # Remove key exchange algorithms that are not supported
+ for kex in other.keyExchangeNames[:]:
+ if kex not in KEY_EXCHANGE_NAMES:
+ other.keyExchangeNames.remove(kex)
+
+ # Remove cipher implementations that are not supported
+ for impl in other.cipherImplementations[:]:
+ if impl not in CIPHER_IMPLEMENTATIONS:
+ other.cipherImplementations.remove(impl)
+
+ # Check if any ciphers remain
+ if not other.cipherNames:
+ raise ValueError("No supported cipher names")
+
+ # Check if any MACs remain
+ if not other.macNames:
+ raise ValueError("No supported MAC names")
+
+ # Check if any key exchange algorithms remain
+ if not other.keyExchangeNames:
+ raise ValueError("No supported key exchange algorithms")
+
+ # Check if any cipher implementations remain
+ if not other.cipherImplementations:
+ raise ValueError("No supported cipher implementations")
def _sanity_check_implementations(self, other):
"""Remove all backends that are not loaded."""
- pass
+ if not cryptomath.m2cryptoLoaded:
+ self._remove_all_matches(other.cipherImplementations, "openssl")
+ if not cryptomath.pycryptoLoaded:
+ self._remove_all_matches(other.cipherImplementations, "pycrypto")
+ if not other.cipherImplementations:
+ raise ValueError("No supported cipher implementations")
def _copy_key_settings(self, other):
"""Copy key-related settings."""
+ other.minKeySize = self.minKeySize
+ other.maxKeySize = self.maxKeySize
+ other.rsaSigHashes = self.rsaSigHashes
+ other.dsaSigHashes = self.dsaSigHashes
+ other.ecdsaSigHashes = self.ecdsaSigHashes
+ other.more_sig_schemes = self.more_sig_schemes
+ other.rsa_schemes = self.rsa_schemes
+ other.eccCurves = self.eccCurves
+ other.dhGroups = self.dhGroups
+ other.defaultCurve = self.defaultCurve
+ other.keyShares = self.keyShares
+ other.certificateTypes = self.certificateTypes
pass
def validate(self):
diff --git a/tlslite/mathtls.py b/tlslite/mathtls.py
index 8e2820a..4692332 100644
--- a/tlslite/mathtls.py
+++ b/tlslite/mathtls.py
@@ -1,10 +1,43 @@
"""Miscellaneous helper functions."""
+import struct
from .utils.compat import *
from .utils.cryptomath import *
from .constants import CipherSuite
from .utils import tlshashlib as hashlib
from .utils import tlshmac as hmac
from .utils.deprecations import deprecated_method
+
+def createMAC_SSL(mac_key, seq_num, content_type, data, version):
+ """Create a SSL/early TLS MAC."""
+ mac = hmac.HMAC(mac_key, digestmod=hashlib.md5)
+ mac.update(bytearray(struct.pack(">Q", seq_num)))
+ mac.update(bytearray([content_type]))
+ mac.update(bytearray(struct.pack(">H", len(data))))
+ mac.update(data)
+ return mac.digest()
+
+def createHMAC(mac_key, seq_num, content_type, data, version, algorithm):
+    """Return the HMAC over seq_num, type, version, length and data of a TLS record."""
+    mac = hmac.HMAC(mac_key, digestmod=getattr(hashlib, algorithm))
+    mac.update(bytearray(struct.pack(">Q", seq_num)))
+    mac.update(bytearray([content_type]))
+    mac.update(bytearray([version[0]]))  # major version: a single byte, not a 16-bit field
+    mac.update(bytearray([version[1]]))  # minor version: a single byte, not a 16-bit field
+    mac.update(bytearray(struct.pack(">H", len(data))))
+    mac.update(data)
+    return mac.digest()
+
+def makeX(salt, username, password):
+    """Create an X value for SRP: SHA1(salt | SHA1(username | ":" | password))."""
+    return bytesToNumber(secureHash(salt + secureHash(username + bytearray(b":") + password, "sha1"), "sha1"))
+
+def makeU(N, A, B):
+    """Create a U value for SRP: SHA1(PAD(A) | PAD(B)), both left-padded to the byte length of N."""
+    return bytesToNumber(secureHash(numberToByteArray(A, len(numberToByteArray(N))) + numberToByteArray(B, len(numberToByteArray(N))), "sha1"))
+
+def makeK(N, g):
+    """Create a K value for SRP: SHA1(N | PAD(g)), g left-padded to the byte length of N."""
+    return bytesToNumber(secureHash(numberToByteArray(N) + numberToByteArray(g, len(numberToByteArray(N))), "sha1"))
FFDHE_PARAMETERS = {}
'\nListing of all well known FFDH parameters.\n\nPlease note that this dictionary includes all groups that are well-known\n(i.e. named), irrespective if their use is recommended or not.\n\nYou should use RFC7919_GROUPS for well-known secure groups.\n'
RFC2409_GROUP1 = (2, int(remove_whitespace('\n FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1\n 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD\n EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245\n E485B576 625E7EC6 F44C42E9 A63A3620 FFFFFFFF FFFFFFFF'), 16))
diff --git a/tlslite/messages.py b/tlslite/messages.py
index 7f0b0cd..3ea54f0 100644
--- a/tlslite/messages.py
+++ b/tlslite/messages.py
@@ -238,7 +238,12 @@ class ClientHello(HelloMessage):
.. deprecated:: 0.5
use extensions field to get the extension for inspection
"""
- pass
+ if self.extensions is None:
+ return None
+ for ext in self.extensions:
+ if isinstance(ext, ClientCertTypeExtension):
+ return ext.cert_types
+ return None
@certificate_types.setter
def certificate_types(self, val):
@@ -253,7 +258,16 @@ class ClientHello(HelloMessage):
:param val: list of supported certificate types by client encoded as
single byte integers
"""
- pass
+ if self.extensions is None:
+ self.extensions = []
+
+ for ext in self.extensions:
+ if isinstance(ext, ClientCertTypeExtension):
+ ext.cert_types = val
+ return
+
+ ext = ClientCertTypeExtension().create(val)
+ self.extensions.append(ext)
@property
def srp_username(self):
@@ -263,7 +277,12 @@ class ClientHello(HelloMessage):
.. deprecated:: 0.5
use extensions field to get the extension for inspection
"""
- pass
+ if self.extensions is None:
+ return None
+ for ext in self.extensions:
+ if isinstance(ext, SRPExtension):
+ return ext.srp_username
+ return None
@srp_username.setter
def srp_username(self, name):
@@ -273,7 +292,16 @@ class ClientHello(HelloMessage):
:type name: bytearray
:param name: UTF-8 encoded username
"""
- pass
+ if self.extensions is None:
+ self.extensions = []
+
+ for ext in self.extensions:
+ if isinstance(ext, SRPExtension):
+ ext.srp_username = name
+ return
+
+ ext = SRPExtension().create(name)
+ self.extensions.append(ext)
@property
def tack(self):
@@ -285,7 +313,12 @@ class ClientHello(HelloMessage):
:rtype: boolean
"""
- pass
+ if self.extensions is None:
+ return False
+ for ext in self.extensions:
+ if isinstance(ext, TACKExtension):
+ return True
+ return False
@tack.setter
def tack(self, present):
@@ -296,7 +329,18 @@ class ClientHello(HelloMessage):
:param present: True will create extension while False will remove
extension from client hello
"""
- pass
+ if self.extensions is None:
+ self.extensions = []
+
+ for ext in self.extensions:
+ if isinstance(ext, TACKExtension):
+ if not present:
+ self.extensions.remove(ext)
+ return
+
+ if present:
+ ext = TACKExtension().create()
+ self.extensions.append(ext)
@property
def supports_npn(self):
@@ -308,7 +352,12 @@ class ClientHello(HelloMessage):
:rtype: boolean
"""
- pass
+ if self.extensions is None:
+ return False
+ for ext in self.extensions:
+ if isinstance(ext, NPNExtension):
+ return True
+ return False
@supports_npn.setter
def supports_npn(self, present):
@@ -319,7 +368,18 @@ class ClientHello(HelloMessage):
:param present: selects whatever to create or remove the extension
from list of supported ones
"""
- pass
+ if self.extensions is None:
+ self.extensions = []
+
+ for ext in self.extensions:
+ if isinstance(ext, NPNExtension):
+ if not present:
+ self.extensions.remove(ext)
+ return
+
+ if present:
+ ext = NPNExtension().create([])
+ self.extensions.append(ext)
@property
def server_name(self):
@@ -331,7 +391,13 @@ class ClientHello(HelloMessage):
:rtype: bytearray
"""
- pass
+ if self.extensions is None:
+ return None
+ for ext in self.extensions:
+ if isinstance(ext, SNIExtension):
+ if ext.host_names:
+ return ext.host_names[0]
+ return None
@server_name.setter
def server_name(self, hostname):
@@ -341,7 +407,16 @@ class ClientHello(HelloMessage):
:type hostname: bytearray
:param hostname: name of the host_name to set
"""
- pass
+ if self.extensions is None:
+ self.extensions = []
+
+ for ext in self.extensions:
+ if isinstance(ext, SNIExtension):
+ ext.host_names = [hostname]
+ return
+
+ ext = SNIExtension().create([hostname])
+ self.extensions.append(ext)
def create(self, version, random, session_id, cipher_suites, certificate_types=None, srpUsername=None, tack=False, supports_npn=None, serverName=None, extensions=None):
"""
@@ -386,19 +461,90 @@ class ClientHello(HelloMessage):
:type extensions: list of :py:class:`~.extensions.TLSExtension`
:param extensions: list of extensions to advertise
"""
- pass
+ self.client_version = version
+ self.random = random
+ self.session_id = session_id
+ self.cipher_suites = cipher_suites
+ self.compression_methods = [0] # only null compression
+ self.extensions = extensions
+
+ if certificate_types is not None:
+ self.certificate_types = certificate_types
+ if srpUsername is not None:
+ self.srp_username = srpUsername
+ if tack:
+ self.tack = tack
+ if supports_npn is not None:
+ self.supports_npn = supports_npn
+ if serverName is not None:
+ self.server_name = serverName
def parse(self, p):
"""Deserialise object from on the wire data."""
- pass
+ if self.ssl2:
+ self.client_version = (p.get(1), p.get(1))
+ cipher_suites_length = p.get(2)
+ session_id_length = p.get(2)
+ challenge_length = p.get(2)
+ self.cipher_suites = []
+ for i in range(cipher_suites_length // 3):
+ self.cipher_suites.append(p.get(3))
+ self.session_id = p.getFixBytes(session_id_length)
+ self.random = p.getFixBytes(challenge_length)
+ self.compression_methods = [0] # SSL2 has no compression
+ return None
+
+ self.client_version = (p.get(1), p.get(1))
+ self.random = p.getFixBytes(32)
+ session_id_length = p.get(1)
+ self.session_id = p.getFixBytes(session_id_length)
+ cipher_suites_length = p.get(2)
+ self.cipher_suites = []
+ for i in range(cipher_suites_length // 2):
+ self.cipher_suites.append(p.get(2))
+ compression_methods_length = p.get(1)
+ self.compression_methods = []
+ for i in range(compression_methods_length):
+ self.compression_methods.append(p.get(1))
+
+ if p.getRemainingLength() > 0:
+ self.extensions = []
+ extensions_length = p.get(2)
+ while p.getRemainingLength() > 0:
+ ext = TLSExtension().parse(p)
+ self.extensions.append(ext)
def _writeSSL2(self):
"""Serialise SSLv2 object to on the wire data."""
- pass
+        w = Writer()
+        w.add(self.client_version[0], 1)
+        w.add(self.client_version[1], 1)
+        w.add(len(self.cipher_suites) * 3, 2)  # every SSLv2 cipher kind occupies 3 bytes
+        w.add(len(self.session_id), 2)
+        w.add(len(self.random), 2)
+        for cipher_suite in self.cipher_suites:
+            w.add(cipher_suite, 3)  # full 3-byte value, mirroring p.get(3) in parse()
+        w.bytes += bytearray(self.session_id)  # appended verbatim
+        w.bytes += bytearray(self.random)  # the SSLv2 challenge, appended verbatim
+        return w.bytes
def _write(self):
"""Serialise SSLv3 or TLS object to on the wire data."""
- pass
+ w = Writer()
+ w.add(self.client_version[0], 1)
+ w.add(self.client_version[1], 1)
+ w.addFixSeq(self.random, 32)
+ w.addVarSeq(self.session_id, 1, 1)
+ w.addVarSeq(Writer.array_to_bytes(self.cipher_suites, 2), 2, 2)
+ w.addVarSeq(Writer.array_to_bytes(self.compression_methods, 1), 1, 1)
+
+ if self.extensions is not None:
+ w2 = Writer()
+ for ext in self.extensions:
+ w2.bytes += ext.write()
+ w.add(len(w2.bytes), 2)
+ w.bytes += w2.bytes
+ return w.bytes
def psk_truncate(self):
"""Return a truncated encoding of message without binders.
@@ -411,11 +557,30 @@ class ClientHello(HelloMessage):
:rtype: bytearray
"""
- pass
+ if not self.extensions:
+ return self.write()
+
+ psk_ext = None
+ for ext in self.extensions:
+ if isinstance(ext, PreSharedKeyExtension):
+ psk_ext = ext
+ break
+
+ if not psk_ext:
+ return self.write()
+
+ # remove the binders from the extension
+ old_binders = psk_ext.binders
+ psk_ext.binders = None
+ ret = self.write()
+ psk_ext.binders = old_binders
+ return ret
def write(self):
"""Serialise object to on the wire data."""
- pass
+ if self.ssl2:
+ return self._writeSSL2()
+ return self._write()
class HelloRequest(HandshakeMsg):
"""
diff --git a/tlslite/utils/compat.py b/tlslite/utils/compat.py
index 4c66fab..728509b 100644
--- a/tlslite/utils/compat.py
+++ b/tlslite/utils/compat.py
@@ -8,95 +8,142 @@ import binascii
import traceback
import time
import ecdsa
+from binascii import a2b_hex, b2a_hex, a2b_base64, b2a_base64
+
+def compat26Str(x):
+    """Return x in a form hashlib accepts: unchanged on Python 3, str(x) on Python 2."""
+    if sys.version_info >= (3, 0):
+        # hashlib and hmac need bytes-like input here; decoding bytes to text
+        # or calling str() on a bytearray (which yields its repr) would break
+        # every digest computed from the result.
+        return x
+    return str(x)
+
+def compatLong(x):
+ """Convert number to long"""
+ if sys.version_info >= (3, 0):
+ return int(x)
+ else:
+ return long(x)
if sys.version_info >= (3, 0):
if sys.version_info < (3, 4):
def compatHMAC(x):
"""Convert bytes-like input to format acceptable for HMAC."""
- pass
+ if isinstance(x, str):
+ return x.encode('ascii')
+ return x
else:
def compatHMAC(x):
"""Convert bytes-like input to format acceptable for HMAC."""
- pass
+ if isinstance(x, str):
+ return x.encode('ascii')
+ return x
def compatAscii2Bytes(val):
"""Convert ASCII string to bytes."""
- pass
+ if isinstance(val, str):
+ return val.encode('ascii')
+ return val
def compat_b2a(val):
"""Convert an ASCII bytes string to string."""
- pass
+ if isinstance(val, bytes):
+ return val.decode('ascii')
+ return val
int_types = tuple([int])
def formatExceptionTrace(e):
"""Return exception information formatted as string"""
- pass
+ return str(traceback.format_exception(type(e), e, e.__traceback__))
def time_stamp():
"""Returns system time as a float"""
- pass
+ return time.time()
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
- pass
+ return re.sub(r'\s+', '', text)
bytes_to_int = int.from_bytes
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
- pass
+ return val.bit_length()
def int_to_bytes(val, length=None, byteorder='big'):
"""Return number converted to bytes"""
- pass
+ if length is None:
+ length = byte_length(val)
+ return val.to_bytes(length, byteorder=byteorder)
else:
if sys.version_info < (2, 7) or sys.version_info < (2, 7, 4) or platform.system() == 'Java':
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
- pass
+ return re.sub(r'\s+', '', text)
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
- pass
+ if val == 0:
+ return 0
+ return len(bin(val)[2:])
else:
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
- pass
+ return re.sub(r'\s+', '', text)
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
- pass
+ return val.bit_length()
def compatAscii2Bytes(val):
"""Convert ASCII string to bytes."""
- pass
+ if isinstance(val, str):
+ return val.encode('ascii')
+ return val
def compat_b2a(val):
"""Convert an ASCII bytes string to string."""
- pass
+ if isinstance(val, bytes):
+ return val.decode('ascii')
+ return val
int_types = (int, long)
def formatExceptionTrace(e):
"""Return exception information formatted as string"""
- pass
+ return str(traceback.format_exc())
def time_stamp():
"""Returns system time as a float"""
- pass
+ return time.time()
def bytes_to_int(val, byteorder):
"""Convert bytes to an int."""
- pass
+ if byteorder == 'big':
+ return int(b2a_hex(val), 16)
+ else:
+ return int(b2a_hex(val[::-1]), 16)
def int_to_bytes(val, length=None, byteorder='big'):
"""Return number converted to bytes"""
- pass
+ if length is None:
+ length = byte_length(val)
+ hex_str = '%x' % val
+ if len(hex_str) % 2:
+ hex_str = '0' + hex_str
+ result = a2b_hex(hex_str)
+ if len(result) < length:
+ result = b'\x00' * (length - len(result)) + result
+ if byteorder == 'little':
+ result = result[::-1]
+ return result
def byte_length(val):
"""Return number of bytes necessary to represent an integer."""
- pass
+ length = val.bit_length()
+ return (length + 7) // 8
try:
getattr(ecdsa, 'NIST192p')
except AttributeError:
diff --git a/tlslite/utils/cryptomath.py b/tlslite/utils/cryptomath.py
index abcdbce..37e88d9 100644
--- a/tlslite/utils/cryptomath.py
+++ b/tlslite/utils/cryptomath.py
@@ -56,19 +56,75 @@ prngName = 'os.urandom'
def MD5(b):
"""Return a MD5 digest of data"""
- pass
+ return hashlib.md5(compat26Str(b)).digest()
def SHA1(b):
"""Return a SHA1 digest of data"""
- pass
+ return hashlib.sha1(compat26Str(b)).digest()
+
+def HMAC_MD5(k, b):
+ """Return HMAC using MD5"""
+ return secureHMAC(k, b, "md5")
+
+def HMAC_SHA1(k, b):
+ """Return HMAC using SHA1"""
+ return secureHMAC(k, b, "sha1")
+
+def HMAC_SHA256(k, b):
+ """Return HMAC using SHA256"""
+ return secureHMAC(k, b, "sha256")
+
+def HMAC_SHA384(k, b):
+ """Return HMAC using SHA384"""
+ return secureHMAC(k, b, "sha384")
def secureHash(data, algorithm):
"""Return a digest of `data` using `algorithm`"""
- pass
+ hashInstance = hashlib.new(algorithm)
+ hashInstance.update(compat26Str(data))
+ return hashInstance.digest()
def secureHMAC(k, b, algorithm):
"""Return a HMAC using `b` and `k` using `algorithm`"""
- pass
+ k = compatHMAC(k)
+ b = compatHMAC(b)
+ return hmac.new(k, b, getattr(hashlib, algorithm)).digest()
+
+def getRandomBytes(howMany):
+ """Return a specified number of random bytes."""
+ return os.urandom(howMany)
+
+def getRandomNumber(low, high):
+    """Return a random number in the range [low, high]; raise ValueError unless low < high."""
+    if low >= high:
+        raise ValueError("Low must be lower than high")
+    span = high - low
+    howManyBytes = (len(bin(span)[2:]) + 7) // 8
+    # Rejection-sample an offset in [0, span] and shift it by low, so the loop always terminates.
+    while True:
+        n = bytesToNumber(getRandomBytes(howManyBytes))
+        if n <= span:
+            return low + n
+
+def HKDF_expand(secret, info, length, algorithm):
+ """
+ HKDF-Expand function from RFC 5869.
+
+ :param bytearray secret: the key from which to derive the keying material
+ :param bytearray info: context specific information
+ :param int length: number of bytes to produce
+ :param str algorithm: name of the secure hash algorithm used as the
+ basis of the HKDF
+ :rtype: bytearray
+ """
+ hash_size = getattr(hashlib, algorithm)().digest_size
+ N = (length + hash_size - 1) // hash_size
+ T = bytearray()
+ output = bytearray()
+ for i in range(N):
+ T = secureHMAC(secret, T + info + bytearray([i + 1]), algorithm)
+ output += T
+ return output[:length]
def HKDF_expand_label(secret, label, hashValue, length, algorithm):
"""
@@ -83,7 +139,11 @@ def HKDF_expand_label(secret, label, hashValue, length, algorithm):
basis of the HKDF
:rtype: bytearray
"""
- pass
+    hkdf_label = Writer()
+    hkdf_label.add(length, 2)
+    hkdf_label.addVarSeq(bytearray(b"tls13 ") + label, 1, 1)  # bytearray so items are ints on every Python version
+    hkdf_label.addVarSeq(hashValue, 1, 1)
+    return HKDF_expand(secret, hkdf_label.bytes, length, algorithm)  # Writer.bytes is a buffer attribute, not a method
def derive_secret(secret, label, handshake_hashes, algorithm):
"""
@@ -99,7 +159,13 @@ def derive_secret(secret, label, handshake_hashes, algorithm):
be generated
:rtype: bytearray
"""
- pass
+ if handshake_hashes is None:
+ hash_value = secureHash(b"", algorithm)
+ else:
+ hash_value = handshake_hashes.digest(algorithm)
+ return HKDF_expand_label(secret, label, hash_value,
+ getattr(hashlib, algorithm).digest_size,
+ algorithm)
def bytesToNumber(b, endian='big'):
"""
@@ -107,7 +173,7 @@ def bytesToNumber(b, endian='big'):
By default assumes big-endian encoding of the number.
"""
- pass
+ return bytes_to_int(b, endian)
def numberToByteArray(n, howManyBytes=None, endian='big'):
"""
@@ -117,30 +183,106 @@ def numberToByteArray(n, howManyBytes=None, endian='big'):
not be larger. The returned bytearray will contain a big- or little-endian
encoding of the input integer (n). Big endian encoding is used by default.
"""
- pass
+ return bytearray(int_to_bytes(n, howManyBytes, endian))
def mpiToNumber(mpi):
"""Convert a MPI (OpenSSL bignum string) to an integer."""
- pass
+ byte_array = bytearray(mpi)
+ byte_len = (byte_array[0] * 256 + byte_array[1] + 7) // 8
+ return bytesToNumber(byte_array[2:2 + byte_len])
+
+def numberToMPI(n):
+ """Convert an integer to a MPI (OpenSSL bignum string)."""
+ b = numberToByteArray(n)
+ ext = 0
+ if len(b) == 0:
+ b = bytearray([0])
+ if b[0] & 0x80:
+ ext = 1
+ length = len(b) + ext
+ b2 = bytearray(2 + length)
+ b2[0] = (length >> 8) & 0xFF
+ b2[1] = length & 0xFF
+ for i in range(len(b)):
+ b2[2+ext+i] = b[i]
+ return bytes(b2)
numBits = bit_length
numBytes = byte_length
if GMPY2_LOADED:
def invMod(a, b):
"""Return inverse of a mod b, zero if none."""
- pass
+ try:
+ return int(powmod(mpz(a), -1, mpz(b)))
+ except (ValueError, ZeroDivisionError):
+ return 0
else:
def invMod(a, b):
"""Return inverse of a mod b, zero if none."""
- pass
+ s = 0
+ t = 1
+ r = b
+ old_s = 1
+ old_t = 0
+ old_r = a
+ while r != 0:
+ quotient = old_r // r
+ old_r, r = r, old_r - quotient * r
+ old_s, s = s, old_s - quotient * s
+ old_t, t = t, old_t - quotient * t
+ if old_r != 1:
+ return 0
+ while old_s < 0:
+ old_s += b
+ return old_s
if gmpyLoaded or GMPY2_LOADED:
+ powMod = powmod
else:
powMod = pow
def divceil(divident, divisor):
"""Integer division with rounding up"""
- pass
+ return (divident + divisor - 1) // divisor
+
+def isPrime(n, iterations=8):
+ """Returns True if n is prime with high probability"""
+ if n < 2:
+ return False
+ if n == 2:
+ return True
+ if n & 1 == 0:
+ return False
+
+ # Write n-1 as d * 2^s by factoring powers of 2 from n-1
+ s = 0
+ d = n - 1
+ while d & 1 == 0:
+ s += 1
+ d >>= 1
+
+ # Try to divide n by a few small primes
+ for i in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]:
+ if n % i == 0:
+ return n == i
+
+ # Do iterations of Miller-Rabin testing
+ for i in range(iterations):
+ a = bytes_to_int(os.urandom(numBytes(n)))
+ if a == 0 or a >= n:
+ a = 1
+ a = powMod(a, d, n)
+ if a == 1:
+ continue
+ for r in range(s):
+ if a == n - 1:
+ break
+ a = powMod(a, 2, n)
+ if a == 1:
+ return False
+ else:
+ return False
+ return True
def getRandomPrime(bits, display=False):
"""
@@ -149,7 +291,14 @@ def getRandomPrime(bits, display=False):
the number will be 'bits' bits long (i.e. generated number will be
larger than `(2^(bits-1) * 3 ) / 2` but smaller than 2^bits.
"""
- pass
+ while True:
+ n = bytes_to_int(os.urandom(bits // 8 + 1))
+ n |= 2 ** (bits - 1) # Set high bit
+ n &= ~(1 << (bits - 1)) - 1 # Clear low bits
+ if display:
+ print(".", end=' ')
+ if isPrime(n, iterations=30):
+ return n
def getRandomSafePrime(bits, display=False):
"""Generate a random safe prime.
@@ -157,4 +306,8 @@ def getRandomSafePrime(bits, display=False):
Will generate a prime `bits` bits long (see getRandomPrime) such that
the (p-1)/2 will also be prime.
"""
- pass
\ No newline at end of file
+ while True:
+ q = getRandomPrime(bits - 1, display)
+ p = 2 * q + 1
+ if isPrime(p, iterations=30):
+ return p
\ No newline at end of file
diff --git a/tlslite/utils/deprecations.py b/tlslite/utils/deprecations.py
index 787ea1d..cb9830e 100644
--- a/tlslite/utils/deprecations.py
+++ b/tlslite/utils/deprecations.py
@@ -15,7 +15,15 @@ def deprecated_class_name(old_name, warn="Class name '{old_name}' is deprecated,
keyword name and the 'new_name' for the current one.
Example: "Old name: {old_nam}, use '{new_name}' instead".
"""
- pass
+ def decorator(cls):
+ new_name = cls.__name__
+ def wrapper(*args, **kwargs):
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ return cls(*args, **kwargs)
+ globals()[old_name] = wrapper
+ return cls
+ return decorator
def deprecated_params(names, warn="Param name '{old_name}' is deprecated, please use '{new_name}'"):
"""Decorator to translate obsolete names and warn about their use.
@@ -28,7 +36,17 @@ def deprecated_params(names, warn="Param name '{old_name}' is deprecated, please
deprecated keyword name and 'new_name' for the current one.
Example: "Old name: {old_name}, use {new_name} instead".
"""
- pass
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ for new_name, old_name in names.items():
+ if old_name in kwargs:
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ kwargs[new_name] = kwargs.pop(old_name)
+ return func(*args, **kwargs)
+ return wrapper
+ return decorator
def deprecated_instance_attrs(names, warn="Attribute '{old_name}' is deprecated, please use '{new_name}'"):
"""Decorator to deprecate class instance attributes.
@@ -45,7 +63,34 @@ def deprecated_instance_attrs(names, warn="Attribute '{old_name}' is deprecated,
deprecated keyword name and 'new_name' for the current one.
Example: "Old name: {old_name}, use {new_name} instead".
"""
- pass
+ def decorator(cls):
+ old_getattr = cls.__getattr__ if hasattr(cls, '__getattr__') else None
+ old_setattr = cls.__setattr__ if hasattr(cls, '__setattr__') else None
+
+ def __getattr__(self, name):
+ for new_name, old_name in names.items():
+ if name == old_name:
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ return getattr(self, new_name)
+ if old_getattr:
+ return old_getattr(self, name)
+ raise AttributeError(name)
+
+ def __setattr__(self, name, value):
+ for new_name, old_name in names.items():
+ if name == old_name:
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ return setattr(self, new_name, value)
+ if old_setattr:
+ return old_setattr(self, name, value)
+ return object.__setattr__(self, name, value)
+
+ cls.__getattr__ = __getattr__
+ cls.__setattr__ = __setattr__
+ return cls
+ return decorator
def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, please use '{new_name}'"):
"""Decorator to deprecate all specified attributes in class.
@@ -62,11 +107,28 @@ def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, please u
deprecated keyword name and 'new_name' for the current one.
Example: "Old name: {old_name}, use {new_name} instead".
"""
- pass
+ class DeprecatedAttrMetaclass(type):
+ def __new__(cls, name, bases, attrs):
+ for new_name, old_name in names.items():
+ if old_name in attrs:
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ attrs[new_name] = attrs.pop(old_name)
+ return super(DeprecatedAttrMetaclass, cls).__new__(cls, name, bases, attrs)
+
+ def decorator(cls):
+ return DeprecatedAttrMetaclass(cls.__name__, cls.__bases__, dict(cls.__dict__))
+ return decorator
def deprecated_method(message):
"""Decorator for deprecating methods.
:param ste message: The message you want to display.
"""
- pass
\ No newline at end of file
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ warnings.warn(message, DeprecationWarning, stacklevel=2)
+ return func(*args, **kwargs)
+ return wrapper
+ return decorator
\ No newline at end of file
diff --git a/tlslite/utils/keyfactory.py b/tlslite/utils/keyfactory.py
index 47b2fba..e815d44 100644
--- a/tlslite/utils/keyfactory.py
+++ b/tlslite/utils/keyfactory.py
@@ -20,7 +20,12 @@ def generateRSAKey(bits, implementations=['openssl', 'python']):
:rtype: ~tlslite.utils.rsakey.RSAKey
:returns: A new RSA private key.
"""
- pass
+ for implementation in implementations:
+ if implementation == 'openssl' and cryptomath.m2cryptoLoaded:
+ return OpenSSL_RSAKey.generate(bits)
+ elif implementation == 'python':
+ return Python_RSAKey.generate(bits)
+ raise ValueError("No acceptable implementations")
def parsePEMKey(s, private=False, public=False, passwordCallback=None, implementations=['openssl', 'python']):
"""Parse a PEM-format key.
@@ -78,7 +83,26 @@ def parsePEMKey(s, private=False, public=False, passwordCallback=None, implement
:raises SyntaxError: If the key is not properly formatted.
"""
- pass
+ for implementation in implementations:
+ if implementation == 'openssl' and cryptomath.m2cryptoLoaded:
+ key = OpenSSL_RSAKey.parse(s, passwordCallback)
+ break
+ elif implementation == 'python':
+ key = Python_RSAKey.parse(s)
+ break
+ else:
+ raise ValueError("No acceptable implementations")
+
+ if private and not key.hasPrivateKey():
+ raise SyntaxError("Not a private key")
+
+ if public:
+ return _createPublicKey(key)
+ else:
+ if private:
+ return _createPrivateKey(key)
+ else:
+ return key
def parseAsPublicKey(s):
"""Parse a PEM-formatted public key.
@@ -91,7 +115,7 @@ def parseAsPublicKey(s):
:raises SyntaxError: If the key is not properly formatted.
"""
- pass
+ return parsePEMKey(s, public=True)
def parsePrivateKey(s):
"""Parse a PEM-formatted private key.
@@ -104,20 +128,31 @@ def parsePrivateKey(s):
:raises SyntaxError: If the key is not properly formatted.
"""
- pass
+ return parsePEMKey(s, private=True)
def _createPublicKey(key):
"""
Create a new public key. Discard any private component,
and return the most efficient key possible.
"""
- pass
+ if not isinstance(key, RSAKey):
+ raise ValueError("Unsupported key type")
+ return key.publicKey()
def _createPrivateKey(key):
"""
Create a new private key. Return the most efficient key possible.
"""
- pass
+ if not isinstance(key, RSAKey):
+ raise ValueError("Unsupported key type")
+ return key
+
+def _createPublicRSAKey(n, e):
+ """Create a new public RSA key from modulus and exponent."""
+ key = Python_RSAKey()
+ key.n = n
+ key.e = e
+ return _createPublicKey(key)
def _create_public_ecdsa_key(point_x, point_y, curve_name, implementations=('python',)):
"""
@@ -139,14 +174,18 @@ def _create_public_ecdsa_key(point_x, point_y, curve_name, implementations=('pyt
concrete implementation of the verifying key (only 'python' is
supported currently)
"""
- pass
+ if 'python' not in implementations:
+ raise ValueError("No acceptable implementations")
+ return Python_ECDSAKey(point_x, point_y, curve_name)
def _create_public_eddsa_key(public_key, implementations=('python',)):
"""
Convert the python-ecdsa public key into concrete implementation of
verifier.
"""
- pass
+ if 'python' not in implementations:
+ raise ValueError("No acceptable implementations")
+ return Python_EdDSAKey(public_key)
def _create_public_dsa_key(p, q, g, y, implementations=('python',)):
"""
@@ -167,4 +206,6 @@ def _create_public_dsa_key(p, q, g, y, implementations=('python',)):
concrete implementation of the verifying key (only 'python' is
supported currently)
"""
- pass
\ No newline at end of file
+ if 'python' not in implementations:
+ raise ValueError("No acceptable implementations")
+ return Python_DSAKey(p, q, g, y)
\ No newline at end of file
diff --git a/tlslite/utils/openssl_aesccm.py b/tlslite/utils/openssl_aesccm.py
index 8c708cb..ed36242 100644
--- a/tlslite/utils/openssl_aesccm.py
+++ b/tlslite/utils/openssl_aesccm.py
@@ -3,10 +3,9 @@ from tlslite.utils.cryptomath import m2cryptoLoaded
from tlslite.utils.aesccm import AESCCM
from tlslite.utils import openssl_aes
if m2cryptoLoaded:
+ class OPENSSL_AESCCM(AESCCM):
-class OPENSSL_AESCCM(AESCCM):
-
- def __init__(self, key, implementation, rawAesEncrypt, tagLength):
- super(OPENSSL_AESCCM, self).__init__(key, implementation, rawAesEncrypt, tagLength)
- self._ctr = openssl_aes.new(key, 6, bytearray(b'\x00' * 16))
- self._cbc = openssl_aes.new(key, 2, bytearray(b'\x00' * 16))
\ No newline at end of file
+ def __init__(self, key, implementation, rawAesEncrypt, tagLength):
+ super(OPENSSL_AESCCM, self).__init__(key, implementation, rawAesEncrypt, tagLength)
+ self._ctr = openssl_aes.new(key, 6, bytearray(b'\x00' * 16))
+ self._cbc = openssl_aes.new(key, 2, bytearray(b'\x00' * 16))
\ No newline at end of file
diff --git a/tlslite/utils/openssl_aesgcm.py b/tlslite/utils/openssl_aesgcm.py
index 5283cf6..4f61aa2 100644
--- a/tlslite/utils/openssl_aesgcm.py
+++ b/tlslite/utils/openssl_aesgcm.py
@@ -4,9 +4,8 @@ from tlslite.utils.aesgcm import AESGCM
from tlslite.utils import openssl_aes
from tlslite.utils.rijndael import Rijndael
if m2cryptoLoaded:
+ class OPENSSL_AESGCM(AESGCM):
-class OPENSSL_AESGCM(AESGCM):
-
- def __init__(self, key, implementation, rawAesEncrypt):
- super(OPENSSL_AESGCM, self).__init__(key, implementation, rawAesEncrypt)
- self._ctr = openssl_aes.new(key, 6, bytearray(b'\x00' * 16))
\ No newline at end of file
+ def __init__(self, key, implementation, rawAesEncrypt):
+ super(OPENSSL_AESGCM, self).__init__(key, implementation, rawAesEncrypt)
+ self._ctr = openssl_aes.new(key, 6, bytearray(b'\x00' * 16))
\ No newline at end of file
diff --git a/tlslite/utils/pem.py b/tlslite/utils/pem.py
index 79ec0cd..b0aff2b 100644
--- a/tlslite/utils/pem.py
+++ b/tlslite/utils/pem.py
@@ -16,7 +16,23 @@ def dePem(s, name):
The first such PEM block in the input will be found, and its
payload will be base64 decoded and returned.
"""
- pass
+ start = "-----BEGIN " + name + "-----"
+ end = "-----END " + name + "-----"
+ s = str(s)
+
+ # Find first PEM block
+ start_index = s.find(start)
+ if start_index == -1:
+ raise SyntaxError("Missing PEM prefix")
+ end_index = s.find(end, start_index + len(start))
+ if end_index == -1:
+ raise SyntaxError("Missing PEM postfix")
+
+ # Get payload
+ s = s[start_index + len(start):end_index]
+ s = ''.join(s.splitlines())
+ s = s.strip()
+ return bytearray(binascii.a2b_base64(s))
def dePemList(s, name):
"""Decode a sequence of PEM blocks into a list of bytearrays.
@@ -42,7 +58,24 @@ def dePemList(s, name):
All such PEM blocks will be found, decoded, and return in an ordered list
of bytearrays, which may have zero elements if not PEM blocks are found.
"""
- pass
+ bList = []
+ start = "-----BEGIN " + name + "-----"
+ end = "-----END " + name + "-----"
+ s = str(s)
+ while True:
+ start_index = s.find(start)
+ if start_index == -1:
+ break
+ end_index = s.find(end, start_index + len(start))
+ if end_index == -1:
+ break
+ # Get the payload
+ payload = s[start_index + len(start):end_index]
+ payload = ''.join(payload.splitlines())
+ payload = payload.strip()
+ bList.append(bytearray(binascii.a2b_base64(payload)))
+ s = s[end_index + len(end):]
+ return bList
def pem(b, name):
"""Encode a payload bytearray into a PEM string.
@@ -56,4 +89,33 @@ def pem(b, name):
KoZIhvcNAQEFBQADAwA5kw==
-----END CERTIFICATE-----
"""
- pass
\ No newline at end of file
+ s = binascii.b2a_base64(b).decode()
+ s = s.rstrip() # remove newline from b2a_base64
+ s = "-----BEGIN " + name + "-----\n" + \
+ s + "\n" + \
+ "-----END " + name + "-----\n"
+ return s
+
+def pemSniff(s, name):
+ """Check if string appears to be a PEM-encoded payload with the specified name.
+
+ :type s: str
+ :param s: The string to check.
+
+ :type name: str
+ :param name: The expected name of the PEM payload.
+
+ :rtype: bool
+ :returns: True if the string appears to be a PEM-encoded payload with the
+ specified name, False otherwise.
+ """
+ start = "-----BEGIN " + name + "-----"
+ end = "-----END " + name + "-----"
+ s = str(s)
+ start_index = s.find(start)
+ if start_index == -1:
+ return False
+ end_index = s.find(end, start_index + len(start))
+ if end_index == -1:
+ return False
+ return True
\ No newline at end of file
diff --git a/tlslite/utils/python_rsakey.py b/tlslite/utils/python_rsakey.py
index 5571d84..6269124 100644
--- a/tlslite/utils/python_rsakey.py
+++ b/tlslite/utils/python_rsakey.py
@@ -55,11 +55,11 @@ class Python_RSAKey(RSAKey):
Does the key has the associated private key (True) or is it only
the public part (False).
"""
- pass
+ return self.d != 0
def acceptsPassword(self):
"""Does it support encrypted key files."""
- pass
+ return False
@staticmethod
def generate(bits, key_type='rsa'):
@@ -67,10 +67,62 @@ class Python_RSAKey(RSAKey):
key_type can be "rsa" for a universal rsaEncryption key or
"rsa-pss" for a key that can be used only for RSASSA-PSS."""
- pass
+ key = Python_RSAKey(key_type=key_type)
+ p = getRandomPrime(bits // 2, False)
+ q = getRandomPrime(bits // 2, False)
+ n = p * q
+ t = lcm(p - 1, q - 1)
+ e = 65537
+ d = invMod(e, t)
+ key.n = n
+ key.e = e
+ key.d = d
+ key.p = p
+ key.q = q
+ key.dP = d % (p - 1)
+ key.dQ = d % (q - 1)
+ key.qInv = invMod(q, p)
+ return key
+
+ @staticmethod
+ def parse(s):
+ """Parse a string containing a PEM-encoded key."""
+ return Python_RSAKey.parsePEM(s)
@staticmethod
@deprecated_params({'data': 's', 'password_callback': 'passwordCallback'})
def parsePEM(data, password_callback=None):
"""Parse a string containing a PEM-encoded <privateKey>."""
- pass
\ No newline at end of file
+ if password_callback:
+ raise ValueError("This implementation does not support encrypted keys")
+
+ # Try to parse as private key first
+ try:
+ der = dePem(data, "RSA PRIVATE KEY")
+ key = Python_RSAKey()
+ p = Parser(der)
+ p.get(1) # skip version
+ key.n = p.getInt(1)
+ key.e = p.getInt(1)
+ key.d = p.getInt(1)
+ key.p = p.getInt(1)
+ key.q = p.getInt(1)
+ key.dP = p.getInt(1)
+ key.dQ = p.getInt(1)
+ key.qInv = p.getInt(1)
+ return key
+ except SyntaxError:
+ pass
+
+ # Try to parse as public key
+ try:
+ der = dePem(data, "RSA PUBLIC KEY")
+ key = Python_RSAKey()
+ p = Parser(der)
+ key.n = p.getInt(1)
+ key.e = p.getInt(1)
+ return key
+ except SyntaxError:
+ pass
+
+ raise SyntaxError("Not a valid PEM file")
\ No newline at end of file
diff --git a/unit_tests/test_tlslite_integration_tlsasynciodispatchermixin.py b/unit_tests/test_tlslite_integration_tlsasynciodispatchermixin.py
index 7c7d3eb..2839882 100644
--- a/unit_tests/test_tlslite_integration_tlsasynciodispatchermixin.py
+++ b/unit_tests/test_tlslite_integration_tlsasynciodispatchermixin.py
@@ -6,9 +6,9 @@ import sys
# which is not available in Python 2- asyncio is used
# in the implementation of TLSAsyncioDispatcherMixIn
try:
+ import asyncio
from tlslite.integration.tlsasynciodispatchermixin \
import TLSAsyncioDispatcherMixIn
- import asyncio
except ImportError:
pass