back to Reference (Gold) summary
Reference (Gold): python-rsa
Pytest Summary for test tests
status | count |
---|---|
passed | 86 |
failed | 1 |
total | 87 |
collected | 87 |
Failed pytests:
test_mypy.py::MypyRunnerTest::test_run_mypy
test_mypy.py::MypyRunnerTest::test_run_mypy
self =def test_run_mypy(self): proj_root = pathlib.Path(__file__).parent.parent args = [ "--incremental", "--ignore-missing-imports", f"--python-version={sys.version_info.major}.{sys.version_info.minor}", ] + [str(proj_root / dirname) for dirname in test_modules] result = mypy.api.run(args) stdout, stderr, status = result messages = [] if stderr: messages.append(stderr) if stdout: messages.append(stdout) if status: messages.append("Mypy failed with status %d" % status) if messages and not all("Success" in message for message in messages): > self.fail("\n".join(["Mypy errors:"] + messages)) E AssertionError: Mypy errors: E setup.cfg: [mypy]: python_version: Python 3.7 is not supported (must be 3.8 or higher) E E rsa/key.py:68: error: Missing return statement [empty-body] E rsa/key.py:80: error: Missing return statement [empty-body] E rsa/key.py:91: error: Missing return statement [empty-body] E rsa/key.py:98: error: Missing return statement [empty-body] E Found 4 errors in 1 file (checked 28 source files) E E Mypy failed with status 1 tests/test_mypy.py:31: AssertionError
Patch diff
diff --git a/rsa/asn1.py b/rsa/asn1.py
index 02e5a82..4cc4dd3 100644
--- a/rsa/asn1.py
+++ b/rsa/asn1.py
@@ -1,19 +1,41 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""ASN.1 definitions.
Not all ASN.1-handling code use these definitions, but when it does, they should be here.
"""
+
from pyasn1.type import univ, namedtype, tag
class PubKeyHeader(univ.Sequence):
- componentType = namedtype.NamedTypes(namedtype.NamedType('oid', univ.
- ObjectIdentifier()), namedtype.NamedType('parameters', univ.Null()))
+ componentType = namedtype.NamedTypes(
+ namedtype.NamedType("oid", univ.ObjectIdentifier()),
+ namedtype.NamedType("parameters", univ.Null()),
+ )
class OpenSSLPubKey(univ.Sequence):
- componentType = namedtype.NamedTypes(namedtype.NamedType('header',
- PubKeyHeader()), namedtype.NamedType('key', univ.OctetString().
- subtype(implicitTag=tag.Tag(tagClass=0, tagFormat=0, tagId=3))))
+ componentType = namedtype.NamedTypes(
+ namedtype.NamedType("header", PubKeyHeader()),
+ # This little hack (the implicit tag) allows us to get a Bit String as Octet String
+ namedtype.NamedType(
+ "key",
+ univ.OctetString().subtype(implicitTag=tag.Tag(tagClass=0, tagFormat=0, tagId=3)),
+ ),
+ )
class AsnPubKey(univ.Sequence):
@@ -23,5 +45,8 @@ class AsnPubKey(univ.Sequence):
modulus INTEGER, -- n
publicExponent INTEGER, -- e
"""
- componentType = namedtype.NamedTypes(namedtype.NamedType('modulus',
- univ.Integer()), namedtype.NamedType('publicExponent', univ.Integer()))
+
+ componentType = namedtype.NamedTypes(
+ namedtype.NamedType("modulus", univ.Integer()),
+ namedtype.NamedType("publicExponent", univ.Integer()),
+ )
diff --git a/rsa/cli.py b/rsa/cli.py
index 4449a1f..4db3f0b 100644
--- a/rsa/cli.py
+++ b/rsa/cli.py
@@ -1,160 +1,318 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Commandline scripts.
These scripts are called by the executables defined in setup.py.
"""
+
import abc
import sys
import typing
import optparse
+
import rsa
import rsa.key
import rsa.pkcs1
+
HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys())
Indexable = typing.Union[typing.Tuple, typing.List[str]]
-def keygen() ->None:
+def keygen() -> None:
"""Key generator."""
- pass
+
+ # Parse the CLI options
+ parser = optparse.OptionParser(
+ usage="usage: %prog [options] keysize",
+ description='Generates a new RSA key pair of "keysize" bits.',
+ )
+
+ parser.add_option(
+ "--pubout",
+ type="string",
+ help="Output filename for the public key. The public key is "
+ "not saved if this option is not present. You can use "
+ "pyrsa-priv2pub to create the public key file later.",
+ )
+
+ parser.add_option(
+ "-o",
+ "--out",
+ type="string",
+ help="Output filename for the private key. The key is "
+ "written to stdout if this option is not present.",
+ )
+
+ parser.add_option(
+ "--form",
+ help="key format of the private and public keys - default PEM",
+ choices=("PEM", "DER"),
+ default="PEM",
+ )
+
+ (cli, cli_args) = parser.parse_args(sys.argv[1:])
+
+ if len(cli_args) != 1:
+ parser.print_help()
+ raise SystemExit(1)
+
+ try:
+ keysize = int(cli_args[0])
+ except ValueError as ex:
+ parser.print_help()
+ print("Not a valid number: %s" % cli_args[0], file=sys.stderr)
+ raise SystemExit(1) from ex
+
+ print("Generating %i-bit key" % keysize, file=sys.stderr)
+ (pub_key, priv_key) = rsa.newkeys(keysize)
+
+ # Save public key
+ if cli.pubout:
+ print("Writing public key to %s" % cli.pubout, file=sys.stderr)
+ data = pub_key.save_pkcs1(format=cli.form)
+ with open(cli.pubout, "wb") as outfile:
+ outfile.write(data)
+
+ # Save private key
+ data = priv_key.save_pkcs1(format=cli.form)
+
+ if cli.out:
+ print("Writing private key to %s" % cli.out, file=sys.stderr)
+ with open(cli.out, "wb") as outfile:
+ outfile.write(data)
+ else:
+ print("Writing private key to stdout", file=sys.stderr)
+ sys.stdout.buffer.write(data)
class CryptoOperation(metaclass=abc.ABCMeta):
"""CLI callable that operates with input, output, and a key."""
- keyname = 'public'
- usage = 'usage: %%prog [options] %(keyname)s_key'
- description = ''
- operation = 'decrypt'
- operation_past = 'decrypted'
- operation_progressive = 'decrypting'
- input_help = (
- 'Name of the file to %(operation)s. Reads from stdin if not specified.'
- )
+
+ keyname = "public" # or 'private'
+ usage = "usage: %%prog [options] %(keyname)s_key"
+ description = ""
+ operation = "decrypt"
+ operation_past = "decrypted"
+ operation_progressive = "decrypting"
+ input_help = "Name of the file to %(operation)s. Reads from stdin if " "not specified."
output_help = (
- 'Name of the file to write the %(operation_past)s file to. Written to stdout if this option is not present.'
- )
+ "Name of the file to write the %(operation_past)s file "
+ "to. Written to stdout if this option is not present."
+ )
expected_cli_args = 1
has_output = True
- key_class = rsa.PublicKey
- def __init__(self) ->None:
+ key_class = rsa.PublicKey # type: typing.Type[rsa.key.AbstractKey]
+
+ def __init__(self) -> None:
self.usage = self.usage % self.__class__.__dict__
self.input_help = self.input_help % self.__class__.__dict__
self.output_help = self.output_help % self.__class__.__dict__
@abc.abstractmethod
- def perform_operation(self, indata: bytes, key: rsa.key.AbstractKey,
- cli_args: Indexable) ->typing.Any:
+ def perform_operation(
+ self, indata: bytes, key: rsa.key.AbstractKey, cli_args: Indexable
+ ) -> typing.Any:
"""Performs the program's operation.
Implement in a subclass.
:returns: the data to write to the output.
"""
- pass
- def __call__(self) ->None:
+ def __call__(self) -> None:
"""Runs the program."""
- cli, cli_args = self.parse_cli()
+
+ (cli, cli_args) = self.parse_cli()
+
key = self.read_key(cli_args[0], cli.keyform)
+
indata = self.read_infile(cli.input)
+
print(self.operation_progressive.title(), file=sys.stderr)
outdata = self.perform_operation(indata, key, cli_args)
+
if self.has_output:
self.write_outfile(outdata, cli.output)
- def parse_cli(self) ->typing.Tuple[optparse.Values, typing.List[str]]:
+ def parse_cli(self) -> typing.Tuple[optparse.Values, typing.List[str]]:
"""Parse the CLI options
:returns: (cli_opts, cli_args)
"""
- pass
- def read_key(self, filename: str, keyform: str) ->rsa.key.AbstractKey:
+ parser = optparse.OptionParser(usage=self.usage, description=self.description)
+
+ parser.add_option("-i", "--input", type="string", help=self.input_help)
+
+ if self.has_output:
+ parser.add_option("-o", "--output", type="string", help=self.output_help)
+
+ parser.add_option(
+ "--keyform",
+ help="Key format of the %s key - default PEM" % self.keyname,
+ choices=("PEM", "DER"),
+ default="PEM",
+ )
+
+ (cli, cli_args) = parser.parse_args(sys.argv[1:])
+
+ if len(cli_args) != self.expected_cli_args:
+ parser.print_help()
+ raise SystemExit(1)
+
+ return cli, cli_args
+
+ def read_key(self, filename: str, keyform: str) -> rsa.key.AbstractKey:
"""Reads a public or private key."""
- pass
- def read_infile(self, inname: str) ->bytes:
+ print("Reading %s key from %s" % (self.keyname, filename), file=sys.stderr)
+ with open(filename, "rb") as keyfile:
+ keydata = keyfile.read()
+
+ return self.key_class.load_pkcs1(keydata, keyform)
+
+ def read_infile(self, inname: str) -> bytes:
"""Read the input file"""
- pass
- def write_outfile(self, outdata: bytes, outname: str) ->None:
+ if inname:
+ print("Reading input from %s" % inname, file=sys.stderr)
+ with open(inname, "rb") as infile:
+ return infile.read()
+
+ print("Reading input from stdin", file=sys.stderr)
+ return sys.stdin.buffer.read()
+
+ def write_outfile(self, outdata: bytes, outname: str) -> None:
"""Write the output file"""
- pass
+
+ if outname:
+ print("Writing output to %s" % outname, file=sys.stderr)
+ with open(outname, "wb") as outfile:
+ outfile.write(outdata)
+ else:
+ print("Writing output to stdout", file=sys.stderr)
+ sys.stdout.buffer.write(outdata)
class EncryptOperation(CryptoOperation):
"""Encrypts a file."""
- keyname = 'public'
+
+ keyname = "public"
description = (
- 'Encrypts a file. The file must be shorter than the key length in order to be encrypted.'
- )
- operation = 'encrypt'
- operation_past = 'encrypted'
- operation_progressive = 'encrypting'
+ "Encrypts a file. The file must be shorter than the key " "length in order to be encrypted."
+ )
+ operation = "encrypt"
+ operation_past = "encrypted"
+ operation_progressive = "encrypting"
- def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey,
- cli_args: Indexable=()) ->bytes:
+ def perform_operation(
+ self, indata: bytes, pub_key: rsa.key.AbstractKey, cli_args: Indexable = ()
+ ) -> bytes:
"""Encrypts files."""
- pass
+ assert isinstance(pub_key, rsa.key.PublicKey)
+ return rsa.encrypt(indata, pub_key)
class DecryptOperation(CryptoOperation):
"""Decrypts a file."""
- keyname = 'private'
+
+ keyname = "private"
description = (
- 'Decrypts a file. The original file must be shorter than the key length in order to have been encrypted.'
- )
- operation = 'decrypt'
- operation_past = 'decrypted'
- operation_progressive = 'decrypting'
+ "Decrypts a file. The original file must be shorter than "
+ "the key length in order to have been encrypted."
+ )
+ operation = "decrypt"
+ operation_past = "decrypted"
+ operation_progressive = "decrypting"
key_class = rsa.PrivateKey
- def perform_operation(self, indata: bytes, priv_key: rsa.key.
- AbstractKey, cli_args: Indexable=()) ->bytes:
+ def perform_operation(
+ self, indata: bytes, priv_key: rsa.key.AbstractKey, cli_args: Indexable = ()
+ ) -> bytes:
"""Decrypts files."""
- pass
+ assert isinstance(priv_key, rsa.key.PrivateKey)
+ return rsa.decrypt(indata, priv_key)
class SignOperation(CryptoOperation):
"""Signs a file."""
- keyname = 'private'
- usage = 'usage: %%prog [options] private_key hash_method'
+
+ keyname = "private"
+ usage = "usage: %%prog [options] private_key hash_method"
description = (
- 'Signs a file, outputs the signature. Choose the hash method from %s' %
- ', '.join(HASH_METHODS))
- operation = 'sign'
- operation_past = 'signature'
- operation_progressive = 'Signing'
+ "Signs a file, outputs the signature. Choose the hash "
+ "method from %s" % ", ".join(HASH_METHODS)
+ )
+ operation = "sign"
+ operation_past = "signature"
+ operation_progressive = "Signing"
key_class = rsa.PrivateKey
expected_cli_args = 2
+
output_help = (
- 'Name of the file to write the signature to. Written to stdout if this option is not present.'
- )
+ "Name of the file to write the signature to. Written "
+ "to stdout if this option is not present."
+ )
- def perform_operation(self, indata: bytes, priv_key: rsa.key.
- AbstractKey, cli_args: Indexable) ->bytes:
+ def perform_operation(
+ self, indata: bytes, priv_key: rsa.key.AbstractKey, cli_args: Indexable
+ ) -> bytes:
"""Signs files."""
- pass
+ assert isinstance(priv_key, rsa.key.PrivateKey)
+
+ hash_method = cli_args[1]
+ if hash_method not in HASH_METHODS:
+ raise SystemExit("Invalid hash method, choose one of %s" % ", ".join(HASH_METHODS))
+
+ return rsa.sign(indata, priv_key, hash_method)
class VerifyOperation(CryptoOperation):
"""Verify a signature."""
- keyname = 'public'
- usage = 'usage: %%prog [options] public_key signature_file'
+
+ keyname = "public"
+ usage = "usage: %%prog [options] public_key signature_file"
description = (
- 'Verifies a signature, exits with status 0 upon success, prints an error message and exits with status 1 upon error.'
- )
- operation = 'verify'
- operation_past = 'verified'
- operation_progressive = 'Verifying'
+ "Verifies a signature, exits with status 0 upon success, "
+ "prints an error message and exits with status 1 upon error."
+ )
+ operation = "verify"
+ operation_past = "verified"
+ operation_progressive = "Verifying"
key_class = rsa.PublicKey
expected_cli_args = 2
has_output = False
- def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey,
- cli_args: Indexable) ->None:
+ def perform_operation(
+ self, indata: bytes, pub_key: rsa.key.AbstractKey, cli_args: Indexable
+ ) -> None:
"""Verifies files."""
- pass
+ assert isinstance(pub_key, rsa.key.PublicKey)
+
+ signature_file = cli_args[1]
+
+ with open(signature_file, "rb") as sigfile:
+ signature = sigfile.read()
+
+ try:
+ rsa.verify(indata, signature, pub_key)
+ except rsa.VerificationError as ex:
+ raise SystemExit("Verification failed.") from ex
+
+ print("Verification OK", file=sys.stderr)
encrypt = EncryptOperation()
diff --git a/rsa/common.py b/rsa/common.py
index 2f4bc71..ca732e5 100644
--- a/rsa/common.py
+++ b/rsa/common.py
@@ -1,18 +1,31 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Common functionality shared by several modules."""
+
import typing
class NotRelativePrimeError(ValueError):
-
- def __init__(self, a: int, b: int, d: int, msg: str='') ->None:
- super().__init__(msg or
- '%d and %d are not relatively prime, divider=%i' % (a, b, d))
+ def __init__(self, a: int, b: int, d: int, msg: str = "") -> None:
+ super().__init__(msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
self.a = a
self.b = b
self.d = d
-def bit_size(num: int) ->int:
+def bit_size(num: int) -> int:
"""
Number of bits needed to represent a integer excluding any prefix
0 bits.
@@ -33,10 +46,14 @@ def bit_size(num: int) ->int:
:returns:
Returns the number of bits in the integer.
"""
- pass
+
+ try:
+ return num.bit_length()
+ except AttributeError as ex:
+ raise TypeError("bit_size(num) only supports integers, not %r" % type(num)) from ex
-def byte_size(number: int) ->int:
+def byte_size(number: int) -> int:
"""
Returns the number of bytes required to hold a specific long number.
@@ -56,10 +73,12 @@ def byte_size(number: int) ->int:
:returns:
The number of bytes required to hold a specific long number.
"""
- pass
+ if number == 0:
+ return 1
+ return ceil_div(bit_size(number), 8)
-def ceil_div(num: int, div: int) ->int:
+def ceil_div(num: int, div: int) -> int:
"""
Returns the ceiling function of a division between `num` and `div`.
@@ -77,15 +96,37 @@ def ceil_div(num: int, div: int) ->int:
:return: Rounded up result of the division between the parameters.
"""
- pass
+ quanta, mod = divmod(num, div)
+ if mod:
+ quanta += 1
+ return quanta
-def extended_gcd(a: int, b: int) ->typing.Tuple[int, int, int]:
+def extended_gcd(a: int, b: int) -> typing.Tuple[int, int, int]:
"""Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb"""
- pass
-
-
-def inverse(x: int, n: int) ->int:
+ # r = gcd(a,b) i = multiplicitive inverse of a mod b
+ # or j = multiplicitive inverse of b mod a
+ # Neg return values for i or j are made positive mod b or a respectively
+ # Iterateive Version is faster and uses much less stack space
+ x = 0
+ y = 1
+ lx = 1
+ ly = 0
+ oa = a # Remember original a/b to remove
+ ob = b # negative values from return results
+ while b != 0:
+ q = a // b
+ (a, b) = (b, a % b)
+ (x, lx) = ((lx - (q * x)), x)
+ (y, ly) = ((ly - (q * y)), y)
+ if lx < 0:
+ lx += ob # If neg wrap modulo original b
+ if ly < 0:
+ ly += oa # If neg wrap modulo original a
+ return a, lx, ly # Return only positive values
+
+
+def inverse(x: int, n: int) -> int:
"""Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n)
>>> inverse(7, 4)
@@ -93,11 +134,16 @@ def inverse(x: int, n: int) ->int:
>>> (inverse(143, 4) * 143) % 4
1
"""
- pass
+ (divider, inv, _) = extended_gcd(x, n)
+
+ if divider != 1:
+ raise NotRelativePrimeError(x, n, divider)
+
+ return inv
-def crt(a_values: typing.Iterable[int], modulo_values: typing.Iterable[int]
- ) ->int:
+
+def crt(a_values: typing.Iterable[int], modulo_values: typing.Iterable[int]) -> int:
"""Chinese Remainder Theorem.
Calculates x such that x = a[i] (mod m[i]) for each i.
@@ -116,9 +162,23 @@ def crt(a_values: typing.Iterable[int], modulo_values: typing.Iterable[int]
>>> crt([2, 3, 0], [7, 11, 15])
135
"""
- pass
+ m = 1
+ x = 0
+
+ for modulo in modulo_values:
+ m *= modulo
+
+ for (m_i, a_i) in zip(modulo_values, a_values):
+ M_i = m // m_i
+ inv = inverse(M_i, m_i)
+
+ x = (x + a_i * M_i * inv) % m
-if __name__ == '__main__':
+ return x
+
+
+if __name__ == "__main__":
import doctest
+
doctest.testmod()
diff --git a/rsa/core.py b/rsa/core.py
index e7dba89..84ed3f8 100644
--- a/rsa/core.py
+++ b/rsa/core.py
@@ -1,3 +1,17 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Core mathematical operations.
This is the actual core RSA implementation, which is only defined
@@ -5,11 +19,35 @@ mathematically on integers.
"""
-def encrypt_int(message: int, ekey: int, n: int) ->int:
+def assert_int(var: int, name: str) -> None:
+ if isinstance(var, int):
+ return
+
+ raise TypeError("%s should be an integer, not %s" % (name, var.__class__))
+
+
+def encrypt_int(message: int, ekey: int, n: int) -> int:
"""Encrypts a message using encryption key 'ekey', working modulo n"""
- pass
+ assert_int(message, "message")
+ assert_int(ekey, "ekey")
+ assert_int(n, "n")
+
+ if message < 0:
+ raise ValueError("Only non-negative numbers are supported")
-def decrypt_int(cyphertext: int, dkey: int, n: int) ->int:
+ if message > n:
+ raise OverflowError("The message %i is too long for n=%i" % (message, n))
+
+ return pow(message, ekey, n)
+
+
+def decrypt_int(cyphertext: int, dkey: int, n: int) -> int:
"""Decrypts a cypher text using the decryption key 'dkey', working modulo n"""
- pass
+
+ assert_int(cyphertext, "cyphertext")
+ assert_int(dkey, "dkey")
+ assert_int(n, "n")
+
+ message = pow(cyphertext, dkey, n)
+ return message
diff --git a/rsa/key.py b/rsa/key.py
index c42592d..f800644 100644
--- a/rsa/key.py
+++ b/rsa/key.py
@@ -1,3 +1,17 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""RSA key generation code.
Create new keys with the newkeys() function. It will give you a PublicKey and a
@@ -16,30 +30,42 @@ of pyasn1.
or unauthenticated source.
"""
+
import threading
import typing
import warnings
+
import rsa.prime
import rsa.pem
import rsa.common
import rsa.randnum
import rsa.core
+
+
DEFAULT_EXPONENT = 65537
-T = typing.TypeVar('T', bound='AbstractKey')
+
+
+T = typing.TypeVar("T", bound="AbstractKey")
class AbstractKey:
"""Abstract superclass for private and public keys."""
- __slots__ = 'n', 'e', 'blindfac', 'blindfac_inverse', 'mutex'
- def __init__(self, n: int, e: int) ->None:
+ __slots__ = ("n", "e", "blindfac", "blindfac_inverse", "mutex")
+
+ def __init__(self, n: int, e: int) -> None:
self.n = n
self.e = e
+
+ # These will be computed properly on the first call to blind().
self.blindfac = self.blindfac_inverse = -1
+
+ # Used to protect updates to the blinding factor in multi-threaded
+ # environments.
self.mutex = threading.Lock()
@classmethod
- def _load_pkcs1_pem(cls: typing.Type[T], keyfile: bytes) ->T:
+ def _load_pkcs1_pem(cls: typing.Type[T], keyfile: bytes) -> T:
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
:param keyfile: contents of a PEM-encoded file that contains
@@ -49,10 +75,9 @@ class AbstractKey:
:return: the loaded key
:rtype: AbstractKey
"""
- pass
@classmethod
- def _load_pkcs1_der(cls: typing.Type[T], keyfile: bytes) ->T:
+ def _load_pkcs1_der(cls: typing.Type[T], keyfile: bytes) -> T:
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
:param keyfile: contents of a DER-encoded file that contains
@@ -62,26 +87,23 @@ class AbstractKey:
:return: the loaded key
:rtype: AbstractKey
"""
- pass
- def _save_pkcs1_pem(self) ->bytes:
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves the key in PKCS#1 PEM format, implement in a subclass.
:returns: the PEM-encoded key.
:rtype: bytes
"""
- pass
- def _save_pkcs1_der(self) ->bytes:
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the key in PKCS#1 DER format, implement in a subclass.
:returns: the DER-encoded key.
:rtype: bytes
"""
- pass
@classmethod
- def load_pkcs1(cls: typing.Type[T], keyfile: bytes, format: str='PEM') ->T:
+ def load_pkcs1(cls: typing.Type[T], keyfile: bytes, format: str = "PEM") -> T:
"""Loads a key in PKCS#1 DER or PEM format.
:param keyfile: contents of a DER- or PEM-encoded file that contains
@@ -93,15 +115,30 @@ class AbstractKey:
:return: the loaded key
:rtype: AbstractKey
"""
- pass
+
+ methods = {
+ "PEM": cls._load_pkcs1_pem,
+ "DER": cls._load_pkcs1_der,
+ }
+
+ method = cls._assert_format_exists(format, methods)
+ return method(keyfile)
@staticmethod
- def _assert_format_exists(file_format: str, methods: typing.Mapping[str,
- typing.Callable]) ->typing.Callable:
+ def _assert_format_exists(
+ file_format: str, methods: typing.Mapping[str, typing.Callable]
+ ) -> typing.Callable:
"""Checks whether the given file format exists in 'methods'."""
- pass
- def save_pkcs1(self, format: str='PEM') ->bytes:
+ try:
+ return methods[file_format]
+ except KeyError as ex:
+ formats = ", ".join(sorted(methods.keys()))
+ raise ValueError(
+ "Unsupported format: %r, try one of %s" % (file_format, formats)
+ ) from ex
+
+ def save_pkcs1(self, format: str = "PEM") -> bytes:
"""Saves the key in PKCS#1 DER or PEM format.
:param format: the format to save; 'PEM' or 'DER'
@@ -109,9 +146,16 @@ class AbstractKey:
:returns: the DER- or PEM-encoded key.
:rtype: bytes
"""
- pass
- def blind(self, message: int) ->typing.Tuple[int, int]:
+ methods = {
+ "PEM": self._save_pkcs1_pem,
+ "DER": self._save_pkcs1_der,
+ }
+
+ method = self._assert_format_exists(format, methods)
+ return method()
+
+ def blind(self, message: int) -> typing.Tuple[int, int]:
"""Performs blinding on the message.
:param message: the message, as integer, to blind.
@@ -122,9 +166,11 @@ class AbstractKey:
See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
"""
- pass
+ blindfac, blindfac_inverse = self._update_blinding_factor()
+ blinded = (message * pow(blindfac, self.e, self.n)) % self.n
+ return blinded, blindfac_inverse
- def unblind(self, blinded: int, blindfac_inverse: int) ->int:
+ def unblind(self, blinded: int, blindfac_inverse: int) -> int:
"""Performs blinding on the message using random number 'blindfac_inverse'.
:param blinded: the blinded message, as integer, to unblind.
@@ -135,9 +181,16 @@ class AbstractKey:
See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
"""
- pass
+ return (blindfac_inverse * blinded) % self.n
+
+ def _initial_blinding_factor(self) -> int:
+ for _ in range(1000):
+ blind_r = rsa.randnum.randint(self.n - 1)
+ if rsa.prime.are_relatively_prime(self.n, blind_r):
+ return blind_r
+ raise RuntimeError("unable to find blinding factor")
- def _update_blinding_factor(self) ->typing.Tuple[int, int]:
+ def _update_blinding_factor(self) -> typing.Tuple[int, int]:
"""Update blinding factors.
Computing a blinding factor is expensive, so instead this function
@@ -148,7 +201,18 @@ class AbstractKey:
:return: the new blinding factor and its inverse.
"""
- pass
+
+ with self.mutex:
+ if self.blindfac < 0:
+ # Compute initial blinding factor, which is rather slow to do.
+ self.blindfac = self._initial_blinding_factor()
+ self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n)
+ else:
+ # Reuse previous blinding factor.
+ self.blindfac = pow(self.blindfac, 2, self.n)
+ self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n)
+
+ return self.blindfac, self.blindfac_inverse
class PublicKey(AbstractKey):
@@ -174,38 +238,41 @@ class PublicKey(AbstractKey):
3
"""
+
__slots__ = ()
- def __getitem__(self, key: str) ->int:
+ def __getitem__(self, key: str) -> int:
return getattr(self, key)
- def __repr__(self) ->str:
- return 'PublicKey(%i, %i)' % (self.n, self.e)
+ def __repr__(self) -> str:
+ return "PublicKey(%i, %i)" % (self.n, self.e)
- def __getstate__(self) ->typing.Tuple[int, int]:
+ def __getstate__(self) -> typing.Tuple[int, int]:
"""Returns the key as tuple for pickling."""
return self.n, self.e
- def __setstate__(self, state: typing.Tuple[int, int]) ->None:
+ def __setstate__(self, state: typing.Tuple[int, int]) -> None:
"""Sets the key from tuple."""
self.n, self.e = state
AbstractKey.__init__(self, self.n, self.e)
- def __eq__(self, other: typing.Any) ->bool:
+ def __eq__(self, other: typing.Any) -> bool:
if other is None:
return False
+
if not isinstance(other, PublicKey):
return False
+
return self.n == other.n and self.e == other.e
- def __ne__(self, other: typing.Any) ->bool:
- return not self == other
+ def __ne__(self, other: typing.Any) -> bool:
+ return not (self == other)
- def __hash__(self) ->int:
+ def __hash__(self) -> int:
return hash((self.n, self.e))
@classmethod
- def _load_pkcs1_der(cls, keyfile: bytes) ->'PublicKey':
+ def _load_pkcs1_der(cls, keyfile: bytes) -> "PublicKey":
"""Loads a key in PKCS#1 DER format.
:param keyfile: contents of a DER-encoded file that contains the public
@@ -224,18 +291,32 @@ class PublicKey(AbstractKey):
PublicKey(2367317549, 65537)
"""
- pass
- def _save_pkcs1_der(self) ->bytes:
+ from pyasn1.codec.der import decoder
+ from rsa.asn1 import AsnPubKey
+
+ (priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
+ return cls(n=int(priv["modulus"]), e=int(priv["publicExponent"]))
+
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the public key in PKCS#1 DER format.
:returns: the DER-encoded public key.
:rtype: bytes
"""
- pass
+
+ from pyasn1.codec.der import encoder
+ from rsa.asn1 import AsnPubKey
+
+ # Create the ASN object
+ asn_key = AsnPubKey()
+ asn_key.setComponentByName("modulus", self.n)
+ asn_key.setComponentByName("publicExponent", self.e)
+
+ return encoder.encode(asn_key)
@classmethod
- def _load_pkcs1_pem(cls, keyfile: bytes) ->'PublicKey':
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> "PublicKey":
"""Loads a PKCS#1 PEM-encoded public key file.
The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
@@ -245,18 +326,22 @@ class PublicKey(AbstractKey):
key.
:return: a PublicKey object
"""
- pass
- def _save_pkcs1_pem(self) ->bytes:
+ der = rsa.pem.load_pem(keyfile, "RSA PUBLIC KEY")
+ return cls._load_pkcs1_der(der)
+
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves a PKCS#1 PEM-encoded public key file.
:return: contents of a PEM-encoded file that contains the public key.
:rtype: bytes
"""
- pass
+
+ der = self._save_pkcs1_der()
+ return rsa.pem.save_pem(der, "RSA PUBLIC KEY")
@classmethod
- def load_pkcs1_openssl_pem(cls, keyfile: bytes) ->'PublicKey':
+ def load_pkcs1_openssl_pem(cls, keyfile: bytes) -> "PublicKey":
"""Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL.
These files can be recognised in that they start with BEGIN PUBLIC KEY
@@ -270,17 +355,29 @@ class PublicKey(AbstractKey):
:type keyfile: bytes
:return: a PublicKey object
"""
- pass
+
+ der = rsa.pem.load_pem(keyfile, "PUBLIC KEY")
+ return cls.load_pkcs1_openssl_der(der)
@classmethod
- def load_pkcs1_openssl_der(cls, keyfile: bytes) ->'PublicKey':
+ def load_pkcs1_openssl_der(cls, keyfile: bytes) -> "PublicKey":
"""Loads a PKCS#1 DER-encoded public key file from OpenSSL.
:param keyfile: contents of a DER-encoded file that contains the public
key, from OpenSSL.
:return: a PublicKey object
"""
- pass
+
+ from rsa.asn1 import OpenSSLPubKey
+ from pyasn1.codec.der import decoder
+ from pyasn1.type import univ
+
+ (keyinfo, _) = decoder.decode(keyfile, asn1Spec=OpenSSLPubKey())
+
+ if keyinfo["header"]["oid"] != univ.ObjectIdentifier("1.2.840.113549.1.1.1"):
+ raise TypeError("This is not a DER-encoded OpenSSL-compatible public key")
+
+ return cls._load_pkcs1_der(keyinfo["key"][1:])
class PrivateKey(AbstractKey):
@@ -306,54 +403,66 @@ class PrivateKey(AbstractKey):
50797
"""
- __slots__ = 'd', 'p', 'q', 'exp1', 'exp2', 'coef'
- def __init__(self, n: int, e: int, d: int, p: int, q: int) ->None:
+ __slots__ = ("d", "p", "q", "exp1", "exp2", "coef")
+
+ def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None:
AbstractKey.__init__(self, n, e)
self.d = d
self.p = p
self.q = q
+
+ # Calculate exponents and coefficient.
self.exp1 = int(d % (p - 1))
self.exp2 = int(d % (q - 1))
self.coef = rsa.common.inverse(q, p)
- def __getitem__(self, key: str) ->int:
+ def __getitem__(self, key: str) -> int:
return getattr(self, key)
- def __repr__(self) ->str:
- return 'PrivateKey(%i, %i, %i, %i, %i)' % (self.n, self.e, self.d,
- self.p, self.q)
+ def __repr__(self) -> str:
+ return "PrivateKey(%i, %i, %i, %i, %i)" % (
+ self.n,
+ self.e,
+ self.d,
+ self.p,
+ self.q,
+ )
- def __getstate__(self) ->typing.Tuple[int, int, int, int, int, int, int,
- int]:
+ def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]:
"""Returns the key as tuple for pickling."""
- return (self.n, self.e, self.d, self.p, self.q, self.exp1, self.
- exp2, self.coef)
+ return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef
- def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int,
- int, int]) ->None:
+ def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]) -> None:
"""Sets the key from tuple."""
- (self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self
- .coef) = state
+ self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state
AbstractKey.__init__(self, self.n, self.e)
- def __eq__(self, other: typing.Any) ->bool:
+ def __eq__(self, other: typing.Any) -> bool:
if other is None:
return False
+
if not isinstance(other, PrivateKey):
return False
- return (self.n == other.n and self.e == other.e and self.d == other
- .d and self.p == other.p and self.q == other.q and self.exp1 ==
- other.exp1 and self.exp2 == other.exp2 and self.coef == other.coef)
- def __ne__(self, other: typing.Any) ->bool:
- return not self == other
+ return (
+ self.n == other.n
+ and self.e == other.e
+ and self.d == other.d
+ and self.p == other.p
+ and self.q == other.q
+ and self.exp1 == other.exp1
+ and self.exp2 == other.exp2
+ and self.coef == other.coef
+ )
+
+ def __ne__(self, other: typing.Any) -> bool:
+ return not (self == other)
- def __hash__(self) ->int:
- return hash((self.n, self.e, self.d, self.p, self.q, self.exp1,
- self.exp2, self.coef))
+ def __hash__(self) -> int:
+ return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef))
- def blinded_decrypt(self, encrypted: int) ->int:
+ def blinded_decrypt(self, encrypted: int) -> int:
"""Decrypts the message using blinding to prevent side-channel attacks.
:param encrypted: the encrypted message
@@ -362,9 +471,22 @@ class PrivateKey(AbstractKey):
:returns: the decrypted message
:rtype: int
"""
- pass
- def blinded_encrypt(self, message: int) ->int:
+ # Blinding and un-blinding should be using the same factor
+ blinded, blindfac_inverse = self.blind(encrypted)
+
+ # Instead of using the core functionality, use the Chinese Remainder
+ # Theorem and be 2-4x faster. This the same as:
+ #
+ # decrypted = rsa.core.decrypt_int(blinded, self.d, self.n)
+ s1 = pow(blinded, self.exp1, self.p)
+ s2 = pow(blinded, self.exp2, self.q)
+ h = ((s1 - s2) * self.coef) % self.p
+ decrypted = s2 + self.q * h
+
+ return self.unblind(decrypted, blindfac_inverse)
+
+ def blinded_encrypt(self, message: int) -> int:
"""Encrypts the message using blinding to prevent side-channel attacks.
:param message: the message to encrypt
@@ -373,10 +495,13 @@ class PrivateKey(AbstractKey):
:returns: the encrypted message
:rtype: int
"""
- pass
+
+ blinded, blindfac_inverse = self.blind(message)
+ encrypted = rsa.core.encrypt_int(blinded, self.d, self.n)
+ return self.unblind(encrypted, blindfac_inverse)
@classmethod
- def _load_pkcs1_der(cls, keyfile: bytes) ->'PrivateKey':
+ def _load_pkcs1_der(cls, keyfile: bytes) -> "PrivateKey":
"""Loads a key in PKCS#1 DER format.
:param keyfile: contents of a DER-encoded file that contains the private
@@ -396,18 +521,83 @@ class PrivateKey(AbstractKey):
PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
"""
- pass
- def _save_pkcs1_der(self) ->bytes:
+ from pyasn1.codec.der import decoder
+
+ (priv, _) = decoder.decode(keyfile)
+
+ # ASN.1 contents of DER encoded private key:
+ #
+ # RSAPrivateKey ::= SEQUENCE {
+ # version Version,
+ # modulus INTEGER, -- n
+ # publicExponent INTEGER, -- e
+ # privateExponent INTEGER, -- d
+ # prime1 INTEGER, -- p
+ # prime2 INTEGER, -- q
+ # exponent1 INTEGER, -- d mod (p-1)
+ # exponent2 INTEGER, -- d mod (q-1)
+ # coefficient INTEGER, -- (inverse of q) mod p
+ # otherPrimeInfos OtherPrimeInfos OPTIONAL
+ # }
+
+ if priv[0] != 0:
+ raise ValueError("Unable to read this file, version %s != 0" % priv[0])
+
+ as_ints = map(int, priv[1:6])
+ key = cls(*as_ints)
+
+ exp1, exp2, coef = map(int, priv[6:9])
+
+ if (key.exp1, key.exp2, key.coef) != (exp1, exp2, coef):
+ warnings.warn(
+ "You have provided a malformed keyfile. Either the exponents "
+ "or the coefficient are incorrect. Using the correct values "
+ "instead.",
+ UserWarning,
+ )
+
+ return key
+
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the private key in PKCS#1 DER format.
:returns: the DER-encoded private key.
:rtype: bytes
"""
- pass
+
+ from pyasn1.type import univ, namedtype
+ from pyasn1.codec.der import encoder
+
+ class AsnPrivKey(univ.Sequence):
+ componentType = namedtype.NamedTypes(
+ namedtype.NamedType("version", univ.Integer()),
+ namedtype.NamedType("modulus", univ.Integer()),
+ namedtype.NamedType("publicExponent", univ.Integer()),
+ namedtype.NamedType("privateExponent", univ.Integer()),
+ namedtype.NamedType("prime1", univ.Integer()),
+ namedtype.NamedType("prime2", univ.Integer()),
+ namedtype.NamedType("exponent1", univ.Integer()),
+ namedtype.NamedType("exponent2", univ.Integer()),
+ namedtype.NamedType("coefficient", univ.Integer()),
+ )
+
+ # Create the ASN object
+ asn_key = AsnPrivKey()
+ asn_key.setComponentByName("version", 0)
+ asn_key.setComponentByName("modulus", self.n)
+ asn_key.setComponentByName("publicExponent", self.e)
+ asn_key.setComponentByName("privateExponent", self.d)
+ asn_key.setComponentByName("prime1", self.p)
+ asn_key.setComponentByName("prime2", self.q)
+ asn_key.setComponentByName("exponent1", self.exp1)
+ asn_key.setComponentByName("exponent2", self.exp2)
+ asn_key.setComponentByName("coefficient", self.coef)
+
+ return encoder.encode(asn_key)
@classmethod
- def _load_pkcs1_pem(cls, keyfile: bytes) ->'PrivateKey':
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> "PrivateKey":
"""Loads a PKCS#1 PEM-encoded private key file.
The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
@@ -418,19 +608,26 @@ class PrivateKey(AbstractKey):
:type keyfile: bytes
:return: a PrivateKey object
"""
- pass
- def _save_pkcs1_pem(self) ->bytes:
+ der = rsa.pem.load_pem(keyfile, b"RSA PRIVATE KEY")
+ return cls._load_pkcs1_der(der)
+
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves a PKCS#1 PEM-encoded private key file.
:return: contents of a PEM-encoded file that contains the private key.
:rtype: bytes
"""
- pass
+ der = self._save_pkcs1_der()
+ return rsa.pem.save_pem(der, b"RSA PRIVATE KEY")
-def find_p_q(nbits: int, getprime_func: typing.Callable[[int], int]=rsa.
- prime.getprime, accurate: bool=True) ->typing.Tuple[int, int]:
+
+def find_p_q(
+ nbits: int,
+ getprime_func: typing.Callable[[int], int] = rsa.prime.getprime,
+ accurate: bool = True,
+) -> typing.Tuple[int, int]:
"""Returns a tuple of two different primes of nbits bits each.
The resulting p * q has exactly 2 * nbits bits, and the returned p and q
@@ -460,11 +657,53 @@ def find_p_q(nbits: int, getprime_func: typing.Callable[[int], int]=rsa.
True
"""
- pass
+ total_bits = nbits * 2
+
+ # Make sure that p and q aren't too close or the factoring programs can
+ # factor n.
+ shift = nbits // 16
+ pbits = nbits + shift
+ qbits = nbits - shift
+
+ # Choose the two initial primes
+ p = getprime_func(pbits)
+ q = getprime_func(qbits)
+
+ def is_acceptable(p: int, q: int) -> bool:
+ """Returns True iff p and q are acceptable:
+
+ - p and q differ
+ - (p * q) has the right nr of bits (when accurate=True)
+ """
+
+ if p == q:
+ return False
+
+ if not accurate:
+ return True
+
+ # Make sure we have just the right amount of bits
+ found_size = rsa.common.bit_size(p * q)
+ return total_bits == found_size
+
+ # Keep choosing other primes until they match our requirements.
+ change_p = False
+ while not is_acceptable(p, q):
+ # Change p on one iteration and q on the other
+ if change_p:
+ p = getprime_func(pbits)
+ else:
+ q = getprime_func(qbits)
+
+ change_p = not change_p
-def calculate_keys_custom_exponent(p: int, q: int, exponent: int
- ) ->typing.Tuple[int, int]:
+ # We want p > q as described on
+ # http://www.di-mgt.com.au/rsa_alg.html#crt
+ return max(p, q), min(p, q)
+
+
+def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tuple[int, int]:
"""Calculates an encryption and a decryption key given p, q and an exponent,
and returns them as a tuple (e, d)
@@ -476,10 +715,29 @@ def calculate_keys_custom_exponent(p: int, q: int, exponent: int
:type exponent: int
"""
- pass
+ phi_n = (p - 1) * (q - 1)
+
+ try:
+ d = rsa.common.inverse(exponent, phi_n)
+ except rsa.common.NotRelativePrimeError as ex:
+ raise rsa.common.NotRelativePrimeError(
+ exponent,
+ phi_n,
+ ex.d,
+ msg="e (%d) and phi_n (%d) are not relatively prime (divider=%i)"
+ % (exponent, phi_n, ex.d),
+ ) from ex
+
+ if (exponent * d) % phi_n != 1:
+ raise ValueError(
+ "e (%d) and d (%d) are not mult. inv. modulo " "phi_n (%d)" % (exponent, d, phi_n)
+ )
+
+ return exponent, d
-def calculate_keys(p: int, q: int) ->typing.Tuple[int, int]:
+
+def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]:
"""Calculates an encryption and a decryption key given p and q, and
returns them as a tuple (e, d)
@@ -488,12 +746,16 @@ def calculate_keys(p: int, q: int) ->typing.Tuple[int, int]:
:return: tuple (e, d) with the encryption and decryption exponents.
"""
- pass
+
+ return calculate_keys_custom_exponent(p, q, DEFAULT_EXPONENT)
-def gen_keys(nbits: int, getprime_func: typing.Callable[[int], int],
- accurate: bool=True, exponent: int=DEFAULT_EXPONENT) ->typing.Tuple[int,
- int, int, int]:
+def gen_keys(
+ nbits: int,
+ getprime_func: typing.Callable[[int], int],
+ accurate: bool = True,
+ exponent: int = DEFAULT_EXPONENT,
+) -> typing.Tuple[int, int, int, int]:
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
Note: this can take a long time, depending on the key size.
@@ -507,11 +769,26 @@ def gen_keys(nbits: int, getprime_func: typing.Callable[[int], int],
private key can be cracked. A very common choice for e is 65537.
:type exponent: int
"""
- pass
-
-def newkeys(nbits: int, accurate: bool=True, poolsize: int=1, exponent: int
- =DEFAULT_EXPONENT) ->typing.Tuple[PublicKey, PrivateKey]:
+ # Regenerate p and q values, until calculate_keys doesn't raise a
+ # ValueError.
+ while True:
+ (p, q) = find_p_q(nbits // 2, getprime_func, accurate)
+ try:
+ (e, d) = calculate_keys_custom_exponent(p, q, exponent=exponent)
+ break
+ except ValueError:
+ pass
+
+ return p, q, e, d
+
+
+def newkeys(
+ nbits: int,
+ accurate: bool = True,
+ poolsize: int = 1,
+ exponent: int = DEFAULT_EXPONENT,
+) -> typing.Tuple[PublicKey, PrivateKey]:
"""Generates public and private keys, and returns them as (pub, priv).
The public key is also known as the 'encryption key', and is a
@@ -536,20 +813,46 @@ def newkeys(nbits: int, accurate: bool=True, poolsize: int=1, exponent: int
Python 2.6 or newer.
"""
- pass
+ if nbits < 16:
+ raise ValueError("Key too small")
+
+ if poolsize < 1:
+ raise ValueError("Pool size (%i) should be >= 1" % poolsize)
+
+ # Determine which getprime function to use
+ if poolsize > 1:
+ from rsa import parallel
-__all__ = ['PublicKey', 'PrivateKey', 'newkeys']
-if __name__ == '__main__':
+ def getprime_func(nbits: int) -> int:
+ return parallel.getprime(nbits, poolsize=poolsize)
+
+ else:
+ getprime_func = rsa.prime.getprime
+
+ # Generate the key components
+ (p, q, e, d) = gen_keys(nbits, getprime_func, accurate=accurate, exponent=exponent)
+
+ # Create the key objects
+ n = p * q
+
+ return (PublicKey(n, e), PrivateKey(n, e, d, p, q))
+
+
+__all__ = ["PublicKey", "PrivateKey", "newkeys"]
+
+if __name__ == "__main__":
import doctest
+
try:
for count in range(100):
- failures, tests = doctest.testmod()
+ (failures, tests) = doctest.testmod()
if failures:
break
- if count % 10 == 0 and count or count == 1:
- print('%i times' % count)
+
+ if (count % 10 == 0 and count) or count == 1:
+ print("%i times" % count)
except KeyboardInterrupt:
- print('Aborted')
+ print("Aborted")
else:
- print('Doctests done')
+ print("Doctests done")
diff --git a/rsa/parallel.py b/rsa/parallel.py
index 0d3a4f8..5020edb 100644
--- a/rsa/parallel.py
+++ b/rsa/parallel.py
@@ -1,3 +1,17 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Functions for parallel computation on multiple cores.
Introduced in Python-RSA 3.1.
@@ -7,13 +21,25 @@ Introduced in Python-RSA 3.1.
Requires Python 2.6 or newer.
"""
+
import multiprocessing as mp
from multiprocessing.connection import Connection
+
import rsa.prime
import rsa.randnum
-def getprime(nbits: int, poolsize: int) ->int:
+def _find_prime(nbits: int, pipe: Connection) -> None:
+ while True:
+ integer = rsa.randnum.read_random_odd_int(nbits)
+
+ # Test for primeness
+ if rsa.prime.is_prime(integer):
+ pipe.send(integer)
+ return
+
+
+def getprime(nbits: int, poolsize: int) -> int:
"""Returns a prime number that can be stored in 'nbits' bits.
Works in multiple threads at the same time.
@@ -31,17 +57,40 @@ def getprime(nbits: int, poolsize: int) ->int:
True
"""
- pass
+ (pipe_recv, pipe_send) = mp.Pipe(duplex=False)
+
+ # Create processes
+ try:
+ procs = [mp.Process(target=_find_prime, args=(nbits, pipe_send)) for _ in range(poolsize)]
+ # Start processes
+ for p in procs:
+ p.start()
+
+ result = pipe_recv.recv()
+ finally:
+ pipe_recv.close()
+ pipe_send.close()
-__all__ = ['getprime']
-if __name__ == '__main__':
- print('Running doctests 1000x or until failure')
+ # Terminate processes
+ for p in procs:
+ p.terminate()
+
+ return result
+
+
+__all__ = ["getprime"]
+
+if __name__ == "__main__":
+ print("Running doctests 1000x or until failure")
import doctest
+
for count in range(100):
- failures, tests = doctest.testmod()
+ (failures, tests) = doctest.testmod()
if failures:
break
+
if count % 10 == 0 and count:
- print('%i times' % count)
- print('Doctests done')
+ print("%i times" % count)
+
+ print("Doctests done")
diff --git a/rsa/pem.py b/rsa/pem.py
index b2b919a..5d26e6e 100644
--- a/rsa/pem.py
+++ b/rsa/pem.py
@@ -1,23 +1,86 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Functions that load and write PEM-encoded files."""
+
import base64
import typing
+
+# Should either be ASCII strings or bytes.
FlexiText = typing.Union[str, bytes]
-def _markers(pem_marker: FlexiText) ->typing.Tuple[bytes, bytes]:
+def _markers(pem_marker: FlexiText) -> typing.Tuple[bytes, bytes]:
"""
Returns the start and end PEM markers, as bytes.
"""
- pass
+ if not isinstance(pem_marker, bytes):
+ pem_marker = pem_marker.encode("ascii")
-def _pem_lines(contents: bytes, pem_start: bytes, pem_end: bytes
- ) ->typing.Iterator[bytes]:
+ return (
+ b"-----BEGIN " + pem_marker + b"-----",
+ b"-----END " + pem_marker + b"-----",
+ )
+
+
+def _pem_lines(contents: bytes, pem_start: bytes, pem_end: bytes) -> typing.Iterator[bytes]:
"""Generator over PEM lines between pem_start and pem_end."""
- pass
+ in_pem_part = False
+ seen_pem_start = False
+
+ for line in contents.splitlines():
+ line = line.strip()
+
+ # Skip empty lines
+ if not line:
+ continue
+
+ # Handle start marker
+ if line == pem_start:
+ if in_pem_part:
+ raise ValueError('Seen start marker "%r" twice' % pem_start)
+
+ in_pem_part = True
+ seen_pem_start = True
+ continue
+
+ # Skip stuff before first marker
+ if not in_pem_part:
+ continue
+
+ # Handle end marker
+ if in_pem_part and line == pem_end:
+ in_pem_part = False
+ break
+
+ # Load fields
+ if b":" in line:
+ continue
-def load_pem(contents: FlexiText, pem_marker: FlexiText) ->bytes:
+ yield line
+
+ # Do some sanity checks
+ if not seen_pem_start:
+ raise ValueError('No PEM start marker "%r" found' % pem_start)
+
+ if in_pem_part:
+ raise ValueError('No PEM end marker "%r" found' % pem_end)
+
+
+def load_pem(contents: FlexiText, pem_marker: FlexiText) -> bytes:
"""Loads a PEM file.
:param contents: the contents of the file to interpret
@@ -31,10 +94,20 @@ def load_pem(contents: FlexiText, pem_marker: FlexiText) ->bytes:
marker cannot be found.
"""
- pass
+ # We want bytes, not text. If it's text, it can be converted to ASCII bytes.
+ if not isinstance(contents, bytes):
+ contents = contents.encode("ascii")
+
+ (pem_start, pem_end) = _markers(pem_marker)
+ pem_lines = [line for line in _pem_lines(contents, pem_start, pem_end)]
+
+ # Base64-decode the contents
+ pem = b"".join(pem_lines)
+ return base64.standard_b64decode(pem)
-def save_pem(contents: bytes, pem_marker: FlexiText) ->bytes:
+
+def save_pem(contents: bytes, pem_marker: FlexiText) -> bytes:
"""Saves a PEM file.
:param contents: the contents to encode in PEM format
@@ -45,4 +118,17 @@ def save_pem(contents: bytes, pem_marker: FlexiText) ->bytes:
:return: the base64-encoded content between the start and end markers, as bytes.
"""
- pass
+
+ (pem_start, pem_end) = _markers(pem_marker)
+
+ b64 = base64.standard_b64encode(contents).replace(b"\n", b"")
+ pem_lines = [pem_start]
+
+ for block_start in range(0, len(b64), 64):
+ block = b64[block_start : block_start + 64]
+ pem_lines.append(block)
+
+ pem_lines.append(pem_end)
+ pem_lines.append(b"")
+
+ return b"\n".join(pem_lines)
diff --git a/rsa/pkcs1.py b/rsa/pkcs1.py
index 5359be7..ec6998e 100644
--- a/rsa/pkcs1.py
+++ b/rsa/pkcs1.py
@@ -1,3 +1,17 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Functions for PKCS#1 version 1.5 encryption and signing
This module implements certain functionality from PKCS#1 version 1.5. For a
@@ -11,35 +25,58 @@ that are raised contain the Python traceback information, which can be used to
deduce where in the process the failure occurred. DO NOT PASS SUCH INFORMATION
to your users.
"""
+
import hashlib
import os
import sys
import typing
from hmac import compare_digest
+
from . import common, transform, core, key
+
if typing.TYPE_CHECKING:
HashType = hashlib._Hash
else:
HashType = typing.Any
-HASH_ASN1 = {'MD5':
- b'0 0\x0c\x06\x08*\x86H\x86\xf7\r\x02\x05\x05\x00\x04\x10', 'SHA-1':
- b'0!0\t\x06\x05+\x0e\x03\x02\x1a\x05\x00\x04\x14', 'SHA-224':
- b'0-0\r\x06\t`\x86H\x01e\x03\x04\x02\x04\x05\x00\x04\x1c', 'SHA-256':
- b'010\r\x06\t`\x86H\x01e\x03\x04\x02\x01\x05\x00\x04 ', 'SHA-384':
- b'0A0\r\x06\t`\x86H\x01e\x03\x04\x02\x02\x05\x00\x040', 'SHA-512':
- b'0Q0\r\x06\t`\x86H\x01e\x03\x04\x02\x03\x05\x00\x04@'}
-HASH_METHODS: typing.Dict[str, typing.Callable[[], HashType]] = {'MD5':
- hashlib.md5, 'SHA-1': hashlib.sha1, 'SHA-224': hashlib.sha224,
- 'SHA-256': hashlib.sha256, 'SHA-384': hashlib.sha384, 'SHA-512':
- hashlib.sha512}
+
+# ASN.1 codes that describe the hash algorithm used.
+HASH_ASN1 = {
+ "MD5": b"\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10",
+ "SHA-1": b"\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14",
+ "SHA-224": b"\x30\x2d\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x04\x05\x00\x04\x1c",
+ "SHA-256": b"\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20",
+ "SHA-384": b"\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30",
+ "SHA-512": b"\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40",
+}
+
+HASH_METHODS: typing.Dict[str, typing.Callable[[], HashType]] = {
+ "MD5": hashlib.md5,
+ "SHA-1": hashlib.sha1,
+ "SHA-224": hashlib.sha224,
+ "SHA-256": hashlib.sha256,
+ "SHA-384": hashlib.sha384,
+ "SHA-512": hashlib.sha512,
+}
"""Hash methods supported by this library."""
+
+
if sys.version_info >= (3, 6):
- HASH_ASN1.update({'SHA3-256':
- b'010\r\x06\t`\x86H\x01e\x03\x04\x02\x08\x05\x00\x04 ', 'SHA3-384':
- b'0A0\r\x06\t`\x86H\x01e\x03\x04\x02\t\x05\x00\x040', 'SHA3-512':
- b'0Q0\r\x06\t`\x86H\x01e\x03\x04\x02\n\x05\x00\x04@'})
- HASH_METHODS.update({'SHA3-256': hashlib.sha3_256, 'SHA3-384': hashlib.
- sha3_384, 'SHA3-512': hashlib.sha3_512})
+ # Python 3.6 introduced SHA3 support.
+ HASH_ASN1.update(
+ {
+ "SHA3-256": b"\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x08\x05\x00\x04\x20",
+ "SHA3-384": b"\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x09\x05\x00\x04\x30",
+ "SHA3-512": b"\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x0a\x05\x00\x04\x40",
+ }
+ )
+
+ HASH_METHODS.update(
+ {
+ "SHA3-256": hashlib.sha3_256,
+ "SHA3-384": hashlib.sha3_384,
+ "SHA3-512": hashlib.sha3_512,
+ }
+ )
class CryptoError(Exception):
@@ -54,8 +91,8 @@ class VerificationError(CryptoError):
"""Raised when verification fails."""
-def _pad_for_encryption(message: bytes, target_length: int) ->bytes:
- """Pads the message for encryption, returning the padded message.
+def _pad_for_encryption(message: bytes, target_length: int) -> bytes:
+ r"""Pads the message for encryption, returning the padded message.
:return: 00 02 RANDOM_DATA 00 MESSAGE
@@ -63,16 +100,44 @@ def _pad_for_encryption(message: bytes, target_length: int) ->bytes:
>>> len(block)
16
>>> block[0:2]
- b'\\x00\\x02'
+ b'\x00\x02'
>>> block[-6:]
- b'\\x00hello'
+ b'\x00hello'
"""
- pass
+ max_msglength = target_length - 11
+ msglength = len(message)
+
+ if msglength > max_msglength:
+ raise OverflowError(
+ "%i bytes needed for message, but there is only"
+ " space for %i" % (msglength, max_msglength)
+ )
-def _pad_for_signing(message: bytes, target_length: int) ->bytes:
- """Pads the message for signing, returning the padded message.
+ # Get random padding
+ padding = b""
+ padding_length = target_length - msglength - 3
+
+ # We remove 0-bytes, so we'll end up with less padding than we've asked for,
+ # so keep adding data until we're at the correct length.
+ while len(padding) < padding_length:
+ needed_bytes = padding_length - len(padding)
+
+ # Always read at least 8 bytes more than we need, and trim off the rest
+ # after removing the 0-bytes. This increases the chance of getting
+ # enough bytes, especially when needed_bytes is small
+ new_padding = os.urandom(needed_bytes + 5)
+ new_padding = new_padding.replace(b"\x00", b"")
+ padding = padding + new_padding[:needed_bytes]
+
+ assert len(padding) == padding_length
+
+ return b"".join([b"\x00\x02", padding, b"\x00", message])
+
+
+def _pad_for_signing(message: bytes, target_length: int) -> bytes:
+ r"""Pads the message for signing, returning the padded message.
The padding is always a repetition of FF bytes.
@@ -82,17 +147,29 @@ def _pad_for_signing(message: bytes, target_length: int) ->bytes:
>>> len(block)
16
>>> block[0:2]
- b'\\x00\\x01'
+ b'\x00\x01'
>>> block[-6:]
- b'\\x00hello'
+ b'\x00hello'
>>> block[2:-6]
- b'\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff'
+ b'\xff\xff\xff\xff\xff\xff\xff\xff'
"""
- pass
+ max_msglength = target_length - 11
+ msglength = len(message)
+
+ if msglength > max_msglength:
+ raise OverflowError(
+ "%i bytes needed for message, but there is only"
+ " space for %i" % (msglength, max_msglength)
+ )
-def encrypt(message: bytes, pub_key: key.PublicKey) ->bytes:
+ padding_length = target_length - msglength - 3
+
+ return b"".join([b"\x00\x01", padding_length * b"\xff", b"\x00", message])
+
+
+def encrypt(message: bytes, pub_key: key.PublicKey) -> bytes:
"""Encrypts the given message using PKCS#1 v1.5
:param message: the message to encrypt. Must be a byte string no longer than
@@ -113,11 +190,19 @@ def encrypt(message: bytes, pub_key: key.PublicKey) ->bytes:
True
"""
- pass
+ keylength = common.byte_size(pub_key.n)
+ padded = _pad_for_encryption(message, keylength)
+
+ payload = transform.bytes2int(padded)
+ encrypted = core.encrypt_int(payload, pub_key.e, pub_key.n)
+ block = transform.int2bytes(encrypted, keylength)
-def decrypt(crypto: bytes, priv_key: key.PrivateKey) ->bytes:
- """Decrypts the given message using PKCS#1 v1.5
+ return block
+
+
+def decrypt(crypto: bytes, priv_key: key.PrivateKey) -> bytes:
+ r"""Decrypts the given message using PKCS#1 v1.5
The decryption is considered 'failed' when the resulting cleartext doesn't
start with the bytes 00 02, or when the 00 byte between the padding and
@@ -141,9 +226,9 @@ def decrypt(crypto: bytes, priv_key: key.PrivateKey) ->bytes:
And with binary data:
- >>> crypto = encrypt(b'\\x00\\x00\\x00\\x00\\x01', pub_key)
+ >>> crypto = encrypt(b'\x00\x00\x00\x00\x01', pub_key)
>>> decrypt(crypto, priv_key)
- b'\\x00\\x00\\x00\\x00\\x01'
+ b'\x00\x00\x00\x00\x01'
Altering the encrypted information will *likely* cause a
:py:class:`rsa.pkcs1.DecryptionError`. If you want to be *sure*, use
@@ -166,11 +251,40 @@ def decrypt(crypto: bytes, priv_key: key.PrivateKey) ->bytes:
rsa.pkcs1.DecryptionError: Decryption failed
"""
- pass
+ blocksize = common.byte_size(priv_key.n)
+ encrypted = transform.bytes2int(crypto)
+ decrypted = priv_key.blinded_decrypt(encrypted)
+ cleartext = transform.int2bytes(decrypted, blocksize)
+
+ # Detect leading zeroes in the crypto. These are not reflected in the
+ # encrypted value (as leading zeroes do not influence the value of an
+ # integer). This fixes CVE-2020-13757.
+ if len(crypto) > blocksize:
+ # This is operating on public information, so doesn't need to be constant-time.
+ raise DecryptionError("Decryption failed")
-def sign_hash(hash_value: bytes, priv_key: key.PrivateKey, hash_method: str
- ) ->bytes:
+ # If we can't find the cleartext marker, decryption failed.
+ cleartext_marker_bad = not compare_digest(cleartext[:2], b"\x00\x02")
+
+ # Find the 00 separator between the padding and the message
+ sep_idx = cleartext.find(b"\x00", 2)
+
+ # sep_idx indicates the position of the `\x00` separator that separates the
+ # padding from the actual message. The padding should be at least 8 bytes
+ # long (see https://tools.ietf.org/html/rfc8017#section-7.2.2 step 3), which
+ # means the separator should be at least at index 10 (because of the
+ # `\x00\x02` marker that precedes it).
+ sep_idx_bad = sep_idx < 10
+
+ anything_bad = cleartext_marker_bad | sep_idx_bad
+ if anything_bad:
+ raise DecryptionError("Decryption failed")
+
+ return cleartext[sep_idx + 1 :]
+
+
+def sign_hash(hash_value: bytes, priv_key: key.PrivateKey, hash_method: str) -> bytes:
"""Signs a precomputed hash with the private key.
Hashes the message, then signs the hash with the given key. This is known
@@ -185,10 +299,25 @@ def sign_hash(hash_value: bytes, priv_key: key.PrivateKey, hash_method: str
requested hash.
"""
- pass
+ # Get the ASN1 code for this hash method
+ if hash_method not in HASH_ASN1:
+ raise ValueError("Invalid hash method: %s" % hash_method)
+ asn1code = HASH_ASN1[hash_method]
+
+ # Encrypt the hash with the private key
+ cleartext = asn1code + hash_value
+ keylength = common.byte_size(priv_key.n)
+ padded = _pad_for_signing(cleartext, keylength)
-def sign(message: bytes, priv_key: key.PrivateKey, hash_method: str) ->bytes:
+ payload = transform.bytes2int(padded)
+ encrypted = priv_key.blinded_encrypt(payload)
+ block = transform.int2bytes(encrypted, keylength)
+
+ return block
+
+
+def sign(message: bytes, priv_key: key.PrivateKey, hash_method: str) -> bytes:
"""Signs the message with the private key.
Hashes the message, then signs the hash with the given key. This is known
@@ -205,10 +334,12 @@ def sign(message: bytes, priv_key: key.PrivateKey, hash_method: str) ->bytes:
requested hash.
"""
- pass
+
+ msg_hash = compute_hash(message, hash_method)
+ return sign_hash(msg_hash, priv_key, hash_method)
-def verify(message: bytes, signature: bytes, pub_key: key.PublicKey) ->str:
+def verify(message: bytes, signature: bytes, pub_key: key.PublicKey) -> str:
"""Verifies that the signature matches the message.
The hash method is detected automatically from the signature.
@@ -222,10 +353,31 @@ def verify(message: bytes, signature: bytes, pub_key: key.PublicKey) ->str:
:returns: the name of the used hash.
"""
- pass
+ keylength = common.byte_size(pub_key.n)
+ encrypted = transform.bytes2int(signature)
+ decrypted = core.decrypt_int(encrypted, pub_key.e, pub_key.n)
+ clearsig = transform.int2bytes(decrypted, keylength)
+
+ # Get the hash method
+ method_name = _find_method_hash(clearsig)
+ message_hash = compute_hash(message, method_name)
+
+ # Reconstruct the expected padded hash
+ cleartext = HASH_ASN1[method_name] + message_hash
+ expected = _pad_for_signing(cleartext, keylength)
+
+ if len(signature) != keylength:
+ raise VerificationError("Verification failed")
-def find_signature_hash(signature: bytes, pub_key: key.PublicKey) ->str:
+ # Compare with the signed one
+ if expected != clearsig:
+ raise VerificationError("Verification failed")
+
+ return method_name
+
+
+def find_signature_hash(signature: bytes, pub_key: key.PublicKey) -> str:
"""Returns the hash name detected from the signature.
If you also want to verify the message, use :py:func:`rsa.verify()` instead.
@@ -235,22 +387,37 @@ def find_signature_hash(signature: bytes, pub_key: key.PublicKey) ->str:
:param pub_key: the :py:class:`rsa.PublicKey` of the person signing the message.
:returns: the name of the used hash.
"""
- pass
+ keylength = common.byte_size(pub_key.n)
+ encrypted = transform.bytes2int(signature)
+ decrypted = core.decrypt_int(encrypted, pub_key.e, pub_key.n)
+ clearsig = transform.int2bytes(decrypted, keylength)
+
+ return _find_method_hash(clearsig)
-def yield_fixedblocks(infile: typing.BinaryIO, blocksize: int
- ) ->typing.Iterator[bytes]:
+
+def yield_fixedblocks(infile: typing.BinaryIO, blocksize: int) -> typing.Iterator[bytes]:
"""Generator, yields each block of ``blocksize`` bytes in the input file.
:param infile: file to read and separate in blocks.
:param blocksize: block size in bytes.
:returns: a generator that yields the contents of each block
"""
- pass
+
+ while True:
+ block = infile.read(blocksize)
+
+ read_bytes = len(block)
+ if read_bytes == 0:
+ break
+
+ yield block
+
+ if read_bytes < blocksize:
+ break
-def compute_hash(message: typing.Union[bytes, typing.BinaryIO], method_name:
- str) ->bytes:
+def compute_hash(message: typing.Union[bytes, typing.BinaryIO], method_name: str) -> bytes:
"""Returns the message digest.
:param message: the signed message. Can be an 8-bit string or a file-like
@@ -260,28 +427,59 @@ def compute_hash(message: typing.Union[bytes, typing.BinaryIO], method_name:
:py:const:`rsa.pkcs1.HASH_METHODS`.
"""
- pass
+ if method_name not in HASH_METHODS:
+ raise ValueError("Invalid hash method: %s" % method_name)
+
+ method = HASH_METHODS[method_name]
+ hasher = method()
+
+ if isinstance(message, bytes):
+ hasher.update(message)
+ else:
+ assert hasattr(message, "read") and hasattr(message.read, "__call__")
+ # read as 1K blocks
+ for block in yield_fixedblocks(message, 1024):
+ hasher.update(block)
+
+ return hasher.digest()
-def _find_method_hash(clearsig: bytes) ->str:
+
+def _find_method_hash(clearsig: bytes) -> str:
"""Finds the hash method.
:param clearsig: full padded ASN1 and hash.
:return: the used hash method.
:raise VerificationFailed: when the hash method cannot be found
"""
- pass
+ for (hashname, asn1code) in HASH_ASN1.items():
+ if asn1code in clearsig:
+ return hashname
+
+ raise VerificationError("Verification failed")
+
+
+__all__ = [
+ "encrypt",
+ "decrypt",
+ "sign",
+ "verify",
+ "DecryptionError",
+ "VerificationError",
+ "CryptoError",
+]
-__all__ = ['encrypt', 'decrypt', 'sign', 'verify', 'DecryptionError',
- 'VerificationError', 'CryptoError']
-if __name__ == '__main__':
- print('Running doctests 1000x or until failure')
+if __name__ == "__main__":
+ print("Running doctests 1000x or until failure")
import doctest
+
for count in range(1000):
- failures, tests = doctest.testmod()
+ (failures, tests) = doctest.testmod()
if failures:
break
+
if count % 100 == 0 and count:
- print('%i times' % count)
- print('Doctests done')
+ print("%i times" % count)
+
+ print("Doctests done")
diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py
index e6d2e23..d68b907 100644
--- a/rsa/pkcs1_v2.py
+++ b/rsa/pkcs1_v2.py
@@ -1,12 +1,31 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Functions for PKCS#1 version 2 encryption and signing
This module implements certain functionality from PKCS#1 version 2. Main
documentation is RFC 2437: https://tools.ietf.org/html/rfc2437
"""
-from rsa import common, pkcs1, transform
+
+from rsa import (
+ common,
+ pkcs1,
+ transform,
+)
-def mgf1(seed: bytes, length: int, hasher: str='SHA-1') ->bytes:
+def mgf1(seed: bytes, length: int, hasher: str = "SHA-1") -> bytes:
"""
MGF1 is a Mask Generation Function based on a hash function.
@@ -27,17 +46,55 @@ def mgf1(seed: bytes, length: int, hasher: str='SHA-1') ->bytes:
:raise OverflowError: when `length` is too large for the specified `hasher`
:raise ValueError: when specified `hasher` is invalid
"""
- pass
+ try:
+ hash_length = pkcs1.HASH_METHODS[hasher]().digest_size
+ except KeyError as ex:
+ raise ValueError(
+ "Invalid `hasher` specified. Please select one of: {hash_list}".format(
+ hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys()))
+ )
+ ) from ex
+
+ # If l > 2^32(hLen), output "mask too long" and stop.
+ if length > (2 ** 32 * hash_length):
+ raise OverflowError(
+ "Desired length should be at most 2**32 times the hasher's output "
+ "length ({hash_length} for {hasher} function)".format(
+ hash_length=hash_length,
+ hasher=hasher,
+ )
+ )
+
+ # Looping `counter` from 0 to ceil(l / hLen)-1, build `output` based on the
+ # hashes formed by (`seed` + C), being `C` an octet string of length 4
+ # generated by converting `counter` with the primitive I2OSP
+ output = b"".join(
+ pkcs1.compute_hash(
+ seed + transform.int2bytes(counter, fill_size=4),
+ method_name=hasher,
+ )
+ for counter in range(common.ceil_div(length, hash_length) + 1)
+ )
+
+ # Output the leading `length` octets of `output` as the octet string mask.
+ return output[:length]
-__all__ = ['mgf1']
-if __name__ == '__main__':
- print('Running doctests 1000x or until failure')
+
+__all__ = [
+ "mgf1",
+]
+
+if __name__ == "__main__":
+ print("Running doctests 1000x or until failure")
import doctest
+
for count in range(1000):
- failures, tests = doctest.testmod()
+ (failures, tests) = doctest.testmod()
if failures:
break
+
if count % 100 == 0 and count:
- print('%i times' % count)
- print('Doctests done')
+ print("%i times" % count)
+
+ print("Doctests done")
diff --git a/rsa/prime.py b/rsa/prime.py
index 07ae2c5..ec486bc 100644
--- a/rsa/prime.py
+++ b/rsa/prime.py
@@ -1,23 +1,42 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Numerical functions related to primes.
Implementation based on the book Algorithm Design by Michael T. Goodrich and
Roberto Tamassia, 2002.
"""
+
import rsa.common
import rsa.randnum
-__all__ = ['getprime', 'are_relatively_prime']
+
+__all__ = ["getprime", "are_relatively_prime"]
-def gcd(p: int, q: int) ->int:
+def gcd(p: int, q: int) -> int:
"""Returns the greatest common divisor of p and q
>>> gcd(48, 180)
12
"""
- pass
+ while q != 0:
+ (p, q) = (q, p % q)
+ return p
-def get_primality_testing_rounds(number: int) ->int:
+
+def get_primality_testing_rounds(number: int) -> int:
"""Returns minimum number of rounds for Miller-Rabing primality testing,
based on number bitsize.
@@ -29,10 +48,21 @@ def get_primality_testing_rounds(number: int) ->int:
* p, q bitsize: 1536; rounds: 3
See: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf
"""
- pass
+
+ # Calculate number bitsize.
+ bitsize = rsa.common.bit_size(number)
+ # Set number of rounds.
+ if bitsize >= 1536:
+ return 3
+ if bitsize >= 1024:
+ return 4
+ if bitsize >= 512:
+ return 7
+ # For smaller bitsizes, set arbitrary number of rounds.
+ return 10
-def miller_rabin_primality_testing(n: int, k: int) ->bool:
+def miller_rabin_primality_testing(n: int, k: int) -> bool:
"""Calculates whether n is composite (which is always correct) or prime
(which theoretically is incorrect with error probability 4**-k), by
applying Miller-Rabin primality testing.
@@ -47,10 +77,45 @@ def miller_rabin_primality_testing(n: int, k: int) ->bool:
:return: False if the number is composite, True if it's probably prime.
:rtype: bool
"""
- pass
+
+ # prevent potential infinite loop when d = 0
+ if n < 2:
+ return False
+
+ # Decompose (n - 1) to write it as (2 ** r) * d
+ # While d is even, divide it by 2 and increase the exponent.
+ d = n - 1
+ r = 0
+
+ while not (d & 1):
+ r += 1
+ d >>= 1
+
+ # Test k witnesses.
+ for _ in range(k):
+ # Generate random integer a, where 2 <= a <= (n - 2)
+ a = rsa.randnum.randint(n - 3) + 1
+
+ x = pow(a, d, n)
+ if x == 1 or x == n - 1:
+ continue
+
+ for _ in range(r - 1):
+ x = pow(x, 2, n)
+ if x == 1:
+ # n is composite.
+ return False
+ if x == n - 1:
+ # Exit inner loop and continue with next witness.
+ break
+ else:
+ # If loop doesn't break, n is composite.
+ return False
+
+ return True
-def is_prime(number: int) ->bool:
+def is_prime(number: int) -> bool:
"""Returns True if the number is prime, and False otherwise.
>>> is_prime(2)
@@ -60,10 +125,23 @@ def is_prime(number: int) ->bool:
>>> is_prime(41)
True
"""
- pass
+ # Check for small numbers.
+ if number < 10:
+ return number in {2, 3, 5, 7}
-def getprime(nbits: int) ->int:
+ # Check for even numbers.
+ if not (number & 1):
+ return False
+
+ # Calculate minimum number of rounds.
+ k = get_primality_testing_rounds(number)
+
+ # Run primality testing with (minimum + 1) rounds.
+ return miller_rabin_primality_testing(number, k + 1)
+
+
+def getprime(nbits: int) -> int:
"""Returns a prime number that can be stored in 'nbits' bits.
>>> p = getprime(128)
@@ -78,10 +156,20 @@ def getprime(nbits: int) ->int:
>>> common.bit_size(p) == 128
True
"""
- pass
+ assert nbits > 3 # the loop will hang on too small numbers
+
+ while True:
+ integer = rsa.randnum.read_random_odd_int(nbits)
+
+ # Test for primeness
+ if is_prime(integer):
+ return integer
-def are_relatively_prime(a: int, b: int) ->bool:
+ # Retry if not prime
+
+
+def are_relatively_prime(a: int, b: int) -> bool:
"""Returns True if a and b are relatively prime, and False if they
are not.
@@ -90,16 +178,21 @@ def are_relatively_prime(a: int, b: int) ->bool:
>>> are_relatively_prime(2, 4)
False
"""
- pass
+
+ d = gcd(a, b)
+ return d == 1
-if __name__ == '__main__':
- print('Running doctests 1000x or until failure')
+if __name__ == "__main__":
+ print("Running doctests 1000x or until failure")
import doctest
+
for count in range(1000):
- failures, tests = doctest.testmod()
+ (failures, tests) = doctest.testmod()
if failures:
break
+
if count % 100 == 0 and count:
- print('%i times' % count)
- print('Doctests done')
+ print("%i times" % count)
+
+ print("Doctests done")
diff --git a/rsa/randnum.py b/rsa/randnum.py
index 24066b9..c65facd 100644
--- a/rsa/randnum.py
+++ b/rsa/randnum.py
@@ -1,37 +1,95 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Functions for generating random numbers."""
+
+# Source inspired by code by Yesudeep Mangalapilly <yesudeep@gmail.com>
+
import os
import struct
+
from rsa import common, transform
-def read_random_bits(nbits: int) ->bytes:
+def read_random_bits(nbits: int) -> bytes:
"""Reads 'nbits' random bits.
If nbits isn't a whole number of bytes, an extra byte will be appended with
only the lower bits set.
"""
- pass
+ nbytes, rbits = divmod(nbits, 8)
+
+ # Get the random bytes
+ randomdata = os.urandom(nbytes)
-def read_random_int(nbits: int) ->int:
+ # Add the remaining random bits
+ if rbits > 0:
+ randomvalue = ord(os.urandom(1))
+ randomvalue >>= 8 - rbits
+ randomdata = struct.pack("B", randomvalue) + randomdata
+
+ return randomdata
+
+
+def read_random_int(nbits: int) -> int:
"""Reads a random integer of approximately nbits bits."""
- pass
+ randomdata = read_random_bits(nbits)
+ value = transform.bytes2int(randomdata)
+
+ # Ensure that the number is large enough to just fill out the required
+ # number of bits.
+ value |= 1 << (nbits - 1)
-def read_random_odd_int(nbits: int) ->int:
+ return value
+
+
+def read_random_odd_int(nbits: int) -> int:
"""Reads a random odd integer of approximately nbits bits.
>>> read_random_odd_int(512) & 1
1
"""
- pass
+
+ value = read_random_int(nbits)
+
+ # Make sure it's odd
+ return value | 1
-def randint(maxvalue: int) ->int:
+def randint(maxvalue: int) -> int:
"""Returns a random integer x with 1 <= x <= maxvalue
May take a very long time in specific situations. If maxvalue needs N bits
to store, the closer maxvalue is to (2 ** N) - 1, the faster this function
is.
"""
- pass
+
+ bit_size = common.bit_size(maxvalue)
+
+ tries = 0
+ while True:
+ value = read_random_int(bit_size)
+ if value <= maxvalue:
+ break
+
+ if tries % 10 == 0 and tries:
+ # After a lot of tries to get the right number of bits but still
+ # smaller than maxvalue, decrease the number of bits by 1. That'll
+ # dramatically increase the chances to get a large enough number.
+ bit_size -= 1
+ tries += 1
+
+ return value
diff --git a/rsa/transform.py b/rsa/transform.py
index 0601701..c609b65 100644
--- a/rsa/transform.py
+++ b/rsa/transform.py
@@ -1,25 +1,40 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Data transformation functions.
From bytes to a number, number to bytes, etc.
"""
+
import math
-def bytes2int(raw_bytes: bytes) ->int:
- """Converts a list of bytes or an 8-bit string to an integer.
+def bytes2int(raw_bytes: bytes) -> int:
+ r"""Converts a list of bytes or an 8-bit string to an integer.
When using unicode strings, encode it to some encoding like UTF8 first.
>>> (((128 * 256) + 64) * 256) + 15
8405007
- >>> bytes2int(b'\\x80@\\x0f')
+ >>> bytes2int(b'\x80@\x0f')
8405007
"""
- pass
+ return int.from_bytes(raw_bytes, "big", signed=False)
-def int2bytes(number: int, fill_size: int=0) ->bytes:
+def int2bytes(number: int, fill_size: int = 0) -> bytes:
"""
Convert an unsigned integer to bytes (big-endian)::
@@ -39,9 +54,19 @@ def int2bytes(number: int, fill_size: int=0) ->bytes:
argument to this function to be set to ``False`` otherwise, no
error will be raised.
"""
- pass
+ if number < 0:
+ raise ValueError("Number must be an unsigned integer: %d" % number)
+
+ bytes_required = max(1, math.ceil(number.bit_length() / 8))
+
+ if fill_size > 0:
+ return number.to_bytes(fill_size, "big")
-if __name__ == '__main__':
+ return number.to_bytes(bytes_required, "big")
+
+
+if __name__ == "__main__":
import doctest
+
doctest.testmod()
diff --git a/rsa/util.py b/rsa/util.py
index efc0a3b..087caf8 100644
--- a/rsa/util.py
+++ b/rsa/util.py
@@ -1,9 +1,97 @@
+# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""Utility functions."""
+
import sys
from optparse import OptionParser
+
import rsa.key
-def private_to_public() ->None:
+def private_to_public() -> None:
"""Reads a private key and outputs the corresponding public key."""
- pass
+
+ # Parse the CLI options
+ parser = OptionParser(
+ usage="usage: %prog [options]",
+ description="Reads a private key and outputs the "
+ "corresponding public key. Both private and public keys use "
+ "the format described in PKCS#1 v1.5",
+ )
+
+ parser.add_option(
+ "-i",
+ "--input",
+ dest="infilename",
+ type="string",
+ help="Input filename. Reads from stdin if not specified",
+ )
+ parser.add_option(
+ "-o",
+ "--output",
+ dest="outfilename",
+ type="string",
+ help="Output filename. Writes to stdout of not specified",
+ )
+
+ parser.add_option(
+ "--inform",
+ dest="inform",
+ help="key format of input - default PEM",
+ choices=("PEM", "DER"),
+ default="PEM",
+ )
+
+ parser.add_option(
+ "--outform",
+ dest="outform",
+ help="key format of output - default PEM",
+ choices=("PEM", "DER"),
+ default="PEM",
+ )
+
+ (cli, cli_args) = parser.parse_args(sys.argv)
+
+ # Read the input data
+ if cli.infilename:
+ print(
+ "Reading private key from %s in %s format" % (cli.infilename, cli.inform),
+ file=sys.stderr,
+ )
+ with open(cli.infilename, "rb") as infile:
+ in_data = infile.read()
+ else:
+ print("Reading private key from stdin in %s format" % cli.inform, file=sys.stderr)
+ in_data = sys.stdin.read().encode("ascii")
+
+ assert type(in_data) == bytes, type(in_data)
+
+ # Take the public fields and create a public key
+ priv_key = rsa.key.PrivateKey.load_pkcs1(in_data, cli.inform)
+ pub_key = rsa.key.PublicKey(priv_key.n, priv_key.e)
+
+ # Save to the output file
+ out_data = pub_key.save_pkcs1(cli.outform)
+
+ if cli.outfilename:
+ print(
+ "Writing public key to %s in %s format" % (cli.outfilename, cli.outform),
+ file=sys.stderr,
+ )
+ with open(cli.outfilename, "wb") as outfile:
+ outfile.write(out_data)
+ else:
+ print("Writing public key to stdout in %s format" % cli.outform, file=sys.stderr)
+ sys.stdout.write(out_data.decode("ascii"))