Back to Claude Sonnet 3.5 - Fill-in summary
Claude Sonnet 3.5 - Fill-in: tlslite-ng
Pytest summary for the unit_tests test suite
status | count |
---|---|
passed | 58 |
failed | 45 |
error | 1 |
total | 104 |
collected | 104 |
Failed tests:
test_tlslite_utils_codec.py::TestParser::test_getRemainingLength
test_tlslite_utils_codec.py::TestParser::test_getRemainingLength
self =def test_getRemainingLength(self): p = Parser(bytearray( b'\x00\x01\x05' )) self.assertEqual(1, p.get(2)) > self.assertEqual(1, p.getRemainingLength()) E AssertionError: 1 != -2 unit_tests/test_tlslite_utils_codec.py:168: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addFixSeq_with_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_addFixSeq_with_overflowing_data
self =def test_addFixSeq_with_overflowing_data(self): w = Writer() with self.assertRaises(ValueError): > w.addFixSeq([16, 17, 256], 1) unit_tests/test_tlslite_utils_codec.py:311: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tlslite/utils/codec.py:106: in addFixSeq self.add(item, length) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: int too big to convert tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq
self =def test_addVarSeq(self): w = Writer() w.addVarSeq([16, 17, 18], 2, 2) > self.assertEqual(bytearray( b'\x00\x06' + b'\x00\x10' + b'\x00\x11' + b'\x00\x12'), w.bytes) E AssertionError: bytearray(b'\x00\x06\x00\x10\x00\x11\x00\x12') != bytearray(b'') unit_tests/test_tlslite_utils_codec.py:317: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_single_byte_data
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_single_byte_data
self =def test_addVarSeq_single_byte_data(self): w = Writer() w.addVarSeq([0xaa, 0xbb, 0xcc], 1, 2) > self.assertEqual(bytearray( b'\x00\x03' + b'\xaa' + b'\xbb' + b'\xcc'), w.bytes) E AssertionError: bytearray(b'\x00\x03\xaa\xbb\xcc') != bytearray(b'') unit_tests/test_tlslite_utils_codec.py:327: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_triple_byte_data
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_triple_byte_data
self =def test_addVarSeq_triple_byte_data(self): w = Writer() w.addVarSeq([0xaa, 0xbb, 0xcc], 3, 2) > self.assertEqual(bytearray( b'\x00\x09' + b'\x00\x00\xaa' + b'\x00\x00\xbb' + b'\x00\x00\xcc'), w.bytes) E AssertionError: bytearray(b'\x00\t\x00\x00\xaa\x00\x00\xbb\x00\x00\xcc') != bytearray(b'') unit_tests/test_tlslite_utils_codec.py:337: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_with_one_byte_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_with_one_byte_overflowing_data
self =def test_addVarSeq_with_one_byte_overflowing_data(self): w = Writer() > with self.assertRaises(ValueError): E AssertionError: ValueError not raised unit_tests/test_tlslite_utils_codec.py:352: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_with_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_with_overflowing_data
self =def test_addVarSeq_with_overflowing_data(self): w = Writer() > with self.assertRaises(ValueError): E AssertionError: ValueError not raised unit_tests/test_tlslite_utils_codec.py:346: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_with_three_byte_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_addVarSeq_with_three_byte_overflowing_data
self =def test_addVarSeq_with_three_byte_overflowing_data(self): w = Writer() > with self.assertRaises(ValueError): E AssertionError: ValueError not raised unit_tests/test_tlslite_utils_codec.py:358: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addVarTupleSeq_with_double_byte_invalid_sized_tuples
test_tlslite_utils_codec.py::TestWriter::test_addVarTupleSeq_with_double_byte_invalid_sized_tuples
self =def test_addVarTupleSeq_with_double_byte_invalid_sized_tuples(self): w = Writer() > with self.assertRaises(ValueError): E AssertionError: ValueError not raised unit_tests/test_tlslite_utils_codec.py:408: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addVarTupleSeq_with_double_byte_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_addVarTupleSeq_with_double_byte_overflowing_data
self =def test_addVarTupleSeq_with_double_byte_overflowing_data(self): w = Writer() with self.assertRaises(ValueError): > w.addVarTupleSeq([(1, 2), (3, 0x10000)], 2, 2) unit_tests/test_tlslite_utils_codec.py:414: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tlslite/utils/codec.py:175: in addVarTupleSeq self.add(item, length) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: int too big to convert tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_addVarTupleSeq_with_invalid_sized_tuples
test_tlslite_utils_codec.py::TestWriter::test_addVarTupleSeq_with_invalid_sized_tuples
self =def test_addVarTupleSeq_with_invalid_sized_tuples(self): w = Writer() > with self.assertRaises(ValueError): E AssertionError: ValueError not raised unit_tests/test_tlslite_utils_codec.py:397: AssertionError
test_tlslite_utils_codec.py::TestWriter::test_addVarTupleSeq_with_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_addVarTupleSeq_with_overflowing_data
self =def test_addVarTupleSeq_with_overflowing_data(self): w = Writer() with self.assertRaises(ValueError): > w.addVarTupleSeq([(1, 2), (2, 256)], 1, 2) unit_tests/test_tlslite_utils_codec.py:404: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tlslite/utils/codec.py:175: in addVarTupleSeq self.add(item, length) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: int too big to convert tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_add_with_five_overflowing_bytes
test_tlslite_utils_codec.py::TestWriter::test_add_with_five_overflowing_bytes
self =def test_add_with_five_overflowing_bytes(self): w = Writer() with self.assertRaises(ValueError): > w.add(0x010000000000, 5) unit_tests/test_tlslite_utils_codec.py:278: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: int too big to convert tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_add_with_five_underflowing_bytes
test_tlslite_utils_codec.py::TestWriter::test_add_with_five_underflowing_bytes
self =def test_add_with_five_underflowing_bytes(self): w = Writer() with self.assertRaises(ValueError): > w.add(-1, 5) unit_tests/test_tlslite_utils_codec.py:284: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: can't convert negative int to unsigned tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_add_with_four_bytes_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_add_with_four_bytes_overflowing_data
self =def test_add_with_four_bytes_overflowing_data(self): w = Writer() with self.assertRaises(ValueError): > w.add(0x0100000000, 4) unit_tests/test_tlslite_utils_codec.py:290: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: int too big to convert tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_add_with_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_add_with_overflowing_data
self =def test_add_with_overflowing_data(self): w = Writer() with self.assertRaises(ValueError): > w.add(256, 1) unit_tests/test_tlslite_utils_codec.py:248: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: int too big to convert tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_add_with_three_overflowing_data
test_tlslite_utils_codec.py::TestWriter::test_add_with_three_overflowing_data
self =def test_add_with_three_overflowing_data(self): w = Writer() with self.assertRaises(ValueError): > w.add(0x01000000, 3) unit_tests/test_tlslite_utils_codec.py:236: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: int too big to convert tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_add_with_three_underflowing_data
test_tlslite_utils_codec.py::TestWriter::test_add_with_three_underflowing_data
self =def test_add_with_three_underflowing_data(self): w = Writer() with self.assertRaises(ValueError): > w.add(-1, 3) unit_tests/test_tlslite_utils_codec.py:242: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: can't convert negative int to unsigned tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_codec.py::TestWriter::test_add_with_underflowing_data
test_tlslite_utils_codec.py::TestWriter::test_add_with_underflowing_data
self =def test_add_with_underflowing_data(self): w = Writer() with self.assertRaises(ValueError): > w.add(-1, 1) unit_tests/test_tlslite_utils_codec.py:254: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def add(self, x, length): """ Add a single positive integer value x, encode it in length bytes Encode positive integer x in big-endian format using length bytes, add to the internal buffer. :type x: int :param x: value to encode :type length: int :param length: number of bytes to use for encoding the value """ > self.bytes += x.to_bytes(length, byteorder='big') E OverflowError: can't convert negative int to unsigned tlslite/utils/codec.py:70: OverflowError
test_tlslite_utils_cryptomath_m2crypto.py::TestM2CryptoLoaded::test_import_with_m2crypto
test_tlslite_utils_cryptomath_m2crypto.py::TestM2CryptoLoaded::test_import_with_m2crypto
self =def test_import_with_m2crypto(self): fake_m2 = mock.MagicMock() with mock.patch.dict('sys.modules', {'M2Crypto': fake_m2}): > import tlslite.utils.cryptomath E File "/testbed/tlslite/utils/cryptomath.py", line 166 E else: E ^^^^ E IndentationError: expected an indented block after 'if' statement on line 165 unit_tests/test_tlslite_utils_cryptomath_m2crypto.py:83: IndentationError
test_tlslite_utils_cryptomath_m2crypto.py::TestM2CryptoLoaded::test_import_with_m2crypto_in_container
test_tlslite_utils_cryptomath_m2crypto.py::TestM2CryptoLoaded::test_import_with_m2crypto_in_container
self =def test_import_with_m2crypto_in_container(self): fake_m2 = mock.MagicMock() with mock.patch.dict('sys.modules', {'M2Crypto': fake_m2}): with mock.patch.object(builtins, 'open', magic_open_error): > import tlslite.utils.cryptomath E File "/testbed/tlslite/utils/cryptomath.py", line 166 E else: E ^^^^ E IndentationError: expected an indented block after 'if' statement on line 165 unit_tests/test_tlslite_utils_cryptomath_m2crypto.py:103: IndentationError
test_tlslite_utils_cryptomath_m2crypto.py::TestM2CryptoLoaded::test_import_with_m2crypto_in_fips_mode
test_tlslite_utils_cryptomath_m2crypto.py::TestM2CryptoLoaded::test_import_with_m2crypto_in_fips_mode
self =def test_import_with_m2crypto_in_fips_mode(self): fake_m2 = mock.MagicMock() with mock.patch.dict('sys.modules', {'M2Crypto': fake_m2}): with mock.patch.object(builtins, 'open', magic_open): > import tlslite.utils.cryptomath E File "/testbed/tlslite/utils/cryptomath.py", line 166 E else: E ^^^^ E IndentationError: expected an indented block after 'if' statement on line 165 unit_tests/test_tlslite_utils_cryptomath_m2crypto.py:93: IndentationError
test_tlslite_utils_cryptomath_m2crypto.py::TestM2CryptoLoaded::test_import_without_m2crypto
test_tlslite_utils_cryptomath_m2crypto.py::TestM2CryptoLoaded::test_import_without_m2crypto
self =def test_import_without_m2crypto(self): with mock.patch.dict('sys.modules', {'M2Crypto': None}): > import tlslite.utils.cryptomath E File "/testbed/tlslite/utils/cryptomath.py", line 166 E else: E ^^^^ E IndentationError: expected an indented block after 'if' statement on line 165 unit_tests/test_tlslite_utils_cryptomath_m2crypto.py:74: IndentationError
test_tlslite_utils_deprecations.py::TestDeprecatedClassName::test_check_callable
test_tlslite_utils_deprecations.py::TestDeprecatedClassName::test_check_callable
self =def test_check_callable(self): @deprecated_class_name('bad_func') def good_func(param): return "got '{0}'".format(param) self.assertEqual("got 'some'", good_func('some')) with self.assertWarns(DeprecationWarning) as e: > val = bad_func('other') E NameError: name 'bad_func' is not defined unit_tests/test_tlslite_utils_deprecations.py:81: NameError
test_tlslite_utils_deprecations.py::TestDeprecatedClassName::test_check_class
test_tlslite_utils_deprecations.py::TestDeprecatedClassName::test_check_class
self =def test_check_class(self): @deprecated_class_name('bad_name') class Test1(object): def __init__(self, param): self.param = param def method(self): return self.param instance = Test1('value') self.assertEqual('value', instance.method()) > self.assertIsInstance(instance, bad_name) E NameError: name 'bad_name' is not defined unit_tests/test_tlslite_utils_deprecations.py:60: NameError
test_tlslite_utils_deprecations.py::TestDeprecatedClassName::test_check_with_duplicated_name
test_tlslite_utils_deprecations.py::TestDeprecatedClassName::test_check_with_duplicated_name
self =def test_check_with_duplicated_name(self): @deprecated_class_name('bad_func2') def good_func(): return None > with self.assertRaises(NameError): E AssertionError: NameError not raised unit_tests/test_tlslite_utils_deprecations.py:91: AssertionError
test_tlslite_utils_deprecations.py::TestDeprecatedParams::test_both_params
test_tlslite_utils_deprecations.py::TestDeprecatedParams::test_both_params
self =def test_both_params(self): @deprecated_params({'param_a': 'older_param'}) def method(param_a, param_b): return (param_a, param_b) a = mock.Mock() b = mock.Mock() c = mock.Mock() > with self.assertRaises(TypeError) as e: E AssertionError: TypeError not raised unit_tests/test_tlslite_utils_deprecations.py:143: AssertionError
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_class_variable_deletion
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_class_variable_deletion
self =def test_class_variable_deletion(self): @deprecated_attrs({"new_cvar": "old_cvar"}) class Clazz(object): new_cvar = "first title" @classmethod def method(cls): return cls.new_cvar instance = Clazz() self.assertEqual(instance.method(), "first title") self.assertEqual(instance.new_cvar, "first title") with self.assertWarns(DeprecationWarning) as e: > self.assertEqual(Clazz.old_cvar, "first title") E AssertionError: <property object at 0x...> != 'first title' unit_tests/test_tlslite_utils_deprecations.py:422: AssertionError
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_attrs_variable_deletion
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_attrs_variable_deletion
self =def test_deprecated_attrs_variable_deletion(self): @deprecated_attrs({"new_cvar": "old_cvar"}) class Clazz(object): new_cvar = "first title" def __init__(self): self.val = "something" @classmethod def method(cls): return cls.new_cvar instance = Clazz() self.assertEqual(instance.method(), "first title") self.assertEqual(instance.new_cvar, "first title") with self.assertWarns(DeprecationWarning) as e: > self.assertEqual(Clazz.old_cvar, "first title") E AssertionError: <property object at 0x...> != 'first title' unit_tests/test_tlslite_utils_deprecations.py:372: AssertionError
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_class_method
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_class_method
self =def test_deprecated_class_method(self): @deprecated_attrs({"foo": "bar"}) class Clazz(object): @classmethod def foo(cls, arg): return "foo: {0}".format(arg) instance = Clazz() self.assertEqual(instance.foo("aa"), "foo: aa") self.assertEqual(Clazz.foo("aa"), "foo: aa") with self.assertWarns(DeprecationWarning) as e: self.assertEqual(instance.bar("aa"), "foo: aa") self.assertIn("bar", str(e.warning)) with self.assertWarns(DeprecationWarning) as e: > self.assertEqual(Clazz.bar("aa"), "foo: aa") E TypeError: 'property' object is not callable unit_tests/test_tlslite_utils_deprecations.py:280: TypeError
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_class_variable
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_class_variable
self =def test_deprecated_class_variable(self): @deprecated_attrs({"new_cvar": "old_cvar"}) class Clazz(object): new_cvar = "some string" def method(self): return self.new_cvar instance = Clazz() self.assertEqual(instance.method(), "some string") Clazz.new_cvar = bytearray(b"new string") self.assertEqual(instance.new_cvar, b"new string") with self.assertWarns(DeprecationWarning) as e: self.assertEqual(instance.old_cvar, b"new string") self.assertIn("old_cvar", str(e.warning)) self.assertIn("new_cvar", str(e.warning)) with self.assertWarns(DeprecationWarning) as e: > self.assertEqual(Clazz.old_cvar, b"new string") E AssertionError: <property object at 0x...> != b'new string' unit_tests/test_tlslite_utils_deprecations.py:323: AssertionError
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_instance_variable
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_instance_variable
self =def test_deprecated_instance_variable(self): @deprecated_attrs({"new_field": "old_field"}) class Clazz(object): def __init__(self): self.new_field = "I'm new_field" instance = Clazz() self.assertEqual(instance.new_field, "I'm new_field") with self.assertWarns(DeprecationWarning) as e: self.assertEqual(instance.old_field, "I'm new_field") instance.old_field = "I've been set" self.assertEqual(instance.new_field, "I've been set") self.assertIn("old_field", str(e.warning)) with self.assertWarns(DeprecationWarning): > del instance.old_field E AttributeError: property of 'Clazz' object has no deleter unit_tests/test_tlslite_utils_deprecations.py:245: AttributeError
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_static_method
test_tlslite_utils_deprecations.py::TestDeprecatedFields::test_deprecated_static_method
self =def test_deprecated_static_method(self): @deprecated_attrs({"new_stic": "old_stic"}) class Clazz(object): @staticmethod def new_stic(param): return "new_stic: {0}".format(param) instance = Clazz() self.assertEqual(instance.new_stic("aaa"), "new_stic: aaa") self.assertEqual(Clazz.new_stic("aaa"), "new_stic: aaa") with self.assertWarns(DeprecationWarning) as e: self.assertEqual(instance.old_stic("aaa"), "new_stic: aaa") self.assertIn("old_stic", str(e.warning)) with self.assertWarns(DeprecationWarning) as e: > self.assertEqual(Clazz.old_stic("aaa"), "new_stic: aaa") E TypeError: 'property' object is not callable unit_tests/test_tlslite_utils_deprecations.py:300: TypeError
test_tlslite_utils_deprecations.py::TestDeprecatedMethods::test_deprecated_method
test_tlslite_utils_deprecations.py::TestDeprecatedMethods::test_deprecated_method
self =def test_deprecated_method(self): @deprecated_method("Please use foo method instead.") def test(param): return param with self.assertWarns(DeprecationWarning) as e: r = test("test") self.assertEqual(r, "test") > self.assertEqual("test is a deprecated method. Please" \ " use foo method instead.", str(e.warning)) E AssertionError: 'test is a deprecated method. Please use foo method instead.' != 'Please use foo method instead.' E - test is a deprecated method. Please use foo method instead. E + Please use foo method instead. unit_tests/test_tlslite_utils_deprecations.py:469: AssertionError
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_example
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_example
self =def test_example(self): > self.assertTrue(is_valid_hostname(b'example.com')) E AssertionError: False is not true unit_tests/test_tlslite_utils_dns_utils.py:15: AssertionError
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_hostname_alone
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_hostname_alone
self =def test_hostname_alone(self): > self.assertTrue(is_valid_hostname(b'localhost')) E AssertionError: False is not true unit_tests/test_tlslite_utils_dns_utils.py:30: AssertionError
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_ip_lookalike_hostname
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_ip_lookalike_hostname
self =def test_ip_lookalike_hostname(self): > self.assertTrue(is_valid_hostname(b'192.168.example.com')) E AssertionError: False is not true unit_tests/test_tlslite_utils_dns_utils.py:24: AssertionError
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_long_hostname
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_long_hostname
self =def test_long_hostname(self): > self.assertTrue(is_valid_hostname(b'a' * 60 + b'.example.com')) E AssertionError: False is not true unit_tests/test_tlslite_utils_dns_utils.py:39: AssertionError
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_with_tld_dot
test_tlslite_utils_dns_utils.py::TestIsValidHostname::test_with_tld_dot
self =def test_with_tld_dot(self): > self.assertTrue(is_valid_hostname(b'example.com.')) E AssertionError: False is not true unit_tests/test_tlslite_utils_dns_utils.py:27: AssertionError
test_tlslite_utils_ecc.py::TestCurveLookup::test_with_correct_name
test_tlslite_utils_ecc.py::TestCurveLookup::test_with_correct_name
self =def test_with_correct_name(self): > curve = getCurveByName('secp256r1') unit_tests/test_tlslite_utils_ecc.py:18: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ curveName = 'secp256r1' def getCurveByName(curveName): """Return curve identified by curveName""" > for curve in ecdsaAllCurves: E TypeError: 'bool' object is not iterable tlslite/utils/ecc.py:8: TypeError
test_tlslite_utils_ecc.py::TestCurveLookup::test_with_invalid_name
test_tlslite_utils_ecc.py::TestCurveLookup::test_with_invalid_name
self =def test_with_invalid_name(self): with self.assertRaises(ValueError): > getCurveByName('NIST256p') unit_tests/test_tlslite_utils_ecc.py:23: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def getCurveByName(curveName): """Return curve identified by curveName""" > for curve in ecdsaAllCurves: E TypeError: 'bool' object is not iterable tlslite/utils/ecc.py:8: TypeError
test_tlslite_utils_ecc.py::TestGetPointByteSize::test_with_curve
test_tlslite_utils_ecc.py::TestGetPointByteSize::test_with_curve
self =def test_with_curve(self): > self.assertEqual(getPointByteSize(ecdsa.NIST256p), 32) unit_tests/test_tlslite_utils_ecc.py:27: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ point = NIST256p def getPointByteSize(point): """Convert the point or curve bit size to bytes""" if isinstance(point, ecdsa.ellipticcurve.Point): return (point.curve().p().bit_length() + 7) // 8 elif isinstance(point, ecdsa.curves.Curve): > return (point.p().bit_length() + 7) // 8 E AttributeError: 'Curve' object has no attribute 'p' tlslite/utils/ecc.py:19: AttributeError
test_tlslite_utils_ecc.py::TestGetPointByteSize::test_with_invalid_argument
test_tlslite_utils_ecc.py::TestGetPointByteSize::test_with_invalid_argument
self =def test_with_invalid_argument(self): with self.assertRaises(ValueError): > getPointByteSize("P-256") unit_tests/test_tlslite_utils_ecc.py:34: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ def getPointByteSize(point): """Convert the point or curve bit size to bytes""" if isinstance(point, ecdsa.ellipticcurve.Point): return (point.curve().p().bit_length() + 7) // 8 elif isinstance(point, ecdsa.curves.Curve): return (point.p().bit_length() + 7) // 8 else: > raise TypeError("Input must be an elliptic curve point or curve") E TypeError: Input must be an elliptic curve point or curve tlslite/utils/ecc.py:21: TypeError
test_tlslite_utils_ecc.py::TestGetPointByteSize::test_with_point
test_tlslite_utils_ecc.py::TestGetPointByteSize::test_with_point
self =def test_with_point(self): > self.assertEqual(getPointByteSize(ecdsa.NIST384p.generator * 10), 48) unit_tests/test_tlslite_utils_ecc.py:30: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ point = def getPointByteSize(point): """Convert the point or curve bit size to bytes""" if isinstance(point, ecdsa.ellipticcurve.Point): return (point.curve().p().bit_length() + 7) // 8 elif isinstance(point, ecdsa.curves.Curve): return (point.p().bit_length() + 7) // 8 else: > raise TypeError("Input must be an elliptic curve point or curve") E TypeError: Input must be an elliptic curve point or curve tlslite/utils/ecc.py:21: TypeError
test_tlslite_utils_lists.py::TestGetFirstMatching::test_no_matches
test_tlslite_utils_lists.py::TestGetFirstMatching::test_no_matches
self =def test_no_matches(self): with self.assertRaises(AssertionError): > getFirstMatching([1, 2, 3], None) unit_tests/test_tlslite_utils_lists.py:35: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tlslite/utils/lists.py:20: in getFirstMatching return next((item for item in values if item in matches), None) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ > return next((item for item in values if item in matches), None) E TypeError: argument of type 'NoneType' is not iterable tlslite/utils/lists.py:20: TypeError
test_tlslite_utils_tlshashlib.py::TestTLSHashlib::test_in_fips_mode
test_tlslite_utils_tlshashlib.py::TestTLSHashlib::test_in_fips_mode
self =def test_in_fips_mode(self): def m(*args, **kwargs): if 'usedforsecurity' not in kwargs: raise ValueError("MD5 disabled in FIPS mode") with mock.patch('hashlib.md5', m): from tlslite.utils.tlshashlib import md5 > md5() unit_tests/test_tlslite_utils_tlshashlib.py:27: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tlslite/utils/tlshashlib.py:18: in md5 return _fipsFunction(hashlib.md5, *args, **kwargs) tlslite/utils/tlshashlib.py:9: in _fipsFunction return func(*args, **kwargs) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ args = (), kwargs = {} def m(*args, **kwargs): if 'usedforsecurity' not in kwargs: > raise ValueError("MD5 disabled in FIPS mode") E ValueError: MD5 disabled in FIPS mode unit_tests/test_tlslite_utils_tlshashlib.py:23: ValueError
Patch diff
diff --git a/test b/test
index e69de29..ad4ee32 100644
--- a/test
+++ b/test
@@ -0,0 +1,128 @@
+from tlslite.x509certchain import X509CertChain
+from tlslite.x509 import X509
+from tlslite.utils.pem import parsePemList
+from tlslite.utils.constanttime import (
+ ct_lt_u32, ct_gt_u32, ct_le_u32, ct_lsb_prop_u8, ct_lsb_prop_u16,
+ ct_isnonzero_u32, ct_neq_u32, ct_eq_u32, ct_check_cbc_mac_and_pad,
+ ct_compare_digest
+)
+
+class TestX509CertChain(unittest.TestCase):
+ def setUp(self):
+ # Sample PEM-encoded certificate
+ self.pem_cert = """-----BEGIN CERTIFICATE-----
+MIIDazCCAlOgAwIBAgIUJQpNHaJuEpNIFiLthZ+6T+JuMb0wDQYJKoZIhvcNAQEL
+BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
+GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzA1MTcxNTIyMzhaFw0yNDA1
+MTYxNTIyMzhaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
+HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
+AQUAA4IBDwAwggEKAoIBAQC8VHvNs5tsD+qLILrjWb5C4GFaWd5eUr6XQgKTuBJM
+F7uqIIJuZOXAJkN7y5+gHj5o7aPP/DLDHoLGM2uX4h01XKVdlLmjXlf+WzUUKLR8
+LfPICNIH7FB5vAn8tawuHjRBNs1nZMgGE7STv756o1FBqZYZu0gF3dGlBg5yFPwO
+1I4IG1j+GnDN4OYw5BjtN6nGvEQiR7pvVGhXHOJTOVHbvZf3sFHPeFvzilKBe0pq
+5bhVzDNAvwXX+jNvGWHKW3YdVcqnJFX9JR8UzWWbYrs0xnuY5NeKzXCcCsXISLrx
+mJRUVDQn4z0Xp2LH1H8vWf+0LNQ4NZY5Tovg+/+AxOYnAgMBAAGjUzBRMB0GA1Ud
+DgQWBBQHWYtkZGWulLut7eo/ufZzONAvYjAfBgNVHSMEGDAWgBQHWYtkZGWulLut
+7eo/ufZzONAvYjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCL
+FWFd5ZzTxbuTxDl2X8qFJEr++eTtLKdwZyq6dfne6v6DGIXIPoKWE9DTrEZijQZP
+/UxttRKM8j3GcKQbOdXf5ZgMT2qh+1zau6CUBw5a9h2xkHC9PsPhRXRIVFdNWA/w
+eQM8aqnuUmRlGqyOQAXhLhLzVOUl+cqDaJjPQ0FqoFiEPD+P5WpWlTxLgLYsKJ9P
+bFXDj9FQcuHwWRaTzpTXdxDcg6wCJI3uJq6FnAfGLAp5DZydtGbAy+q1pMEmwZrw
+RX+1PTAOF2tc0TK7Gx1WqTbMQ3ZSie6oBNQ8UkHMqLznEt8mZnUMBOXjAqrwq8IJ
+1/Jg9RBljKJXXRYQbWPk
+-----END CERTIFICATE-----"""
+
+ def test_parsePemList(self):
+ chain = X509CertChain()
+ chain.parsePemList(self.pem_cert)
+ self.assertEqual(chain.getNumCerts(), 1)
+
+ def test_getNumCerts(self):
+ chain = X509CertChain()
+ self.assertEqual(chain.getNumCerts(), 0)
+
+ chain.parsePemList(self.pem_cert)
+ self.assertEqual(chain.getNumCerts(), 1)
+
+ def test_getEndEntityPublicKey(self):
+ chain = X509CertChain()
+ chain.parsePemList(self.pem_cert)
+ public_key = chain.getEndEntityPublicKey()
+ self.assertIsNotNone(public_key)
+
+ def test_getFingerprint(self):
+ chain = X509CertChain()
+ chain.parsePemList(self.pem_cert)
+ fingerprint = chain.getFingerprint()
+ self.assertIsInstance(fingerprint, str)
+ self.assertEqual(len(fingerprint), 64) # SHA256 fingerprint is 64 characters long
+
+ def test_getTackExt(self):
+ chain = X509CertChain()
+ chain.parsePemList(self.pem_cert)
+ tack_ext = chain.getTackExt()
+ self.assertIsNone(tack_ext) # Assuming the sample cert doesn't have a TACK extension
+
+ def test_empty_chain(self):
+ chain = X509CertChain()
+ with self.assertRaises(ValueError):
+ chain.getEndEntityPublicKey()
+ with self.assertRaises(ValueError):
+ chain.getFingerprint()
+ self.assertIsNone(chain.getTackExt())
+
+class TestConstantTimeFunctions(unittest.TestCase):
+ def test_ct_lt_u32(self):
+ self.assertEqual(ct_lt_u32(5, 10), 1)
+ self.assertEqual(ct_lt_u32(10, 5), 0)
+ self.assertEqual(ct_lt_u32(5, 5), 0)
+
+ def test_ct_gt_u32(self):
+ self.assertEqual(ct_gt_u32(10, 5), 1)
+ self.assertEqual(ct_gt_u32(5, 10), 0)
+ self.assertEqual(ct_gt_u32(5, 5), 0)
+
+ def test_ct_le_u32(self):
+ self.assertEqual(ct_le_u32(5, 10), 1)
+ self.assertEqual(ct_le_u32(5, 5), 1)
+ self.assertEqual(ct_le_u32(10, 5), 0)
+
+ def test_ct_lsb_prop_u8(self):
+ self.assertEqual(ct_lsb_prop_u8(1), 0xFF)
+ self.assertEqual(ct_lsb_prop_u8(2), 0x00)
+
+ def test_ct_lsb_prop_u16(self):
+ self.assertEqual(ct_lsb_prop_u16(1), 0xFFFF)
+ self.assertEqual(ct_lsb_prop_u16(2), 0x0000)
+
+ def test_ct_isnonzero_u32(self):
+ self.assertEqual(ct_isnonzero_u32(0), 0)
+ self.assertEqual(ct_isnonzero_u32(1), 1)
+ self.assertEqual(ct_isnonzero_u32(100), 1)
+
+ def test_ct_neq_u32(self):
+ self.assertEqual(ct_neq_u32(5, 10), 1)
+ self.assertEqual(ct_neq_u32(5, 5), 0)
+
+ def test_ct_eq_u32(self):
+ self.assertEqual(ct_eq_u32(5, 5), 1)
+ self.assertEqual(ct_eq_u32(5, 10), 0)
+
+ def test_ct_check_cbc_mac_and_pad(self):
+ # This is a simplified test. In a real scenario, you'd need to set up
+ # proper HMAC, data, and other parameters.
+ data = bytearray(b'test' + b'\x0c'*12) # 4 bytes of data + 12 bytes of padding
+ mac = hmac.new(b'key', digestmod='sha256')
+ seqnumBytes = bytearray(8)
+ contentType = 23 # application_data
+ version = (3, 3) # TLS 1.2
+
+ result = ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version)
+ self.assertFalse(result) # This should fail as we didn't set up proper HMAC
+
+ def test_ct_compare_digest(self):
+ self.assertTrue(ct_compare_digest(b'same', b'same'))
+ self.assertFalse(ct_compare_digest(b'different', b'strings'))
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_x25519.py b/tests/test_x25519.py
new file mode 100644
index 0000000..5ca17db
--- /dev/null
+++ b/tests/test_x25519.py
@@ -0,0 +1,40 @@
+import unittest
+from tlslite.utils.x25519 import x25519, x448, X25519_G, X448_G
+
+class TestX25519(unittest.TestCase):
+ def test_x25519(self):
+ # Test vector from RFC 7748
+ scalar = bytes.fromhex('a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4')
+ u_coordinate = bytes.fromhex('e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c')
+ expected_output = bytes.fromhex('c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552')
+
+ result = x25519(scalar, u_coordinate)
+ self.assertEqual(result, expected_output)
+
+ def test_x25519_base_point(self):
+ # Test with the base point
+ scalar = bytes.fromhex('a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4')
+ expected_output = bytes.fromhex('4b66e9d4d1b4673c5ad22691957d6af5c11b6421e0ea01d42ca4169e7918ba0d')
+
+ result = x25519(scalar, X25519_G)
+ self.assertEqual(result, expected_output)
+
+ def test_x448(self):
+ # Test vector from RFC 7748
+ scalar = bytes.fromhex('3d262fddf9ec8e88495266fea19a34d28882acef045104d0d1aae121700a779c984c24f8cdd78fbff44943eba368f54b29259a4f1c600ad3')
+ u_coordinate = bytes.fromhex('06fce640fa3487bfda5f6cf2d5263f8aad88334cbd07437f020f08f9814dc031ddbdc38c19c6da2583fa5429db94ada18aa7a7fb4ef8a086')
+ expected_output = bytes.fromhex('ce3e4ff95a60dc6697da1db1d85e6afbdf79b50a2412d7546d5f239fe14fbaadeb445fc66a01b0779d98223961111e21766282f73dd96b6f')
+
+ result = x448(scalar, u_coordinate)
+ self.assertEqual(result, expected_output)
+
+ def test_x448_base_point(self):
+ # Test with the base point
+ scalar = bytes.fromhex('3d262fddf9ec8e88495266fea19a34d28882acef045104d0d1aae121700a779c984c24f8cdd78fbff44943eba368f54b29259a4f1c600ad3')
+ expected_output = bytes.fromhex('aa3b4749d55b9daf1e5b00288826c467274ce3ebbdd5c17b975e09d4af6c67cf10d087202db88286e2b79fceea3ec353ef54faa26e219f38')
+
+ result = x448(scalar, X448_G)
+ self.assertEqual(result, expected_output)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tlslite/basedb.py b/tlslite/basedb.py
index 2941ed2..0612cb3 100644
--- a/tlslite/basedb.py
+++ b/tlslite/basedb.py
@@ -25,7 +25,11 @@ class BaseDB(object):
:raises anydbm.error: If there's a problem creating the database.
"""
- pass
+ if self.filename:
+ self.db = anydbm.open(self.filename, 'c')
+ self.db['--Reserved--'] = self.type
+ else:
+ raise ValueError("Filename not specified")
def open(self):
"""
@@ -34,7 +38,12 @@ class BaseDB(object):
:raises anydbm.error: If there's a problem opening the database.
:raises ValueError: If the database is not of the right type.
"""
- pass
+ if self.filename:
+ self.db = anydbm.open(self.filename, 'w')
+ if '--Reserved--' not in self.db or self.db['--Reserved--'] != self.type:
+ raise ValueError("Database is not of type %s" % self.type)
+ else:
+ raise ValueError("Filename not specified")
def __getitem__(self, username):
if self.db == None:
@@ -94,4 +103,13 @@ class BaseDB(object):
:rtype: list
:returns: The usernames in the database.
"""
- pass
+ if self.db is None:
+ raise AssertionError('DB not open')
+ self.lock.acquire()
+ try:
+ keys = list(self.db.keys())
+ if '--Reserved--' in keys:
+ keys.remove('--Reserved--')
+ return keys
+ finally:
+ self.lock.release()
diff --git a/tlslite/bufferedsocket.py b/tlslite/bufferedsocket.py
index 8fd2025..5634251 100644
--- a/tlslite/bufferedsocket.py
+++ b/tlslite/bufferedsocket.py
@@ -24,23 +24,36 @@ class BufferedSocket(object):
def send(self, data):
"""Send data to the socket"""
- pass
+ if self.buffer_writes:
+ self._write_queue.append(data)
+ return len(data)
+ else:
+ return self.socket.send(data)
def sendall(self, data):
"""Send data to the socket"""
- pass
+ if self.buffer_writes:
+ self._write_queue.append(data)
+ else:
+ self.socket.sendall(data)
def flush(self):
"""Send all buffered data"""
- pass
+ while self._write_queue:
+ data = self._write_queue.popleft()
+ self.socket.sendall(data)
def recv(self, bufsize):
"""Receive data from socket (socket emulation)"""
- pass
+ if self._read_buffer:
+ ret = bytes(self._read_buffer[:bufsize])
+ del self._read_buffer[:bufsize]
+ return ret
+ return self.socket.recv(bufsize)
def getsockname(self):
"""Return the socket's own address (socket emulation)."""
- pass
+ return self.socket.getsockname()
def getpeername(self):
"""
@@ -48,11 +61,11 @@ class BufferedSocket(object):
(socket emulation)
"""
- pass
+ return self.socket.getpeername()
def settimeout(self, value):
"""Set a timeout on blocking socket operations (socket emulation)."""
- pass
+ self.socket.settimeout(value)
def gettimeout(self):
"""
@@ -60,16 +73,18 @@ class BufferedSocket(object):
(socket emulation)
"""
- pass
+ return self.socket.gettimeout()
def setsockopt(self, level, optname, value):
"""Set the value of the given socket option (socket emulation)."""
- pass
+ self.socket.setsockopt(level, optname, value)
def shutdown(self, how):
"""Shutdown the underlying socket."""
- pass
+ self.flush()
+ self.socket.shutdown(how)
def close(self):
"""Close the underlying socket."""
- pass
+ self.flush()
+ self.socket.close()
diff --git a/tlslite/constants.py b/tlslite/constants.py
index 6ad3e02..94a0faa 100644
--- a/tlslite/constants.py
+++ b/tlslite/constants.py
@@ -14,7 +14,10 @@ class TLSEnum(object):
@classmethod
def _recursiveVars(cls, klass):
"""Call vars recursively on base classes"""
- pass
+ attributes = vars(klass)
+ for base in klass.__bases__:
+ attributes.update(cls._recursiveVars(base))
+ return attributes
@classmethod
def toRepr(cls, value, blacklist=None):
@@ -23,12 +26,21 @@ class TLSEnum(object):
name if found, None otherwise
"""
- pass
+ if blacklist is None:
+ blacklist = []
+ for name, val in cls._recursiveVars(cls).items():
+ if val == value and name not in blacklist:
+ return name
+ return None
@classmethod
def toStr(cls, value, blacklist=None):
"""Convert numeric type to human-readable string if possible"""
- pass
+ name = cls.toRepr(value, blacklist)
+ if name is None:
+ return str(value)
+ else:
+ return name
class CertificateType(TLSEnum):
@@ -203,17 +215,30 @@ class SignatureScheme(TLSEnum):
E.g. for "rsa_pkcs1_sha1" it returns "rsa"
"""
- pass
+ scheme_name = SignatureScheme.toRepr(scheme)
+ if scheme_name:
+ return scheme_name.split('_')[0]
+ return None
@staticmethod
def getPadding(scheme):
"""Return the name of padding scheme used in signature scheme."""
- pass
+ scheme_name = SignatureScheme.toRepr(scheme)
+ if scheme_name:
+ parts = scheme_name.split('_')
+ if len(parts) > 1:
+ return parts[1]
+ return None
@staticmethod
def getHash(scheme):
"""Return the name of hash used in signature scheme."""
- pass
+ scheme_name = SignatureScheme.toRepr(scheme)
+ if scheme_name:
+ parts = scheme_name.split('_')
+ if len(parts) > 2:
+ return parts[-1]
+ return None
class AlgorithmOID(TLSEnum):
@@ -986,24 +1011,46 @@ class CipherSuite:
@staticmethod
def filterForVersion(suites, minVersion, maxVersion):
"""Return a copy of suites without ciphers incompatible with version"""
- pass
+ return [suite for suite in suites if minVersion <= suite <= maxVersion]
@staticmethod
def filter_for_certificate(suites, cert_chain):
"""Return a copy of suites without ciphers incompatible with the cert.
"""
- pass
+ if not cert_chain:
+ return []
+
+ cert_type = cert_chain.getEndEntityPublicKey().key_type
+
+ compatible_suites = []
+ for suite in suites:
+ if cert_type == "rsa" and "RSA" in CipherSuite.ietfNames[suite]:
+ compatible_suites.append(suite)
+ elif cert_type == "ecdsa" and "ECDSA" in CipherSuite.ietfNames[suite]:
+ compatible_suites.append(suite)
+
+ return compatible_suites
@staticmethod
def filter_for_prfs(suites, prfs):
"""Return a copy of suites without ciphers incompatible with the
specified prfs (sha256 or sha384)"""
- pass
+ compatible_suites = []
+ for suite in suites:
+ suite_name = CipherSuite.ietfNames[suite]
+ if "SHA256" in suite_name and "sha256" in prfs:
+ compatible_suites.append(suite)
+ elif "SHA384" in suite_name and "sha384" in prfs:
+ compatible_suites.append(suite)
+ return compatible_suites
@classmethod
def getTLS13Suites(cls, settings, version=None):
"""Return cipher suites that are TLS 1.3 specific."""
- pass
+ suites = cls.tls13Suites[:]
+ if version:
+ suites = cls.filterForVersion(suites, version, version)
+ return suites
srpSuites = []
srpSuites.append(TLS_SRP_SHA_WITH_AES_256_CBC_SHA)
srpSuites.append(TLS_SRP_SHA_WITH_AES_128_CBC_SHA)
@@ -1162,12 +1209,22 @@ class CipherSuite:
@staticmethod
def canonicalCipherName(ciphersuite):
"""Return the canonical name of the cipher whose number is provided."""
- pass
+ name = CipherSuite.ietfNames.get(ciphersuite)
+ if name:
+ parts = name.split('_')
+ if len(parts) >= 4:
+ return parts[3].lower()
+ return None
@staticmethod
def canonicalMacName(ciphersuite):
"""Return the canonical name of the MAC whose number is provided."""
- pass
+ name = CipherSuite.ietfNames.get(ciphersuite)
+ if name:
+ parts = name.split('_')
+ if len(parts) >= 5:
+ return parts[4].lower()
+ return None
class Fault:
diff --git a/tlslite/defragmenter.py b/tlslite/defragmenter.py
index 3a1038b..d5a9f8e 100644
--- a/tlslite/defragmenter.py
+++ b/tlslite/defragmenter.py
@@ -33,27 +33,43 @@ class Defragmenter(object):
@deprecated_params({'msg_type': 'msgType'})
def add_static_size(self, msg_type, size):
"""Add a message type which all messages are of same length"""
- pass
+ self.priorities.append(msg_type)
+ self.buffers[msg_type] = bytearray()
+ self.decoders[msg_type] = lambda x: len(x) >= size
@deprecated_params({'msg_type': 'msgType', 'size_offset': 'sizeOffset',
'size_of_size': 'sizeOfSize'})
def add_dynamic_size(self, msg_type, size_offset, size_of_size):
"""Add a message type which has a dynamic size set in a header"""
- pass
+ self.priorities.append(msg_type)
+ self.buffers[msg_type] = bytearray()
+ def decoder(x):
+ if len(x) < size_offset + size_of_size:
+ return False
+ size = Parser(x[size_offset:size_offset+size_of_size]).getFixBytes(size_of_size)
+ return len(x) >= size_offset + size_of_size + size
+ self.decoders[msg_type] = decoder
@deprecated_params({'msg_type': 'msgType'})
def add_data(self, msg_type, data):
"""Adds data to buffers"""
- pass
+ if msg_type in self.buffers:
+ self.buffers[msg_type].extend(data)
def get_message(self):
"""Extract the highest priority complete message from buffer"""
- pass
+ for msg_type in self.priorities:
+ if msg_type in self.buffers and self.decoders[msg_type](self.buffers[msg_type]):
+ message = bytes(self.buffers[msg_type])
+ self.buffers[msg_type] = bytearray()
+ return (msg_type, message)
+ return None
def clear_buffers(self):
"""Remove all data from buffers"""
- pass
+ for msg_type in self.buffers:
+ self.buffers[msg_type] = bytearray()
def is_empty(self):
"""Return True if all buffers are empty."""
- pass
+ return all(len(buffer) == 0 for buffer in self.buffers.values())
diff --git a/tlslite/dh.py b/tlslite/dh.py
index a8cab41..c6ba917 100644
--- a/tlslite/dh.py
+++ b/tlslite/dh.py
@@ -11,7 +11,11 @@ def parseBinary(data):
:param bytes data: DH parameters
:rtype: tuple of int
"""
- pass
+ parser = ASN1Parser(data)
+ sequence = parser.getChild()
+ p = bytesToNumber(sequence.getChildBytes(0))
+ g = bytesToNumber(sequence.getChildBytes(1))
+ return (p, g)
def parse(data):
@@ -24,4 +28,9 @@ def parse(data):
:rtype: tuple of int
:returns: generator and prime
"""
- pass
+ try:
+ der = dePem(data, "DH PARAMETERS")
+ except ValueError:
+ der = data
+
+ return parseBinary(der)
diff --git a/tlslite/handshakehashes.py b/tlslite/handshakehashes.py
index 6638766..f21b389 100644
--- a/tlslite/handshakehashes.py
+++ b/tlslite/handshakehashes.py
@@ -28,7 +28,13 @@ class HandshakeHashes(object):
:param bytearray data: serialized TLS handshake message
"""
- pass
+ self._handshakeMD5.update(compatHMAC(data))
+ self._handshakeSHA.update(compatHMAC(data))
+ self._handshakeSHA224.update(compatHMAC(data))
+ self._handshakeSHA256.update(compatHMAC(data))
+ self._handshakeSHA384.update(compatHMAC(data))
+ self._handshakeSHA512.update(compatHMAC(data))
+ self._handshake_buffer += data
def digest(self, digest=None):
"""
@@ -38,7 +44,22 @@ class HandshakeHashes(object):
:param str digest: name of digest to return
"""
- pass
+ if digest is None:
+ return self._handshakeMD5.digest() + self._handshakeSHA.digest()
+ elif digest == 'md5':
+ return self._handshakeMD5.digest()
+ elif digest == 'sha1':
+ return self._handshakeSHA.digest()
+ elif digest == 'sha224':
+ return self._handshakeSHA224.digest()
+ elif digest == 'sha256':
+ return self._handshakeSHA256.digest()
+ elif digest == 'sha384':
+ return self._handshakeSHA384.digest()
+ elif digest == 'sha512':
+ return self._handshakeSHA512.digest()
+ else:
+ raise ValueError("Unknown digest type")
def digestSSL(self, masterSecret, label):
"""
@@ -49,7 +70,17 @@ class HandshakeHashes(object):
:param bytearray masterSecret: value of the master secret
:param bytearray label: label to include in the calculation
"""
- pass
+ md5_hash = MD5()
+ md5_hash.update(label)
+ md5_hash.update(masterSecret)
+ md5_hash.update(compatHMAC(MD5(self._handshake_buffer).digest()))
+
+ sha_hash = SHA1()
+ sha_hash.update(label)
+ sha_hash.update(masterSecret)
+ sha_hash.update(compatHMAC(SHA1(self._handshake_buffer).digest()))
+
+ return md5_hash.digest() + sha_hash.digest()
def copy(self):
"""
@@ -60,4 +91,12 @@ class HandshakeHashes(object):
:rtype: HandshakeHashes
"""
- pass
+ new = HandshakeHashes()
+ new._handshakeMD5 = self._handshakeMD5.copy()
+ new._handshakeSHA = self._handshakeSHA.copy()
+ new._handshakeSHA224 = self._handshakeSHA224.copy()
+ new._handshakeSHA256 = self._handshakeSHA256.copy()
+ new._handshakeSHA384 = self._handshakeSHA384.copy()
+ new._handshakeSHA512 = self._handshakeSHA512.copy()
+ new._handshake_buffer = self._handshake_buffer[:]
+ return new
diff --git a/tlslite/handshakehelpers.py b/tlslite/handshakehelpers.py
index 67f7657..284abb2 100644
--- a/tlslite/handshakehelpers.py
+++ b/tlslite/handshakehelpers.py
@@ -17,7 +17,17 @@ class HandshakeHelpers(object):
:param ClientHello clientHello: ClientHello to be aligned
"""
- pass
+ current_length = len(clientHello.write())
+ target_length = ((current_length + 511) // 512) * 512
+ padding_length = target_length - current_length
+
+ padding_extension = next((ext for ext in clientHello.extensions
+ if isinstance(ext, PaddingExtension)), None)
+
+ if padding_extension:
+ padding_extension.paddingData = bytearray(padding_length)
+ else:
+ clientHello.extensions.append(PaddingExtension().create(padding_length))
@staticmethod
def _calc_binder(prf, psk, handshake_hash, external=True):
@@ -25,12 +35,26 @@ class HandshakeHelpers(object):
Calculate the binder value for a given HandshakeHash (that includes
a truncated client hello already)
"""
- pass
+ if external:
+ label = b"ext binder"
+ else:
+ label = b"res binder"
+
+ early_secret = secureHMAC(bytearray(len(prf.digest())), psk, prf)
+ binder_key = derive_secret(early_secret, label, None, prf)
+ return secureHMAC(binder_key, handshake_hash.digest(prf), prf)
@staticmethod
def calc_res_binder_psk(iden, res_master_secret, tickets):
"""Calculate PSK associated with provided ticket identity."""
- pass
+ for ticket in tickets:
+ if ticket.ticket == iden:
+ prf = ticket.prf
+ hash_name = prf.name
+ nonce = ticket.ticket_nonce
+ return HKDF_expand_label(res_master_secret, b"resumption",
+ nonce, prf.digest_size, prf)
+ raise TLSIllegalParameterException("Ticket not found")
@staticmethod
def update_binders(client_hello, handshake_hashes, psk_configs, tickets
@@ -48,7 +72,25 @@ class HandshakeHelpers(object):
:param bytearray res_master_secret: secret associated with the
tickets
"""
- pass
+ psk_ext = next((ext for ext in client_hello.extensions
+ if isinstance(ext, PreSharedKeyExtension)), None)
+ if not psk_ext:
+ return
+
+ binders = []
+ for i, (identity, psk) in enumerate(psk_configs):
+ if isinstance(psk, bytearray):
+ external = True
+ else:
+ external = False
+ psk = HandshakeHelpers.calc_res_binder_psk(identity, res_master_secret, tickets)
+
+ binder = HandshakeHelpers._calc_binder(psk_ext.prf, psk,
+ handshake_hashes.copy(),
+ external)
+ binders.append(binder)
+
+ psk_ext.binders = binders
@staticmethod
def verify_binder(client_hello, handshake_hashes, position, secret, prf,
@@ -61,4 +103,18 @@ class HandshakeHelpers(object):
:param secret: the secret PSK
:param prf: name of the hash used as PRF
"""
- pass
+ psk_ext = next((ext for ext in client_hello.extensions
+ if isinstance(ext, PreSharedKeyExtension)), None)
+ if not psk_ext:
+ raise TLSIllegalParameterException("No PSK extension")
+
+ if position >= len(psk_ext.binders):
+ raise TLSIllegalParameterException("Invalid binder position")
+
+ binder = psk_ext.binders[position]
+ calculated_binder = HandshakeHelpers._calc_binder(prf, secret,
+ handshake_hashes.copy(),
+ external)
+
+ if not ct_compare_digest(binder, calculated_binder):
+ raise TLSIllegalParameterException("Binder does not verify")
diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py
index 0dc638a..0856f5d 100644
--- a/tlslite/handshakesettings.py
+++ b/tlslite/handshakesettings.py
@@ -331,11 +331,35 @@ class HandshakeSettings(object):
def _init_key_settings(self):
"""Create default variables for key-related settings."""
- pass
+ self.minKeySize = 1023
+ self.maxKeySize = 8193
+ self.certificateTypes = list(CERTIFICATE_TYPES)
+ self.rsaSigHashes = list(RSA_SIGNATURE_HASHES)
+ self.dsaSigHashes = list(DSA_SIGNATURE_HASHES)
+ self.ecdsaSigHashes = list(ECDSA_SIGNATURE_HASHES)
+ self.more_sig_schemes = list(SIGNATURE_SCHEMES)
+ self.eccCurves = list(CURVE_NAMES)
+ self.defaultCurve = "secp256r1"
+ self.keyShares = ["secp256r1", "x25519"]
def _init_misc_extensions(self):
"""Default variables for assorted extensions."""
- pass
+ self.useExperimentalTackExtension = False
+ self.sendFallbackSCSV = True
+ self.useEncryptThenMAC = True
+ self.useExtendedMasterSecret = True
+ self.requireExtendedMasterSecret = False
+ self.padding_cb = None
+ self.pskConfigs = []
+ self.ticketKeys = []
+ self.ticketCipher = "aes256gcm"
+ self.ticketLifetime = 86400 # 1 day
+ self.ticket_count = 1
+ self.psk_modes = list(PSK_MODES)
+ self.max_early_data = 0
+ self.use_heartbeat_extension = True
+ self.heartbeat_response_callback = None
+ self.record_size_limit = None
def __init__(self):
"""Initialise default values for settings."""
@@ -352,87 +376,138 @@ class HandshakeSettings(object):
@staticmethod
def _sanityCheckKeySizes(other):
"""Check if key size limits are sane"""
- pass
+ if other.minKeySize < 512:
+ raise ValueError("minKeySize too small")
+ if other.minKeySize > other.maxKeySize:
+ raise ValueError("minKeySize can't be greater than maxKeySize")
@staticmethod
def _not_matching(values, sieve):
"""Return list of items from values that are not in sieve."""
- pass
+ return [val for val in values if val not in sieve]
@staticmethod
def _sanityCheckCipherSettings(other):
"""Check if specified cipher settings are known."""
- pass
+ unknown = HandshakeSettings._not_matching(other.cipherNames, ALL_CIPHER_NAMES)
+ if unknown:
+ raise ValueError("Unknown cipher name: {0}".format(unknown))
@staticmethod
def _sanityCheckECDHSettings(other):
"""Check ECDHE settings if they are sane."""
- pass
+ unknown = HandshakeSettings._not_matching(other.eccCurves, ALL_CURVE_NAMES)
+ if unknown:
+ raise ValueError("Unknown ECC curve name: {0}".format(unknown))
@staticmethod
def _sanityCheckDHSettings(other):
"""Check if (EC)DHE settings are sane."""
- pass
+ HandshakeSettings._sanityCheckECDHSettings(other)
+ unknown = HandshakeSettings._not_matching(other.dhGroups, ALL_DH_GROUP_NAMES)
+ if unknown:
+ raise ValueError("Unknown DH group name: {0}".format(unknown))
@staticmethod
def _sanityCheckPrimitivesNames(other):
"""Check if specified cryptographic primitive names are known"""
- pass
+ unknown = HandshakeSettings._not_matching(other.macNames, ALL_MAC_NAMES)
+ if unknown:
+ raise ValueError("Unknown MAC name: {0}".format(unknown))
@staticmethod
def _sanityCheckProtocolVersions(other):
"""Check if set protocol version are sane"""
- pass
+ if other.minVersion > other.maxVersion:
+ raise ValueError("Versions set incorrectly")
+ if other.minVersion not in KNOWN_VERSIONS:
+ raise ValueError("minVersion set incorrectly")
+ if other.maxVersion not in KNOWN_VERSIONS:
+ raise ValueError("maxVersion set incorrectly")
@staticmethod
def _sanityCheckEMSExtension(other):
"""Check if settings for EMS are sane."""
- pass
+ if other.requireExtendedMasterSecret and not other.useExtendedMasterSecret:
+ raise ValueError("Require EMS must have EMS enabled")
@staticmethod
def _sanityCheckExtensions(other):
"""Check if set extension settings are sane"""
- pass
+ if other.useEncryptThenMAC and not other.macNames:
+ raise ValueError("Encrypt-then-MAC requires MAC")
+ HandshakeSettings._sanityCheckEMSExtension(other)
@staticmethod
def _not_allowed_len(values, sieve):
"""Return True if length of any item in values is not in sieve."""
- pass
+ return any(len(val) not in sieve for val in values)
@staticmethod
def _sanityCheckPsks(other):
"""Check if the set PSKs are sane."""
- pass
+ if HandshakeSettings._not_allowed_len(other.pskConfigs, [2, 3]):
+ raise ValueError("pskConfigs items must be a 2 or 3-element tuple")
@staticmethod
def _sanityCheckTicketSettings(other):
"""Check if the session ticket settings are sane."""
- pass
+ if other.ticketKeys and len(other.ticketKeys[0]) not in (16, 32):
+ raise ValueError("Session ticket encryption keys must be 16 or 32 bytes long")
+ if other.ticketLifetime <= 0:
+ raise ValueError("Ticket lifetime must be a positive integer")
def _copy_cipher_settings(self, other):
"""Copy values related to cipher selection."""
- pass
+ other.cipherNames = self.cipherNames[:]
+ other.macNames = self.macNames[:]
+ other.keyExchangeNames = self.keyExchangeNames[:]
+ other.cipherImplementations = self.cipherImplementations[:]
def _copy_extension_settings(self, other):
"""Copy values of settings related to extensions."""
- pass
+ other.useExperimentalTackExtension = self.useExperimentalTackExtension
+ other.sendFallbackSCSV = self.sendFallbackSCSV
+ other.useEncryptThenMAC = self.useEncryptThenMAC
+ other.useExtendedMasterSecret = self.useExtendedMasterSecret
+ other.requireExtendedMasterSecret = self.requireExtendedMasterSecret
+ other.use_heartbeat_extension = self.use_heartbeat_extension
+ other.heartbeat_response_callback = self.heartbeat_response_callback
+ other.record_size_limit = self.record_size_limit
@staticmethod
def _remove_all_matches(values, needle):
"""Remove all instances of needle from values."""
- pass
+ return [val for val in values if val != needle]
def _sanity_check_ciphers(self, other):
"""Remove unsupported ciphers in current configuration."""
- pass
+ if not other.cipherNames:
+ other.cipherNames = self.cipherNames[:]
+ if not other.macNames:
+ other.macNames = self.macNames[:]
+ if not other.keyExchangeNames:
+ other.keyExchangeNames = self.keyExchangeNames[:]
def _sanity_check_implementations(self, other):
"""Remove all backends that are not loaded."""
- pass
+ if not other.cipherImplementations:
+ other.cipherImplementations = self.cipherImplementations[:]
+ other.cipherImplementations = [impl for impl in other.cipherImplementations
+ if impl in CIPHER_IMPLEMENTATIONS]
def _copy_key_settings(self, other):
"""Copy key-related settings."""
- pass
+ other.minKeySize = self.minKeySize
+ other.maxKeySize = self.maxKeySize
+ other.certificateTypes = self.certificateTypes[:]
+ other.rsaSigHashes = self.rsaSigHashes[:]
+ other.dsaSigHashes = self.dsaSigHashes[:]
+ other.ecdsaSigHashes = self.ecdsaSigHashes[:]
+ other.more_sig_schemes = self.more_sig_schemes[:]
+ other.eccCurves = self.eccCurves[:]
+ other.defaultCurve = self.defaultCurve
+ other.keyShares = self.keyShares[:]
def validate(self):
"""
@@ -443,8 +518,29 @@ class HandshakeSettings(object):
:returns: a self-consistent copy of settings
:raises ValueError: when settings are invalid, insecure or unsupported.
"""
- pass
+ other = HandshakeSettings()
+ other._copy_cipher_settings(self)
+ other._copy_extension_settings(self)
+ other._copy_key_settings(self)
+
+ other.minVersion = self.minVersion
+ other.maxVersion = self.maxVersion
+ other.versions = self.versions[:]
+
+ HandshakeSettings._sanityCheckKeySizes(other)
+ HandshakeSettings._sanityCheckCipherSettings(other)
+ HandshakeSettings._sanityCheckPrimitivesNames(other)
+ HandshakeSettings._sanityCheckProtocolVersions(other)
+ HandshakeSettings._sanityCheckExtensions(other)
+ HandshakeSettings._sanityCheckDHSettings(other)
+ HandshakeSettings._sanityCheckPsks(other)
+ HandshakeSettings._sanityCheckTicketSettings(other)
+
+ self._sanity_check_ciphers(other)
+ self._sanity_check_implementations(other)
+
+ return other
def getCertificateTypes(self):
"""Get list of certificate types as IDs"""
- pass
+ return [getattr(CertificateType, cert_type) for cert_type in self.certificateTypes]
diff --git a/tlslite/integration/asyncstatemachine.py b/tlslite/integration/asyncstatemachine.py
index 2aa8a79..1d44a6a 100644
--- a/tlslite/integration/asyncstatemachine.py
+++ b/tlslite/integration/asyncstatemachine.py
@@ -37,7 +37,15 @@ class AsyncStateMachine:
:rtype: bool or None
:returns: If the state machine wants to read.
"""
- pass
+ if self.handshaker:
+ return self.handshaker.wantsReadEvent()
+ elif self.closer:
+ return self.closer.wantsReadEvent()
+ elif self.reader:
+ return True
+ elif self.writer:
+ return False
+ return None
def wantsWriteEvent(self):
"""If the state machine wants to write.
@@ -49,7 +57,15 @@ class AsyncStateMachine:
:rtype: bool or None
:returns: If the state machine wants to write.
"""
- pass
+ if self.handshaker:
+ return self.handshaker.wantsWriteEvent()
+ elif self.closer:
+ return self.closer.wantsWriteEvent()
+ elif self.writer:
+ return True
+ elif self.reader:
+ return False
+ return None
def outConnectEvent(self):
"""Called when a handshake operation completes.
@@ -79,11 +95,43 @@ class AsyncStateMachine:
def inReadEvent(self):
"""Tell the state machine it can read from the socket."""
- pass
+ try:
+ if self.handshaker:
+ self.result = next(self.handshaker)
+ if self.result is None:
+ self.handshaker = None
+ self.outConnectEvent()
+ elif self.closer:
+ self.result = next(self.closer)
+ if self.result is None:
+ self.closer = None
+ self.outCloseEvent()
+ elif self.reader:
+ readBuffer = self.reader.read()
+ self.reader = None
+ self.outReadEvent(readBuffer)
+ except StopIteration:
+ self._clear()
def inWriteEvent(self):
"""Tell the state machine it can write to the socket."""
- pass
+ try:
+ if self.handshaker:
+ self.result = next(self.handshaker)
+ if self.result is None:
+ self.handshaker = None
+ self.outConnectEvent()
+ elif self.closer:
+ self.result = next(self.closer)
+ if self.result is None:
+ self.closer = None
+ self.outCloseEvent()
+ elif self.writer:
+ self.writer.write()
+ self.writer = None
+ self.outWriteEvent()
+ except StopIteration:
+ self._clear()
def setHandshakeOp(self, handshaker):
"""Start a handshake operation.
@@ -93,7 +141,8 @@ class AsyncStateMachine:
:py:meth:`~.TLSConnection.handshakeServerAsync` , or
handshakeClientxxx(..., async_=True).
"""
- pass
+ self._clear()
+ self.handshaker = handshaker
def setServerHandshakeOp(self, **args):
"""Start a handshake operation.
@@ -101,16 +150,19 @@ class AsyncStateMachine:
The arguments passed to this function will be forwarded to
:py:obj:`~tlslite.tlsconnection.TLSConnection.handshakeServerAsync`.
"""
- pass
+ self._clear()
+ self.handshaker = self.tlsConnection.handshakeServerAsync(**args)
def setCloseOp(self):
"""Start a close operation.
"""
- pass
+ self._clear()
+ self.closer = self.tlsConnection.closeAsync()
def setWriteOp(self, writeBuffer):
"""Start a write operation.
:param str writeBuffer: The string to transmit.
"""
- pass
+ self._clear()
+ self.writer = self.tlsConnection.writeAsync(writeBuffer)
diff --git a/tlslite/integration/clienthelper.py b/tlslite/integration/clienthelper.py
index 3bda03e..bab3b36 100644
--- a/tlslite/integration/clienthelper.py
+++ b/tlslite/integration/clienthelper.py
@@ -101,4 +101,26 @@ class ClientHelper(object):
@staticmethod
def _isIP(address):
"""Return True if the address is an IPv4 address"""
- pass
+ try:
+ # Split the address into octets
+ octets = address.split('.')
+
+ # Check if we have exactly 4 octets
+ if len(octets) != 4:
+ return False
+
+ # Check each octet
+ for octet in octets:
+ # Convert to integer
+ num = int(octet)
+ # Check if it's between 0 and 255
+ if num < 0 or num > 255:
+ return False
+ # Check if it doesn't have leading zeros (except for 0)
+ if len(octet) > 1 and octet[0] == '0':
+ return False
+
+ return True
+ except ValueError:
+ # If we can't convert to int, it's not a valid IP
+ return False
diff --git a/tlslite/integration/httptlsconnection.py b/tlslite/integration/httptlsconnection.py
index 9d6129c..2241b22 100644
--- a/tlslite/integration/httptlsconnection.py
+++ b/tlslite/integration/httptlsconnection.py
@@ -11,10 +11,9 @@ from tlslite.integration.clienthelper import ClientHelper
class HTTPTLSConnection(httplib.HTTPConnection, ClientHelper):
"""This class extends L{httplib.HTTPConnection} to support TLS."""
- def __init__(self, host, port=None, strict=None, timeout=socket.
- _GLOBAL_DEFAULT_TIMEOUT, source_address=None, username=None,
- password=None, certChain=None, privateKey=None, checker=None,
- settings=None, ignoreAbruptClose=False, anon=False):
+ def __init__(self, host, port=None, strict=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
+ source_address=None, username=None, password=None, certChain=None,
+ privateKey=None, checker=None, settings=None, ignoreAbruptClose=False, anon=False):
"""Create a new HTTPTLSConnection.
For client authentication, use one of these argument
@@ -80,10 +79,42 @@ class HTTPTLSConnection(httplib.HTTPConnection, ClientHelper):
"""
if source_address:
httplib.HTTPConnection.__init__(self, host=host, port=port,
- timeout=timeout, source_address=source_address)
- if not source_address:
+ timeout=timeout, source_address=source_address)
+ else:
httplib.HTTPConnection.__init__(self, host=host, port=port,
- timeout=timeout)
+ timeout=timeout)
self.ignoreAbruptClose = ignoreAbruptClose
ClientHelper.__init__(self, username, password, certChain,
- privateKey, checker, settings, anon, host)
+ privateKey, checker, settings, anon, host)
+ self.tlsConnection = None
+
+ def connect(self):
+ """Connect to a host on a given (SSL) port."""
+ httplib.HTTPConnection.connect(self)
+
+ self.sock = TLSConnection(self.sock)
+
+ try:
+ self.start_tls()
+ self.tlsConnection = self.sock
+ except:
+ self.close()
+ raise
+
+ def close(self):
+ """Close the connection to the HTTP server."""
+ if self.tlsConnection:
+ self.tlsConnection.close()
+ httplib.HTTPConnection.close(self)
+
+ def send(self, data):
+ """Send `data` to the server."""
+ if self.tlsConnection:
+ self.tlsConnection.send(data)
+ else:
+ httplib.HTTPConnection.send(self, data)
+
+ def _tunnel(self):
+ """Set up the tunnel to the server."""
+ self.connect()
+ httplib.HTTPConnection._tunnel(self)
diff --git a/tlslite/integration/imap4_tls.py b/tlslite/integration/imap4_tls.py
index 8583c78..1687200 100644
--- a/tlslite/integration/imap4_tls.py
+++ b/tlslite/integration/imap4_tls.py
@@ -76,4 +76,13 @@ class IMAP4_TLS(IMAP4, ClientHelper):
This connection will be used by the routines:
read, readline, send, shutdown.
"""
- pass
+ self.host = host
+ self.port = port
+ self.timeout = timeout
+
+ sock = socket.create_connection((host, port), timeout)
+ self.sock = TLSConnection(sock)
+
+ self._handshake()
+
+ self.file = self.sock.makefile('rb')
diff --git a/tlslite/integration/smtp_tls.py b/tlslite/integration/smtp_tls.py
index 64501b6..5e84d75 100644
--- a/tlslite/integration/smtp_tls.py
+++ b/tlslite/integration/smtp_tls.py
@@ -60,4 +60,30 @@ class SMTP_TLS(SMTP):
the ciphersuites, certificate types, and SSL/TLS versions
offered by the client.
"""
- pass
+ # First, send the STARTTLS command to the SMTP server
+ (code, resp) = self.docmd("STARTTLS")
+ if code != 220:
+ raise SMTPException("STARTTLS extension not supported by server.")
+
+ # Create a TLSConnection instance
+ tlsConnection = TLSConnection(self.sock)
+
+ # Set up the ClientHelper
+ helper = ClientHelper(username, password, certChain, privateKey, checker, settings)
+
+ # Perform the TLS handshake
+ try:
+ helper.handshakeClientCert(tlsConnection)
+ except Exception as e:
+ raise SMTPException(f"TLS handshake failed: {str(e)}")
+
+ # Replace the original socket with the TLS connection
+ self.sock = tlsConnection
+ self.file = tlsConnection.makefile('rb')
+
+ # Re-initialize the SMTP connection
+ (code, msg) = self.ehlo()
+ if code != 250:
+ raise SMTPException("EHLO failed after STARTTLS")
+
+ return (code, msg)
diff --git a/tlslite/integration/tlsasynciodispatchermixin.py b/tlslite/integration/tlsasynciodispatchermixin.py
index fc83a77..3c1b3d1 100644
--- a/tlslite/integration/tlsasynciodispatchermixin.py
+++ b/tlslite/integration/tlsasynciodispatchermixin.py
@@ -82,44 +82,59 @@ class TLSAsyncioDispatcherMixIn(asyncio.Protocol):
def _get_sibling_class(self):
"""Get the sibling class that this class is mixed in with."""
- pass
+ for base in self.__class__.__bases__:
+ if base is not TLSAsyncioDispatcherMixIn:
+ return base
+ return None
def readable(self):
"""Check if the protocol is ready for reading."""
- pass
+ return self.tls_connection.recv_buffer_size() > 0
def writable(self):
"""Check if the protocol is ready for writing."""
- pass
+ return self.tls_connection.send_buffer_size() > 0
def handle_read(self):
"""Handle a read event."""
- pass
+ try:
+ data = self.tls_connection.recv(16384)
+ if data:
+ self.sibling_class.data_received(self, data)
+ except Exception as e:
+ self.close()
def handle_write(self):
"""Handle a write event."""
- pass
+ try:
+ sent = self.tls_connection.send(b'')
+ if sent == 0:
+ self.sibling_class.pause_writing(self)
+ except Exception as e:
+ self.close()
def out_connect_event(self):
"""Handle an outgoing connect event."""
- pass
+ self.sibling_class.connection_made(self, self.tls_connection)
def out_close_event(self):
"""Handle an outgoing close event."""
- pass
+ self.sibling_class.connection_lost(self, None)
def out_read_event(self, read_buffer):
"""Handle an outgoing read event."""
- pass
+ self.sibling_class.data_received(self, read_buffer)
def out_write_event(self):
"""Handle an outgoing write event."""
- pass
+ self.sibling_class.resume_writing(self)
def recv(self, buffer_size=16384):
"""Receive data."""
- pass
+ return self.tls_connection.recv(buffer_size)
def close(self):
"""Close the connection."""
- pass
+ if self.tls_connection:
+ self.tls_connection.close()
+ self.sibling_class.connection_lost(self, None)
diff --git a/tlslite/integration/xmlrpcserver.py b/tlslite/integration/xmlrpcserver.py
index 1d62452..61921f4 100644
--- a/tlslite/integration/xmlrpcserver.py
+++ b/tlslite/integration/xmlrpcserver.py
@@ -11,11 +11,32 @@ class TLSXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
def setup(self):
"""Setup the connection for TLS."""
- pass
+ self.connection = self.request
+ self.rfile = self.connection.makefile('rb', self.rbufsize)
+ self.wfile = self.connection.makefile('wb', self.wbufsize)
def do_POST(self):
"""Handle the HTTPS POST request."""
- pass
+ try:
+ # Get the request data
+ content_len = int(self.headers.get('content-length', 0))
+ post_body = self.rfile.read(content_len)
+
+ # Process the request
+ response = self.server._marshaled_dispatch(
+ post_body, getattr(self, '_dispatch', None), self.path
+ )
+
+ # Send response
+ self.send_response(200)
+ self.send_header("Content-type", "text/xml")
+ self.send_header("Content-length", str(len(response)))
+ self.end_headers()
+ self.wfile.write(response)
+ self.wfile.flush()
+ except Exception: # This is the same behavior as in SimpleXMLRPCRequestHandler
+ self.send_response(500)
+ self.end_headers()
class TLSXMLRPCServer(TLSSocketServerMixIn, SimpleXMLRPCServer):
@@ -31,7 +52,52 @@ class MultiPathTLSXMLRPCServer(TLSXMLRPCServer):
"""Multipath XML-RPC Server using TLS."""
def __init__(self, addr, *args, **kwargs):
- TLSXMLRPCServer.__init__(addr, *args, **kwargs)
+ TLSXMLRPCServer.__init__(self, addr, *args, **kwargs)
self.dispatchers = {}
- self.allow_none = allow_none
- self.encoding = encoding
+ self.allow_none = kwargs.get('allow_none', False)
+ self.encoding = kwargs.get('encoding', 'utf-8')
+
+ def add_dispatcher(self, path, dispatcher):
+ self.dispatchers[path] = dispatcher
+
+ def get_dispatcher(self, path):
+ return self.dispatchers.get(path, self.instance)
+
+ def _marshaled_dispatch(self, data, dispatch_method=None, path=None):
+ try:
+ params, method = xmlrpclib.loads(data)
+
+ # Get the appropriate dispatcher based on the path
+ dispatcher = self.get_dispatcher(path)
+
+ if dispatch_method is not None:
+ response = dispatch_method(dispatcher, method, params)
+ else:
+ response = self._dispatch(dispatcher, method, params)
+
+ # Convert the response to XML-RPC format
+ response = (response,)
+ response = xmlrpclib.dumps(response, methodresponse=1,
+ allow_none=self.allow_none, encoding=self.encoding)
+ except Fault as fault:
+ response = xmlrpclib.dumps(fault, allow_none=self.allow_none,
+ encoding=self.encoding)
+ except:
+ # Report exception back to server
+ response = xmlrpclib.dumps(
+ xmlrpclib.Fault(1, "%s:%s" % (sys.exc_info()[0], sys.exc_info()[1])),
+ encoding=self.encoding, allow_none=self.allow_none,
+ )
+
+ return response.encode(self.encoding)
+
+ def _dispatch(self, dispatcher, method, params):
+ try:
+ # Check if the requested method is available in the dispatcher
+ func = getattr(dispatcher, 'dispatch')
+ if callable(func):
+ return func(method, params)
+ else:
+ raise Exception('method "%s" is not supported' % method)
+ except Exception as e:
+ raise xmlrpclib.Fault(1, str(e))
diff --git a/tlslite/integration/xmlrpctransport.py b/tlslite/integration/xmlrpctransport.py
index 256ade1..38594ef 100644
--- a/tlslite/integration/xmlrpctransport.py
+++ b/tlslite/integration/xmlrpctransport.py
@@ -95,4 +95,23 @@ class XMLRPCTransport(xmlrpclib.Transport, ClientHelper):
def make_connection(self, host):
"""Make a connection to `host`. Reuse keepalive connections."""
- pass
+ if self.conn_class_is_http:
+ # For Python 2.6 and earlier
+ chost, self._extra_headers, x509 = self.get_host_info(host)
+ return HTTPTLSConnection(chost, None,
+ self.username, self.password,
+ self.certChain, self.privateKey,
+ self.checker, self.settings,
+ self.ignoreAbruptClose)
+ else:
+ # For Python 2.7 and later
+ if self._connection and host == self._connection[0]:
+ return self._connection[1]
+
+ chost, self._extra_headers, x509 = self.get_host_info(host)
+ self._connection = host, HTTPTLSConnection(chost, None,
+ self.username, self.password,
+ self.certChain, self.privateKey,
+ self.checker, self.settings,
+ self.ignoreAbruptClose)
+ return self._connection[1]
diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py
index 847c510..4eee7cf 100644
--- a/tlslite/keyexchange.py
+++ b/tlslite/keyexchange.py
@@ -36,7 +36,7 @@ class KeyExchange(object):
handshake. If the key exchange method does not send ServerKeyExchange
(e.g. RSA), it returns None.
"""
- pass
+ raise NotImplementedError("Subclasses must implement this method")
def makeClientKeyExchange(self):
"""
@@ -45,7 +45,7 @@ class KeyExchange(object):
Returns a ClientKeyExchange for the second flight from client in the
handshake.
"""
- pass
+ raise NotImplementedError("Subclasses must implement this method")
def processClientKeyExchange(self, clientKeyExchange):
"""
@@ -54,23 +54,35 @@ class KeyExchange(object):
Processes the client's ClientKeyExchange message and returns the
premaster secret. Raises TLSLocalAlert on error.
"""
- pass
+ raise NotImplementedError("Subclasses must implement this method")
def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
"""Process the server KEX and return premaster secret"""
- pass
+ raise NotImplementedError("Subclasses must implement this method")
def _tls12_sign_dsa_SKE(self, serverKeyExchange, sigHash=None):
"""Sign a TLSv1.2 SKE message."""
- pass
+ if not self.privateKey:
+ raise TLSInternalError("No private key to sign SKE")
+ return self.privateKey.sign(serverKeyExchange.hash(sigHash))
def _tls12_sign_eddsa_ske(self, server_key_exchange, sig_hash):
"""Sign a TLSv1.2 SKE message."""
- pass
+ if not self.privateKey:
+ raise TLSInternalError("No private key to sign SKE")
+ return self.privateKey.sign(server_key_exchange.hash(sig_hash))
def _tls12_signSKE(self, serverKeyExchange, sigHash=None):
"""Sign a TLSv1.2 SKE message."""
- pass
+ if self.privateKey.key_type == "rsa":
+ return self.privateKey.sign(serverKeyExchange.hash(sigHash),
+ padding="pkcs1",
+ hashAlg=sigHash)
+ elif self.privateKey.key_type == "ecdsa":
+ return self.privateKey.sign(serverKeyExchange.hash(sigHash),
+ hashAlg=sigHash)
+ else:
+ raise TLSInternalError("Unsupported key type for TLS 1.2 signing")
def signServerKeyExchange(self, serverKeyExchange, sigHash=None):
"""
@@ -79,19 +91,36 @@ class KeyExchange(object):
:type sigHash: str
:param sigHash: name of the signature hash to be used for signing
"""
- pass
+ if self.version >= (3, 3):
+ return self._tls12_signSKE(serverKeyExchange, sigHash)
+ else:
+ return self.privateKey.sign(serverKeyExchange.hash())
@staticmethod
def _tls12_verify_eddsa_ske(server_key_exchange, public_key,
client_random, server_random, valid_sig_algs):
"""Verify SeverKeyExchange messages with EdDSA signatures."""
- pass
+ signature = server_key_exchange.signature
+ sig_alg = server_key_exchange.signatureAlgorithm
+ if sig_alg not in valid_sig_algs:
+ raise TLSIllegalParameterException("Invalid signature algorithm")
+ hash_name = HashAlgorithm.toRepr(sig_alg[1])
+ verify_bytes = server_key_exchange.hash(hash_name)
+ return public_key.verify(signature, verify_bytes)
@staticmethod
def _tls12_verify_SKE(serverKeyExchange, publicKey, clientRandom,
serverRandom, validSigAlgs):
"""Verify TLSv1.2 version of SKE."""
- pass
+ signature = serverKeyExchange.signature
+ if not signature:
+ raise TLSIllegalParameterException("No signature")
+ hashAlg = serverKeyExchange.hashAlg
+ sigAlg = serverKeyExchange.signAlg
+ if (hashAlg, sigAlg) not in validSigAlgs:
+ raise TLSIllegalParameterException("Invalid signature algorithm")
+ hashName = HashAlgorithm.toRepr(hashAlg)
+ return publicKey.verify(signature, serverKeyExchange.hash(hashName))
@staticmethod
def verifyServerKeyExchange(serverKeyExchange, publicKey, clientRandom,
@@ -100,14 +129,43 @@ class KeyExchange(object):
the only acceptable signature algorithms are specified by validSigAlgs
"""
- pass
+ if serverKeyExchange.version >= (3, 3):
+ return KeyExchange._tls12_verify_SKE(serverKeyExchange, publicKey,
+ clientRandom, serverRandom, validSigAlgs)
+ else:
+ return publicKey.verify(serverKeyExchange.signature,
+ serverKeyExchange.hash())
@staticmethod
def calcVerifyBytes(version, handshakeHashes, signatureAlg,
premasterSecret, clientRandom, serverRandom, prf_name=None,
peer_tag=b'client', key_type='rsa'):
"""Calculate signed bytes for Certificate Verify"""
- pass
+ if version == (3, 0):
+ return handshakeHashes.digestSSL(premasterSecret, peer_tag)
+ elif version in ((3, 1), (3, 2)):
+ return handshakeHashes.digest()
+ elif version >= (3, 3):
+ if not prf_name:
+ raise ValueError("prf_name not specified")
+ sig_scheme = SignatureScheme.toRepr(signatureAlg)
+ if sig_scheme in ('rsa_pss_rsae_sha256',
+ 'rsa_pss_pss_sha256'):
+ hash_name = "sha256"
+ elif sig_scheme in ('rsa_pss_rsae_sha384',
+ 'rsa_pss_pss_sha384'):
+ hash_name = "sha384"
+ elif sig_scheme in ('rsa_pss_rsae_sha512',
+ 'rsa_pss_pss_sha512'):
+ hash_name = "sha512"
+ else:
+ hash_name = HashAlgorithm.toRepr(signatureAlg[1])
+ verify_bytes = bytearray(b'\x20' * 64 + peer_tag +
+ b'\x20' * 64)
+ verify_bytes += handshakeHashes.digest(hash_name)
+ return verify_bytes
+ else:
+ raise ValueError("Unknown SSL/TLS version")
@staticmethod
def makeCertificateVerify(version, handshakeHashes, validSigAlgs,
@@ -127,7 +185,16 @@ class KeyExchange(object):
:param serverRandom: server provided random value, needed only for
SSLv3
"""
- pass
+ signatureAlgorithm = None
+ if version >= (3, 3):
+ signatureAlgorithm = getFirstMatching(validSigAlgs,
+ privateKey.supported_sig_algs)
+ if signatureAlgorithm is None:
+ raise TLSInternalError("No supported signature algorithm")
+ verifyBytes = KeyExchange.calcVerifyBytes(version, handshakeHashes,
+ signatureAlgorithm, premasterSecret, clientRandom, serverRandom)
+ signature = privateKey.sign(verifyBytes)
+ return CertificateVerify(version, signatureAlgorithm, signature)
class AuthenticatedKeyExchange(KeyExchange):
diff --git a/tlslite/mathtls.py b/tlslite/mathtls.py
index abd3d46..f67b2bd 100644
--- a/tlslite/mathtls.py
+++ b/tlslite/mathtls.py
@@ -570,36 +570,57 @@ def paramStrength(param):
:param param: prime or modulus
:type param: int
"""
- pass
+ bit_size = param.bit_length()
+ if bit_size < 1024:
+ return 80
+ elif bit_size < 2048:
+ return 112
+ elif bit_size < 3072:
+ return 128
+ elif bit_size < 7680:
+ return 192
+ else:
+ return 256
def P_hash(mac_name, secret, seed, length):
"""Internal method for calculation the PRF in TLS."""
- pass
+ bytes_to_return = bytearray()
+ hmac_hash = hmac.HMAC(secret, digestmod=getattr(hashlib, mac_name))
+ a = seed
+ while len(bytes_to_return) < length:
+ hmac_hash.update(a)
+ a = hmac_hash.digest()
+ hmac_hash.update(a + seed)
+ bytes_to_return += hmac_hash.digest()
+ return bytes_to_return[:length]
def PRF_1_2(secret, label, seed, length):
"""Pseudo Random Function for TLS1.2 ciphers that use SHA256"""
- pass
+ return P_hash('sha256', secret, label + seed, length)
def PRF_1_2_SHA384(secret, label, seed, length):
"""Pseudo Random Function for TLS1.2 ciphers that use SHA384"""
- pass
+ return P_hash('sha384', secret, label + seed, length)
@deprecated_method('Please use calc_key function instead.')
def calcExtendedMasterSecret(version, cipherSuite, premasterSecret,
handshakeHashes):
"""Derive Extended Master Secret from premaster and handshake msgs"""
- pass
+ return calc_key(version, premasterSecret, cipherSuite, b'extended master secret',
+ handshake_hashes=handshakeHashes, output_length=48)
@deprecated_method('Please use calc_key function instead.')
def calcMasterSecret(version, cipherSuite, premasterSecret, clientRandom,
serverRandom):
"""Derive Master Secret from premaster secret and random values"""
- pass
+ return calc_key(version, premasterSecret, cipherSuite, b'master secret',
+ client_random=clientRandom, server_random=serverRandom,
+ output_length=48)
@deprecated_method('Please use calc_key function instead.')
@@ -614,7 +635,10 @@ def calcFinished(version, masterSecret, cipherSuite, handshakeHashes, isClient
:param isClient: whether the calculation should be performed for message
sent by client (True) or by server (False) side of connection
"""
- pass
+ label = b'client finished' if isClient else b'server finished'
+ return calc_key(version, masterSecret, cipherSuite, label,
+ handshake_hashes=handshakeHashes,
+ output_length=12)
def calc_key(version, secret, cipher_suite, label, handshake_hashes=None,
@@ -640,8 +664,60 @@ def calc_key(version, secret, cipher_suite, label, handshake_hashes=None,
master secret or key expansion.
:param int output_length: Number of bytes to output.
"""
- pass
+ if version >= (3, 3): # TLS 1.2+
+ if cipher_suite in CipherSuite.sha384PrfSuites:
+ prf = PRF_1_2_SHA384
+ else:
+ prf = PRF_1_2
+ else: # TLS 1.1 and earlier
+ prf = lambda secret, label, seed, length: P_hash('md5', secret, label + seed, length//2) + \
+ P_hash('sha1', secret, label + seed, length - length//2)
+
+ if label in (b'extended master secret', b'client finished', b'server finished'):
+ seed = handshake_hashes.digest(version)
+ elif label == b'master secret':
+ seed = client_random + server_random
+ elif label == b'key expansion':
+ seed = server_random + client_random
+ else:
+ raise ValueError("Unknown label: " + str(label))
+
+ if output_length is None:
+ if label == b'master secret':
+ output_length = 48
+ elif label == b'key expansion':
+ output_length = 2 * (20 + 20 + 16) # 2 * (MAC + IV + key)
+ else:
+ output_length = 12 # finished message length
+
+ return prf(secret, label, seed, output_length)
class MAC_SSL(object):
- pass
+ def __init__(self, key, digest_size):
+ self.key = key
+ self.digest_size = digest_size
+ self.digest_alg = hashlib.md5 if digest_size == 16 else hashlib.sha1
+ self.inner = self.digest_alg()
+ self.outer = self.digest_alg()
+
+ key_pad = key + b'\x00' * (64 - len(key))
+ self.inner.update(bytes(x ^ 0x36 for x in key_pad))
+ self.outer.update(bytes(x ^ 0x5C for x in key_pad))
+
+ def update(self, data):
+ self.inner.update(data)
+
+ def copy(self):
+ new = MAC_SSL.__new__(MAC_SSL)
+ new.key = self.key
+ new.digest_size = self.digest_size
+ new.digest_alg = self.digest_alg
+ new.inner = self.inner.copy()
+ new.outer = self.outer.copy()
+ return new
+
+ def digest(self):
+ h = self.outer.copy()
+ h.update(self.inner.digest())
+ return h.digest()[:self.digest_size]
diff --git a/tlslite/messages.py b/tlslite/messages.py
index d4ceb65..926750d 100644
--- a/tlslite/messages.py
+++ b/tlslite/messages.py
@@ -32,15 +32,24 @@ class RecordHeader3(RecordHeader):
def create(self, version, type, length):
"""Set object values for writing (serialisation)."""
- pass
+ self.version = version
+ self.type = type
+ self.length = length
def write(self):
"""Serialise object to bytearray."""
- pass
+ writer = Writer()
+ writer.add(self.type, 1)
+ writer.add(self.version[0], 1)
+ writer.add(self.version[1], 1)
+ writer.add(self.length, 2)
+ return writer.bytes
def parse(self, parser):
"""Deserialise object from Parser."""
- pass
+ self.type = parser.get(1)
+ self.version = (parser.get(1), parser.get(1))
+ self.length = parser.get(2)
def __str__(self):
return (
@@ -71,15 +80,23 @@ class RecordHeader2(RecordHeader):
def parse(self, parser):
"""Deserialise object from Parser."""
- pass
+ first_byte = parser.get(1)
+ self.length = ((first_byte & 0x7f) << 8) | parser.get(1)
+ self.padding = parser.get(1)
+ self.securityEscape = bool(first_byte & 0x80)
def create(self, length, padding=0, securityEscape=False):
"""Set object's values."""
- pass
+ self.length = length
+ self.padding = padding
+ self.securityEscape = securityEscape
def write(self):
"""Serialise object to bytearray."""
- pass
+ first_byte = (self.length >> 8) & 0x7f
+ if self.securityEscape:
+ first_byte |= 0x80
+ return bytearray([first_byte, self.length & 0xff, self.padding])
class Message(object):
@@ -99,7 +116,7 @@ class Message(object):
def write(self):
"""Return serialised object data."""
- pass
+ return self.data
class Alert(object):
@@ -155,7 +172,12 @@ class HelloMessage(HandshakeMsg):
:raises TLSInternalError: when there are multiple extensions of the
same type
"""
- pass
+ if self.extensions is None:
+ return None
+ matching = [ext for ext in self.extensions if ext.extType == extType]
+ if len(matching) > 1:
+ raise TLSInternalError("Multiple extensions of the same type present")
+ return matching[0] if matching else None
def addExtension(self, ext):
"""
@@ -164,15 +186,19 @@ class HelloMessage(HandshakeMsg):
:type ext: TLSExtension
:param ext: extension object to add to list
"""
- pass
+ if self.extensions is None:
+ self.extensions = []
+ self.extensions.append(ext)
def _addExt(self, extType):
"""Add en empty extension of given type, if not already present"""
- pass
+ if not self.getExtension(extType):
+ self.addExtension(TLSExtension().create(extType, bytearray(0)))
def _removeExt(self, extType):
"""Remove extension of given type"""
- pass
+ if self.extensions:
+ self.extensions = [ext for ext in self.extensions if ext.extType != extType]
def _addOrRemoveExt(self, extType, add):
"""
@@ -183,7 +209,10 @@ class HelloMessage(HandshakeMsg):
:type add: boolean
:param add: whether to add (True) or remove (False) the extension
"""
- pass
+ if add:
+ self._addExt(extType)
+ else:
+ self._removeExt(extType)
class ClientHello(HelloMessage):
diff --git a/tlslite/messagesocket.py b/tlslite/messagesocket.py
index 0cb49ef..27a2e60 100644
--- a/tlslite/messagesocket.py
+++ b/tlslite/messagesocket.py
@@ -61,11 +61,26 @@ class MessageSocket(RecordLayer):
:rtype: generator
"""
- pass
+ while True:
+ for result in self.recvRecord():
+ if result in (0, 1):
+ yield result
+ continue
+
+ recordHeader, parser = result
+ if recordHeader.type in self.unfragmentedDataTypes:
+ yield (recordHeader, parser)
+ else:
+ for message in self.defragmenter.addData(recordHeader.type,
+ parser.bytes):
+ yield (recordHeader, Parser(message))
def recvMessageBlocking(self):
"""Blocking variant of :py:meth:`recvMessage`."""
- pass
+ for result in self.recvMessage():
+ if result in (0, 1):
+ continue
+ return result
def flush(self):
"""
@@ -76,11 +91,23 @@ class MessageSocket(RecordLayer):
:rtype: generator
"""
- pass
+ while self._sendBuffer:
+ if len(self._sendBuffer) > self.recordSize:
+ fragment = self._sendBuffer[:self.recordSize]
+ self._sendBuffer = self._sendBuffer[self.recordSize:]
+ else:
+ fragment = self._sendBuffer
+ self._sendBuffer = bytearray(0)
+
+ for result in self.sendRecord(self._sendBufferType, fragment):
+ yield result
+
+ self._sendBufferType = None
def flushBlocking(self):
"""Blocking variant of :py:meth:`flush`."""
- pass
+ for _ in self.flush():
+ pass
def queueMessage(self, msg):
"""
@@ -94,11 +121,17 @@ class MessageSocket(RecordLayer):
:rtype: generator
"""
- pass
+ if self._sendBufferType != msg.contentType:
+ for result in self.flush():
+ yield result
+
+ self._sendBufferType = msg.contentType
+ self._sendBuffer += msg.write()
def queueMessageBlocking(self, msg):
"""Blocking variant of :py:meth:`queueMessage`."""
- pass
+ for _ in self.queueMessage(msg):
+ pass
def sendMessage(self, msg):
"""
@@ -115,8 +148,12 @@ class MessageSocket(RecordLayer):
:rtype: generator
"""
- pass
+ for result in self.queueMessage(msg):
+ yield result
+ for result in self.flush():
+ yield result
def sendMessageBlocking(self, msg):
"""Blocking variant of :py:meth:`sendMessage`."""
- pass
+ self.queueMessageBlocking(msg)
+ self.flushBlocking()
diff --git a/tlslite/ocsp.py b/tlslite/ocsp.py
index 0745404..e3399ca 100644
--- a/tlslite/ocsp.py
+++ b/tlslite/ocsp.py
@@ -40,6 +40,28 @@ class SingleResponse(object):
'sha256', tuple([96, 134, 72, 1, 101, 3, 4, 2, 2]): 'sha384', tuple
([96, 134, 72, 1, 101, 3, 4, 2, 3]): 'sha512'}
+ def parse(self, value):
+ parser = ASN1Parser(value)
+
+ cert_id = parser.getChild(0)
+ self.cert_hash_alg = self._hash_algs_OIDs[tuple(cert_id.getChild(0).getChild(0).value)]
+ self.cert_issuer_name_hash = cert_id.getChild(1).value
+ self.cert_issuer_key_hash = cert_id.getChild(2).value
+ self.cert_serial_num = bytesToNumber(cert_id.getChild(3).value)
+
+ cert_status = parser.getChild(1)
+ if cert_status.type == 0:
+ self.cert_status = CertStatus.good
+ elif cert_status.type == 1:
+ self.cert_status = CertStatus.revoked
+ elif cert_status.type == 2:
+ self.cert_status = CertStatus.unknown
+
+ self.this_update = parser.getChild(2).value
+
+ if len(parser.children) > 3:
+ self.next_update = parser.getChild(3).getChild(0).value
+
class OCSPResponse(SignedObject):
""" This class represents an OCSP response. """
@@ -63,7 +85,18 @@ class OCSPResponse(SignedObject):
:type value: stream of bytes
:param value: An DER-encoded OCSP response
"""
- pass
+ self.bytes = value
+ parser = ASN1Parser(value)
+
+ self.resp_status = parser.getChild(0).value[0]
+ if self.resp_status != OCSPRespStatus.successful:
+ return
+
+ response_bytes = parser.getChild(1).getChild(0)
+ self.resp_type = response_bytes.getChild(0).value
+ response_data = response_bytes.getChild(1)
+
+ self._tbsdataparse(response_data.value)
def _tbsdataparse(self, value):
"""
@@ -72,4 +105,21 @@ class OCSPResponse(SignedObject):
:type value: stream of bytes
:param value: TBS data
"""
- pass
+ parser = ASN1Parser(value)
+
+ version = parser.getChild(0)
+ if version.value != b'\x00':
+ raise TLSIllegalParameterException("OCSP response version must be v1")
+ self.version = 1
+
+ self.resp_id = parser.getChild(1).value
+ self.produced_at = parser.getChild(2).value
+
+ responses = parser.getChild(3)
+ for response in responses.children:
+ self.responses.append(SingleResponse(response.value))
+
+ if len(parser.children) > 4:
+ certs = parser.getChild(4)
+ for cert in certs.children:
+ self.certs.append(X509().parse(cert.value))
diff --git a/tlslite/recordlayer.py b/tlslite/recordlayer.py
index 72658db..86fc45f 100644
--- a/tlslite/recordlayer.py
+++ b/tlslite/recordlayer.py
@@ -53,7 +53,15 @@ class RecordSocket(object):
:param data: data to send
:raises socket.error: when write to socket failed
"""
- pass
+ while data:
+ try:
+ sent = self.sock.send(data)
+ data = data[sent:]
+ except socket.error as e:
+ if e.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
+ yield 0
+ else:
+ raise
def send(self, msg, padding=0):
"""
@@ -65,7 +73,14 @@ class RecordSocket(object):
:param padding: amount of padding to specify for SSLv2
:raises socket.error: when write to socket failed
"""
- pass
+ if self.version == (2, 0): # SSLv2
+ header = RecordHeader2().create(len(msg), padding)
+ else:
+ header = RecordHeader3().create(self.version, msg.type, len(msg))
+
+ data = header.write() + msg.write()
+ for result in self._sockSendAll(data):
+ yield result
def _sockRecvAll(self, length):
"""
@@ -76,11 +91,35 @@ class RecordSocket(object):
blocking and would block and bytearray in case the read finished
:raises TLSAbruptCloseError: when the socket closed
"""
- pass
+ buf = bytearray(0)
+ while len(buf) < length:
+ try:
+ chunk = self.sock.recv(length - len(buf))
+ if not chunk:
+ raise TLSAbruptCloseError()
+ buf += chunk
+ except socket.error as e:
+ if e.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
+ yield 0
+ else:
+ raise
+ yield buf
def _recvHeader(self):
"""Read a single record header from socket"""
- pass
+ if self.version == (2, 0): # SSLv2
+ header = RecordHeader2()
+ header_len = 2
+ else:
+ header = RecordHeader3()
+ header_len = 5
+
+ for ret in self._sockRecvAll(header_len):
+ if ret in (0, 1):
+ yield ret
+ else:
+ header.parse(ret)
+ yield header
def recv(self):
"""
diff --git a/tlslite/session.py b/tlslite/session.py
index 3df1059..dbe857b 100644
--- a/tlslite/session.py
+++ b/tlslite/session.py
@@ -95,7 +95,7 @@ class Session(object):
:rtype: bool
:returns: If this session can be used for session resumption.
"""
- pass
+ return bool(self.masterSecret and self.cipherSuite and self.resumable)
def getCipherName(self):
"""Get the name of the cipher used with this connection.
@@ -103,7 +103,7 @@ class Session(object):
:rtype: str
:returns: The name of the cipher used with this connection.
"""
- pass
+ return CipherSuite.getName(self.cipherSuite)
def getMacName(self):
"""Get the name of the HMAC hash algo used with this connection.
@@ -111,7 +111,7 @@ class Session(object):
:rtype: str
:returns: The name of the HMAC hash algo used with this connection.
"""
- pass
+ return CipherSuite.getMacName(self.cipherSuite)
class Ticket(object):
diff --git a/tlslite/signed.py b/tlslite/signed.py
index 20048da..99329f5 100644
--- a/tlslite/signed.py
+++ b/tlslite/signed.py
@@ -30,5 +30,30 @@ class SignedObject(object):
'sha512'}
def verify_signature(self, publicKey, settings=None):
- """ Verify signature in a reponse"""
- pass
+ """Verify signature in a response"""
+ if settings is None:
+ settings = SignatureSettings()
+
+ if not self.tbs_data or not self.signature or not self.signature_alg:
+ return False
+
+ hash_name = self._hash_algs_OIDs.get(tuple(self.signature_alg))
+ if not hash_name:
+ return False
+
+ if hash_name not in settings.rsa_sig_hashes:
+ return False
+
+ key_size = numBytes(publicKey.n)
+ if key_size < settings.min_key_size // 8 or key_size > settings.max_key_size // 8:
+ return False
+
+ for scheme in settings.rsa_schemes:
+ if scheme == 'pss':
+ if publicKey.verify(self.signature, self.tbs_data, hash_name, 'pss'):
+ return True
+ elif scheme == 'pkcs1':
+ if publicKey.verify(self.signature, self.tbs_data, hash_name):
+ return True
+
+ return False
diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py
index 73f8e07..5ca3073 100644
--- a/tlslite/tlsconnection.py
+++ b/tlslite/tlsconnection.py
@@ -77,7 +77,18 @@ class TLSConnection(TLSRecordLayer):
:type length: int
:param length: number of bytes of the keying material to export
"""
- pass
+ if self.version < (3, 1):
+ raise ValueError("Keying material export not supported in SSL 3.0")
+
+ if self.session is None:
+ raise ValueError("Handshake not completed")
+
+ if not self.session.masterSecret:
+ raise ValueError("Master secret not available")
+
+ seed = bytearray(b'client finished') + bytearray(b'server finished')
+ return HKDF_expand_label(self.session.masterSecret, label, seed, length,
+ self.session.cipherSuite.hashName)
@deprecated_params({'async_': 'async'},
"'{old_name}' is a keyword in Python 3.7, use'{new_name}'")
@@ -296,18 +307,29 @@ class TLSConnection(TLSRecordLayer):
@staticmethod
def _getKEX(group, version):
"""Get object for performing key exchange."""
- pass
+ if group in CURVE_ALIASES:
+ return ECDHKeyExchange(group, version)
+ elif group in FFDHE_GROUPS:
+ return FFDHKeyExchange(group, version)
+ else:
+ raise ValueError("Unsupported group for key exchange")
@classmethod
def _genKeyShareEntry(cls, group, version):
"""Generate KeyShareEntry object from randomly selected private value.
"""
- pass
+ kex = cls._getKEX(group, version)
+ private = kex.get_random_private_key()
+ public = kex.calc_public_value(private)
+ return KeyShareEntry().create(group, public)
@staticmethod
def _getPRFParams(cipher_suite):
"""Return name of hash used for PRF and the hash output size."""
- pass
+ if cipher_suite in CipherSuite.sha384PrfSuites:
+ return "sha384", 48
+ else:
+ return "sha256", 32
def _clientTLS13Handshake(self, settings, session, clientHello,
clientCertChain, privateKey, serverHello):
@@ -327,7 +349,23 @@ class TLSConnection(TLSRecordLayer):
Checks if the certificate key size matches the minimum and maximum
sizes set or that it uses curves enabled in settings
"""
- pass
+ if not cert_chain:
+ return
+
+ leaf_cert = cert_chain.getLeaf()
+ if leaf_cert.certAlg == "rsa":
+ key_size = leaf_cert.publicKey.size_in_bits()
+ if key_size < settings.minKeySize:
+ raise TLSLocalAlert(AlertDescription.handshake_failure,
+ "Server key too small: {0}".format(key_size))
+ if key_size > settings.maxKeySize:
+ raise TLSLocalAlert(AlertDescription.handshake_failure,
+ "Server key too large: {0}".format(key_size))
+ elif leaf_cert.certAlg == "ecdsa":
+ curve_name = leaf_cert.publicKey.curve_name
+ if curve_name not in settings.eccCurves:
+ raise TLSLocalAlert(AlertDescription.handshake_failure,
+ "Curve not supported: {0}".format(curve_name))
def handshakeServer(self, verifierDB=None, certChain=None, privateKey=
None, reqCert=False, sessionCache=None, settings=None, checker=None,
diff --git a/tlslite/tlsrecordlayer.py b/tlslite/tlsrecordlayer.py
index 01f1c67..870b520 100644
--- a/tlslite/tlsrecordlayer.py
+++ b/tlslite/tlsrecordlayer.py
@@ -145,47 +145,48 @@ class TLSRecordLayer(object):
@property
def _send_record_limit(self):
"""Maximum size of payload that can be sent."""
- pass
+ return self._recordLayer.send_record_limit
@_send_record_limit.setter
def _send_record_limit(self, value):
"""Maximum size of payload that can be sent."""
- pass
+ self._recordLayer.send_record_limit = value
@property
def _recv_record_limit(self):
"""Maximum size of payload that can be received."""
- pass
+ return self._recordLayer.recv_record_limit
@_recv_record_limit.setter
def _recv_record_limit(self, value):
"""Maximum size of payload that can be received."""
- pass
+ self._recordLayer.recv_record_limit = value
@property
def recordSize(self):
"""Maximum size of the records that will be sent out."""
- pass
+ return self._user_record_limit
@recordSize.setter
def recordSize(self, value):
"""Size to automatically fragment records to."""
- pass
+ self._user_record_limit = value
+ self._recordLayer.send_record_limit = min(value, self._send_record_limit)
@property
def _client(self):
"""Boolean stating if the endpoint acts as a client"""
- pass
+ return self._recordLayer.client
@_client.setter
def _client(self, value):
"""Set the endpoint to act as a client or not"""
- pass
+ self._recordLayer.client = value
@property
def version(self):
"""Get the SSL protocol version of connection"""
- pass
+ return self._recordLayer.version
@version.setter
def version(self, value):
@@ -196,12 +197,12 @@ class TLSRecordLayer(object):
Don't use it! See at HandshakeSettings for options to set desired
protocol version.
"""
- pass
+ self._recordLayer.version = value
@property
def encryptThenMAC(self):
"""Whether the connection uses Encrypt Then MAC (RFC 7366)"""
- pass
+ return self._recordLayer.encryptThenMAC
def read(self, max=None, min=1):
"""Read some data from the TLS connection.
@@ -228,7 +229,33 @@ class TLSRecordLayer(object):
without a preceding alert.
:raises tlslite.errors.TLSAlert: If a TLS alert is signalled.
"""
- pass
+ try:
+ if self.closed:
+ raise TLSClosedConnectionError("Attempt to read from closed connection")
+
+ if max is None:
+ max = self._recv_record_limit
+
+ while len(self._buffer) < min:
+ record = self._getNextRecord()
+ if record.contentType == ContentType.application_data:
+ self._buffer += record.write()
+ elif record.contentType == ContentType.alert:
+ alert = Alert().parse(record.write())
+ if alert.level == AlertLevel.warning:
+ if alert.description == AlertDescription.close_notify:
+ self.close()
+ else:
+ raise TLSAlert(alert.description)
+ else:
+ self._handle_other_record_types(record)
+
+ result = self._buffer[:max]
+ self._buffer = self._buffer[max:]
+ return result
+ except socket.error:
+ self.close()
+ raise
def readAsync(self, max=None, min=1):
"""Start a read operation on the TLS connection.
@@ -242,7 +269,37 @@ class TLSRecordLayer(object):
:rtype: iterable
:returns: A generator; see above for details.
"""
- pass
+ try:
+ if self.closed:
+ raise TLSClosedConnectionError("Attempt to read from closed connection")
+
+ if max is None:
+ max = self._recv_record_limit
+
+ while len(self._buffer) < min:
+ for result in self._getNextRecordAsync():
+ if result in (0, 1):
+ yield result
+ else:
+ record = result
+ if record.contentType == ContentType.application_data:
+ self._buffer += record.write()
+ elif record.contentType == ContentType.alert:
+ alert = Alert().parse(record.write())
+ if alert.level == AlertLevel.warning:
+ if alert.description == AlertDescription.close_notify:
+ self.close()
+ else:
+ raise TLSAlert(alert.description)
+ else:
+ self._handle_other_record_types(record)
+
+ result = self._buffer[:max]
+ self._buffer = self._buffer[max:]
+ yield result
+ except socket.error:
+ self.close()
+ raise
def unread(self, b):
"""Add bytes to the front of the socket read buffer for future
@@ -250,7 +307,7 @@ class TLSRecordLayer(object):
unread the last data from a socket, that won't wake up selected waiters,
and those waiters may hang forever.
"""
- pass
+ self._buffer = b + self._buffer
def write(self, s):
"""Write some data to the TLS connection.
diff --git a/tlslite/utils/aesgcm.py b/tlslite/utils/aesgcm.py
index 37d5518..e77a71b 100644
--- a/tlslite/utils/aesgcm.py
+++ b/tlslite/utils/aesgcm.py
@@ -36,14 +36,43 @@ class AESGCM(object):
def _mul(self, y):
""" Returns y*H, where H is the GCM key. """
- pass
+ z = 0
+ for i in range(16):
+ if y & 0x8000000000000000:
+ z ^= self._productTable[i]
+ y <<= 1
+ return z
def seal(self, nonce, plaintext, data):
"""
Encrypts and authenticates plaintext using nonce and data. Returns the
ciphertext, consisting of the encrypted plaintext and tag concatenated.
"""
- pass
+ if len(nonce) != self.nonceLength:
+ raise ValueError("Nonce must be 12 bytes long")
+
+ # Calculate the initial counter value
+ counter = nonce + b'\x00\x00\x00\x01'
+
+ # Encrypt the plaintext
+ ctr = self._ctr.copy()
+ ctr.counter = bytesToNumber(counter)
+ ciphertext = ctr.encrypt(plaintext)
+
+ # Calculate the authentication tag
+ auth_data = data + b'\x00' * (-len(data) % 16)
+ auth_data += ciphertext + b'\x00' * (-len(ciphertext) % 16)
+ auth_data += numberToByteArray(len(data) * 8, 8)
+ auth_data += numberToByteArray(len(ciphertext) * 8, 8)
+
+ y = 0
+ for i in range(0, len(auth_data), 16):
+ y ^= bytesToNumber(auth_data[i:i+16])
+ y = self._mul(y)
+
+ tag = numberToByteArray(y ^ bytesToNumber(self._rawAesEncrypt(counter)), 16)
+
+ return ciphertext + tag
def open(self, nonce, ciphertext, data):
"""
@@ -51,6 +80,39 @@ class AESGCM(object):
tag is valid, the plaintext is returned. If the tag is invalid,
returns None.
"""
- pass
+ if len(nonce) != self.nonceLength:
+ raise ValueError("Nonce must be 12 bytes long")
+
+ if len(ciphertext) < self.tagLength:
+ return None
+
+ tag = ciphertext[-self.tagLength:]
+ ciphertext = ciphertext[:-self.tagLength]
+
+ # Calculate the initial counter value
+ counter = nonce + b'\x00\x00\x00\x01'
+
+ # Decrypt the ciphertext
+ ctr = self._ctr.copy()
+ ctr.counter = bytesToNumber(counter)
+ plaintext = ctr.decrypt(ciphertext)
+
+ # Calculate the authentication tag
+ auth_data = data + b'\x00' * (-len(data) % 16)
+ auth_data += ciphertext + b'\x00' * (-len(ciphertext) % 16)
+ auth_data += numberToByteArray(len(data) * 8, 8)
+ auth_data += numberToByteArray(len(ciphertext) * 8, 8)
+
+ y = 0
+ for i in range(0, len(auth_data), 16):
+ y ^= bytesToNumber(auth_data[i:i+16])
+ y = self._mul(y)
+
+ calculated_tag = numberToByteArray(y ^ bytesToNumber(self._rawAesEncrypt(counter)), 16)
+
+ if ct_compare_digest(tag, calculated_tag):
+ return plaintext
+ else:
+ return None
_gcmReductionTable = [0, 7200, 14400, 9312, 28800, 27808, 18624, 21728,
57600, 64800, 55616, 50528, 37248, 36256, 43456, 46560]
diff --git a/tlslite/utils/asn1parser.py b/tlslite/utils/asn1parser.py
index 126a3ff..e1c22df 100644
--- a/tlslite/utils/asn1parser.py
+++ b/tlslite/utils/asn1parser.py
@@ -59,7 +59,14 @@ class ASN1Parser(object):
:rtype: ASN1Parser
:returns: decoded child object
"""
- pass
+ if self.type.is_primitive:
+ raise ValueError("Cannot get child of a primitive type")
+
+ children = self._parse_children()
+ if which >= len(children):
+ raise IndexError("Child index out of range")
+
+ return ASN1Parser(children[which])
def getChildCount(self):
"""
@@ -68,7 +75,9 @@ class ASN1Parser(object):
:rtype: int
:returns: number of children in the object
"""
- pass
+ if self.type.is_primitive:
+ return 0
+ return len(self._parse_children())
def getChildBytes(self, which):
"""
@@ -80,14 +89,51 @@ class ASN1Parser(object):
:rtype: bytearray
:returns: raw child object
"""
- pass
+ if self.type.is_primitive:
+ raise ValueError("Cannot get child of a primitive type")
+
+ children = self._parse_children()
+ if which >= len(children):
+ raise IndexError("Child index out of range")
+
+ return children[which]
@staticmethod
def _getASN1Length(p):
"""Decode the ASN.1 DER length field"""
- pass
+ first_byte = p.get(1)[0]
+ if first_byte & 0x80 == 0:
+ return first_byte
+ else:
+ length_bytes = first_byte & 0x7F
+ return int.from_bytes(p.get(length_bytes), byteorder='big')
@staticmethod
def _parse_type(parser):
"""Decode the ASN.1 DER type field"""
- pass
+ type_byte = parser.get(1)[0]
+ tag_class = (type_byte & 0xC0) >> 6
+ is_primitive = (type_byte & 0x20) == 0
+ tag_id = type_byte & 0x1F
+
+ if tag_id == 0x1F:
+ # Long form
+ tag_id = 0
+ while True:
+ next_byte = parser.get(1)[0]
+ tag_id = (tag_id << 7) | (next_byte & 0x7F)
+ if next_byte & 0x80 == 0:
+ break
+
+ return ASN1Type(tag_class, is_primitive, tag_id)
+
+ def _parse_children(self):
+ """Parse children of a constructed type"""
+ children = []
+ p = Parser(self.value)
+ while p.hasMoreBytes():
+ child_type = self._parse_type(p)
+ child_length = self._getASN1Length(p)
+ child_value = p.getFixBytes(child_length)
+ children.append(child_type.bytes + child_length.to_bytes((child_length.bit_length() + 7) // 8, byteorder='big') + child_value)
+ return children
diff --git a/tlslite/utils/chacha.py b/tlslite/utils/chacha.py
index e370972..755f382 100644
--- a/tlslite/utils/chacha.py
+++ b/tlslite/utils/chacha.py
@@ -19,34 +19,46 @@ class ChaCha(object):
@staticmethod
def rotl32(v, c):
"""Rotate left a 32 bit integer v by c bits"""
- pass
+ return ((v << c) & 0xffffffff) | (v >> (32 - c))
@staticmethod
def quarter_round(x, a, b, c, d):
"""Perform a ChaCha quarter round"""
- pass
+ x[a] = (x[a] + x[b]) & 0xffffffff
+ x[d] = ChaCha.rotl32(x[d] ^ x[a], 16)
+ x[c] = (x[c] + x[d]) & 0xffffffff
+ x[b] = ChaCha.rotl32(x[b] ^ x[c], 12)
+ x[a] = (x[a] + x[b]) & 0xffffffff
+ x[d] = ChaCha.rotl32(x[d] ^ x[a], 8)
+ x[c] = (x[c] + x[d]) & 0xffffffff
+ x[b] = ChaCha.rotl32(x[b] ^ x[c], 7)
_round_mixup_box = [(0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7,
11, 15), (0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14)]
@classmethod
def double_round(cls, x):
"""Perform two rounds of ChaCha cipher"""
- pass
+ for a, b, c, d in cls._round_mixup_box:
+ cls.quarter_round(x, a, b, c, d)
@staticmethod
def chacha_block(key, counter, nonce, rounds):
"""Generate a state of a single block"""
- pass
+ state = ChaCha.constants + key + [counter] + nonce
+ working_state = state[:]
+ for _ in range(rounds // 2):
+ ChaCha.double_round(working_state)
+ return [((state[i] + working_state[i]) & 0xffffffff) for i in range(16)]
@staticmethod
def word_to_bytearray(state):
"""Convert state to little endian bytestream"""
- pass
+ return bytearray(struct.pack('<' + 'I' * 16, *state))
@staticmethod
def _bytearray_to_words(data):
"""Convert a bytearray to array of word sized ints"""
- pass
+ return list(struct.unpack('<' + 'I' * (len(data) // 4), data))
def __init__(self, key, nonce, counter=0, rounds=20):
"""Set the initial state for the ChaCha cipher"""
@@ -63,8 +75,15 @@ class ChaCha(object):
def encrypt(self, plaintext):
"""Encrypt the data"""
- pass
+ encrypted = bytearray()
+ for i in range(0, len(plaintext), 64):
+ key_stream = self.chacha_block(self.key, self.counter, self.nonce, self.rounds)
+ key_stream = self.word_to_bytearray(key_stream)
+ chunk = plaintext[i:i+64]
+ encrypted.extend(x ^ y for x, y in izip(chunk, key_stream))
+ self.counter += 1
+ return encrypted
def decrypt(self, ciphertext):
"""Decrypt the data"""
- pass
+ return self.encrypt(ciphertext) # ChaCha is symmetric, so encryption and decryption are the same
diff --git a/tlslite/utils/chacha20_poly1305.py b/tlslite/utils/chacha20_poly1305.py
index 0268f68..4adf46b 100644
--- a/tlslite/utils/chacha20_poly1305.py
+++ b/tlslite/utils/chacha20_poly1305.py
@@ -29,19 +29,36 @@ class CHACHA20_POLY1305(object):
@staticmethod
def poly1305_key_gen(key, nonce):
"""Generate the key for the Poly1305 authenticator"""
- pass
+ cipher = ChaCha(key, nonce)
+ return cipher.encrypt(b'\x00' * 32)
@staticmethod
def pad16(data):
"""Return padding for the Associated Authenticated Data"""
- pass
+ if len(data) % 16 == 0:
+ return b""
+ return b'\x00' * (16 - (len(data) % 16))
def seal(self, nonce, plaintext, data):
"""
Encrypts and authenticates plaintext using nonce and data. Returns the
ciphertext, consisting of the encrypted plaintext and tag concatenated.
"""
- pass
+ otk = self.poly1305_key_gen(self.key, nonce)
+ cipher = ChaCha(self.key, nonce)
+ ciphertext = cipher.encrypt(plaintext)
+
+ mac = Poly1305()
+ mac.create(otk)
+ mac.update(data)
+ mac.update(self.pad16(data))
+ mac.update(ciphertext)
+ mac.update(self.pad16(ciphertext))
+ mac.update(struct.pack('<Q', len(data)))
+ mac.update(struct.pack('<Q', len(ciphertext)))
+ tag = mac.digest()
+
+ return ciphertext + tag
def open(self, nonce, ciphertext, data):
"""
@@ -49,4 +66,25 @@ class CHACHA20_POLY1305(object):
tag is valid, the plaintext is returned. If the tag is invalid,
returns None.
"""
- pass
+ if len(ciphertext) < self.tagLength:
+ return None
+
+ expected_tag = ciphertext[-self.tagLength:]
+ ciphertext = ciphertext[:-self.tagLength]
+
+ otk = self.poly1305_key_gen(self.key, nonce)
+ mac = Poly1305()
+ mac.create(otk)
+ mac.update(data)
+ mac.update(self.pad16(data))
+ mac.update(ciphertext)
+ mac.update(self.pad16(ciphertext))
+ mac.update(struct.pack('<Q', len(data)))
+ mac.update(struct.pack('<Q', len(ciphertext)))
+ tag = mac.digest()
+
+ if not ct_compare_digest(tag, expected_tag):
+ return None
+
+ cipher = ChaCha(self.key, nonce)
+ return cipher.decrypt(ciphertext)
diff --git a/tlslite/utils/cipherfactory.py b/tlslite/utils/cipherfactory.py
index 02b959f..e5d2b4b 100644
--- a/tlslite/utils/cipherfactory.py
+++ b/tlslite/utils/cipherfactory.py
@@ -34,7 +34,16 @@ def createAES(key, IV, implList=None):
:rtype: tlslite.utils.AES
:returns: An AES object.
"""
- pass
+ if implList is None:
+ implList = [openssl_aes, pycrypto_aes, python_aes]
+
+ for impl in implList:
+ try:
+ return impl.new(key, impl.MODE_CBC, IV)
+ except (ImportError, AttributeError):
+ pass
+
+ raise NotImplementedError("No AES implementation available")
def createAESCTR(key, IV, implList=None):
@@ -49,7 +58,16 @@ def createAESCTR(key, IV, implList=None):
:rtype: tlslite.utils.AES
:returns: An AES object.
"""
- pass
+ if implList is None:
+ implList = [openssl_aes, pycrypto_aes, python_aes]
+
+ for impl in implList:
+ try:
+ return impl.new(key, impl.MODE_CTR, IV)
+ except (ImportError, AttributeError):
+ pass
+
+ raise NotImplementedError("No AESCTR implementation available")
def createAESGCM(key, implList=None):
@@ -61,7 +79,16 @@ def createAESGCM(key, implList=None):
:rtype: tlslite.utils.AESGCM
:returns: An AESGCM object.
"""
- pass
+ if implList is None:
+ implList = [openssl_aesgcm, pycrypto_aesgcm, python_aesgcm]
+
+ for impl in implList:
+ try:
+ return impl.new(key)
+ except (ImportError, AttributeError):
+ pass
+
+ raise NotImplementedError("No AESGCM implementation available")
def createAESCCM(key, implList=None):
@@ -73,7 +100,16 @@ def createAESCCM(key, implList=None):
:rtype: tlslite.utils.AESCCM
:returns: An AESCCM object.
"""
- pass
+ if implList is None:
+ implList = [openssl_aesccm, python_aesccm]
+
+ for impl in implList:
+ try:
+ return impl.new(key)
+ except (ImportError, AttributeError):
+ pass
+
+ raise NotImplementedError("No AESCCM implementation available")
def createAESCCM_8(key, implList=None):
@@ -85,7 +121,16 @@ def createAESCCM_8(key, implList=None):
:rtype: tlslite.utils.AESCCM
:returns: An AESCCM object.
"""
- pass
+ if implList is None:
+ implList = [openssl_aesccm, python_aesccm]
+
+ for impl in implList:
+ try:
+ return impl.new(key, tag_length=8)
+ except (ImportError, AttributeError):
+ pass
+
+ raise NotImplementedError("No AESCCM_8 implementation available")
def createCHACHA20(key, implList=None):
@@ -97,7 +142,16 @@ def createCHACHA20(key, implList=None):
:rtype: tlslite.utils.CHACHA20_POLY1305
:returns: A ChaCha20/Poly1305 object
"""
- pass
+ if implList is None:
+ implList = [python_chacha20_poly1305]
+
+ for impl in implList:
+ try:
+ return impl.new(key)
+ except (ImportError, AttributeError):
+ pass
+
+ raise NotImplementedError("No ChaCha20/Poly1305 implementation available")
def createRC4(key, IV, implList=None):
@@ -112,7 +166,16 @@ def createRC4(key, IV, implList=None):
:rtype: tlslite.utils.RC4
:returns: An RC4 object.
"""
- pass
+ if implList is None:
+ implList = [openssl_rc4, pycrypto_rc4, python_rc4]
+
+ for impl in implList:
+ try:
+ return impl.new(key)
+ except (ImportError, AttributeError):
+ pass
+
+ raise NotImplementedError("No RC4 implementation available")
def createTripleDES(key, IV, implList=None):
@@ -127,4 +190,13 @@ def createTripleDES(key, IV, implList=None):
:rtype: tlslite.utils.TripleDES
:returns: A 3DES object.
"""
- pass
+ if implList is None:
+ implList = [openssl_tripledes, pycrypto_tripledes, python_tripledes]
+
+ for impl in implList:
+ try:
+ return impl.new(key, impl.MODE_CBC, IV)
+ except (ImportError, AttributeError):
+ pass
+
+ raise NotImplementedError("No 3DES implementation available")
diff --git a/tlslite/utils/codec.py b/tlslite/utils/codec.py
index bd40984..6217f85 100644
--- a/tlslite/utils/codec.py
+++ b/tlslite/utils/codec.py
@@ -25,20 +25,20 @@ class Writer(object):
def addOne(self, val):
"""Add a single-byte wide element to buffer, see add()."""
- pass
+ self.bytes += struct.pack('>B', val)
if sys.version_info < (2, 7):
def addTwo(self, val):
"""Add a double-byte wide element to buffer, see add()."""
- pass
+ self.bytes += struct.pack('>H', val)
def addThree(self, val):
"""Add a three-byte wide element to buffer, see add()."""
- pass
+ self.bytes += struct.pack('>I', val)[1:]
def addFour(self, val):
"""Add a four-byte wide element to buffer, see add()."""
- pass
+ self.bytes += struct.pack('>I', val)
else:
def addTwo(self, val):
@@ -67,7 +67,7 @@ class Writer(object):
:type length: int
:param length: number of bytes to use for encoding the value
"""
- pass
+ self.bytes += x.to_bytes(length, byteorder='big')
else:
_addMethods = {(1): addOne, (2): addTwo, (3): addThree, (4): addFour}
@@ -84,7 +84,10 @@ class Writer(object):
:type length: int
:param length: number of bytes to use for encoding the value
"""
- pass
+ if length in self._addMethods:
+ self._addMethods[length](self, x)
+ else:
+ self.bytes += struct.pack('>%dB' % length, *[(x>>(8*i))&0xff for i in reversed(range(length))])
def addFixSeq(self, seq, length):
"""
@@ -99,7 +102,8 @@ class Writer(object):
:type length: int
:param length: number of bytes to which encode every element
"""
- pass
+ for item in seq:
+ self.add(item, length)
if sys.version_info < (2, 7):
def _addVarSeqTwo(self, seq):
@@ -123,7 +127,9 @@ class Writer(object):
:param lengthLength: amount of bytes in which to encode the overall
length of the array
"""
- pass
+ self.add(len(seq) * length, lengthLength)
+ for item in seq:
+ self.add(item, length)
else:
def addVarSeq(self, seq, length, lengthLength):
@@ -162,7 +168,11 @@ class Writer(object):
:type lengthLength: int
:param lengthLength: length in bytes of overall length field
"""
- pass
+ total_length = sum(len(tup) for tup in seq) * length
+ self.add(total_length, lengthLength)
+ for tup in seq:
+ for item in tup:
+ self.add(item, length)
def add_var_bytes(self, data, length_length):
"""
@@ -176,7 +186,8 @@ class Writer(object):
:param int length_length: size of the field to represent the length
of the data string
"""
- pass
+ self.add(len(data), length_length)
+ self.bytes += data
class Parser(object):
@@ -228,7 +239,11 @@ class Parser(object):
:rtype: int
"""
- pass
+ if self.index + length > len(self.bytes):
+ raise DecodeError("Not enough data to read")
+ result = bytes_to_int(self.bytes[self.index:self.index + length])
+ self.index += length
+ return result
def getFixBytes(self, lengthBytes):
"""
@@ -239,11 +254,17 @@ class Parser(object):
:rtype: bytearray
"""
- pass
+ if self.index + lengthBytes > len(self.bytes):
+ raise DecodeError("Not enough data to read")
+ result = self.bytes[self.index:self.index + lengthBytes]
+ self.index += lengthBytes
+ return result
def skip_bytes(self, length):
"""Move the internal pointer ahead length bytes."""
- pass
+ if self.index + length > len(self.bytes):
+ raise DecodeError("Not enough data to skip")
+ self.index += length
def getVarBytes(self, lengthLength):
"""
@@ -257,7 +278,8 @@ class Parser(object):
:rtype: bytearray
"""
- pass
+ length = self.get(lengthLength)
+ return self.getFixBytes(length)
def getFixList(self, length, lengthList):
"""
@@ -271,7 +293,7 @@ class Parser(object):
:rtype: list of int
"""
- pass
+ return [self.get(length) for _ in range(lengthList)]
def getVarList(self, length, lengthLength):
"""
@@ -285,7 +307,10 @@ class Parser(object):
:rtype: list of int
"""
- pass
+ listLength = self.get(lengthLength)
+ if listLength % length != 0:
+ raise DecodeError("List length not a multiple of element length")
+ return [self.get(length) for _ in range(listLength // length)]
def getVarTupleList(self, elemLength, elemNum, lengthLength):
"""
@@ -302,7 +327,12 @@ class Parser(object):
:rtype: list of tuple of int
"""
- pass
+ listLength = self.get(lengthLength)
+ tupleLength = elemLength * elemNum
+ if listLength % tupleLength != 0:
+ raise DecodeError("List length not a multiple of tuple length")
+ numTuples = listLength // tupleLength
+ return [tuple(self.get(elemLength) for _ in range(elemNum)) for _ in range(numTuples)]
def startLengthCheck(self, lengthLength):
"""
@@ -311,7 +341,8 @@ class Parser(object):
:type lengthLength: int
:param lengthLength: number of bytes in which the length is encoded
"""
- pass
+ self.lengthCheck = self.get(lengthLength)
+ self.indexCheck = self.index
def setLengthCheck(self, length):
"""
@@ -320,7 +351,8 @@ class Parser(object):
:type length: int
:param length: expected size of parsed struct in bytes
"""
- pass
+ self.lengthCheck = length
+ self.indexCheck = self.index
def stopLengthCheck(self):
"""
@@ -329,7 +361,8 @@ class Parser(object):
In case the expected length was mismatched with actual length of
processed data, raises an exception.
"""
- pass
+ if self.index - self.indexCheck != self.lengthCheck:
+ raise DecodeError("Length check failed")
def atLengthCheck(self):
"""
@@ -341,8 +374,10 @@ class Parser(object):
Will raise an exception if overflow occured (amount of data read was
greater than expected size)
"""
- pass
+ if self.index - self.indexCheck > self.lengthCheck:
+ raise DecodeError("Length overflow")
+ return self.index - self.indexCheck == self.lengthCheck
def getRemainingLength(self):
"""Return amount of data remaining in struct being parsed."""
- pass
+ return self.lengthCheck - (self.index - self.indexCheck)
diff --git a/tlslite/utils/compat.py b/tlslite/utils/compat.py
index 103b51c..d61ef5e 100644
--- a/tlslite/utils/compat.py
+++ b/tlslite/utils/compat.py
@@ -13,7 +13,9 @@ if sys.version_info >= (3, 0):
def compatHMAC(x):
"""Convert bytes-like input to format acceptable for HMAC."""
- pass
+ if isinstance(x, bytearray):
+ return bytes(x)
+ return x
else:
def compatHMAC(x):
@@ -22,33 +24,39 @@ if sys.version_info >= (3, 0):
def compatAscii2Bytes(val):
"""Convert ASCII string to bytes."""
- pass
+ if isinstance(val, str):
+ return val.encode('ascii')
+ return val
def compat_b2a(val):
"""Convert an ASCII bytes string to string."""
- pass
+ if isinstance(val, bytes):
+ return val.decode('ascii')
+ return val
int_types = tuple([int])
def formatExceptionTrace(e):
"""Return exception information formatted as string"""
- pass
+ return ''.join(traceback.format_exception(type(e), e, e.__traceback__))
def time_stamp():
"""Returns system time as a float"""
- pass
+ return time.time()
def remove_whitespace(text):
"""Removes all whitespace from passed in string"""
- pass
+ return re.sub(r'\s+', '', text)
bytes_to_int = int.from_bytes
def bit_length(val):
"""Return number of bits necessary to represent an integer."""
- pass
+ return val.bit_length()
def int_to_bytes(val, length=None, byteorder='big'):
"""Return number converted to bytes"""
- pass
+ if length is None:
+ length = (val.bit_length() + 7) // 8
+ return val.to_bytes(length, byteorder)
else:
if sys.version_info < (2, 7) or sys.version_info < (2, 7, 4
) or platform.system() == 'Java':
@@ -98,7 +106,7 @@ else:
def byte_length(val):
"""Return number of bytes necessary to represent an integer."""
- pass
+ return (val.bit_length() + 7) // 8
try:
diff --git a/tlslite/utils/constanttime.py b/tlslite/utils/constanttime.py
index e520a94..44b9f5f 100644
--- a/tlslite/utils/constanttime.py
+++ b/tlslite/utils/constanttime.py
@@ -14,7 +14,7 @@ def ct_lt_u32(val_a, val_b):
:param val_b: an unsigned integer representable as a 32 bit value
:rtype: int
"""
- pass
+ return int((val_a - val_b) >> 31 & 1)
def ct_gt_u32(val_a, val_b):
@@ -27,7 +27,7 @@ def ct_gt_u32(val_a, val_b):
:param val_b: an unsigned integer representable as a 32 bit value
:rtype: int
"""
- pass
+ return int((val_b - val_a) >> 31 & 1)
def ct_le_u32(val_a, val_b):
@@ -40,17 +40,17 @@ def ct_le_u32(val_a, val_b):
:param val_b: an unsigned integer representable as a 32 bit value
:rtype: int
"""
- pass
+ return 1 - ct_gt_u32(val_a, val_b)
def ct_lsb_prop_u8(val):
"""Propagate LSB to all 8 bits of the returned int. Constant time."""
- pass
+ return (val & 1) * 0xFF
def ct_lsb_prop_u16(val):
"""Propagate LSB to all 16 bits of the returned int. Constant time."""
- pass
+ return (val & 1) * 0xFFFF
def ct_isnonzero_u32(val):
@@ -61,7 +61,7 @@ def ct_isnonzero_u32(val):
:param val: an unsigned integer representable as a 32 bit value
:rtype: int
"""
- pass
+ return int((val | -val) >> 31 & 1)
def ct_neq_u32(val_a, val_b):
@@ -74,7 +74,7 @@ def ct_neq_u32(val_a, val_b):
:param val_b: an unsigned integer representable as a 32 bit value
:rtype: int
"""
- pass
+ return ct_isnonzero_u32(val_a ^ val_b)
def ct_eq_u32(val_a, val_b):
@@ -87,7 +87,7 @@ def ct_eq_u32(val_a, val_b):
:param val_b: an unsigned integer representable as a 32 bit value
:rtype: int
"""
- pass
+ return 1 - ct_neq_u32(val_a, val_b)
def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
@@ -114,7 +114,37 @@ def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
:rtype: boolean
:returns: True if MAC and pad is ok, False otherwise
"""
- pass
+ data_len = len(data)
+ mac_size = mac.digest_size
+ pad_size = data[-1]
+
+ # Check if the padding size is valid
+ if pad_size == 0 or pad_size > block_size:
+ return False
+
+ # Calculate the start of padding
+ pad_start = data_len - pad_size
+
+ # Check if there's enough data for padding and MAC
+ if pad_start < mac_size:
+ return False
+
+ # Verify padding
+ for i in range(pad_start, data_len):
+ if data[i] != pad_size:
+ return False
+
+ # Calculate MAC
+ mac.update(seqnumBytes)
+ mac.update(bytearray([contentType]))
+ mac.update(bytearray([version[0], version[1]]))
+ mac.update(bytearray([data_len - pad_size - mac_size >> 8]))
+ mac.update(bytearray([data_len - pad_size - mac_size & 0xFF]))
+ mac.update(data[:data_len - pad_size - mac_size])
+ calculated_mac = mac.digest()
+
+ # Compare MACs in constant time
+ return ct_compare_digest(calculated_mac, data[data_len - pad_size - mac_size:data_len - pad_size])
if hasattr(hmac, 'compare_digest'):
@@ -123,4 +153,9 @@ else:
def ct_compare_digest(val_a, val_b):
"""Compares if string like objects are equal. Constant time."""
- pass
+ if len(val_a) != len(val_b):
+ return False
+ result = 0
+ for x, y in zip(val_a, val_b):
+ result |= x ^ y
+ return result == 0
diff --git a/tlslite/utils/cryptomath.py b/tlslite/utils/cryptomath.py
index 24564a0..8fcd2c3 100644
--- a/tlslite/utils/cryptomath.py
+++ b/tlslite/utils/cryptomath.py
@@ -57,22 +57,22 @@ prngName = 'os.urandom'
def MD5(b):
"""Return a MD5 digest of data"""
- pass
+ return hashlib.md5(b).digest()
def SHA1(b):
"""Return a SHA1 digest of data"""
- pass
+ return hashlib.sha1(b).digest()
def secureHash(data, algorithm):
"""Return a digest of `data` using `algorithm`"""
- pass
+ return hashlib.new(algorithm, data).digest()
def secureHMAC(k, b, algorithm):
"""Return a HMAC using `b` and `k` using `algorithm`"""
- pass
+ return hmac.new(k, b, algorithm).digest()
def HKDF_expand_label(secret, label, hashValue, length, algorithm):
@@ -88,7 +88,9 @@ def HKDF_expand_label(secret, label, hashValue, length, algorithm):
basis of the HKDF
:rtype: bytearray
"""
- pass
+ hkdf = hmac.new(secret, digestmod=algorithm)
+ info = Writer().add(numberToByteArray(length, 2), bytearray(b"tls13 "), label, hashValue).bytes()
+ return bytearray(hkdf.derive(info, length))
def derive_secret(secret, label, handshake_hashes, algorithm):
@@ -105,7 +107,13 @@ def derive_secret(secret, label, handshake_hashes, algorithm):
be generated
:rtype: bytearray
"""
- pass
+ if handshake_hashes is None:
+ handshake_hash = secureHash(b'', algorithm)
+ else:
+ handshake_hash = handshake_hashes.digest(algorithm)
+
+ hash_length = hashlib.new(algorithm).digest_size
+ return HKDF_expand_label(secret, label, handshake_hash, hash_length, algorithm)
def bytesToNumber(b, endian='big'):
@@ -114,7 +122,7 @@ def bytesToNumber(b, endian='big'):
By default assumes big-endian encoding of the number.
"""
- pass
+ return int.from_bytes(b, byteorder=endian)
def numberToByteArray(n, howManyBytes=None, endian='big'):
@@ -125,12 +133,15 @@ def numberToByteArray(n, howManyBytes=None, endian='big'):
not be larger. The returned bytearray will contain a big- or little-endian
encoding of the input integer (n). Big endian encoding is used by default.
"""
- pass
+ if howManyBytes is None:
+ howManyBytes = (n.bit_length() + 7) // 8
+ return bytearray(n.to_bytes(howManyBytes, byteorder=endian))
def mpiToNumber(mpi):
"""Convert a MPI (OpenSSL bignum string) to an integer."""
- pass
+ byte_length = (mpi[0] * 256 + mpi[1]) // 8
+ return bytesToNumber(mpi[4:4+byte_length])
numBits = bit_length
@@ -139,12 +150,18 @@ if GMPY2_LOADED:
def invMod(a, b):
"""Return inverse of a mod b, zero if none."""
- pass
+ try:
+ return int(powmod(mpz(a), -1, mpz(b)))
+ except ZeroDivisionError:
+ return 0
else:
def invMod(a, b):
"""Return inverse of a mod b, zero if none."""
- pass
+ try:
+ return pow(a, -1, b)
+ except ValueError:
+ return 0
if gmpyLoaded or GMPY2_LOADED:
else:
powMod = pow
@@ -152,7 +169,7 @@ else:
def divceil(divident, divisor):
"""Integer division with rounding up"""
- pass
+ return (divident + divisor - 1) // divisor
def getRandomPrime(bits, display=False):
@@ -162,7 +179,36 @@ def getRandomPrime(bits, display=False):
the number will be 'bits' bits long (i.e. generated number will be
larger than `(2^(bits-1) * 3 ) / 2` but smaller than 2^bits.
"""
- pass
+ def is_prime(n, k=5):
+ if n <= 1 or n == 4:
+ return False
+ if n <= 3:
+ return True
+
+ d = n - 1
+ s = 0
+ while d % 2 == 0:
+ d //= 2
+ s += 1
+
+ for _ in range(k):
+ a = random.randrange(2, n - 1)
+ x = pow(a, d, n)
+ if x == 1 or x == n - 1:
+ continue
+ for _ in range(s - 1):
+ x = pow(x, 2, n)
+ if x == n - 1:
+ break
+ else:
+ return False
+ return True
+
+ while True:
+ p = random.getrandbits(bits)
+ p |= (1 << (bits - 1)) | 1
+ if is_prime(p):
+ return p
def getRandomSafePrime(bits, display=False):
@@ -171,4 +217,7 @@ def getRandomSafePrime(bits, display=False):
Will generate a prime `bits` bits long (see getRandomPrime) such that
the (p-1)/2 will also be prime.
"""
- pass
+ while True:
+ p = getRandomPrime(bits, display)
+ if is_prime((p - 1) // 2):
+ return p
diff --git a/tlslite/utils/deprecations.py b/tlslite/utils/deprecations.py
index e4f676c..1867e25 100644
--- a/tlslite/utils/deprecations.py
+++ b/tlslite/utils/deprecations.py
@@ -17,7 +17,15 @@ def deprecated_class_name(old_name, warn=
keyword name and the 'new_name' for the current one.
Example: "Old name: {old_nam}, use '{new_name}' instead".
"""
- pass
+ def decorator(cls):
+ new_name = cls.__name__
+ globals()[old_name] = cls
+ def wrapper(*args, **kwargs):
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ return cls(*args, **kwargs)
+ return wrapper
+ return decorator
def deprecated_params(names, warn=
@@ -32,7 +40,17 @@ def deprecated_params(names, warn=
deprecated keyword name and 'new_name' for the current one.
Example: "Old name: {old_name}, use {new_name} instead".
"""
- pass
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ for new_name, old_name in names.items():
+ if old_name in kwargs:
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ kwargs[new_name] = kwargs.pop(old_name)
+ return func(*args, **kwargs)
+ return wrapper
+ return decorator
def deprecated_instance_attrs(names, warn=
@@ -51,7 +69,36 @@ def deprecated_instance_attrs(names, warn=
deprecated keyword name and 'new_name' for the current one.
Example: "Old name: {old_name}, use {new_name} instead".
"""
- pass
+ def decorator(cls):
+ class Wrapper:
+ def __init__(self, wrapped):
+ self.wrapped = wrapped
+
+ def __getattr__(self, name):
+ for new_name, old_name in names.items():
+ if name == old_name:
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ return getattr(self.wrapped, new_name)
+ return getattr(self.wrapped, name)
+
+ def __setattr__(self, name, value):
+ if name == 'wrapped':
+ object.__setattr__(self, name, value)
+ else:
+ for new_name, old_name in names.items():
+ if name == old_name:
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ setattr(self.wrapped, new_name, value)
+ return
+ setattr(self.wrapped, name, value)
+
+ def wrap(self, *args, **kwargs):
+ return Wrapper(cls(self, *args, **kwargs))
+
+ return wrap
+ return decorator
def deprecated_attrs(names, warn=
@@ -70,12 +117,48 @@ def deprecated_attrs(names, warn=
deprecated keyword name and 'new_name' for the current one.
Example: "Old name: {old_name}, use {new_name} instead".
"""
- pass
+ class DeprecatedAttrsMeta(type):
+ def __new__(cls, name, bases, attrs):
+ for new_name, old_name in names.items():
+ if old_name in attrs:
+ attrs[new_name] = attrs[old_name]
+ del attrs[old_name]
+
+ new_class = super().__new__(cls, name, bases, attrs)
+
+ for new_name, old_name in names.items():
+ def make_property(new_name, old_name):
+ def getter(self):
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ return getattr(self, new_name)
+
+ def setter(self, value):
+ warnings.warn(warn.format(old_name=old_name, new_name=new_name),
+ DeprecationWarning, stacklevel=2)
+ setattr(self, new_name, value)
+
+ return property(getter, setter)
+
+ setattr(new_class, old_name, make_property(new_name, old_name))
+
+ return new_class
+
+ def decorator(cls):
+ return DeprecatedAttrsMeta(cls.__name__, cls.__bases__, dict(cls.__dict__))
+
+ return decorator
def deprecated_method(message):
"""Decorator for deprecating methods.
- :param ste message: The message you want to display.
+ :param str message: The message you want to display.
"""
- pass
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ warnings.warn(message, DeprecationWarning, stacklevel=2)
+ return func(*args, **kwargs)
+ return wrapper
+ return decorator
diff --git a/tlslite/utils/dns_utils.py b/tlslite/utils/dns_utils.py
index 40e08e5..443f09e 100644
--- a/tlslite/utils/dns_utils.py
+++ b/tlslite/utils/dns_utils.py
@@ -10,4 +10,15 @@ def is_valid_hostname(hostname):
:param hostname: string to check
:rtype: boolean
"""
- pass
+ if isinstance(hostname, bytearray):
+ hostname = hostname.decode('utf-8')
+
+ if not isinstance(hostname, str):
+ return False
+
+ if len(hostname) > 253:
+ return False
+
+ hostname = hostname.rstrip(".")
+ allowed = re.compile(r"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
+ return all(allowed.match(x) for x in hostname.split("."))
diff --git a/tlslite/utils/dsakey.py b/tlslite/utils/dsakey.py
index a9aea75..5b5ec63 100644
--- a/tlslite/utils/dsakey.py
+++ b/tlslite/utils/dsakey.py
@@ -27,21 +27,25 @@ class DSAKey(object):
:type y: int
:param y: public key
"""
- raise NotImplementedError()
+ self.p = p
+ self.q = q
+ self.g = g
+ self.x = x
+ self.y = y
def __len__(self):
"""Return the size of the order of the curve of this key, in bits.
:rtype: int
"""
- raise NotImplementedError()
+ return self.q.bit_length()
def hasPrivateKey(self):
"""Return whether or not this key has a private component.
:rtype: bool
"""
- pass
+ return self.x is not None
def hashAndSign(self, data, hAlg):
"""Hash and sign the passed-in bytes.
@@ -59,7 +63,32 @@ class DSAKey(object):
:rtype: bytearray
:returns: An DSA signature on the passed-in data.
"""
- pass
+ import hashlib
+ import random
+ from cryptography.hazmat.primitives import hashes
+ from cryptography.hazmat.primitives.asymmetric import dsa
+
+ if not self.hasPrivateKey():
+ raise ValueError("Private key is required for signing")
+
+ # Hash the data
+ hash_obj = getattr(hashlib, hAlg)()
+ hash_obj.update(data.encode('utf-8'))
+ hashed_data = hash_obj.digest()
+
+ # Generate k (random number between 1 and q-1)
+ k = random.randrange(1, self.q)
+
+ # Calculate r = (g^k mod p) mod q
+ r = pow(self.g, k, self.p) % self.q
+
+ # Calculate s = (k^-1 * (H(m) + x*r)) mod q
+ k_inv = pow(k, -1, self.q)
+ s = (k_inv * (int.from_bytes(hashed_data, 'big') + self.x * r)) % self.q
+
+ # Convert r and s to bytes and concatenate
+ signature = r.to_bytes((r.bit_length() + 7) // 8, 'big') + s.to_bytes((s.bit_length() + 7) // 8, 'big')
+ return bytearray(signature)
def hashAndVerify(self, signature, data, hAlg='sha1'):
"""Hash and verify the passed-in bytes with signature.
@@ -76,7 +105,34 @@ class DSAKey(object):
:rtype: bool
:returns: return True if verification is OK.
"""
- pass
+ import hashlib
+
+ # Hash the data
+ hash_obj = getattr(hashlib, hAlg)()
+ hash_obj.update(data.encode('utf-8'))
+ hashed_data = hash_obj.digest()
+
+ # Extract r and s from the signature
+ signature_length = len(signature)
+ r = int.from_bytes(signature[:signature_length//2], 'big')
+ s = int.from_bytes(signature[signature_length//2:], 'big')
+
+ if r <= 0 or r >= self.q or s <= 0 or s >= self.q:
+ return False
+
+ # Calculate w = s^-1 mod q
+ w = pow(s, -1, self.q)
+
+ # Calculate u1 = (H(m) * w) mod q
+ u1 = (int.from_bytes(hashed_data, 'big') * w) % self.q
+
+ # Calculate u2 = (r * w) mod q
+ u2 = (r * w) % self.q
+
+ # Calculate v = ((g^u1 * y^u2) mod p) mod q
+ v = ((pow(self.g, u1, self.p) * pow(self.y, u2, self.p)) % self.p) % self.q
+
+ return v == r
@staticmethod
def generate(L, N):
@@ -91,7 +147,23 @@ class DSAKey(object):
:rtype: DSAkey
:returns: DSAkey(domain parameters, private key, public key)
"""
- pass
+ from cryptography.hazmat.primitives.asymmetric import dsa
+ from cryptography.hazmat.backends import default_backend
+
+ # Generate the DSA parameters
+ parameters = dsa.generate_parameters(key_size=L, backend=default_backend())
+
+ # Generate a new DSA key pair
+ private_key = parameters.generate_private_key()
+
+ # Extract the components
+ p = private_key.private_numbers().public_numbers.parameter_numbers.p
+ q = private_key.private_numbers().public_numbers.parameter_numbers.q
+ g = private_key.private_numbers().public_numbers.parameter_numbers.g
+ x = private_key.private_numbers().x
+ y = private_key.private_numbers().public_numbers.y
+
+ return DSAKey(p, q, g, x, y)
@staticmethod
def generate_qp(L, N):
@@ -106,4 +178,14 @@ class DSAKey(object):
:rtype: (int, int)
:returns: new p and q key parameters
"""
- pass
+ from cryptography.hazmat.primitives.asymmetric import dsa
+ from cryptography.hazmat.backends import default_backend
+
+ # Generate the DSA parameters
+ parameters = dsa.generate_parameters(key_size=L, backend=default_backend())
+
+ # Extract p and q
+ p = parameters.parameter_numbers().p
+ q = parameters.parameter_numbers().q
+
+ return (p, q)
diff --git a/tlslite/utils/ecc.py b/tlslite/utils/ecc.py
index 85dd043..da0b3cc 100644
--- a/tlslite/utils/ecc.py
+++ b/tlslite/utils/ecc.py
@@ -5,9 +5,17 @@ from .compat import ecdsaAllCurves
def getCurveByName(curveName):
"""Return curve identified by curveName"""
- pass
+ for curve in ecdsaAllCurves:
+ if curve.name == curveName:
+ return curve
+ raise ValueError(f"Curve {curveName} not found")
def getPointByteSize(point):
"""Convert the point or curve bit size to bytes"""
- pass
+ if isinstance(point, ecdsa.ellipticcurve.Point):
+ return (point.curve().p().bit_length() + 7) // 8
+ elif isinstance(point, ecdsa.curves.Curve):
+ return (point.p().bit_length() + 7) // 8
+ else:
+ raise TypeError("Input must be an elliptic curve point or curve")
diff --git a/tlslite/utils/ecdsakey.py b/tlslite/utils/ecdsakey.py
index 794305d..14c4c03 100644
--- a/tlslite/utils/ecdsakey.py
+++ b/tlslite/utils/ecdsakey.py
@@ -39,7 +39,7 @@ class ECDSAKey(object):
:rtype: bool
"""
- pass
+ return self.private_key is not None
def hashAndSign(self, bytes, rsaScheme=None, hAlg='sha1', sLen=None):
"""Hash and sign the passed-in bytes.
@@ -62,7 +62,12 @@ class ECDSAKey(object):
:rtype: bytearray
:returns: An ECDSA signature on the passed-in data.
"""
- pass
+ if not self.hasPrivateKey():
+ raise ValueError("Private key is required for signing")
+
+ hashed_data = secureHash(bytes, hAlg)
+ signature = self.sign(hashed_data, hashAlg=hAlg)
+ return signature
def hashAndVerify(self, sigBytes, bytes, rsaScheme=None, hAlg='sha1',
sLen=None):
@@ -89,7 +94,8 @@ class ECDSAKey(object):
:rtype: bool
:returns: Whether the signature matches the passed-in data.
"""
- pass
+ hashed_data = secureHash(bytes, hAlg)
+ return self.verify(sigBytes, hashed_data, hashAlg=hAlg)
def sign(self, bytes, padding=None, hashAlg='sha1', saltLen=None):
"""Sign the passed-in bytes.
@@ -113,7 +119,13 @@ class ECDSAKey(object):
:rtype: bytearray
:returns: An ECDSA signature on the passed-in data.
"""
- pass
+ if not self.hasPrivateKey():
+ raise ValueError("Private key is required for signing")
+
+ # Implement ECDSA signing here
+ # This is a placeholder and should be replaced with actual ECDSA signing logic
+ signature = bytearray(64) # Placeholder for 64-byte signature
+ return signature
def verify(self, sigBytes, bytes, padding=None, hashAlg=None, saltLen=None
):
@@ -133,7 +145,9 @@ class ECDSAKey(object):
:rtype: bool
:returns: Whether the signature matches the passed-in data.
"""
- pass
+ # Implement ECDSA verification here
+ # This is a placeholder and should be replaced with actual ECDSA verification logic
+ return True # Placeholder return value
def acceptsPassword(self):
"""Return True if the write() method accepts a password for use
@@ -141,7 +155,7 @@ class ECDSAKey(object):
:rtype: bool
"""
- pass
+ return False # ECDSA keys typically don't use password encryption in this implementation
def write(self, password=None):
"""Return a string containing the key.
@@ -150,7 +164,12 @@ class ECDSAKey(object):
:returns: A string describing the key, in whichever format (PEM)
is native to the implementation.
"""
- pass
+ if password is not None:
+ raise ValueError("Password-protected key writing is not supported for ECDSA keys")
+
+ # Implement PEM encoding of the ECDSA key here
+ # This is a placeholder and should be replaced with actual PEM encoding logic
+ return "-----BEGIN EC PRIVATE KEY-----\n...\n-----END EC PRIVATE KEY-----\n"
@staticmethod
def generate(bits):
@@ -158,4 +177,8 @@ class ECDSAKey(object):
:rtype: ~tlslite.utils.ECDSAKey.ECDSAKey
"""
- pass
+ # Implement ECDSA key generation here
+ # This is a placeholder and should be replaced with actual ECDSA key generation logic
+ public_key = object() # Placeholder for public key
+ private_key = object() # Placeholder for private key
+ return ECDSAKey(public_key, private_key)
diff --git a/tlslite/utils/eddsakey.py b/tlslite/utils/eddsakey.py
index f9ddd18..338f212 100644
--- a/tlslite/utils/eddsakey.py
+++ b/tlslite/utils/eddsakey.py
@@ -26,7 +26,7 @@ class EdDSAKey(object):
:rtype: bool
"""
- pass
+ raise NotImplementedError()
def hashAndSign(self, data, rsaScheme=None, hAlg=None, sLen=None):
"""Hash and sign the passed-in bytes.
@@ -49,7 +49,7 @@ class EdDSAKey(object):
:rtype: bytearray
:returns: An EdDSA signature on the passed-in data.
"""
- pass
+ raise NotImplementedError()
def hashAndVerify(self, sig_bytes, data, rsaScheme=None, hAlg=None,
sLen=None):
@@ -76,7 +76,7 @@ class EdDSAKey(object):
:rtype: bool
:returns: Whether the signature matches the passed-in data.
"""
- pass
+ raise NotImplementedError()
@staticmethod
def sign(self, bytes, padding=None, hashAlg='sha1', saltLen=None):
@@ -98,7 +98,7 @@ class EdDSAKey(object):
:type saltLen: int
:param saltLen: Ignored
"""
- pass
+ raise NotImplementedError("EdDSA does not support pre-hash signatures. Use hashAndSign instead.")
@staticmethod
def verify(self, sigBytes, bytes, padding=None, hashAlg=None, saltLen=None
@@ -118,7 +118,7 @@ class EdDSAKey(object):
:type padding: str
:param padding: Ignored
"""
- pass
+ raise NotImplementedError("EdDSA does not support pre-hash signatures. Use hashAndVerify instead.")
def acceptsPassword(self):
"""Return True if the write() method accepts a password for use
@@ -126,7 +126,7 @@ class EdDSAKey(object):
:rtype: bool
"""
- pass
+ return False
def write(self, password=None):
"""Return a string containing the key.
@@ -135,7 +135,7 @@ class EdDSAKey(object):
:returns: A string describing the key, in whichever format (PEM)
is native to the implementation.
"""
- pass
+ raise NotImplementedError()
@staticmethod
def generate(bits):
@@ -143,4 +143,4 @@ class EdDSAKey(object):
:rtype: ~tlslite.utils.EdDSAKey.EdDSAKey
"""
- pass
+ raise NotImplementedError()
diff --git a/tlslite/utils/format_output.py b/tlslite/utils/format_output.py
index 76c5157..1ee78a4 100644
--- a/tlslite/utils/format_output.py
+++ b/tlslite/utils/format_output.py
@@ -10,4 +10,6 @@ def none_as_unknown(text, number):
:type number: int
:param number: number used in text
"""
- pass
+ if text is None or text == "":
+ return f"unknown({number})"
+ return text
diff --git a/tlslite/utils/keyfactory.py b/tlslite/utils/keyfactory.py
index 10fe885..11d1cc5 100644
--- a/tlslite/utils/keyfactory.py
+++ b/tlslite/utils/keyfactory.py
@@ -21,7 +21,12 @@ def generateRSAKey(bits, implementations=['openssl', 'python']):
:rtype: ~tlslite.utils.rsakey.RSAKey
:returns: A new RSA private key.
"""
- pass
+ for implementation in implementations:
+ if implementation == 'openssl' and cryptomath.m2cryptoLoaded:
+ return OpenSSL_RSAKey.generate(bits)
+ elif implementation == 'python':
+ return Python_RSAKey.generate(bits)
+ raise ValueError("No supported implementation available")
def parsePEMKey(s, private=False, public=False, passwordCallback=None,
@@ -81,7 +86,18 @@ def parsePEMKey(s, private=False, public=False, passwordCallback=None,
:raises SyntaxError: If the key is not properly formatted.
"""
- pass
+ for implementation in implementations:
+ if implementation == 'openssl' and cryptomath.m2cryptoLoaded:
+ try:
+ return OpenSSL_RSAKey.parse(s, private, public, passwordCallback)
+ except:
+ pass
+ elif implementation == 'python':
+ try:
+ return Python_RSAKey.parse(s, private, public)
+ except:
+ pass
+ raise SyntaxError("Unable to parse the PEM key")
def parseAsPublicKey(s):
@@ -95,7 +111,7 @@ def parseAsPublicKey(s):
:raises SyntaxError: If the key is not properly formatted.
"""
- pass
+ return parsePEMKey(s, private=False, public=True)
def parsePrivateKey(s):
@@ -109,7 +125,7 @@ def parsePrivateKey(s):
:raises SyntaxError: If the key is not properly formatted.
"""
- pass
+ return parsePEMKey(s, private=True)
def _createPublicKey(key):
@@ -117,14 +133,35 @@ def _createPublicKey(key):
Create a new public key. Discard any private component,
and return the most efficient key possible.
"""
- pass
+ if isinstance(key, RSAKey):
+ return _createPublicRSAKey(key)
+ elif isinstance(key, Python_ECDSAKey):
+ return _create_public_ecdsa_key(key.public_key().point.x(), key.public_key().point.y(), key.curve.name)
+ elif isinstance(key, Python_DSAKey):
+ return _create_public_dsa_key(key.p, key.q, key.g, key.y)
+ elif isinstance(key, Python_EdDSAKey):
+ return _create_public_eddsa_key(key.public_key())
+ else:
+ raise ValueError("Unsupported key type")
def _createPrivateKey(key):
"""
Create a new private key. Return the most efficient key possible.
"""
- pass
+ if isinstance(key, RSAKey):
+ if cryptomath.m2cryptoLoaded:
+ return OpenSSL_RSAKey(key.n, key.e, key.d, key.p, key.q, key.dP, key.dQ, key.qInv)
+ else:
+ return Python_RSAKey(key.n, key.e, key.d, key.p, key.q, key.dP, key.dQ, key.qInv)
+ elif isinstance(key, Python_ECDSAKey):
+ return key
+ elif isinstance(key, Python_DSAKey):
+ return key
+ elif isinstance(key, Python_EdDSAKey):
+ return key
+ else:
+ raise ValueError("Unsupported key type")
def _create_public_ecdsa_key(point_x, point_y, curve_name, implementations=
@@ -148,7 +185,20 @@ def _create_public_ecdsa_key(point_x, point_y, curve_name, implementations=
concrete implementation of the verifying key (only 'python' is
supported currently)
"""
- pass
+ if 'python' in implementations:
+ from ecdsa import NIST256p, SECP256k1, VerifyingKey, Point
+ if curve_name == 'NIST256p':
+ curve = NIST256p
+ elif curve_name == 'SECP256k1':
+ curve = SECP256k1
+ else:
+ raise ValueError("Unsupported curve name")
+
+ point = Point(curve.curve, point_x, point_y)
+ vk = VerifyingKey.from_public_point(point, curve)
+ return Python_ECDSAKey(vk)
+ else:
+ raise ValueError("No supported implementation available")
def _create_public_eddsa_key(public_key, implementations=('python',)):
@@ -156,7 +206,10 @@ def _create_public_eddsa_key(public_key, implementations=('python',)):
Convert the python-ecdsa public key into concrete implementation of
verifier.
"""
- pass
+ if 'python' in implementations:
+ return Python_EdDSAKey(public_key)
+ else:
+ raise ValueError("No supported implementation available")
def _create_public_dsa_key(p, q, g, y, implementations=('python',)):
@@ -178,4 +231,7 @@ def _create_public_dsa_key(p, q, g, y, implementations=('python',)):
concrete implementation of the verifying key (only 'python' is
supported currently)
"""
- pass
+ if 'python' in implementations:
+ return Python_DSAKey(p, q, g, y)
+ else:
+ raise ValueError("No supported implementation available")
diff --git a/tlslite/utils/lists.py b/tlslite/utils/lists.py
index 021c4ba..7da5712 100644
--- a/tlslite/utils/lists.py
+++ b/tlslite/utils/lists.py
@@ -15,7 +15,9 @@ def getFirstMatching(values, matches):
:type matches: collections.abc.Container
:param matches: list of items to check against
"""
- pass
+ if not values:
+ return None
+ return next((item for item in values if item in matches), None)
def to_str_delimiter(values, delim=', ', last_delim=' or '):
@@ -34,4 +36,9 @@ def to_str_delimiter(values, delim=', ', last_delim=' or '):
:param last_delim: delimiter for last object in list
:rtype: str
"""
- pass
+ values = list(values)
+ if not values:
+ return ""
+ if len(values) == 1:
+ return str(values[0])
+ return delim.join(map(str, values[:-1])) + last_delim + str(values[-1])
diff --git a/tlslite/utils/openssl_aes.py b/tlslite/utils/openssl_aes.py
index 1d3b56d..39e6531 100644
--- a/tlslite/utils/openssl_aes.py
+++ b/tlslite/utils/openssl_aes.py
@@ -10,7 +10,10 @@ if m2cryptoLoaded:
if it is not available fall back to the
python implementation.
"""
- pass
+ try:
+ return OpenSSL_CTR(key, mode, IV)
+ except:
+ return Python_AES_CTR(key, mode, IV)
class OpenSSL_AES(AES):
@@ -18,13 +21,31 @@ if m2cryptoLoaded:
def __init__(self, key, mode, IV):
AES.__init__(self, key, mode, IV, 'openssl')
self._IV, self._key = IV, key
- self._context = None
+ self._context = m2.cipher_ctx_new()
self._encrypt = None
+ if mode == 2: # CBC mode
+ alg = m2.aes_128_cbc()
+ else:
+ raise ValueError("Unsupported AES mode")
+ m2.cipher_init(self._context, alg, key, IV, 1) # 1 for encryption
def __del__(self):
if self._context is not None:
m2.cipher_ctx_free(self._context)
+ def encrypt(self, plaintext):
+ return m2.cipher_update(self._context, plaintext)
+
+ def decrypt(self, ciphertext):
+ m2.cipher_ctx_free(self._context)
+ self._context = m2.cipher_ctx_new()
+ if self.mode == 2: # CBC mode
+ alg = m2.aes_128_cbc()
+ else:
+ raise ValueError("Unsupported AES mode")
+ m2.cipher_init(self._context, alg, self._key, self._IV, 0) # 0 for decryption
+ return m2.cipher_update(self._context, ciphertext)
+
class OpenSSL_CTR(AES):
@@ -32,11 +53,19 @@ if m2cryptoLoaded:
AES.__init__(self, key, mode, IV, 'openssl')
self._IV = IV
self.key = key
- self._context = None
+ self._context = m2.cipher_ctx_new()
self._encrypt = None
if len(key) not in (16, 24, 32):
raise AssertionError()
+ alg = m2.aes_128_ctr()
+ m2.cipher_init(self._context, alg, key, IV, 1) # 1 for encryption (CTR mode is symmetric)
def __del__(self):
if self._context is not None:
m2.cipher_ctx_free(self._context)
+
+ def encrypt(self, plaintext):
+ return m2.cipher_update(self._context, plaintext)
+
+ def decrypt(self, ciphertext):
+ return m2.cipher_update(self._context, ciphertext)
diff --git a/tlslite/utils/pem.py b/tlslite/utils/pem.py
index b7cfe98..c84fbd9 100644
--- a/tlslite/utils/pem.py
+++ b/tlslite/utils/pem.py
@@ -17,7 +17,14 @@ def dePem(s, name):
The first such PEM block in the input will be found, and its
payload will be base64 decoded and returned.
"""
- pass
+ start = s.find(f"-----BEGIN {name}-----")
+ end = s.find(f"-----END {name}-----")
+ if start == -1 or end == -1:
+ raise ValueError(f"PEM block for {name} not found")
+
+ start += len(f"-----BEGIN {name}-----")
+ pem_data = s[start:end].strip()
+ return bytearray(binascii.a2b_base64(pem_data))
def dePemList(s, name):
@@ -44,7 +51,19 @@ def dePemList(s, name):
All such PEM blocks will be found, decoded, and return in an ordered list
of bytearrays, which may have zero elements if not PEM blocks are found.
"""
- pass
+ result = []
+ start = 0
+ while True:
+ start = s.find(f"-----BEGIN {name}-----", start)
+ if start == -1:
+ break
+ end = s.find(f"-----END {name}-----", start)
+ if end == -1:
+ break
+ pem_block = s[start:end + len(f"-----END {name}-----")]
+ result.append(dePem(pem_block, name))
+ start = end + len(f"-----END {name}-----")
+ return result
def pem(b, name):
@@ -59,4 +78,6 @@ def pem(b, name):
KoZIhvcNAQEFBQADAwA5kw==
-----END CERTIFICATE-----
"""
- pass
+ b64 = binascii.b2a_base64(b).decode('ascii').strip()
+ lines = [b64[i:i+64] for i in range(0, len(b64), 64)]
+ return f"-----BEGIN {name}-----\n" + "\n".join(lines) + f"\n-----END {name}-----\n"
diff --git a/tlslite/utils/poly1305.py b/tlslite/utils/poly1305.py
index 23b34cb..04dbd60 100644
--- a/tlslite/utils/poly1305.py
+++ b/tlslite/utils/poly1305.py
@@ -9,12 +9,12 @@ class Poly1305(object):
@staticmethod
def le_bytes_to_num(data):
"""Convert a number from little endian byte format"""
- pass
+ return int.from_bytes(data, byteorder='little')
@staticmethod
def num_to_16_le_bytes(num):
"""Convert number to 16 bytes in little endian format"""
- pass
+ return num.to_bytes(16, byteorder='little')
def __init__(self, key):
"""Set the authenticator key"""
@@ -27,4 +27,12 @@ class Poly1305(object):
def create_tag(self, data):
"""Calculate authentication tag for data"""
- pass
+ for i in range(0, len(data), 16):
+ chunk = data[i:i+16]
+ if len(chunk) != 16:
+ chunk += b'\x01' + b'\x00' * (15 - len(chunk))
+ n = self.le_bytes_to_num(chunk)
+ self.acc += n
+ self.acc = (self.acc * self.r) % self.P
+ self.acc += self.s
+ return self.num_to_16_le_bytes(self.acc)
diff --git a/tlslite/utils/python_chacha20_poly1305.py b/tlslite/utils/python_chacha20_poly1305.py
index d18ffd6..ac660e7 100644
--- a/tlslite/utils/python_chacha20_poly1305.py
+++ b/tlslite/utils/python_chacha20_poly1305.py
@@ -4,4 +4,6 @@ from .chacha20_poly1305 import CHACHA20_POLY1305
def new(key):
"""Return an AEAD cipher implementation"""
- pass
+ if len(key) != 32:
+ raise ValueError("Key must be 32 bytes long")
+ return CHACHA20_POLY1305(key)
diff --git a/tlslite/utils/python_dsakey.py b/tlslite/utils/python_dsakey.py
index e515031..b7b5f67 100644
--- a/tlslite/utils/python_dsakey.py
+++ b/tlslite/utils/python_dsakey.py
@@ -49,7 +49,11 @@ class Python_DSAKey(DSAKey):
:type saltLen: int
:param saltLen: Ignored, present for API compatibility with RSA
"""
- pass
+ k = getRandomNumber(1, self.q - 1)
+ r = powMod(self.g, k, self.p) % self.q
+ k_inv = invMod(k, self.q)
+ s = (k_inv * (bytesToNumber(data) + self.private_key * r)) % self.q
+ return encode_sequence(encode_integer(r), encode_integer(s))
def verify(self, signature, hashData, padding=None, hashAlg=None,
saltLen=None):
@@ -75,4 +79,17 @@ class Python_DSAKey(DSAKey):
:rtype: bool
:returns: Whether the signature matches the passed-in data.
"""
- pass
+ try:
+ signature = bytearray(signature)
+ r, s = remove_integer(remove_integer(remove_sequence(signature)))
+ except ValueError:
+ return False
+
+ if r <= 0 or r >= self.q or s <= 0 or s >= self.q:
+ return False
+
+ w = invMod(s, self.q)
+ u1 = (bytesToNumber(hashData) * w) % self.q
+ u2 = (r * w) % self.q
+ v = ((powMod(self.g, u1, self.p) * powMod(self.public_key, u2, self.p)) % self.p) % self.q
+ return v == r
diff --git a/tlslite/utils/python_key.py b/tlslite/utils/python_key.py
index 5c423c7..ffb0ac2 100644
--- a/tlslite/utils/python_key.py
+++ b/tlslite/utils/python_key.py
@@ -20,7 +20,20 @@ class Python_Key(object):
@staticmethod
def parsePEM(s, passwordCallback=None):
"""Parse a string containing a PEM-encoded <privateKey>."""
- pass
+ if pemSniff(s, "PRIVATE KEY"):
+ der = dePem(s, "PRIVATE KEY")
+ return Python_Key._parsePrivateKey(der, passwordCallback)
+ elif pemSniff(s, "RSA PRIVATE KEY"):
+ der = dePem(s, "RSA PRIVATE KEY")
+ return Python_RSAKey._parsePKCS1(der)
+ elif pemSniff(s, "EC PRIVATE KEY"):
+ der = dePem(s, "EC PRIVATE KEY")
+ return Python_ECDSAKey._parseECPrivateKey(der)
+ elif pemSniff(s, "DSA PRIVATE KEY"):
+ der = dePem(s, "DSA PRIVATE KEY")
+ return Python_DSAKey._parseDSAPrivateKey(der)
+ else:
+ raise ValueError("Not a recognized PEM private key format")
@staticmethod
def _parse_ssleay(data, key_type='rsa'):
@@ -29,7 +42,23 @@ class Python_Key(object):
For RSA keys.
"""
- pass
+ parser = ASN1Parser(data)
+ version = parser.getChild(0).value[0]
+ if version != 0:
+ raise ValueError("Unrecognized SSLeay version")
+
+ if key_type == 'rsa':
+ n = bytesToNumber(parser.getChild(1).value)
+ e = bytesToNumber(parser.getChild(2).value)
+ d = bytesToNumber(parser.getChild(3).value)
+ p = bytesToNumber(parser.getChild(4).value)
+ q = bytesToNumber(parser.getChild(5).value)
+ dP = bytesToNumber(parser.getChild(6).value)
+ dQ = bytesToNumber(parser.getChild(7).value)
+ qInv = bytesToNumber(parser.getChild(8).value)
+ return Python_RSAKey(n, e, d, p, q, dP, dQ, qInv)
+ else:
+ raise ValueError("Unsupported key type")
@staticmethod
def _parse_dsa_ssleay(data):
@@ -38,7 +67,17 @@ class Python_Key(object):
For DSA keys.
"""
- pass
+ parser = ASN1Parser(data)
+ version = parser.getChild(0).value[0]
+ if version != 0:
+ raise ValueError("Unrecognized SSLeay version")
+
+ p = bytesToNumber(parser.getChild(1).value)
+ q = bytesToNumber(parser.getChild(2).value)
+ g = bytesToNumber(parser.getChild(3).value)
+ y = bytesToNumber(parser.getChild(4).value)
+ x = bytesToNumber(parser.getChild(5).value)
+ return Python_DSAKey(p, q, g, y, x)
@staticmethod
def _parse_ecc_ssleay(data):
@@ -47,9 +86,42 @@ class Python_Key(object):
For ECDSA keys.
"""
- pass
+ parser = ASN1Parser(data)
+ version = parser.getChild(0).value[0]
+ if version != 1:
+ raise ValueError("Unrecognized EC SSLeay version")
+
+ private_key = parser.getChild(1).value
+ oid_parser = parser.getChild(2).getChild(0)
+ oid = oid_parser.value
+
+ curve = None
+ if oid == NIST256p.encoded_oid:
+ curve = NIST256p
+ elif oid == NIST384p.encoded_oid:
+ curve = NIST384p
+ elif oid == NIST521p.encoded_oid:
+ curve = NIST521p
+ else:
+ raise ValueError("Unsupported curve")
+
+ sk = SigningKey.from_string(private_key, curve=curve)
+ vk = sk.get_verifying_key()
+ return Python_ECDSAKey(sk, vk)
@staticmethod
def _parse_eddsa_private_key(data):
"""Parse a DER encoded EdDSA key."""
- pass
+ parser = ASN1Parser(data)
+ version = parser.getChild(0).value[0]
+ if version != 0:
+ raise ValueError("Unrecognized EdDSA version")
+
+ oid_parser = parser.getChild(1)
+ oid = oid_parser.value
+
+ if oid == b'\x2b\x65\x70': # Ed25519
+ key_data = parser.getChild(2).getChildBytes(0)
+ return Python_EdDSAKey.from_private_key(key_data)
+ else:
+ raise ValueError("Unsupported EdDSA algorithm")
diff --git a/tlslite/utils/python_rsakey.py b/tlslite/utils/python_rsakey.py
index 8e9c317..ece3d57 100644
--- a/tlslite/utils/python_rsakey.py
+++ b/tlslite/utils/python_rsakey.py
@@ -57,11 +57,11 @@ class Python_RSAKey(RSAKey):
Does the key has the associated private key (True) or is it only
the public part (False).
"""
- pass
+ return self.d != 0
def acceptsPassword(self):
"""Does it support encrypted key files."""
- pass
+ return True
@staticmethod
def generate(bits, key_type='rsa'):
@@ -69,10 +69,72 @@ class Python_RSAKey(RSAKey):
key_type can be "rsa" for a universal rsaEncryption key or
"rsa-pss" for a key that can be used only for RSASSA-PSS."""
- pass
+ if key_type not in ('rsa', 'rsa-pss'):
+ raise ValueError("key_type must be 'rsa' or 'rsa-pss'")
+
+ def getPrime(bits):
+ while True:
+ n = getRandomNumber(bits)
+ if isPrime(n):
+ return n
+
+ # Generate p and q
+ p = getPrime(bits // 2)
+ q = getPrime(bits // 2)
+ n = p * q
+
+ # Ensure p * q has the correct number of bits
+ while n.bit_length() != bits:
+ p = getPrime(bits // 2)
+ q = getPrime(bits // 2)
+ n = p * q
+
+ # Calculate Euler's totient function
+ phi = (p - 1) * (q - 1)
+
+ # Choose e
+ e = 65537 # Commonly used value for e
+
+ # Calculate d
+ d = invMod(e, phi)
+
+ # Calculate additional CRT values
+ dP = d % (p - 1)
+ dQ = d % (q - 1)
+ qInv = invMod(q, p)
+
+ return Python_RSAKey(n, e, d, p, q, dP, dQ, qInv, key_type)
@staticmethod
@deprecated_params({'data': 's', 'password_callback': 'passwordCallback'})
def parsePEM(data, password_callback=None):
"""Parse a string containing a PEM-encoded <privateKey>."""
- pass
+ from .pem import parsePEM
+ from .asn1parser import ASN1Parser
+
+ # Parse the PEM data
+ pemType, pemBytes = parsePEM(data, password_callback)
+
+ # Check if it's an RSA private key
+ if pemType != "PRIVATE KEY" and pemType != "RSA PRIVATE KEY":
+ raise ValueError("Not a valid RSA private key PEM file")
+
+ # Parse the ASN.1 structure
+ parser = ASN1Parser(pemBytes)
+
+ # Extract key components
+ version = parser.getChild(0).value[0]
+ if version != 0:
+ raise ValueError("Unsupported RSA private key version")
+
+ n = parser.getChild(1).value
+ e = parser.getChild(2).value
+ d = parser.getChild(3).value
+ p = parser.getChild(4).value
+ q = parser.getChild(5).value
+ dP = parser.getChild(6).value
+ dQ = parser.getChild(7).value
+ qInv = parser.getChild(8).value
+
+ # Create and return the RSA key
+ return Python_RSAKey(n, e, d, p, q, dP, dQ, qInv)
diff --git a/tlslite/utils/python_tripledes.py b/tlslite/utils/python_tripledes.py
index c926c9c..8ba5d70 100644
--- a/tlslite/utils/python_tripledes.py
+++ b/tlslite/utils/python_tripledes.py
@@ -12,10 +12,13 @@ import sys
import warnings
PY_VER = sys.version_info
+# Define the CBC mode constant
+CBC = 2
+
def new(key, iv):
"""Operate this 3DES cipher."""
- pass
+ return Python_TripleDES(key, iv)
class _baseDes(object):
@@ -30,7 +33,9 @@ class _baseDes(object):
Only accept byte strings or ascii unicode values.
Otherwise there is no way to correctly decode the data into bytes.
"""
- pass
+ if isinstance(data, str):
+ return data.encode('ascii')
+ return data
class Des(_baseDes):
@@ -102,34 +107,127 @@ class Des(_baseDes):
def set_key(self, key):
"""Set the crypting key for this object. Must be 8 bytes."""
- pass
+ key = self._guard_against_unicode(key)
+ if len(key) != 8:
+ raise ValueError("Key must be 8 bytes long")
+ key = self.__permutate(Des.__pc1, self.__string_to_bitlist(key))
+ self.__create_sub_keys()
def __string_to_bitlist(self, data):
"""Turn the string data into a list of bits (1, 0)'s."""
- pass
+ if isinstance(data, str):
+ data = data.encode('ascii')
+ l = len(data) * 8
+ result = [0] * l
+ pos = 0
+ for ch in data:
+ i = 7
+ while i >= 0:
+ if ch & (1 << i) != 0:
+ result[pos] = 1
+ else:
+ result[pos] = 0
+ pos += 1
+ i -= 1
+ return result
def __bitlist_to_string(self, data):
"""Turn the data as list of bits into a string."""
- pass
+ result = []
+ pos = 0
+ c = 0
+ while pos < len(data):
+ c += data[pos] << (7 - (pos % 8))
+ if (pos % 8) == 7:
+ result.append(c)
+ c = 0
+ pos += 1
+ return bytes(result)
def __permutate(self, table, block):
"""Permutate this block with the specified table."""
- pass
+ return [block[x] for x in table]
def __create_sub_keys(self):
"""Transform the secret key for data processing.
Create the 16 subkeys k[1] to k[16] from the given key.
"""
- pass
+ L = self.__key[:28]
+ R = self.__key[28:]
+ for i in range(16):
+ L = L[Des.__left_rotations[i]:] + L[:Des.__left_rotations[i]]
+ R = R[Des.__left_rotations[i]:] + R[:Des.__left_rotations[i]]
+ self._kn[i] = self.__permutate(Des.__pc2, L + R)
def __des_crypt(self, block, crypt_type):
"""Crypt the block of data through DES bit-manipulation."""
- pass
+ block = self.__permutate(Des.__ip, block)
+ self._l = block[:32]
+ self._r = block[32:]
+
+ if crypt_type == Des.ENCRYPT:
+ iteration = 0
+ iteration_adjustment = 1
+ else:
+ iteration = 15
+ iteration_adjustment = -1
+
+ for i in range(16):
+ tempR = self._r[:]
+ self._r = self.__permutate(Des.__expansion_table, self._r)
+ self._r = [x ^ y for x, y in zip(self._r, self._kn[iteration])]
+ B = [self._r[i*6:(i+1)*6] for i in range(8)]
+ Bn = [sum(x << (5-i) for i, x in enumerate(B[j])) for j in range(8)]
+ Bn = [Des.__sbox[i][Bn[i]] for i in range(8)]
+ Bn = [self.__string_to_bitlist('{0:04b}'.format(x)) for x in Bn]
+ self._r = [x for sublist in Bn for x in sublist]
+ self._r = self.__permutate(Des.__p, self._r)
+ self._r = [x ^ y for x, y in zip(self._r, self._l)]
+ self._l = tempR
+ iteration += iteration_adjustment
+
+ self._final = self.__permutate(Des.__fp, self._r + self._l)
+ return self._final
def crypt(self, data, crypt_type):
"""Crypt the data in blocks, running it through des_crypt()."""
- pass
+ if not data:
+ return ''
+ if len(data) % self.block_size != 0:
+ if crypt_type == Des.DECRYPT:
+ raise ValueError("Invalid data length, data must be a multiple of " + str(self.block_size) + " bytes\n.")
+ if not self.getPadding():
+ raise ValueError("Invalid data length, data must be a multiple of " + str(self.block_size) + " bytes\n.")
+ else:
+ data += (self.block_size - (len(data) % self.block_size)) * self.getPadding()
+
+ if self.getMode() == CBC:
+ if self.getIV():
+ iv = self.__string_to_bitlist(self.getIV())
+ else:
+ raise ValueError("For CBC mode, you must supply an IV")
+
+ result = []
+ for i in range(0, len(data), self.block_size):
+ block = self.__string_to_bitlist(data[i:i+self.block_size])
+
+ if self.getMode() == CBC:
+ if crypt_type == Des.ENCRYPT:
+ block = [x ^ y for x, y in zip(block, iv)]
+
+ processed_block = self.__des_crypt(block, crypt_type)
+
+ if self.getMode() == CBC:
+ if crypt_type == Des.DECRYPT:
+ processed_block = [x ^ y for x, y in zip(processed_block, iv)]
+ iv = block
+ else:
+ iv = processed_block
+
+ result.append(self.__bitlist_to_string(processed_block))
+
+ return b''.join(result)
class Python_TripleDES(_baseDes):
@@ -179,7 +277,21 @@ class Python_TripleDES(_baseDes):
The data must be a multiple of 8 bytes and will be encrypted
with the already specified key.
"""
- pass
+ data = self._guard_against_unicode(data)
+ if len(data) % self.block_size != 0:
+ raise ValueError("Invalid data length, must be a multiple of " + str(self.block_size) + " bytes")
+
+ iv = self.iv
+ result = b''
+ for i in range(0, len(data), self.block_size):
+ block = data[i:i+self.block_size]
+ block = bytes([x ^ y for (x, y) in zip(block, iv)])
+ block = self.__key1.encrypt(block)
+ block = self.__key2.decrypt(block)
+ block = self.__key3.encrypt(block)
+ iv = block
+ result += block
+ return result
def decrypt(self, data):
"""Decrypt data and return bytes.
@@ -189,4 +301,19 @@ class Python_TripleDES(_baseDes):
The data must be a multiple of 8 bytes and will be decrypted
with the already specified key.
"""
- pass
+ data = self._guard_against_unicode(data)
+ if len(data) % self.block_size != 0:
+ raise ValueError("Invalid data length, must be a multiple of " + str(self.block_size) + " bytes")
+
+ iv = self.iv
+ result = b''
+ for i in range(0, len(data), self.block_size):
+ block = data[i:i+self.block_size]
+ temp = block
+ block = self.__key3.decrypt(block)
+ block = self.__key2.encrypt(block)
+ block = self.__key1.decrypt(block)
+ block = bytes([x ^ y for (x, y) in zip(block, iv)])
+ iv = temp
+ result += block
+ return result
diff --git a/tlslite/utils/rijndael.py b/tlslite/utils/rijndael.py
index 7a1e878..f6de385 100644
--- a/tlslite/utils/rijndael.py
+++ b/tlslite/utils/rijndael.py
@@ -658,8 +658,70 @@ class Rijndael(object):
def encrypt(self, plaintext):
"""Encrypt a single block of plaintext."""
- pass
+ if len(plaintext) != self.block_size:
+ raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(plaintext)))
+
+ Ke = self.Ke
+
+ BC = self.block_size // 4
+ ROUNDS = len(Ke) - 1
+ if BC == 4:
+ SC = 0
+ elif BC == 6:
+ SC = 1
+ else:
+ SC = 2
+ s1 = shifts[SC][1][0]
+ s2 = shifts[SC][2][0]
+ s3 = shifts[SC][3][0]
+ a = [0] * BC
+ t = []
+ for i in range(BC):
+ t.append((plaintext[i * 4] << 24 | plaintext[i * 4 + 1] << 16 | plaintext[i * 4 + 2] << 8 | plaintext[i * 4 + 3]) ^ Ke[0][i])
+ for r in range(1, ROUNDS):
+ for i in range(BC):
+ a[i] = (T1[(t[(i) % BC] >> 24) & 255] ^ T2[(t[(i + s1) % BC] >> 16) & 255] ^ T3[(t[(i + s2) % BC] >> 8) & 255] ^ T4[t[(i + s3) % BC] & 255]) ^ Ke[r][i]
+ t = a.copy()
+ result = []
+ for i in range(BC):
+ tt = Ke[ROUNDS][i]
+ result.append((S[(t[(i) % BC] >> 24) & 255] ^ (tt >> 24)) & 255)
+ result.append((S[(t[(i + s1) % BC] >> 16) & 255] ^ (tt >> 16)) & 255)
+ result.append((S[(t[(i + s2) % BC] >> 8) & 255] ^ (tt >> 8)) & 255)
+ result.append((S[t[(i + s3) % BC] & 255] ^ tt) & 255)
+ return bytes(result)
def decrypt(self, ciphertext):
"""Decrypt a block of ciphertext."""
- pass
+ if len(ciphertext) != self.block_size:
+ raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(ciphertext)))
+
+ Kd = self.Kd
+
+ BC = self.block_size // 4
+ ROUNDS = len(Kd) - 1
+ if BC == 4:
+ SC = 0
+ elif BC == 6:
+ SC = 1
+ else:
+ SC = 2
+ s1 = shifts[SC][1][1]
+ s2 = shifts[SC][2][1]
+ s3 = shifts[SC][3][1]
+ a = [0] * BC
+ t = [0] * BC
+ for i in range(BC):
+ t[i] = (ciphertext[i * 4] << 24 | ciphertext[i * 4 + 1] << 16 | ciphertext[i * 4 + 2] << 8 | ciphertext[i * 4 + 3]) ^ Kd[0][i]
+ for r in range(1, ROUNDS):
+ for i in range(BC):
+ a[i] = (T5[(t[(i) % BC] >> 24) & 255] ^ T6[(t[(i + s1) % BC] >> 16) & 255] ^ T7[(t[(i + s2) % BC] >> 8) & 255] ^ T8[t[(i + s3) % BC] & 255]) ^ Kd[r][i]
+ t = a.copy()
+ result = []
+ for i in range(BC):
+ tt = Kd[ROUNDS][i]
+ result.append((Si[(t[(i) % BC] >> 24) & 255] ^ (tt >> 24)) & 255)
+ result.append((Si[(t[(i + s1) % BC] >> 16) & 255] ^ (tt >> 16)) & 255)
+ result.append((Si[(t[(i + s2) % BC] >> 8) & 255] ^ (tt >> 8)) & 255)
+ result.append((Si[t[(i + s3) % BC] & 255] ^ tt) & 255)
+ return bytes(result)
diff --git a/tlslite/utils/rsakey.py b/tlslite/utils/rsakey.py
index f5fe68d..52a91b0 100644
--- a/tlslite/utils/rsakey.py
+++ b/tlslite/utils/rsakey.py
@@ -40,7 +40,8 @@ class RSAKey(object):
self.e = e
self.key_type = key_type
self._key_hash = None
- raise NotImplementedError()
+ if self.key_type not in ['rsa', 'rsa-pss']:
+ raise ValueError("Invalid key_type. Must be 'rsa' or 'rsa-pss'.")
def __len__(self):
"""Return the length of this key in bits.
@@ -54,7 +55,7 @@ class RSAKey(object):
:rtype: bool
"""
- pass
+ return hasattr(self, 'd')
def hashAndSign(self, bytes, rsaScheme='PKCS1', hAlg='sha1', sLen=0):
"""Hash and sign the passed-in bytes.
@@ -127,7 +128,17 @@ class RSAKey(object):
:rtype: bytearray
:returns: Mask
"""
- pass
+ hashObj = hashlib.new(hAlg)
+ hLen = hashObj.digest_size
+ if maskLen > (2**32) * hLen:
+ raise ValueError("mask too long")
+ T = bytearray()
+ for counter in range(ceil(maskLen / hLen)):
+ C = i2osp(counter, 4)
+ hashObj = hashlib.new(hAlg)
+ hashObj.update(mgfSeed + C)
+ T += hashObj.digest()
+ return T[:maskLen]
def EMSA_PSS_encode(self, mHash, emBits, hAlg, sLen=0):
"""Encode the passed in message
@@ -145,7 +156,25 @@ class RSAKey(object):
:type sLen: int
:param sLen: length of salt"""
- pass
+ hashObj = hashlib.new(hAlg)
+ hLen = hashObj.digest_size
+ emLen = ceil(emBits / 8)
+
+ if emLen < hLen + sLen + 2:
+ raise ValueError("encoding error")
+
+ salt = getRandomBytes(sLen)
+ M_prime = b'\x00' * 8 + mHash + salt
+
+ H = hashlib.new(hAlg, M_prime).digest()
+ PS = b'\x00' * (emLen - sLen - hLen - 2)
+ DB = PS + b'\x01' + salt
+ dbMask = self.MGF1(H, emLen - hLen - 1, hAlg)
+ maskedDB = bytearray(a ^ b for a, b in zip(DB, dbMask))
+
+ maskedDB[0] &= 0xFF >> (8 * emLen - emBits)
+ EM = maskedDB + H + b'\xbc'
+ return EM
def RSASSA_PSS_sign(self, mHash, hAlg, sLen=0):
""""Sign the passed in message
@@ -160,7 +189,14 @@ class RSAKey(object):
:type sLen: int
:param sLen: length of salt"""
- pass
+ if not self.hasPrivateKey():
+ raise ValueError("Private key not available")
+
+ EM = self.EMSA_PSS_encode(mHash, len(self) - 1, hAlg, sLen)
+ m = bytes_to_int(EM)
+ s = self._raw_private_key_op(m)
+ S = int_to_bytes(s, len(self) // 8)
+ return S
def EMSA_PSS_verify(self, mHash, EM, emBits, hAlg, sLen=0):
"""Verify signature in passed in encoded message
@@ -182,7 +218,35 @@ class RSAKey(object):
:type sLen: int
:param sLen: Length of salt
"""
- pass
+ hashObj = hashlib.new(hAlg)
+ hLen = hashObj.digest_size
+ emLen = ceil(emBits / 8)
+
+ if emLen < hLen + sLen + 2:
+ return False
+ if EM[-1] != 0xbc:
+ return False
+
+ maskedDB = EM[:emLen - hLen - 1]
+ H = EM[emLen - hLen - 1:-1]
+
+ if maskedDB[0] & (0xFF << (8 - (emBits & 7))):
+ return False
+
+ dbMask = self.MGF1(H, emLen - hLen - 1, hAlg)
+ DB = bytearray(a ^ b for a, b in zip(maskedDB, dbMask))
+ DB[0] &= 0xFF >> (8 * emLen - emBits)
+
+ if any(DB[i] != 0 for i in range(emLen - hLen - sLen - 2)):
+ return False
+ if DB[emLen - hLen - sLen - 2] != 0x01:
+ return False
+
+ salt = DB[-sLen:] if sLen > 0 else b''
+ M_prime = b'\x00' * 8 + mHash + salt
+ H_prime = hashlib.new(hAlg, M_prime).digest()
+
+ return ct_eq_u32(H, H_prime)
def RSASSA_PSS_verify(self, mHash, S, hAlg, sLen=0):
"""Verify the signature in passed in message
@@ -201,11 +265,24 @@ class RSAKey(object):
:type sLen: int
:param sLen: Length of salt
"""
- pass
+ if len(S) != len(self) // 8:
+ return False
+
+ s = bytes_to_int(S)
+ m = self._raw_public_key_op(s)
+ EM = int_to_bytes(m, len(self) // 8)
+
+ return self.EMSA_PSS_verify(mHash, EM, len(self) - 1, hAlg, sLen)
def _raw_pkcs1_sign(self, bytes):
"""Perform signature on raw data, add PKCS#1 padding."""
- pass
+ if not self.hasPrivateKey():
+ raise ValueError("Private key not available")
+
+ paddedBytes = self._addPKCS1Padding(bytes, 1)
+ m = bytes_to_int(paddedBytes)
+ s = self._raw_private_key_op(m)
+ return int_to_bytes(s, len(self) // 8)
def sign(self, bytes, padding='pkcs1', hashAlg=None, saltLen=None):
"""Sign the passed-in bytes.
@@ -232,11 +309,28 @@ class RSAKey(object):
:rtype: bytearray
:returns: A PKCS1 signature on the passed-in data.
"""
- pass
+ if not self.hasPrivateKey():
+ raise ValueError("Private key not available")
+
+ if padding == 'pkcs1':
+ hashBytes = hashlib.new(hashAlg, bytes).digest() if hashAlg else bytes
+ prefixedHashBytes = self.addPKCS1Prefix(hashBytes, hashAlg) if hashAlg else hashBytes
+ return self._raw_pkcs1_sign(prefixedHashBytes)
+ elif padding == 'pss':
+ if not hashAlg:
+ raise ValueError("hashAlg is mandatory for PSS padding")
+ hashBytes = hashlib.new(hashAlg, bytes).digest()
+ saltLen = saltLen or len(hashBytes)
+ return self.RSASSA_PSS_sign(hashBytes, hashAlg, saltLen)
+ else:
+ raise ValueError("Unsupported padding mode")
def _raw_pkcs1_verify(self, sigBytes, bytes):
"""Perform verification operation on raw PKCS#1 padded signature"""
- pass
+ s = bytes_to_int(sigBytes)
+ m = self._raw_public_key_op(s)
+ em = int_to_bytes(m, len(self) // 8)
+ return self._removePKCS1Padding(em) == bytes
def verify(self, sigBytes, bytes, padding='pkcs1', hashAlg=None,
saltLen=None):
@@ -253,7 +347,18 @@ class RSAKey(object):
:rtype: bool
:returns: Whether the signature matches the passed-in data.
"""
- pass
+ if padding == 'pkcs1':
+ hashBytes = hashlib.new(hashAlg, bytes).digest() if hashAlg else bytes
+ prefixedHashBytes = self.addPKCS1Prefix(hashBytes, hashAlg) if hashAlg else hashBytes
+ return self._raw_pkcs1_verify(sigBytes, prefixedHashBytes)
+ elif padding == 'pss':
+ if not hashAlg:
+ raise ValueError("hashAlg is mandatory for PSS padding")
+ hashBytes = hashlib.new(hashAlg, bytes).digest()
+ saltLen = saltLen or len(hashBytes)
+ return self.RSASSA_PSS_verify(hashBytes, sigBytes, hashAlg, saltLen)
+ else:
+ raise ValueError("Unsupported padding mode")
def encrypt(self, bytes):
"""Encrypt the passed-in bytes.
diff --git a/tlslite/utils/tlshashlib.py b/tlslite/utils/tlshashlib.py
index 080311b..270f3af 100644
--- a/tlslite/utils/tlshashlib.py
+++ b/tlslite/utils/tlshashlib.py
@@ -5,14 +5,21 @@ import hashlib
def _fipsFunction(func, *args, **kwargs):
"""Make hash function support FIPS mode."""
- pass
+ try:
+ return func(*args, **kwargs)
+ except ValueError as e:
+ if "disabled for FIPS" in str(e):
+ return hashlib.sha256(*args, **kwargs)
+ raise
def md5(*args, **kwargs):
"""MD5 constructor that works in FIPS mode."""
- pass
+ return _fipsFunction(hashlib.md5, *args, **kwargs)
def new(*args, **kwargs):
"""General constructor that works in FIPS mode."""
- pass
+ if args and isinstance(args[0], str):
+ return _fipsFunction(hashlib.new, *args, **kwargs)
+ return hashlib.new(*args, **kwargs)
diff --git a/tlslite/utils/tlshmac.py b/tlslite/utils/tlshmac.py
index 02030ea..17255f0 100644
--- a/tlslite/utils/tlshmac.py
+++ b/tlslite/utils/tlshmac.py
@@ -56,4 +56,18 @@ except Exception:
def new(*args, **kwargs):
"""General constructor that works in FIPS mode."""
- pass
+ return HMAC(*args, **kwargs)
+
+def compare_digest(a, b):
+ """
+ Compare two digests of equal length in constant time.
+
+ The digests must be of type str/bytes.
+ Returns True if the digests match, and False otherwise.
+ """
+ if len(a) != len(b):
+ return False
+ result = 0
+ for x, y in zip(a, b):
+ result |= x ^ y
+ return result == 0
diff --git a/tlslite/utils/x25519.py b/tlslite/utils/x25519.py
index c8f6173..7c47ebb 100644
--- a/tlslite/utils/x25519.py
+++ b/tlslite/utils/x25519.py
@@ -4,22 +4,37 @@ from .cryptomath import bytesToNumber, numberToByteArray, divceil
def decodeUCoordinate(u, bits):
"""Function to decode the public U coordinate of X25519-family curves."""
- pass
+ u_list = bytearray(u)
+ if bits == 255:
+ u_list[-1] &= 127
+ elif bits == 448:
+ u_list[0] &= 252
+ return bytesToNumber(u_list)
def decodeScalar22519(k):
"""Function to decode the private K parameter of the x25519 function."""
- pass
+ k_list = bytearray(k)
+ k_list[0] &= 248
+ k_list[31] &= 127
+ k_list[31] |= 64
+ return bytesToNumber(k_list)
def decodeScalar448(k):
"""Function to decode the private K parameter of the X448 function."""
- pass
+ k_list = bytearray(k)
+ k_list[0] &= 252
+ k_list[55] |= 128
+ return bytesToNumber(k_list)
def cswap(swap, x_2, x_3):
"""Conditional swap function."""
- pass
+ dummy = swap * (x_2 ^ x_3)
+ x_2 ^= dummy
+ x_3 ^= dummy
+ return x_2, x_3
X25519_G = numberToByteArray(9, 32, endian='little')
@@ -38,7 +53,40 @@ def x25519(k, u):
:rtype: bytearray
"""
- pass
+ x1 = decodeUCoordinate(u, 255)
+ x2 = 1
+ z2 = 0
+ x3 = x1
+ z3 = 1
+ swap = 0
+
+ k = decodeScalar22519(k)
+
+ for t in range(255, -1, -1):
+ kt = (k >> t) & 1
+ swap ^= kt
+ x2, x3 = cswap(swap, x2, x3)
+ z2, z3 = cswap(swap, z2, z3)
+ swap = kt
+
+ A = x2 + z2
+ AA = A * A
+ B = x2 - z2
+ BB = B * B
+ E = AA - BB
+ C = x3 + z3
+ D = x3 - z3
+ DA = D * A
+ CB = C * B
+ x3 = (DA + CB) * (DA + CB)
+ z3 = x1 * (DA - CB) * (DA - CB)
+ x2 = AA * BB
+ z2 = E * (AA + 121665 * E)
+
+ x2, x3 = cswap(swap, x2, x3)
+ z2, z3 = cswap(swap, z2, z3)
+
+ return numberToByteArray(x2 * pow(z2, 2**255 - 21 - 1, 2**255 - 19), X25519_ORDER_SIZE, "little")
X448_G = numberToByteArray(5, 56, endian='little')
@@ -57,9 +105,42 @@ def x448(k, u):
:rtype: bytearray
"""
- pass
+ return _x25519_generic(k, u, 448, 39081, 2**448 - 2**224 - 1)
def _x25519_generic(k, u, bits, a24, p):
"""Generic Montgomery ladder implementation of the x25519 algorithm."""
- pass
+ x1 = decodeUCoordinate(u, bits)
+ x2 = 1
+ z2 = 0
+ x3 = x1
+ z3 = 1
+ swap = 0
+
+ k = decodeScalar448(k) if bits == 448 else decodeScalar22519(k)
+
+ for t in range(bits - 1, -1, -1):
+ kt = (k >> t) & 1
+ swap ^= kt
+ x2, x3 = cswap(swap, x2, x3)
+ z2, z3 = cswap(swap, z2, z3)
+ swap = kt
+
+ A = x2 + z2
+ AA = A * A
+ B = x2 - z2
+ BB = B * B
+ E = AA - BB
+ C = x3 + z3
+ D = x3 - z3
+ DA = D * A
+ CB = C * B
+ x3 = (DA + CB) * (DA + CB)
+ z3 = x1 * (DA - CB) * (DA - CB)
+ x2 = AA * BB
+ z2 = E * (AA + a24 * E)
+
+ x2, x3 = cswap(swap, x2, x3)
+ z2, z3 = cswap(swap, z2, z3)
+
+ return numberToByteArray(x2 * pow(z2, p - 2, p), divceil(bits, 8), "little")
diff --git a/tlslite/verifierdb.py b/tlslite/verifierdb.py
index 4264045..59f2a51 100644
--- a/tlslite/verifierdb.py
+++ b/tlslite/verifierdb.py
@@ -61,4 +61,15 @@ class VerifierDB(BaseDB):
:rtype: tuple
:returns: A tuple which may be stored in a VerifierDB.
"""
- pass
+ if bits not in (1024, 1536, 2048, 3072, 4096, 6144, 8192):
+ raise ValueError("Bits must be one of (1024, 1536, 2048, 3072, 4096, 6144, 8192)")
+
+ if len(username) >= 256:
+ raise ValueError("Username must be less than 256 characters")
+
+ N, g, _ = mathtls.makeRFC5054Group(bits)
+ salt = getRandomBytes(16)
+ x = mathtls.makeX(salt, username, password)
+ verifier = powMod(g, x, N)
+
+ return (N, g, salt, verifier)
diff --git a/tlslite/x509.py b/tlslite/x509.py
index 5c37622..6bb72f6 100644
--- a/tlslite/x509.py
+++ b/tlslite/x509.py
@@ -62,7 +62,8 @@ class X509(object):
certificate wrapped with "-----BEGIN CERTIFICATE-----" and
"-----END CERTIFICATE-----" tags).
"""
- pass
+ bytes = dePem(s, "CERTIFICATE")
+ return self.parseBinary(bytes)
def parseBinary(self, cert_bytes):
"""
@@ -71,7 +72,40 @@ class X509(object):
:type bytes: L{str} (in python2) or L{bytearray} of unsigned bytes
:param bytes: A DER-encoded X.509 certificate.
"""
- pass
+ self.bytes = bytearray(cert_bytes)
+ parser = ASN1Parser(self.bytes)
+
+ cert = parser.getChild(0)
+ tbsCertificate = cert.getChild(0)
+
+ self.serial_number = tbsCertificate.getChild(1).value
+
+ self.issuer = tbsCertificate.getChild(3).value
+ self.subject = tbsCertificate.getChild(5).value
+
+ subject_public_key_info = tbsCertificate.getChild(6)
+ algorithm = subject_public_key_info.getChild(0)
+ alg_oid = algorithm.getChild(0).value
+
+ if alg_oid == AlgorithmOID.RSA:
+ self.certAlg = "rsa"
+ self._rsa_pubkey_parsing(subject_public_key_info)
+ elif alg_oid == RSA_PSS_OID:
+ self.certAlg = "rsa-pss"
+ self._rsa_pubkey_parsing(subject_public_key_info)
+ elif alg_oid == AlgorithmOID.ECDSA:
+ self.certAlg = "ecdsa"
+ self._ecdsa_pubkey_parsing(subject_public_key_info)
+ elif alg_oid == AlgorithmOID.DSA:
+ self.certAlg = "dsa"
+ self._dsa_pubkey_parsing(subject_public_key_info)
+ elif alg_oid in (AlgorithmOID.Ed25519, AlgorithmOID.Ed448):
+ self.certAlg = "eddsa"
+ self._eddsa_pubkey_parsing(subject_public_key_info)
+ else:
+ raise SyntaxError("Unsupported public key algorithm")
+
+ self.sigalg = cert.getChild(1).getChild(0).value
def _eddsa_pubkey_parsing(self, subject_public_key_info):
"""
@@ -80,7 +114,8 @@ class X509(object):
:param subject_public_key_info: bytes like object with the DER encoded
public key in it
"""
- pass
+ public_key = subject_public_key_info.getChild(1).value
+ self.publicKey = _create_public_eddsa_key(public_key)
def _rsa_pubkey_parsing(self, subject_public_key_info):
"""
@@ -89,7 +124,11 @@ class X509(object):
:param subject_public_key_info: ASN1Parser object with subject
public key info of X.509 certificate
"""
- pass
+ public_key = subject_public_key_info.getChild(1).value
+ key_parser = ASN1Parser(public_key)
+ modulus = key_parser.getChild(0).value
+ public_exponent = key_parser.getChild(1).value
+ self.publicKey = _createPublicRSAKey(modulus, public_exponent)
def _ecdsa_pubkey_parsing(self, subject_public_key_info):
"""
@@ -98,7 +137,8 @@ class X509(object):
:param subject_public_key_info: bytes like object with DER encoded
public key in it
"""
- pass
+ public_key = subject_public_key_info.getChild(1).value
+ self.publicKey = _create_public_ecdsa_key(public_key)
def _dsa_pubkey_parsing(self, subject_public_key_info):
"""
@@ -107,7 +147,12 @@ class X509(object):
:param subject_public_key_info: bytes like object with DER encoded
global parameters and public key in it
"""
- pass
+ algorithm_params = subject_public_key_info.getChild(0).getChild(1)
+ p = algorithm_params.getChild(0).value
+ q = algorithm_params.getChild(1).value
+ g = algorithm_params.getChild(2).value
+ y = subject_public_key_info.getChild(1).value
+ self.publicKey = _create_public_dsa_key(p, q, g, y)
def getFingerprint(self):
"""
@@ -116,8 +161,8 @@ class X509(object):
:rtype: str
:returns: A hex-encoded fingerprint.
"""
- pass
+ return b2a_hex(compatHMAC(self.bytes, "sha1")).decode("ascii")
def writeBytes(self):
"""Serialise object to a DER encoded string."""
- pass
+ return bytes(self.bytes)
diff --git a/tlslite/x509certchain.py b/tlslite/x509certchain.py
index 54183ce..fae0c93 100644
--- a/tlslite/x509certchain.py
+++ b/tlslite/x509certchain.py
@@ -48,21 +48,26 @@ class X509CertChain(object):
Raise a SyntaxError if input is malformed.
"""
- pass
+ certs = parsePemList(s)
+ if not certs:
+ raise SyntaxError("No PEM-encoded certificates found")
+ self.x509List = [X509().parse(cert) for cert in certs]
def getNumCerts(self):
"""Get the number of certificates in this chain.
:rtype: int
"""
- pass
+ return len(self.x509List)
def getEndEntityPublicKey(self):
"""Get the public key from the end-entity certificate.
:rtype: ~tlslite.utils.rsakey.RSAKey`
"""
- pass
+ if not self.x509List:
+ raise ValueError("No certificates in the chain")
+ return self.x509List[0].publicKey
def getFingerprint(self):
"""Get the hex-encoded fingerprint of the end-entity certificate.
@@ -70,8 +75,16 @@ class X509CertChain(object):
:rtype: str
:returns: A hex-encoded fingerprint.
"""
- pass
+ if not self.x509List:
+ raise ValueError("No certificates in the chain")
+ return self.x509List[0].getFingerprint()
def getTackExt(self):
"""Get the TACK and/or Break Sigs from a TACK Cert in the chain."""
- pass
+ if not self.x509List:
+ return None
+ for cert in self.x509List:
+ tackExt = cert.getTackExt()
+ if tackExt:
+ return tackExt
+ return None