Back to Reference (Gold) summary
Reference (Gold): dnspython
Pytest summary for the `tests` directory
| status | count |
|---|---|
| passed | 1308 |
| skipped | 37 |
| total | 1345 |
| collected | 1345 |
Failed pytests: none
Patch diff
diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py
index eba0ba6..49f14fe 100644
--- a/dns/_asyncbackend.py
+++ b/dns/_asyncbackend.py
@@ -1,5 +1,10 @@
-class NullContext:
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# This is a nullcontext for both sync and async. 3.7 has a nullcontext,
+# but it is only for sync use.
+
+class NullContext:
def __init__(self, enter_result=None):
self.enter_result = enter_result
@@ -16,7 +21,22 @@ class NullContext:
pass
-class Socket:
+# These are declared here so backends can import them without creating
+# circular dependencies with dns.asyncbackend.
+
+
+class Socket: # pragma: no cover
+ async def close(self):
+ pass
+
+ async def getpeername(self):
+ raise NotImplementedError
+
+ async def getsockname(self):
+ raise NotImplementedError
+
+ async def getpeercert(self, timeout):
+ raise NotImplementedError
async def __aenter__(self):
return self
@@ -25,19 +45,55 @@ class Socket:
await self.close()
-class DatagramSocket(Socket):
-
+class DatagramSocket(Socket): # pragma: no cover
def __init__(self, family: int):
self.family = family
+ async def sendto(self, what, destination, timeout):
+ raise NotImplementedError
-class StreamSocket(Socket):
- pass
+ async def recvfrom(self, size, timeout):
+ raise NotImplementedError
-class NullTransport:
- pass
+class StreamSocket(Socket): # pragma: no cover
+ async def sendall(self, what, timeout):
+ raise NotImplementedError
+
+ async def recv(self, size, timeout):
+ raise NotImplementedError
-class Backend:
- pass
+class NullTransport:
+ async def connect_tcp(self, host, port, timeout, local_address):
+ raise NotImplementedError
+
+
+class Backend: # pragma: no cover
+ def name(self):
+ return "unknown"
+
+ async def make_socket(
+ self,
+ af,
+ socktype,
+ proto=0,
+ source=None,
+ destination=None,
+ timeout=None,
+ ssl_context=None,
+ server_hostname=None,
+ ):
+ raise NotImplementedError
+
+ def datagram_connection_required(self):
+ return False
+
+ async def sleep(self, interval):
+ raise NotImplementedError
+
+ def get_transport_class(self):
+ raise NotImplementedError
+
+ async def wait_for(self, awaitable, timeout):
+ raise NotImplementedError
diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py
index 6b61355..9d9ed36 100644
--- a/dns/_asyncio_backend.py
+++ b/dns/_asyncio_backend.py
@@ -1,49 +1,140 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
"""asyncio library query support"""
+
import asyncio
import socket
import sys
+
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
-_is_win32 = sys.platform == 'win32'
+_is_win32 = sys.platform == "win32"
+
+
+def _get_running_loop():
+ try:
+ return asyncio.get_running_loop()
+ except AttributeError: # pragma: no cover
+ return asyncio.get_event_loop()
-class _DatagramProtocol:
+class _DatagramProtocol:
def __init__(self):
self.transport = None
self.recvfrom = None
+ def connection_made(self, transport):
+ self.transport = transport
-class DatagramSocket(dns._asyncbackend.DatagramSocket):
+ def datagram_received(self, data, addr):
+ if self.recvfrom and not self.recvfrom.done():
+ self.recvfrom.set_result((data, addr))
+
+ def error_received(self, exc): # pragma: no cover
+ if self.recvfrom and not self.recvfrom.done():
+ self.recvfrom.set_exception(exc)
+
+ def connection_lost(self, exc):
+ if self.recvfrom and not self.recvfrom.done():
+ if exc is None:
+ # EOF we triggered. Is there a better way to do this?
+ try:
+ raise EOFError
+ except EOFError as e:
+ self.recvfrom.set_exception(e)
+ else:
+ self.recvfrom.set_exception(exc)
+ def close(self):
+ self.transport.close()
+
+
+async def _maybe_wait_for(awaitable, timeout):
+ if timeout is not None:
+ try:
+ return await asyncio.wait_for(awaitable, timeout)
+ except asyncio.TimeoutError:
+ raise dns.exception.Timeout(timeout=timeout)
+ else:
+ return await awaitable
+
+
+class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol):
super().__init__(family)
self.transport = transport
self.protocol = protocol
+ async def sendto(self, what, destination, timeout): # pragma: no cover
+ # no timeout for asyncio sendto
+ self.transport.sendto(what, destination)
+ return len(what)
+
+ async def recvfrom(self, size, timeout):
+ # ignore size as there's no way I know to tell protocol about it
+ done = _get_running_loop().create_future()
+ try:
+ assert self.protocol.recvfrom is None
+ self.protocol.recvfrom = done
+ await _maybe_wait_for(done, timeout)
+ return done.result()
+ finally:
+ self.protocol.recvfrom = None
+
+ async def close(self):
+ self.protocol.close()
+
+ async def getpeername(self):
+ return self.transport.get_extra_info("peername")
+
+ async def getsockname(self):
+ return self.transport.get_extra_info("sockname")
+
+ async def getpeercert(self, timeout):
+ raise NotImplementedError
-class StreamSocket(dns._asyncbackend.StreamSocket):
+class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer):
self.family = af
self.reader = reader
self.writer = writer
+ async def sendall(self, what, timeout):
+ self.writer.write(what)
+ return await _maybe_wait_for(self.writer.drain(), timeout)
+
+ async def recv(self, size, timeout):
+ return await _maybe_wait_for(self.reader.read(size), timeout)
+
+ async def close(self):
+ self.writer.close()
+
+ async def getpeername(self):
+ return self.writer.get_extra_info("peername")
+
+ async def getsockname(self):
+ return self.writer.get_extra_info("sockname")
+
+ async def getpeercert(self, timeout):
+ return self.writer.get_extra_info("peercert")
-if dns._features.have('doh'):
+
+if dns._features.have("doh"):
import anyio
import httpcore
import httpcore._backends.anyio
import httpx
+
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream
- from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
+ from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
-
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
@@ -52,23 +143,133 @@ if dns._features.have('doh'):
self._family = family
if local_port != 0:
raise NotImplementedError(
- 'the asyncio transport for HTTPX cannot set the local port'
- )
+ "the asyncio transport for HTTPX cannot set the local port"
+ )
+ async def connect_tcp(
+ self, host, port, timeout, local_address, socket_options=None
+ ): # pylint: disable=signature-differs
+ addresses = []
+ _, expiration = _compute_times(timeout)
+ if dns.inet.is_address(host):
+ addresses.append(host)
+ elif self._bootstrap_address is not None:
+ addresses.append(self._bootstrap_address)
+ else:
+ timeout = _remaining(expiration)
+ family = self._family
+ if local_address:
+ family = dns.inet.af_for_address(local_address)
+ answers = await self._resolver.resolve_name(
+ host, family=family, lifetime=timeout
+ )
+ addresses = answers.addresses()
+ for address in addresses:
+ try:
+ attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
+ timeout = _remaining(attempt_expiration)
+ with anyio.fail_after(timeout):
+ stream = await anyio.connect_tcp(
+ remote_host=address,
+ remote_port=port,
+ local_host=local_address,
+ )
+ return _CoreAnyIOStream(stream)
+ except Exception:
+ pass
+ raise httpcore.ConnectError
- class _HTTPTransport(httpx.AsyncHTTPTransport):
+ async def connect_unix_socket(
+ self, path, timeout, socket_options=None
+ ): # pylint: disable=signature-differs
+ raise NotImplementedError
+
+ async def sleep(self, seconds): # pylint: disable=signature-differs
+ await anyio.sleep(seconds)
- def __init__(self, *args, local_port=0, bootstrap_address=None,
- resolver=None, family=socket.AF_UNSPEC, **kwargs):
+ class _HTTPTransport(httpx.AsyncHTTPTransport):
+ def __init__(
+ self,
+ *args,
+ local_port=0,
+ bootstrap_address=None,
+ resolver=None,
+ family=socket.AF_UNSPEC,
+ **kwargs,
+ ):
if resolver is None:
+ # pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
+
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
- self._pool._network_backend = _NetworkBackend(resolver,
- local_port, bootstrap_address, family)
+ self._pool._network_backend = _NetworkBackend(
+ resolver, local_port, bootstrap_address, family
+ )
+
else:
- _HTTPTransport = dns._asyncbackend.NullTransport
+ _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend):
- pass
+ def name(self):
+ return "asyncio"
+
+ async def make_socket(
+ self,
+ af,
+ socktype,
+ proto=0,
+ source=None,
+ destination=None,
+ timeout=None,
+ ssl_context=None,
+ server_hostname=None,
+ ):
+ loop = _get_running_loop()
+ if socktype == socket.SOCK_DGRAM:
+ if _is_win32 and source is None:
+ # Win32 wants explicit binding before recvfrom(). This is the
+ # proper fix for [#637].
+ source = (dns.inet.any_for_af(af), 0)
+ transport, protocol = await loop.create_datagram_endpoint(
+ _DatagramProtocol,
+ source,
+ family=af,
+ proto=proto,
+ remote_addr=destination,
+ )
+ return DatagramSocket(af, transport, protocol)
+ elif socktype == socket.SOCK_STREAM:
+ if destination is None:
+ # This shouldn't happen, but we check to make code analysis software
+ # happier.
+ raise ValueError("destination required for stream sockets")
+ (r, w) = await _maybe_wait_for(
+ asyncio.open_connection(
+ destination[0],
+ destination[1],
+ ssl=ssl_context,
+ family=af,
+ proto=proto,
+ local_addr=source,
+ server_hostname=server_hostname,
+ ),
+ timeout,
+ )
+ return StreamSocket(af, r, w)
+ raise NotImplementedError(
+ "unsupported socket " + f"type {socktype}"
+ ) # pragma: no cover
+
+ async def sleep(self, interval):
+ await asyncio.sleep(interval)
+
+ def datagram_connection_required(self):
+ return False
+
+ def get_transport_class(self):
+ return _HTTPTransport
+
+ async def wait_for(self, awaitable, timeout):
+ return await _maybe_wait_for(awaitable, timeout)
diff --git a/dns/_ddr.py b/dns/_ddr.py
index f03d10e..bf5c11e 100644
--- a/dns/_ddr.py
+++ b/dns/_ddr.py
@@ -1,17 +1,29 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+#
+# Support for Discovery of Designated Resolvers
+
import socket
import time
from urllib.parse import urlparse
+
import dns.asyncbackend
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdtypes.svcbbase
-_local_resolver_name = dns.name.from_text('_dns.resolver.arpa')
+# The special name of the local resolver when using DDR
+_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")
-class _SVCBInfo:
+#
+# Processing is split up into I/O independent and I/O dependent parts to
+# make supporting sync and async versions easy.
+#
+
+
+class _SVCBInfo:
def __init__(self, bootstrap_address, port, hostname, nameservers):
self.bootstrap_address = bootstrap_address
self.port = port
@@ -20,16 +32,123 @@ class _SVCBInfo:
def ddr_check_certificate(self, cert):
"""Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
- pass
+ for name, value in cert["subjectAltName"]:
+ if name == "IP Address" and value == self.bootstrap_address:
+ return True
+ return False
+
+ def make_tls_context(self):
+ ssl = dns.query.ssl
+ ctx = ssl.create_default_context()
+ ctx.minimum_version = ssl.TLSVersion.TLSv1_2
+ return ctx
+
+ def ddr_tls_check_sync(self, lifetime):
+ ctx = self.make_tls_context()
+ expiration = time.time() + lifetime
+ with socket.create_connection(
+ (self.bootstrap_address, self.port), lifetime
+ ) as s:
+ with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
+ ts.settimeout(dns.query._remaining(expiration))
+ ts.do_handshake()
+ cert = ts.getpeercert()
+ return self.ddr_check_certificate(cert)
+
+ async def ddr_tls_check_async(self, lifetime, backend=None):
+ if backend is None:
+ backend = dns.asyncbackend.get_default_backend()
+ ctx = self.make_tls_context()
+ expiration = time.time() + lifetime
+ async with await backend.make_socket(
+ dns.inet.af_for_address(self.bootstrap_address),
+ socket.SOCK_STREAM,
+ 0,
+ None,
+ (self.bootstrap_address, self.port),
+ lifetime,
+ ctx,
+ self.hostname,
+ ) as ts:
+ cert = await ts.getpeercert(dns.query._remaining(expiration))
+ return self.ddr_check_certificate(cert)
+
+
+def _extract_nameservers_from_svcb(answer):
+ bootstrap_address = answer.nameserver
+ if not dns.inet.is_address(bootstrap_address):
+ return []
+ infos = []
+ for rr in answer.rrset.processing_order():
+ nameservers = []
+ param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
+ if param is None:
+ continue
+ alpns = set(param.ids)
+ host = rr.target.to_text(omit_final_dot=True)
+ port = None
+ param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
+ if param is not None:
+ port = param.port
+ # For now we ignore address hints and address resolution and always use the
+ # bootstrap address
+ if b"h2" in alpns:
+ param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
+ if param is None or not param.value.endswith(b"{?dns}"):
+ continue
+ path = param.value[:-6].decode()
+ if not path.startswith("/"):
+ path = "/" + path
+ if port is None:
+ port = 443
+ url = f"https://{host}:{port}{path}"
+ # check the URL
+ try:
+ urlparse(url)
+ nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
+ except Exception:
+ # continue processing other ALPN types
+ pass
+ if b"dot" in alpns:
+ if port is None:
+ port = 853
+ nameservers.append(
+ dns.nameserver.DoTNameserver(bootstrap_address, port, host)
+ )
+ if b"doq" in alpns:
+ if port is None:
+ port = 853
+ nameservers.append(
+ dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
+ )
+ if len(nameservers) > 0:
+ infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
+ return infos
def _get_nameservers_sync(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
- pass
+ nameservers = []
+ infos = _extract_nameservers_from_svcb(answer)
+ for info in infos:
+ try:
+ if info.ddr_tls_check_sync(lifetime):
+ nameservers.extend(info.nameservers)
+ except Exception:
+ pass
+ return nameservers
async def _get_nameservers_async(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
- pass
+ nameservers = []
+ infos = _extract_nameservers_from_svcb(answer)
+ for info in infos:
+ try:
+ if await info.ddr_tls_check_async(lifetime):
+ nameservers.extend(info.nameservers)
+ except Exception:
+ pass
+ return nameservers
diff --git a/dns/_features.py b/dns/_features.py
index bb537c8..03ccaa7 100644
--- a/dns/_features.py
+++ b/dns/_features.py
@@ -1,23 +1,50 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import importlib.metadata
import itertools
import string
from typing import Dict, List, Tuple
-def _version_check(requirement: str) ->bool:
+def _tuple_from_text(version: str) -> Tuple:
+ text_parts = version.split(".")
+ int_parts = []
+ for text_part in text_parts:
+ digit_prefix = "".join(
+ itertools.takewhile(lambda x: x in string.digits, text_part)
+ )
+ try:
+ int_parts.append(int(digit_prefix))
+ except Exception:
+ break
+ return tuple(int_parts)
+
+
+def _version_check(
+ requirement: str,
+) -> bool:
"""Is the requirement fulfilled?
The requirement must be of the form
package>=version
"""
- pass
+ package, minimum = requirement.split(">=")
+ try:
+ version = importlib.metadata.version(package)
+ except Exception:
+ return False
+ t_version = _tuple_from_text(version)
+ t_minimum = _tuple_from_text(minimum)
+ if t_version < t_minimum:
+ return False
+ return True
_cache: Dict[str, bool] = {}
-def have(feature: str) ->bool:
+def have(feature: str) -> bool:
"""Is *feature* available?
This tests if all optional packages needed for the
@@ -27,19 +54,39 @@ def have(feature: str) ->bool:
and ``False`` if it is not or if metadata is
missing.
"""
- pass
+ value = _cache.get(feature)
+ if value is not None:
+ return value
+ requirements = _requirements.get(feature)
+ if requirements is None:
+ # we make a cache entry here for consistency not performance
+ _cache[feature] = False
+ return False
+ ok = True
+ for requirement in requirements:
+ if not _version_check(requirement):
+ ok = False
+ break
+ _cache[feature] = ok
+ return ok
-def force(feature: str, enabled: bool) ->None:
+def force(feature: str, enabled: bool) -> None:
"""Force the status of *feature* to be *enabled*.
This method is provided as a workaround for any cases
where importlib.metadata is ineffective, or for testing.
"""
- pass
+ _cache[feature] = enabled
-_requirements: Dict[str, List[str]] = {'dnssec': ['cryptography>=41'],
- 'doh': ['httpcore>=1.0.0', 'httpx>=0.26.0', 'h2>=4.1.0'], 'doq': [
- 'aioquic>=0.9.25'], 'idna': ['idna>=3.6'], 'trio': ['trio>=0.23'],
- 'wmi': ['wmi>=1.5.1']}
+_requirements: Dict[str, List[str]] = {
+ ### BEGIN generated requirements
+ "dnssec": ["cryptography>=41"],
+ "doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
+ "doq": ["aioquic>=0.9.25"],
+ "idna": ["idna>=3.6"],
+ "trio": ["trio>=0.23"],
+ "wmi": ["wmi>=1.5.1"],
+ ### END generated requirements
+}
diff --git a/dns/_immutable_ctx.py b/dns/_immutable_ctx.py
index f0ff108..ae7a33b 100644
--- a/dns/_immutable_ctx.py
+++ b/dns/_immutable_ctx.py
@@ -1,10 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# This implementation of the immutable decorator requires python >=
+# 3.7, and is significantly more storage efficient when making classes
+# with slots immutable. It's also faster.
+
import contextvars
import inspect
-_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False)
+
+_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
class _Immutable:
"""Immutable mixin class"""
+
+ # We set slots to the empty list to say "we don't have any attributes".
+ # We do this so that if we're mixed in with a class with __slots__, we
+ # don't cause a __dict__ to be added which would waste space.
+
__slots__ = ()
def __setattr__(self, name, value):
@@ -18,3 +30,47 @@ class _Immutable:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
+
+
+def _immutable_init(f):
+ def nf(*args, **kwargs):
+ previous = _in__init__.set(args[0])
+ try:
+ # call the actual __init__
+ f(*args, **kwargs)
+ finally:
+ _in__init__.reset(previous)
+
+ nf.__signature__ = inspect.signature(f)
+ return nf
+
+
+def immutable(cls):
+ if _Immutable in cls.__mro__:
+ # Some ancestor already has the mixin, so just make sure we keep
+ # following the __init__ protocol.
+ cls.__init__ = _immutable_init(cls.__init__)
+ if hasattr(cls, "__setstate__"):
+ cls.__setstate__ = _immutable_init(cls.__setstate__)
+ ncls = cls
+ else:
+ # Mixin the Immutable class and follow the __init__ protocol.
+ class ncls(_Immutable, cls):
+ # We have to do the __slots__ declaration here too!
+ __slots__ = ()
+
+ @_immutable_init
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ if hasattr(cls, "__setstate__"):
+
+ @_immutable_init
+ def __setstate__(self, *args, **kwargs):
+ super().__setstate__(*args, **kwargs)
+
+ # make ncls have the same name and module as cls
+ ncls.__name__ = cls.__name__
+ ncls.__qualname__ = cls.__qualname__
+ ncls.__module__ = cls.__module__
+ return ncls
diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py
index 52268ee..398e327 100644
--- a/dns/_trio_backend.py
+++ b/dns/_trio_backend.py
@@ -1,42 +1,115 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
"""trio async I/O library query support"""
+
import socket
+
import trio
-import trio.socket
+import trio.socket # type: ignore
+
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
-if not dns._features.have('trio'):
- raise ImportError('trio not found or too old')
+
+if not dns._features.have("trio"):
+ raise ImportError("trio not found or too old")
+
+
+def _maybe_timeout(timeout):
+ if timeout is not None:
+ return trio.move_on_after(timeout)
+ else:
+ return dns._asyncbackend.NullContext()
+
+
+# for brevity
_lltuple = dns.inet.low_level_address_tuple
+# pylint: disable=redefined-outer-name
-class DatagramSocket(dns._asyncbackend.DatagramSocket):
+class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
+ async def sendto(self, what, destination, timeout):
+ with _maybe_timeout(timeout):
+ return await self.socket.sendto(what, destination)
+ raise dns.exception.Timeout(
+ timeout=timeout
+ ) # pragma: no cover lgtm[py/unreachable-statement]
-class StreamSocket(dns._asyncbackend.StreamSocket):
+ async def recvfrom(self, size, timeout):
+ with _maybe_timeout(timeout):
+ return await self.socket.recvfrom(size)
+ raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
+
+ async def close(self):
+ self.socket.close()
+ async def getpeername(self):
+ return self.socket.getpeername()
+
+ async def getsockname(self):
+ return self.socket.getsockname()
+
+ async def getpeercert(self, timeout):
+ raise NotImplementedError
+
+
+class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False):
self.family = family
self.stream = stream
self.tls = tls
-
-if dns._features.have('doh'):
+ async def sendall(self, what, timeout):
+ with _maybe_timeout(timeout):
+ return await self.stream.send_all(what)
+ raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
+
+ async def recv(self, size, timeout):
+ with _maybe_timeout(timeout):
+ return await self.stream.receive_some(size)
+ raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
+
+ async def close(self):
+ await self.stream.aclose()
+
+ async def getpeername(self):
+ if self.tls:
+ return self.stream.transport_stream.socket.getpeername()
+ else:
+ return self.stream.socket.getpeername()
+
+ async def getsockname(self):
+ if self.tls:
+ return self.stream.transport_stream.socket.getsockname()
+ else:
+ return self.stream.socket.getsockname()
+
+ async def getpeercert(self, timeout):
+ if self.tls:
+ with _maybe_timeout(timeout):
+ await self.stream.do_handshake()
+ return self.stream.getpeercert()
+ else:
+ raise NotImplementedError
+
+
+if dns._features.have("doh"):
import httpcore
import httpcore._backends.trio
import httpx
+
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreTrioStream = httpcore._backends.trio.TrioStream
- from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
+ from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
-
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
@@ -44,20 +117,134 @@ if dns._features.have('doh'):
self._bootstrap_address = bootstrap_address
self._family = family
+ async def connect_tcp(
+ self, host, port, timeout, local_address, socket_options=None
+ ): # pylint: disable=signature-differs
+ addresses = []
+ _, expiration = _compute_times(timeout)
+ if dns.inet.is_address(host):
+ addresses.append(host)
+ elif self._bootstrap_address is not None:
+ addresses.append(self._bootstrap_address)
+ else:
+ timeout = _remaining(expiration)
+ family = self._family
+ if local_address:
+ family = dns.inet.af_for_address(local_address)
+ answers = await self._resolver.resolve_name(
+ host, family=family, lifetime=timeout
+ )
+ addresses = answers.addresses()
+ for address in addresses:
+ try:
+ af = dns.inet.af_for_address(address)
+ if local_address is not None or self._local_port != 0:
+ source = (local_address, self._local_port)
+ else:
+ source = None
+ destination = (address, port)
+ attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
+ timeout = _remaining(attempt_expiration)
+ sock = await Backend().make_socket(
+ af, socket.SOCK_STREAM, 0, source, destination, timeout
+ )
+ return _CoreTrioStream(sock.stream)
+ except Exception:
+ continue
+ raise httpcore.ConnectError
+
+ async def connect_unix_socket(
+ self, path, timeout, socket_options=None
+ ): # pylint: disable=signature-differs
+ raise NotImplementedError
+
+ async def sleep(self, seconds): # pylint: disable=signature-differs
+ await trio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
-
- def __init__(self, *args, local_port=0, bootstrap_address=None,
- resolver=None, family=socket.AF_UNSPEC, **kwargs):
+ def __init__(
+ self,
+ *args,
+ local_port=0,
+ bootstrap_address=None,
+ resolver=None,
+ family=socket.AF_UNSPEC,
+ **kwargs,
+ ):
if resolver is None:
+ # pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
+
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
- self._pool._network_backend = _NetworkBackend(resolver,
- local_port, bootstrap_address, family)
+ self._pool._network_backend = _NetworkBackend(
+ resolver, local_port, bootstrap_address, family
+ )
+
else:
- _HTTPTransport = dns._asyncbackend.NullTransport
+ _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend):
- pass
+ def name(self):
+ return "trio"
+
+ async def make_socket(
+ self,
+ af,
+ socktype,
+ proto=0,
+ source=None,
+ destination=None,
+ timeout=None,
+ ssl_context=None,
+ server_hostname=None,
+ ):
+ s = trio.socket.socket(af, socktype, proto)
+ stream = None
+ try:
+ if source:
+ await s.bind(_lltuple(source, af))
+ if socktype == socket.SOCK_STREAM:
+ connected = False
+ with _maybe_timeout(timeout):
+ await s.connect(_lltuple(destination, af))
+ connected = True
+ if not connected:
+ raise dns.exception.Timeout(
+ timeout=timeout
+ ) # lgtm[py/unreachable-statement]
+ except Exception: # pragma: no cover
+ s.close()
+ raise
+ if socktype == socket.SOCK_DGRAM:
+ return DatagramSocket(s)
+ elif socktype == socket.SOCK_STREAM:
+ stream = trio.SocketStream(s)
+ tls = False
+ if ssl_context:
+ tls = True
+ try:
+ stream = trio.SSLStream(
+ stream, ssl_context, server_hostname=server_hostname
+ )
+ except Exception: # pragma: no cover
+ await stream.aclose()
+ raise
+ return StreamSocket(af, stream, tls)
+ raise NotImplementedError(
+ "unsupported socket " + f"type {socktype}"
+ ) # pragma: no cover
+
+ async def sleep(self, interval):
+ await trio.sleep(interval)
+
+ def get_transport_class(self):
+ return _HTTPTransport
+
+ async def wait_for(self, awaitable, timeout):
+ with _maybe_timeout(timeout):
+ return await awaitable
+ raise dns.exception.Timeout(
+ timeout=timeout
+ ) # pragma: no cover lgtm[py/unreachable-statement]
diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py
index 3e2691b..0ec58b0 100644
--- a/dns/asyncbackend.py
+++ b/dns/asyncbackend.py
@@ -1,8 +1,24 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
from typing import Dict
+
import dns.exception
-from dns._asyncbackend import Backend, DatagramSocket, Socket, StreamSocket
+
+# pylint: disable=unused-import
+from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
+ Backend,
+ DatagramSocket,
+ Socket,
+ StreamSocket,
+)
+
+# pylint: enable=unused-import
+
_default_backend = None
+
_backends: Dict[str, Backend] = {}
+
+# Allow sniffio import to be disabled for testing purposes
_no_sniffio = False
@@ -10,7 +26,7 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException):
pass
-def get_backend(name: str) ->Backend:
+def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio"
@@ -18,25 +34,60 @@ def get_backend(name: str) ->Backend:
Raises NotImplementedError if an unknown backend name is specified.
"""
- pass
+ # pylint: disable=import-outside-toplevel,redefined-outer-name
+ backend = _backends.get(name)
+ if backend:
+ return backend
+ if name == "trio":
+ import dns._trio_backend
+ backend = dns._trio_backend.Backend()
+ elif name == "asyncio":
+ import dns._asyncio_backend
-def sniff() ->str:
+ backend = dns._asyncio_backend.Backend()
+ else:
+ raise NotImplementedError(f"unimplemented async backend {name}")
+ _backends[name] = backend
+ return backend
+
+
+def sniff() -> str:
"""Attempt to determine the in-use asynchronous I/O library by using
the ``sniffio`` module if it is available.
Returns the name of the library, or raises AsyncLibraryNotFoundError
if the library cannot be determined.
"""
- pass
+ # pylint: disable=import-outside-toplevel
+ try:
+ if _no_sniffio:
+ raise ImportError
+ import sniffio
+
+ try:
+ return sniffio.current_async_library()
+ except sniffio.AsyncLibraryNotFoundError:
+ raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
+ except ImportError:
+ import asyncio
+ try:
+ asyncio.get_running_loop()
+ return "asyncio"
+ except RuntimeError:
+ raise AsyncLibraryNotFoundError("no async library detected")
-def get_default_backend() ->Backend:
+
+def get_default_backend() -> Backend:
"""Get the default backend, initializing it if necessary."""
- pass
+ if _default_backend:
+ return _default_backend
+ return set_default_backend(sniff())
-def set_default_backend(name: str) ->Backend:
+
+def set_default_backend(name: str) -> Backend:
"""Set the default backend.
It's not normally necessary to call this method, as
@@ -45,4 +96,6 @@ def set_default_backend(name: str) ->Backend:
in testing situations, this function allows the backend to be set
explicitly.
"""
- pass
+ global _default_backend
+ _default_backend = get_backend(name)
+ return _default_backend
diff --git a/dns/asyncquery.py b/dns/asyncquery.py
index 5fbfbb5..4d9ab9a 100644
--- a/dns/asyncquery.py
+++ b/dns/asyncquery.py
@@ -1,10 +1,29 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Talk to a DNS server."""
+
import base64
import contextlib
import socket
import struct
import time
from typing import Any, Dict, Optional, Tuple, Union
+
import dns.asyncbackend
import dns.exception
import dns.inet
@@ -16,15 +35,57 @@ import dns.rdataclass
import dns.rdatatype
import dns.transaction
from dns._asyncbackend import NullContext
-from dns.query import BadResponse, NoDOH, NoDOQ, UDPMode, _compute_times, _make_dot_ssl_context, _matches_destination, _remaining, have_doh, ssl
+from dns.query import (
+ BadResponse,
+ NoDOH,
+ NoDOQ,
+ UDPMode,
+ _compute_times,
+ _make_dot_ssl_context,
+ _matches_destination,
+ _remaining,
+ have_doh,
+ ssl,
+)
+
if have_doh:
import httpx
+
+# for brevity
_lltuple = dns.inet.low_level_address_tuple
-async def send_udp(sock: dns.asyncbackend.DatagramSocket, what: Union[dns.
- message.Message, bytes], destination: Any, expiration: Optional[float]=None
- ) ->Tuple[int, float]:
+def _source_tuple(af, address, port):
+ # Make a high level source tuple, or return None if address and port
+ # are both None
+ if address or port:
+ if address is None:
+ if af == socket.AF_INET:
+ address = "0.0.0.0"
+ elif af == socket.AF_INET6:
+ address = "::"
+ else:
+ raise NotImplementedError(f"unknown address family {af}")
+ return (address, port)
+ else:
+ return None
+
+
+def _timeout(expiration, now=None):
+ if expiration is not None:
+ if not now:
+ now = time.time()
+ return max(expiration - now, 0)
+ else:
+ return None
+
+
+async def send_udp(
+ sock: dns.asyncbackend.DatagramSocket,
+ what: Union[dns.message.Message, bytes],
+ destination: Any,
+ expiration: Optional[float] = None,
+) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
@@ -41,16 +102,27 @@ async def send_udp(sock: dns.asyncbackend.DatagramSocket, what: Union[dns.
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
- pass
-
-async def receive_udp(sock: dns.asyncbackend.DatagramSocket, destination:
- Optional[Any]=None, expiration: Optional[float]=None, ignore_unexpected:
- bool=False, one_rr_per_rrset: bool=False, keyring: Optional[Dict[dns.
- name.Name, dns.tsig.Key]]=None, request_mac: Optional[bytes]=b'',
- ignore_trailing: bool=False, raise_on_truncation: bool=False,
- ignore_errors: bool=False, query: Optional[dns.message.Message]=None
- ) ->Any:
+ if isinstance(what, dns.message.Message):
+ what = what.to_wire()
+ sent_time = time.time()
+ n = await sock.sendto(what, destination, _timeout(expiration, sent_time))
+ return (n, sent_time)
+
+
+async def receive_udp(
+ sock: dns.asyncbackend.DatagramSocket,
+ destination: Optional[Any] = None,
+ expiration: Optional[float] = None,
+ ignore_unexpected: bool = False,
+ one_rr_per_rrset: bool = False,
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
+ request_mac: Optional[bytes] = b"",
+ ignore_trailing: bool = False,
+ raise_on_truncation: bool = False,
+ ignore_errors: bool = False,
+ query: Optional[dns.message.Message] = None,
+) -> Any:
"""Read a DNS message from a UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
@@ -61,16 +133,59 @@ async def receive_udp(sock: dns.asyncbackend.DatagramSocket, destination:
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
received time, and the address where the message arrived from.
"""
- pass
-
-async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=
- None, port: int=53, source: Optional[str]=None, source_port: int=0,
- ignore_unexpected: bool=False, one_rr_per_rrset: bool=False,
- ignore_trailing: bool=False, raise_on_truncation: bool=False, sock:
- Optional[dns.asyncbackend.DatagramSocket]=None, backend: Optional[dns.
- asyncbackend.Backend]=None, ignore_errors: bool=False
- ) ->dns.message.Message:
+ wire = b""
+ while True:
+ (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
+ if not _matches_destination(
+ sock.family, from_address, destination, ignore_unexpected
+ ):
+ continue
+ received_time = time.time()
+ try:
+ r = dns.message.from_wire(
+ wire,
+ keyring=keyring,
+ request_mac=request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ raise_on_truncation=raise_on_truncation,
+ )
+ except dns.message.Truncated as e:
+ # See the comment in query.py for details.
+ if (
+ ignore_errors
+ and query is not None
+ and not query.is_response(e.message())
+ ):
+ continue
+ else:
+ raise
+ except Exception:
+ if ignore_errors:
+ continue
+ else:
+ raise
+ if ignore_errors and query is not None and not query.is_response(r):
+ continue
+ return (r, received_time, from_address)
+
+
+async def udp(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 53,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ ignore_unexpected: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ raise_on_truncation: bool = False,
+ sock: Optional[dns.asyncbackend.DatagramSocket] = None,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+ ignore_errors: bool = False,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
@@ -84,16 +199,59 @@ async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=
See :py:func:`dns.query.udp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
- pass
-
-
-async def udp_with_fallback(q: dns.message.Message, where: str, timeout:
- Optional[float]=None, port: int=53, source: Optional[str]=None,
- source_port: int=0, ignore_unexpected: bool=False, one_rr_per_rrset:
- bool=False, ignore_trailing: bool=False, udp_sock: Optional[dns.
- asyncbackend.DatagramSocket]=None, tcp_sock: Optional[dns.asyncbackend.
- StreamSocket]=None, backend: Optional[dns.asyncbackend.Backend]=None,
- ignore_errors: bool=False) ->Tuple[dns.message.Message, bool]:
+ wire = q.to_wire()
+ (begin_time, expiration) = _compute_times(timeout)
+ af = dns.inet.af_for_address(where)
+ destination = _lltuple((where, port), af)
+ if sock:
+ cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+ else:
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ stuple = _source_tuple(af, source, source_port)
+ if backend.datagram_connection_required():
+ dtuple = (where, port)
+ else:
+ dtuple = None
+ cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
+ async with cm as s:
+ await send_udp(s, wire, destination, expiration)
+ (r, received_time, _) = await receive_udp(
+ s,
+ destination,
+ expiration,
+ ignore_unexpected,
+ one_rr_per_rrset,
+ q.keyring,
+ q.mac,
+ ignore_trailing,
+ raise_on_truncation,
+ ignore_errors,
+ q,
+ )
+ r.time = received_time - begin_time
+ # We don't need to check q.is_response() if we are in ignore_errors mode
+ # as receive_udp() will have checked it.
+ if not (ignore_errors or q.is_response(r)):
+ raise BadResponse
+ return r
+
+
+async def udp_with_fallback(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 53,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ ignore_unexpected: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
+ tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+ ignore_errors: bool = False,
+) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
@@ -114,12 +272,44 @@ async def udp_with_fallback(q: dns.message.Message, where: str, timeout:
of the other parameters, exceptions, and return type of this
method.
"""
- pass
-
-
-async def send_tcp(sock: dns.asyncbackend.StreamSocket, what: Union[dns.
- message.Message, bytes], expiration: Optional[float]=None) ->Tuple[int,
- float]:
+ try:
+ response = await udp(
+ q,
+ where,
+ timeout,
+ port,
+ source,
+ source_port,
+ ignore_unexpected,
+ one_rr_per_rrset,
+ ignore_trailing,
+ True,
+ udp_sock,
+ backend,
+ ignore_errors,
+ )
+ return (response, False)
+ except dns.message.Truncated:
+ response = await tcp(
+ q,
+ where,
+ timeout,
+ port,
+ source,
+ source_port,
+ one_rr_per_rrset,
+ ignore_trailing,
+ tcp_sock,
+ backend,
+ )
+ return (response, True)
+
+
+async def send_tcp(
+ sock: dns.asyncbackend.StreamSocket,
+ what: Union[dns.message.Message, bytes],
+ expiration: Optional[float] = None,
+) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
@@ -127,20 +317,41 @@ async def send_tcp(sock: dns.asyncbackend.StreamSocket, what: Union[dns.
See :py:func:`dns.query.send_tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
- pass
+
+ if isinstance(what, dns.message.Message):
+ tcpmsg = what.to_wire(prepend_length=True)
+ else:
+ # copying the wire into tcpmsg is inefficient, but lets us
+ # avoid writev() or doing a short write that would get pushed
+ # onto the net
+ tcpmsg = len(what).to_bytes(2, "big") + what
+ sent_time = time.time()
+ await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
+ return (len(tcpmsg), sent_time)
async def _read_exactly(sock, count, expiration):
"""Read the specified number of bytes from stream. Keep trying until we
either get the desired amount, or we hit EOF.
"""
- pass
-
-
-async def receive_tcp(sock: dns.asyncbackend.StreamSocket, expiration:
- Optional[float]=None, one_rr_per_rrset: bool=False, keyring: Optional[
- Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac: Optional[bytes]=
- b'', ignore_trailing: bool=False) ->Tuple[dns.message.Message, float]:
+ s = b""
+ while count > 0:
+ n = await sock.recv(count, _timeout(expiration))
+ if n == b"":
+ raise EOFError
+ count = count - len(n)
+ s = s + n
+ return s
+
+
+async def receive_tcp(
+ sock: dns.asyncbackend.StreamSocket,
+ expiration: Optional[float] = None,
+ one_rr_per_rrset: bool = False,
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
+ request_mac: Optional[bytes] = b"",
+ ignore_trailing: bool = False,
+) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
@@ -148,14 +359,33 @@ async def receive_tcp(sock: dns.asyncbackend.StreamSocket, expiration:
See :py:func:`dns.query.receive_tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
- pass
-
-async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=
- None, port: int=53, source: Optional[str]=None, source_port: int=0,
- one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock:
- Optional[dns.asyncbackend.StreamSocket]=None, backend: Optional[dns.
- asyncbackend.Backend]=None) ->dns.message.Message:
+ ldata = await _read_exactly(sock, 2, expiration)
+ (l,) = struct.unpack("!H", ldata)
+ wire = await _read_exactly(sock, l, expiration)
+ received_time = time.time()
+ r = dns.message.from_wire(
+ wire,
+ keyring=keyring,
+ request_mac=request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ return (r, received_time)
+
+
+async def tcp(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 53,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ sock: Optional[dns.asyncbackend.StreamSocket] = None,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
@@ -169,16 +399,52 @@ async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=
See :py:func:`dns.query.tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
- pass
-
-async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=
- None, port: int=853, source: Optional[str]=None, source_port: int=0,
- one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock:
- Optional[dns.asyncbackend.StreamSocket]=None, backend: Optional[dns.
- asyncbackend.Backend]=None, ssl_context: Optional[ssl.SSLContext]=None,
- server_hostname: Optional[str]=None, verify: Union[bool, str]=True
- ) ->dns.message.Message:
+ wire = q.to_wire()
+ (begin_time, expiration) = _compute_times(timeout)
+ if sock:
+ # Verify that the socket is connected, as if it's not connected,
+ # it's not writable, and the polling in send_tcp() will time out or
+ # hang forever.
+ await sock.getpeername()
+ cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+ else:
+ # These are simple (address, port) pairs, not family-dependent tuples
+ # you pass to low-level socket code.
+ af = dns.inet.af_for_address(where)
+ stuple = _source_tuple(af, source, source_port)
+ dtuple = (where, port)
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ cm = await backend.make_socket(
+ af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
+ )
+ async with cm as s:
+ await send_tcp(s, wire, expiration)
+ (r, received_time) = await receive_tcp(
+ s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
+ )
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
+async def tls(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 853,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ sock: Optional[dns.asyncbackend.StreamSocket] = None,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+ ssl_context: Optional[ssl.SSLContext] = None,
+ server_hostname: Optional[str] = None,
+ verify: Union[bool, str] = True,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
@@ -194,16 +460,63 @@ async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=
See :py:func:`dns.query.tls()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
- pass
-
-
-async def https(q: dns.message.Message, where: str, timeout: Optional[float
- ]=None, port: int=443, source: Optional[str]=None, source_port: int=0,
- one_rr_per_rrset: bool=False, ignore_trailing: bool=False, client:
- Optional['httpx.AsyncClient']=None, path: str='/dns-query', post: bool=
- True, verify: Union[bool, str]=True, bootstrap_address: Optional[str]=
- None, resolver: Optional['dns.asyncresolver.Resolver']=None, family:
- Optional[int]=socket.AF_UNSPEC) ->dns.message.Message:
+ (begin_time, expiration) = _compute_times(timeout)
+ if sock:
+ cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+ else:
+ if ssl_context is None:
+ ssl_context = _make_dot_ssl_context(server_hostname, verify)
+ af = dns.inet.af_for_address(where)
+ stuple = _source_tuple(af, source, source_port)
+ dtuple = (where, port)
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ cm = await backend.make_socket(
+ af,
+ socket.SOCK_STREAM,
+ 0,
+ stuple,
+ dtuple,
+ timeout,
+ ssl_context,
+ server_hostname,
+ )
+ async with cm as s:
+ timeout = _timeout(expiration)
+ response = await tcp(
+ q,
+ where,
+ timeout,
+ port,
+ source,
+ source_port,
+ one_rr_per_rrset,
+ ignore_trailing,
+ s,
+ backend,
+ )
+ end_time = time.time()
+ response.time = end_time - begin_time
+ return response
+
+
+async def https(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 443,
+ source: Optional[str] = None,
+ source_port: int = 0, # pylint: disable=W0613
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ client: Optional["httpx.AsyncClient"] = None,
+ path: str = "/dns-query",
+ post: bool = True,
+ verify: Union[bool, str] = True,
+ bootstrap_address: Optional[str] = None,
+ resolver: Optional["dns.asyncresolver.Resolver"] = None,
+ family: Optional[int] = socket.AF_UNSPEC,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*client*, a ``httpx.AsyncClient``. If provided, the client to use for
@@ -215,14 +528,107 @@ async def https(q: dns.message.Message, where: str, timeout: Optional[float
See :py:func:`dns.query.https()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
- pass
-
-async def inbound_xfr(where: str, txn_manager: dns.transaction.
- TransactionManager, query: Optional[dns.message.Message]=None, port:
- int=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
- source: Optional[str]=None, source_port: int=0, udp_mode: UDPMode=
- UDPMode.NEVER, backend: Optional[dns.asyncbackend.Backend]=None) ->None:
+ if not have_doh:
+ raise NoDOH # pragma: no cover
+ if client and not isinstance(client, httpx.AsyncClient):
+ raise ValueError("session parameter must be an httpx.AsyncClient")
+
+ wire = q.to_wire()
+ try:
+ af = dns.inet.af_for_address(where)
+ except ValueError:
+ af = None
+ transport = None
+ headers = {"accept": "application/dns-message"}
+ if af is not None and dns.inet.is_address(where):
+ if af == socket.AF_INET:
+ url = "https://{}:{}{}".format(where, port, path)
+ elif af == socket.AF_INET6:
+ url = "https://[{}]:{}{}".format(where, port, path)
+ else:
+ url = where
+
+ backend = dns.asyncbackend.get_default_backend()
+
+ if source is None:
+ local_address = None
+ local_port = 0
+ else:
+ local_address = source
+ local_port = source_port
+ transport = backend.get_transport_class()(
+ local_address=local_address,
+ http1=True,
+ http2=True,
+ verify=verify,
+ local_port=local_port,
+ bootstrap_address=bootstrap_address,
+ resolver=resolver,
+ family=family,
+ )
+
+ if client:
+ cm: contextlib.AbstractAsyncContextManager = NullContext(client)
+ else:
+ cm = httpx.AsyncClient(
+ http1=True, http2=True, verify=verify, transport=transport
+ )
+
+ async with cm as the_client:
+ # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
+ # GET and POST examples
+ if post:
+ headers.update(
+ {
+ "content-type": "application/dns-message",
+ "content-length": str(len(wire)),
+ }
+ )
+ response = await backend.wait_for(
+ the_client.post(url, headers=headers, content=wire), timeout
+ )
+ else:
+ wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
+ twire = wire.decode() # httpx does a repr() if we give it bytes
+ response = await backend.wait_for(
+ the_client.get(url, headers=headers, params={"dns": twire}), timeout
+ )
+
+ # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
+ # status codes
+ if response.status_code < 200 or response.status_code > 299:
+ raise ValueError(
+ "{} responded with status code {}"
+ "\nResponse body: {!r}".format(
+ where, response.status_code, response.content
+ )
+ )
+ r = dns.message.from_wire(
+ response.content,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ r.time = response.elapsed.total_seconds()
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
+async def inbound_xfr(
+ where: str,
+ txn_manager: dns.transaction.TransactionManager,
+ query: Optional[dns.message.Message] = None,
+ port: int = 53,
+ timeout: Optional[float] = None,
+ lifetime: Optional[float] = None,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ udp_mode: UDPMode = UDPMode.NEVER,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+) -> None:
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
@@ -232,15 +638,100 @@ async def inbound_xfr(where: str, txn_manager: dns.transaction.
See :py:func:`dns.query.inbound_xfr()` for the documentation of
the other parameters, exceptions, and return type of this method.
"""
- pass
-
-
-async def quic(q: dns.message.Message, where: str, timeout: Optional[float]
- =None, port: int=853, source: Optional[str]=None, source_port: int=0,
- one_rr_per_rrset: bool=False, ignore_trailing: bool=False, connection:
- Optional[dns.quic.AsyncQuicConnection]=None, verify: Union[bool, str]=
- True, backend: Optional[dns.asyncbackend.Backend]=None, server_hostname:
- Optional[str]=None) ->dns.message.Message:
+ if query is None:
+ (query, serial) = dns.xfr.make_query(txn_manager)
+ else:
+ serial = dns.xfr.extract_serial_from_query(query)
+ rdtype = query.question[0].rdtype
+ is_ixfr = rdtype == dns.rdatatype.IXFR
+ origin = txn_manager.from_wire_origin()
+ wire = query.to_wire()
+ af = dns.inet.af_for_address(where)
+ stuple = _source_tuple(af, source, source_port)
+ dtuple = (where, port)
+ (_, expiration) = _compute_times(lifetime)
+ retry = True
+ while retry:
+ retry = False
+ if is_ixfr and udp_mode != UDPMode.NEVER:
+ sock_type = socket.SOCK_DGRAM
+ is_udp = True
+ else:
+ sock_type = socket.SOCK_STREAM
+ is_udp = False
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ s = await backend.make_socket(
+ af, sock_type, 0, stuple, dtuple, _timeout(expiration)
+ )
+ async with s:
+ if is_udp:
+ await s.sendto(wire, dtuple, _timeout(expiration))
+ else:
+ tcpmsg = struct.pack("!H", len(wire)) + wire
+ await s.sendall(tcpmsg, expiration)
+ with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
+ done = False
+ tsig_ctx = None
+ while not done:
+ (_, mexpiration) = _compute_times(timeout)
+ if mexpiration is None or (
+ expiration is not None and mexpiration > expiration
+ ):
+ mexpiration = expiration
+ if is_udp:
+ destination = _lltuple((where, port), af)
+ while True:
+ timeout = _timeout(mexpiration)
+ (rwire, from_address) = await s.recvfrom(65535, timeout)
+ if _matches_destination(
+ af, from_address, destination, True
+ ):
+ break
+ else:
+ ldata = await _read_exactly(s, 2, mexpiration)
+ (l,) = struct.unpack("!H", ldata)
+ rwire = await _read_exactly(s, l, mexpiration)
+ is_ixfr = rdtype == dns.rdatatype.IXFR
+ r = dns.message.from_wire(
+ rwire,
+ keyring=query.keyring,
+ request_mac=query.mac,
+ xfr=True,
+ origin=origin,
+ tsig_ctx=tsig_ctx,
+ multi=(not is_udp),
+ one_rr_per_rrset=is_ixfr,
+ )
+ try:
+ done = inbound.process_message(r)
+ except dns.xfr.UseTCP:
+ assert is_udp # should not happen if we used TCP!
+ if udp_mode == UDPMode.ONLY:
+ raise
+ done = True
+ retry = True
+ udp_mode = UDPMode.NEVER
+ continue
+ tsig_ctx = r.tsig_ctx
+ if not retry and query.keyring and not r.had_tsig:
+ raise dns.exception.FormError("missing TSIG")
+
+
+async def quic(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 853,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ connection: Optional[dns.quic.AsyncQuicConnection] = None,
+ verify: Union[bool, str] = True,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+ server_hostname: Optional[str] = None,
+) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
@@ -250,4 +741,40 @@ async def quic(q: dns.message.Message, where: str, timeout: Optional[float]
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
- pass
+
+ if not dns.quic.have_quic:
+ raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
+
+ q.id = 0
+ wire = q.to_wire()
+ the_connection: dns.quic.AsyncQuicConnection
+ if connection:
+ cfactory = dns.quic.null_factory
+ mfactory = dns.quic.null_factory
+ the_connection = connection
+ else:
+ (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
+
+ async with cfactory() as context:
+ async with mfactory(
+ context, verify_mode=verify, server_name=server_hostname
+ ) as the_manager:
+ if not connection:
+ the_connection = the_manager.connect(where, port, source, source_port)
+ (start, expiration) = _compute_times(timeout)
+ stream = await the_connection.make_stream(timeout)
+ async with stream:
+ await stream.send(wire, True)
+ wire = await stream.receive(_remaining(expiration))
+ finish = time.time()
+ r = dns.message.from_wire(
+ wire,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ r.time = max(finish - start, 0.0)
+ if not q.is_response(r):
+ raise BadResponse
+ return r
diff --git a/dns/asyncresolver.py b/dns/asyncresolver.py
index e587e04..8f5e062 100644
--- a/dns/asyncresolver.py
+++ b/dns/asyncresolver.py
@@ -1,7 +1,26 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Asynchronous DNS stub resolver."""
+
import socket
import time
from typing import Any, Dict, List, Optional, Union
+
import dns._ddr
import dns.asyncbackend
import dns.asyncquery
@@ -10,8 +29,12 @@ import dns.name
import dns.query
import dns.rdataclass
import dns.rdatatype
-import dns.resolver
+import dns.resolver # lgtm[py/import-and-import-from]
+
+# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
+
+# for indentation purposes below
_udp = dns.asyncquery.udp
_tcp = dns.asyncquery.tcp
@@ -19,13 +42,19 @@ _tcp = dns.asyncquery.tcp
class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver."""
- async def resolve(self, qname: Union[dns.name.Name, str], rdtype: Union
- [dns.rdatatype.RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.
- rdataclass.RdataClass, str]=dns.rdataclass.IN, tcp: bool=False,
- source: Optional[str]=None, raise_on_no_answer: bool=True,
- source_port: int=0, lifetime: Optional[float]=None, search:
- Optional[bool]=None, backend: Optional[dns.asyncbackend.Backend]=None
- ) ->dns.resolver.Answer:
+ async def resolve(
+ self,
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ tcp: bool = False,
+ source: Optional[str] = None,
+ raise_on_no_answer: bool = True,
+ source_port: int = 0,
+ lifetime: Optional[float] = None,
+ search: Optional[bool] = None,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+ ) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
@@ -35,10 +64,52 @@ class Resolver(dns.resolver.BaseResolver):
documentation of the other parameters, exceptions, and return
type of this method.
"""
- pass
- async def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Any
- ) ->dns.resolver.Answer:
+ resolution = dns.resolver._Resolution(
+ self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
+ )
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ start = time.time()
+ while True:
+ (request, answer) = resolution.next_request()
+ # Note we need to say "if answer is not None" and not just
+ # "if answer" because answer implements __len__, and python
+ # will call that. We want to return if we have an answer
+ # object, including in cases where its length is 0.
+ if answer is not None:
+ # cache hit!
+ return answer
+ assert request is not None # needed for type checking
+ done = False
+ while not done:
+ (nameserver, tcp, backoff) = resolution.next_nameserver()
+ if backoff:
+ await backend.sleep(backoff)
+ timeout = self._compute_timeout(start, lifetime, resolution.errors)
+ try:
+ response = await nameserver.async_query(
+ request,
+ timeout=timeout,
+ source=source,
+ source_port=source_port,
+ max_size=tcp,
+ backend=backend,
+ )
+ except Exception as ex:
+ (_, done) = resolution.query_result(None, ex)
+ continue
+ (answer, done) = resolution.query_result(response, None)
+ # Note we need to say "if answer is not None" and not just
+ # "if answer" because answer implements __len__, and python
+ # will call that. We want to return if we have an answer
+ # object, including in cases where its length is 0.
+ if answer is not None:
+ return answer
+
+ async def resolve_address(
+ self, ipaddr: str, *args: Any, **kwargs: Any
+ ) -> dns.resolver.Answer:
"""Use an asynchronous resolver to run a reverse query for PTR
records.
@@ -53,10 +124,23 @@ class Resolver(dns.resolver.BaseResolver):
function.
"""
- pass
-
- async def resolve_name(self, name: Union[dns.name.Name, str], family:
- int=socket.AF_UNSPEC, **kwargs: Any) ->dns.resolver.HostAnswers:
+ # We make a modified kwargs for type checking happiness, as otherwise
+ # we get a legit warning about possibly having rdtype and rdclass
+ # in the kwargs more than once.
+ modified_kwargs: Dict[str, Any] = {}
+ modified_kwargs.update(kwargs)
+ modified_kwargs["rdtype"] = dns.rdatatype.PTR
+ modified_kwargs["rdclass"] = dns.rdataclass.IN
+ return await self.resolve(
+ dns.reversename.from_address(ipaddr), *args, **modified_kwargs
+ )
+
+ async def resolve_name(
+ self,
+ name: Union[dns.name.Name, str],
+ family: int = socket.AF_UNSPEC,
+ **kwargs: Any,
+ ) -> dns.resolver.HostAnswers:
"""Use an asynchronous resolver to query for address records.
This utilizes the resolve() method to perform A and/or AAAA lookups on
@@ -71,10 +155,56 @@ class Resolver(dns.resolver.BaseResolver):
except for rdtype and rdclass are also supported by this
function.
"""
- pass
-
- async def canonical_name(self, name: Union[dns.name.Name, str]
- ) ->dns.name.Name:
+ # We make a modified kwargs for type checking happiness, as otherwise
+ # we get a legit warning about possibly having rdtype and rdclass
+ # in the kwargs more than once.
+ modified_kwargs: Dict[str, Any] = {}
+ modified_kwargs.update(kwargs)
+ modified_kwargs.pop("rdtype", None)
+ modified_kwargs["rdclass"] = dns.rdataclass.IN
+
+ if family == socket.AF_INET:
+ v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
+ return dns.resolver.HostAnswers.make(v4=v4)
+ elif family == socket.AF_INET6:
+ v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
+ return dns.resolver.HostAnswers.make(v6=v6)
+ elif family != socket.AF_UNSPEC:
+ raise NotImplementedError(f"unknown address family {family}")
+
+ raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
+ lifetime = modified_kwargs.pop("lifetime", None)
+ start = time.time()
+ v6 = await self.resolve(
+ name,
+ dns.rdatatype.AAAA,
+ raise_on_no_answer=False,
+ lifetime=self._compute_timeout(start, lifetime),
+ **modified_kwargs,
+ )
+ # Note that setting name ensures we query the same name
+ # for A as we did for AAAA. (This is just in case search lists
+ # are active by default in the resolver configuration and
+ # we might be talking to a server that says NXDOMAIN when it
+ # wants to say NOERROR no data.
+ name = v6.qname
+ v4 = await self.resolve(
+ name,
+ dns.rdatatype.A,
+ raise_on_no_answer=False,
+ lifetime=self._compute_timeout(start, lifetime),
+ **modified_kwargs,
+ )
+ answers = dns.resolver.HostAnswers.make(
+ v6=v6, v4=v4, add_empty=not raise_on_no_answer
+ )
+ if not answers:
+ raise NoAnswer(response=v6.response)
+ return answers
+
+ # pylint: disable=redefined-outer-name
+
+ async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
@@ -88,9 +218,14 @@ class Resolver(dns.resolver.BaseResolver):
Returns a ``dns.name.Name``.
"""
- pass
-
- async def try_ddr(self, lifetime: float=5.0) ->None:
+ try:
+ answer = await self.resolve(name, raise_on_no_answer=False)
+ canonical_name = answer.canonical_name
+ except dns.resolver.NXDOMAIN as e:
+ canonical_name = e.canonical_name
+ return canonical_name
+
+ async def try_ddr(self, lifetime: float = 5.0) -> None:
"""Try to update the resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
@@ -109,32 +244,53 @@ class Resolver(dns.resolver.BaseResolver):
the bootstrap nameserver is in the Subject Alternative Name field of the
TLS certficate.
"""
- pass
+ try:
+ expiration = time.time() + lifetime
+ answer = await self.resolve(
+ dns._ddr._local_resolver_name, "svcb", lifetime=lifetime
+ )
+ timeout = dns.query._remaining(expiration)
+ nameservers = await dns._ddr._get_nameservers_async(answer, timeout)
+ if len(nameservers) > 0:
+ self.nameservers = nameservers
+ except Exception:
+ pass
default_resolver = None
-def get_default_resolver() ->Resolver:
+def get_default_resolver() -> Resolver:
"""Get the default asynchronous resolver, initializing it if necessary."""
- pass
+ if default_resolver is None:
+ reset_default_resolver()
+ assert default_resolver is not None
+ return default_resolver
-def reset_default_resolver() ->None:
+def reset_default_resolver() -> None:
"""Re-initialize default asynchronous resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
systems) will be re-read immediately.
"""
- pass
-
-async def resolve(qname: Union[dns.name.Name, str], rdtype: Union[dns.
- rdatatype.RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.
- rdataclass.RdataClass, str]=dns.rdataclass.IN, tcp: bool=False, source:
- Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0,
- lifetime: Optional[float]=None, search: Optional[bool]=None, backend:
- Optional[dns.asyncbackend.Backend]=None) ->dns.resolver.Answer:
+ global default_resolver
+ default_resolver = Resolver()
+
+
+async def resolve(
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ tcp: bool = False,
+ source: Optional[str] = None,
+ raise_on_no_answer: bool = True,
+ source_port: int = 0,
+ lifetime: Optional[float] = None,
+ search: Optional[bool] = None,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver
@@ -143,63 +299,107 @@ async def resolve(qname: Union[dns.name.Name, str], rdtype: Union[dns.
See :py:func:`dns.asyncresolver.Resolver.resolve` for more
information on the parameters.
"""
- pass
-
-async def resolve_address(ipaddr: str, *args: Any, **kwargs: Any
- ) ->dns.resolver.Answer:
+ return await get_default_resolver().resolve(
+ qname,
+ rdtype,
+ rdclass,
+ tcp,
+ source,
+ raise_on_no_answer,
+ source_port,
+ lifetime,
+ search,
+ backend,
+ )
+
+
+async def resolve_address(
+ ipaddr: str, *args: Any, **kwargs: Any
+) -> dns.resolver.Answer:
"""Use a resolver to run a reverse query for PTR records.
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
information on the parameters.
"""
- pass
+ return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
-async def resolve_name(name: Union[dns.name.Name, str], family: int=socket.
- AF_UNSPEC, **kwargs: Any) ->dns.resolver.HostAnswers:
+
+async def resolve_name(
+ name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
+) -> dns.resolver.HostAnswers:
"""Use a resolver to asynchronously query for address records.
See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
information on the parameters.
"""
- pass
+
+ return await get_default_resolver().resolve_name(name, family, **kwargs)
-async def canonical_name(name: Union[dns.name.Name, str]) ->dns.name.Name:
+async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more
information on the parameters and possible exceptions.
"""
- pass
+ return await get_default_resolver().canonical_name(name)
-async def try_ddr(timeout: float=5.0) ->None:
+
+async def try_ddr(timeout: float = 5.0) -> None:
"""Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
"""
- pass
+ return await get_default_resolver().try_ddr(timeout)
-async def zone_for_name(name: Union[dns.name.Name, str], rdclass: dns.
- rdataclass.RdataClass=dns.rdataclass.IN, tcp: bool=False, resolver:
- Optional[Resolver]=None, backend: Optional[dns.asyncbackend.Backend]=None
- ) ->dns.name.Name:
+async def zone_for_name(
+ name: Union[dns.name.Name, str],
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+ tcp: bool = False,
+ resolver: Optional[Resolver] = None,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+) -> dns.name.Name:
"""Find the name of the zone which contains the specified name.
See :py:func:`dns.resolver.Resolver.zone_for_name` for more
information on the parameters and possible exceptions.
"""
- pass
-
-async def make_resolver_at(where: Union[dns.name.Name, str], port: int=53,
- family: int=socket.AF_UNSPEC, resolver: Optional[Resolver]=None
- ) ->Resolver:
+ if isinstance(name, str):
+ name = dns.name.from_text(name, dns.name.root)
+ if resolver is None:
+ resolver = get_default_resolver()
+ if not name.is_absolute():
+ raise NotAbsolute(name)
+ while True:
+ try:
+ answer = await resolver.resolve(
+ name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
+ )
+ assert answer.rrset is not None
+ if answer.rrset.name == name:
+ return name
+ # otherwise we were CNAMEd or DNAMEd and need to look higher
+ except (NXDOMAIN, NoAnswer):
+ pass
+ try:
+ name = name.parent()
+ except dns.name.NoParent: # pragma: no cover
+ raise NoRootSOA
+
+
+async def make_resolver_at(
+ where: Union[dns.name.Name, str],
+ port: int = 53,
+ family: int = socket.AF_UNSPEC,
+ resolver: Optional[Resolver] = None,
+) -> Resolver:
"""Make a stub resolver using the specified destination as the full resolver.
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
@@ -217,17 +417,36 @@ async def make_resolver_at(where: Union[dns.name.Name, str], port: int=53,
Returns a ``dns.resolver.Resolver`` or raises an exception.
"""
- pass
-
-
-async def resolve_at(where: Union[dns.name.Name, str], qname: Union[dns.
- name.Name, str], rdtype: Union[dns.rdatatype.RdataType, str]=dns.
- rdatatype.A, rdclass: Union[dns.rdataclass.RdataClass, str]=dns.
- rdataclass.IN, tcp: bool=False, source: Optional[str]=None,
- raise_on_no_answer: bool=True, source_port: int=0, lifetime: Optional[
- float]=None, search: Optional[bool]=None, backend: Optional[dns.
- asyncbackend.Backend]=None, port: int=53, family: int=socket.AF_UNSPEC,
- resolver: Optional[Resolver]=None) ->dns.resolver.Answer:
+ if resolver is None:
+ resolver = get_default_resolver()
+ nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
+ if isinstance(where, str) and dns.inet.is_address(where):
+ nameservers.append(dns.nameserver.Do53Nameserver(where, port))
+ else:
+ answers = await resolver.resolve_name(where, family)
+ for address in answers.addresses():
+ nameservers.append(dns.nameserver.Do53Nameserver(address, port))
+ res = dns.asyncresolver.Resolver(configure=False)
+ res.nameservers = nameservers
+ return res
+
+
+async def resolve_at(
+ where: Union[dns.name.Name, str],
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ tcp: bool = False,
+ source: Optional[str] = None,
+ raise_on_no_answer: bool = True,
+ source_port: int = 0,
+ lifetime: Optional[float] = None,
+ search: Optional[bool] = None,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+ port: int = 53,
+ family: int = socket.AF_UNSPEC,
+ resolver: Optional[Resolver] = None,
+) -> dns.resolver.Answer:
"""Query nameservers to find the answer to the question.
This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
@@ -241,4 +460,16 @@ async def resolve_at(where: Union[dns.name.Name, str], qname: Union[dns.
``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
instead of calling ``resolve_at()`` multiple times.
"""
- pass
+ res = await make_resolver_at(where, port, family, resolver)
+ return await res.resolve(
+ qname,
+ rdtype,
+ rdclass,
+ tcp,
+ source,
+ raise_on_no_answer,
+ source_port,
+ lifetime,
+ search,
+ backend,
+ )
diff --git a/dns/dnssec.py b/dns/dnssec.py
index 2787e8a..e49c3b7 100644
--- a/dns/dnssec.py
+++ b/dns/dnssec.py
@@ -1,4 +1,23 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Common DNSSEC-related functions and constants."""
+
+
import base64
import contextlib
import functools
@@ -7,6 +26,7 @@ import struct
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
+
import dns._features
import dns.exception
import dns.name
@@ -19,7 +39,12 @@ import dns.rrset
import dns.transaction
import dns.zone
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
-from dns.exception import AlgorithmKeyMismatch, DeniedByPolicy, UnsupportedAlgorithm, ValidationFailure
+from dns.exception import ( # pylint: disable=W0611
+ AlgorithmKeyMismatch,
+ DeniedByPolicy,
+ UnsupportedAlgorithm,
+ ValidationFailure,
+)
from dns.rdtypes.ANY.CDNSKEY import CDNSKEY
from dns.rdtypes.ANY.CDS import CDS
from dns.rdtypes.ANY.DNSKEY import DNSKEY
@@ -28,78 +53,142 @@ from dns.rdtypes.ANY.NSEC import NSEC, Bitmap
from dns.rdtypes.ANY.NSEC3PARAM import NSEC3PARAM
from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime
from dns.rdtypes.dnskeybase import Flag
-PublicKey = Union['GenericPublicKey', 'rsa.RSAPublicKey',
- 'ec.EllipticCurvePublicKey', 'ed25519.Ed25519PublicKey',
- 'ed448.Ed448PublicKey']
-PrivateKey = Union['GenericPrivateKey', 'rsa.RSAPrivateKey',
- 'ec.EllipticCurvePrivateKey', 'ed25519.Ed25519PrivateKey',
- 'ed448.Ed448PrivateKey']
+
+PublicKey = Union[
+ "GenericPublicKey",
+ "rsa.RSAPublicKey",
+ "ec.EllipticCurvePublicKey",
+ "ed25519.Ed25519PublicKey",
+ "ed448.Ed448PublicKey",
+]
+
+PrivateKey = Union[
+ "GenericPrivateKey",
+ "rsa.RSAPrivateKey",
+ "ec.EllipticCurvePrivateKey",
+ "ed25519.Ed25519PrivateKey",
+ "ed448.Ed448PrivateKey",
+]
+
RRsetSigner = Callable[[dns.transaction.Transaction, dns.rrset.RRset], None]
-def algorithm_from_text(text: str) ->Algorithm:
+def algorithm_from_text(text: str) -> Algorithm:
"""Convert text into a DNSSEC algorithm value.
*text*, a ``str``, the text to convert to into an algorithm value.
Returns an ``int``.
"""
- pass
+ return Algorithm.from_text(text)
-def algorithm_to_text(value: Union[Algorithm, int]) ->str:
+
+def algorithm_to_text(value: Union[Algorithm, int]) -> str:
"""Convert a DNSSEC algorithm value to text
*value*, a ``dns.dnssec.Algorithm``.
Returns a ``str``, the name of a DNSSEC algorithm.
"""
- pass
-
-def to_timestamp(value: Union[datetime, str, float, int]) ->int:
- """Convert various format to a timestamp"""
- pass
+ return Algorithm.to_text(value)
-def key_id(key: Union[DNSKEY, CDNSKEY]) ->int:
+def to_timestamp(value: Union[datetime, str, float, int]) -> int:
+ """Convert various format to a timestamp"""
+ if isinstance(value, datetime):
+ return int(value.timestamp())
+ elif isinstance(value, str):
+ return sigtime_to_posixtime(value)
+ elif isinstance(value, float):
+ return int(value)
+ elif isinstance(value, int):
+ return value
+ else:
+ raise TypeError("Unsupported timestamp type")
+
+
+def key_id(key: Union[DNSKEY, CDNSKEY]) -> int:
"""Return the key id (a 16-bit number) for the specified key.
*key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``
Returns an ``int`` between 0 and 65535
"""
- pass
+ rdata = key.to_wire()
+ if key.algorithm == Algorithm.RSAMD5:
+ return (rdata[-3] << 8) + rdata[-2]
+ else:
+ total = 0
+ for i in range(len(rdata) // 2):
+ total += (rdata[2 * i] << 8) + rdata[2 * i + 1]
+ if len(rdata) % 2 != 0:
+ total += rdata[len(rdata) - 1] << 8
+ total += (total >> 16) & 0xFFFF
+ return total & 0xFFFF
-class Policy:
+class Policy:
def __init__(self):
pass
+ def ok_to_sign(self, _: DNSKEY) -> bool: # pragma: no cover
+ return False
-class SimpleDeny(Policy):
+ def ok_to_validate(self, _: DNSKEY) -> bool: # pragma: no cover
+ return False
+
+ def ok_to_create_ds(self, _: DSDigest) -> bool: # pragma: no cover
+ return False
+
+ def ok_to_validate_ds(self, _: DSDigest) -> bool: # pragma: no cover
+ return False
- def __init__(self, deny_sign, deny_validate, deny_create_ds,
- deny_validate_ds):
+
+class SimpleDeny(Policy):
+ def __init__(self, deny_sign, deny_validate, deny_create_ds, deny_validate_ds):
super().__init__()
self._deny_sign = deny_sign
self._deny_validate = deny_validate
self._deny_create_ds = deny_create_ds
self._deny_validate_ds = deny_validate_ds
+ def ok_to_sign(self, key: DNSKEY) -> bool:
+ return key.algorithm not in self._deny_sign
+
+ def ok_to_validate(self, key: DNSKEY) -> bool:
+ return key.algorithm not in self._deny_validate
+
+ def ok_to_create_ds(self, algorithm: DSDigest) -> bool:
+ return algorithm not in self._deny_create_ds
+
+ def ok_to_validate_ds(self, algorithm: DSDigest) -> bool:
+ return algorithm not in self._deny_validate_ds
+
+
+rfc_8624_policy = SimpleDeny(
+ {Algorithm.RSAMD5, Algorithm.DSA, Algorithm.DSANSEC3SHA1, Algorithm.ECCGOST},
+ {Algorithm.RSAMD5, Algorithm.DSA, Algorithm.DSANSEC3SHA1},
+ {DSDigest.NULL, DSDigest.SHA1, DSDigest.GOST},
+ {DSDigest.NULL},
+)
-rfc_8624_policy = SimpleDeny({Algorithm.RSAMD5, Algorithm.DSA, Algorithm.
- DSANSEC3SHA1, Algorithm.ECCGOST}, {Algorithm.RSAMD5, Algorithm.DSA,
- Algorithm.DSANSEC3SHA1}, {DSDigest.NULL, DSDigest.SHA1, DSDigest.GOST},
- {DSDigest.NULL})
allow_all_policy = SimpleDeny(set(), set(), set(), set())
+
+
default_policy = rfc_8624_policy
-def make_ds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata,
- algorithm: Union[DSDigest, str], origin: Optional[dns.name.Name]=None,
- policy: Optional[Policy]=None, validating: bool=False) ->DS:
+def make_ds(
+ name: Union[dns.name.Name, str],
+ key: dns.rdata.Rdata,
+ algorithm: Union[DSDigest, str],
+ origin: Optional[dns.name.Name] = None,
+ policy: Optional[Policy] = None,
+ validating: bool = False,
+) -> DS:
"""Create a DS record for a DNSSEC key.
*name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record.
@@ -128,12 +217,52 @@ def make_ds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata,
Returns a ``dns.rdtypes.ANY.DS.DS``
"""
- pass
-
-def make_cds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata,
- algorithm: Union[DSDigest, str], origin: Optional[dns.name.Name]=None
- ) ->CDS:
+ if policy is None:
+ policy = default_policy
+ try:
+ if isinstance(algorithm, str):
+ algorithm = DSDigest[algorithm.upper()]
+ except Exception:
+ raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
+ if validating:
+ check = policy.ok_to_validate_ds
+ else:
+ check = policy.ok_to_create_ds
+ if not check(algorithm):
+ raise DeniedByPolicy
+ if not isinstance(key, (DNSKEY, CDNSKEY)):
+ raise ValueError("key is not a DNSKEY/CDNSKEY")
+ if algorithm == DSDigest.SHA1:
+ dshash = hashlib.sha1()
+ elif algorithm == DSDigest.SHA256:
+ dshash = hashlib.sha256()
+ elif algorithm == DSDigest.SHA384:
+ dshash = hashlib.sha384()
+ else:
+ raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
+
+ if isinstance(name, str):
+ name = dns.name.from_text(name, origin)
+ wire = name.canonicalize().to_wire()
+ assert wire is not None
+ dshash.update(wire)
+ dshash.update(key.to_wire(origin=origin))
+ digest = dshash.digest()
+
+ dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest
+ ds = dns.rdata.from_wire(
+ dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, len(dsrdata)
+ )
+ return cast(DS, ds)
+
+
+def make_cds(
+ name: Union[dns.name.Name, str],
+ key: dns.rdata.Rdata,
+ algorithm: Union[DSDigest, str],
+ origin: Optional[dns.name.Name] = None,
+) -> CDS:
"""Create a CDS record for a DNSSEC key.
*name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record.
@@ -152,13 +281,64 @@ def make_cds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata,
Returns a ``dns.rdtypes.ANY.DS.CDS``
"""
- pass
-
-def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.
- rdataset.Rdataset]], rrsig: RRSIG, keys: Dict[dns.name.Name, Union[dns.
- node.Node, dns.rdataset.Rdataset]], origin: Optional[dns.name.Name]=
- None, now: Optional[float]=None, policy: Optional[Policy]=None) ->None:
+ ds = make_ds(name, key, algorithm, origin)
+ return CDS(
+ rdclass=ds.rdclass,
+ rdtype=dns.rdatatype.CDS,
+ key_tag=ds.key_tag,
+ algorithm=ds.algorithm,
+ digest_type=ds.digest_type,
+ digest=ds.digest,
+ )
+
+
+def _find_candidate_keys(
+ keys: Dict[dns.name.Name, Union[dns.rdataset.Rdataset, dns.node.Node]], rrsig: RRSIG
+) -> Optional[List[DNSKEY]]:
+ value = keys.get(rrsig.signer)
+ if isinstance(value, dns.node.Node):
+ rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY)
+ else:
+ rdataset = value
+ if rdataset is None:
+ return None
+ return [
+ cast(DNSKEY, rd)
+ for rd in rdataset
+ if rd.algorithm == rrsig.algorithm
+ and key_id(rd) == rrsig.key_tag
+ and (rd.flags & Flag.ZONE) == Flag.ZONE # RFC 4034 2.1.1
+ and rd.protocol == 3 # RFC 4034 2.1.2
+ ]
+
+
+def _get_rrname_rdataset(
+ rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]:
+ if isinstance(rrset, tuple):
+ return rrset[0], rrset[1]
+ else:
+ return rrset.name, rrset
+
+
+def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
+ public_cls = get_algorithm_cls_from_dnskey(key).public_cls
+ try:
+ public_key = public_cls.from_dnskey(key)
+ except ValueError:
+ raise ValidationFailure("invalid public key")
+ public_key.verify(sig, data)
+
+
+def _validate_rrsig(
+ rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ rrsig: RRSIG,
+ keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]],
+ origin: Optional[dns.name.Name] = None,
+ now: Optional[float] = None,
+ policy: Optional[Policy] = None,
+) -> None:
"""Validate an RRset against a single signature rdata, throwing an
exception if validation is not successful.
@@ -190,14 +370,44 @@ def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.
Raises ``UnsupportedAlgorithm`` if the algorithm is recognized by
dnspython but not implemented.
"""
- pass
-
-def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.
- rdataset.Rdataset]], rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.
- Name, dns.rdataset.Rdataset]], keys: Dict[dns.name.Name, Union[dns.node
- .Node, dns.rdataset.Rdataset]], origin: Optional[dns.name.Name]=None,
- now: Optional[float]=None, policy: Optional[Policy]=None) ->None:
+ if policy is None:
+ policy = default_policy
+
+ candidate_keys = _find_candidate_keys(keys, rrsig)
+ if candidate_keys is None:
+ raise ValidationFailure("unknown key")
+
+ if now is None:
+ now = time.time()
+ if rrsig.expiration < now:
+ raise ValidationFailure("expired")
+ if rrsig.inception > now:
+ raise ValidationFailure("not yet valid")
+
+ data = _make_rrsig_signature_data(rrset, rrsig, origin)
+
+ for candidate_key in candidate_keys:
+ if not policy.ok_to_validate(candidate_key):
+ continue
+ try:
+ _validate_signature(rrsig.signature, data, candidate_key)
+ return
+ except (InvalidSignature, ValidationFailure):
+ # this happens on an individual validation failure
+ continue
+ # nothing verified -- raise failure:
+ raise ValidationFailure("verify failure")
+
+
+def _validate(
+ rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]],
+ origin: Optional[dns.name.Name] = None,
+ now: Optional[float] = None,
+ policy: Optional[Policy] = None,
+) -> None:
"""Validate an RRset against a signature RRset, throwing an exception
if none of the signatures validate.
@@ -228,15 +438,53 @@ def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.
the public key is invalid, the algorithm is unknown, the verification
fails, etc.
"""
- pass
-
-def _sign(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.
- Rdataset]], private_key: PrivateKey, signer: dns.name.Name, dnskey:
- DNSKEY, inception: Optional[Union[datetime, str, int, float]]=None,
- expiration: Optional[Union[datetime, str, int, float]]=None, lifetime:
- Optional[int]=None, verify: bool=False, policy: Optional[Policy]=None,
- origin: Optional[dns.name.Name]=None) ->RRSIG:
+ if policy is None:
+ policy = default_policy
+
+ if isinstance(origin, str):
+ origin = dns.name.from_text(origin, dns.name.root)
+
+ if isinstance(rrset, tuple):
+ rrname = rrset[0]
+ else:
+ rrname = rrset.name
+
+ if isinstance(rrsigset, tuple):
+ rrsigname = rrsigset[0]
+ rrsigrdataset = rrsigset[1]
+ else:
+ rrsigname = rrsigset.name
+ rrsigrdataset = rrsigset
+
+ rrname = rrname.choose_relativity(origin)
+ rrsigname = rrsigname.choose_relativity(origin)
+ if rrname != rrsigname:
+ raise ValidationFailure("owner names do not match")
+
+ for rrsig in rrsigrdataset:
+ if not isinstance(rrsig, RRSIG):
+ raise ValidationFailure("expected an RRSIG")
+ try:
+ _validate_rrsig(rrset, rrsig, keys, origin, now, policy)
+ return
+ except (ValidationFailure, UnsupportedAlgorithm):
+ pass
+ raise ValidationFailure("no RRSIGs validated")
+
+
+def _sign(
+ rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ private_key: PrivateKey,
+ signer: dns.name.Name,
+ dnskey: DNSKEY,
+ inception: Optional[Union[datetime, str, int, float]] = None,
+ expiration: Optional[Union[datetime, str, int, float]] = None,
+ lifetime: Optional[int] = None,
+ verify: bool = False,
+ policy: Optional[Policy] = None,
+ origin: Optional[dns.name.Name] = None,
+) -> RRSIG:
"""Sign RRset using private key.
*rrset*, the RRset to validate. This can be a
@@ -277,12 +525,80 @@ def _sign(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.
Raises ``DeniedByPolicy`` if the signature is denied by policy.
"""
- pass
-
-def _make_rrsig_signature_data(rrset: Union[dns.rrset.RRset, Tuple[dns.name
- .Name, dns.rdataset.Rdataset]], rrsig: RRSIG, origin: Optional[dns.name
- .Name]=None) ->bytes:
+ if policy is None:
+ policy = default_policy
+ if not policy.ok_to_sign(dnskey):
+ raise DeniedByPolicy
+
+ if isinstance(rrset, tuple):
+ rdclass = rrset[1].rdclass
+ rdtype = rrset[1].rdtype
+ rrname = rrset[0]
+ original_ttl = rrset[1].ttl
+ else:
+ rdclass = rrset.rdclass
+ rdtype = rrset.rdtype
+ rrname = rrset.name
+ original_ttl = rrset.ttl
+
+ if inception is not None:
+ rrsig_inception = to_timestamp(inception)
+ else:
+ rrsig_inception = int(time.time())
+
+ if expiration is not None:
+ rrsig_expiration = to_timestamp(expiration)
+ elif lifetime is not None:
+ rrsig_expiration = rrsig_inception + lifetime
+ else:
+ raise ValueError("expiration or lifetime must be specified")
+
+ # Derelativize now because we need a correct labels length for the
+ # rrsig_template.
+ if origin is not None:
+ rrname = rrname.derelativize(origin)
+ labels = len(rrname) - 1
+
+ # Adjust labels appropriately for wildcards.
+ if rrname.is_wild():
+ labels -= 1
+
+ rrsig_template = RRSIG(
+ rdclass=rdclass,
+ rdtype=dns.rdatatype.RRSIG,
+ type_covered=rdtype,
+ algorithm=dnskey.algorithm,
+ labels=labels,
+ original_ttl=original_ttl,
+ expiration=rrsig_expiration,
+ inception=rrsig_inception,
+ key_tag=key_id(dnskey),
+ signer=signer,
+ signature=b"",
+ )
+
+ data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
+
+ if isinstance(private_key, GenericPrivateKey):
+ signing_key = private_key
+ else:
+ try:
+ private_cls = get_algorithm_cls_from_dnskey(dnskey)
+ signing_key = private_cls(key=private_key)
+ except UnsupportedAlgorithm:
+ raise TypeError("Unsupported key algorithm")
+
+ signature = signing_key.sign(data, verify)
+
+ return cast(RRSIG, rrsig_template.replace(signature=signature))
+
+
+def _make_rrsig_signature_data(
+ rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ rrsig: RRSIG,
+ origin: Optional[dns.name.Name] = None,
+) -> bytes:
"""Create signature rdata.
*rrset*, the RRset to sign/validate. This can be a
@@ -298,11 +614,57 @@ def _make_rrsig_signature_data(rrset: Union[dns.rrset.RRset, Tuple[dns.name
Raises ``UnsupportedAlgorithm`` if the algorithm is recognized by
dnspython but not implemented.
"""
- pass
-
-def _make_dnskey(public_key: PublicKey, algorithm: Union[int, str], flags:
- int=Flag.ZONE, protocol: int=3) ->DNSKEY:
+ if isinstance(origin, str):
+ origin = dns.name.from_text(origin, dns.name.root)
+
+ signer = rrsig.signer
+ if not signer.is_absolute():
+ if origin is None:
+ raise ValidationFailure("relative RR name without an origin specified")
+ signer = signer.derelativize(origin)
+
+ # For convenience, allow the rrset to be specified as a (name,
+ # rdataset) tuple as well as a proper rrset
+ rrname, rdataset = _get_rrname_rdataset(rrset)
+
+ data = b""
+ data += rrsig.to_wire(origin=signer)[:18]
+ data += rrsig.signer.to_digestable(signer)
+
+ # Derelativize the name before considering labels.
+ if not rrname.is_absolute():
+ if origin is None:
+ raise ValidationFailure("relative RR name without an origin specified")
+ rrname = rrname.derelativize(origin)
+
+ name_len = len(rrname)
+ if rrname.is_wild() and rrsig.labels != name_len - 2:
+ raise ValidationFailure("wild owner name has wrong label length")
+ if name_len - 1 < rrsig.labels:
+ raise ValidationFailure("owner name longer than RRSIG labels")
+ elif rrsig.labels < name_len - 1:
+ suffix = rrname.split(rrsig.labels + 1)[1]
+ rrname = dns.name.from_text("*", suffix)
+ rrnamebuf = rrname.to_digestable()
+ rrfixed = struct.pack("!HHI", rdataset.rdtype, rdataset.rdclass, rrsig.original_ttl)
+ rdatas = [rdata.to_digestable(origin) for rdata in rdataset]
+ for rdata in sorted(rdatas):
+ data += rrnamebuf
+ data += rrfixed
+ rrlen = struct.pack("!H", len(rdata))
+ data += rrlen
+ data += rdata
+
+ return data
+
+
+def _make_dnskey(
+ public_key: PublicKey,
+ algorithm: Union[int, str],
+ flags: int = Flag.ZONE,
+ protocol: int = 3,
+) -> DNSKEY:
"""Convert a public key to DNSKEY Rdata
*public_key*, a ``PublicKey`` (``GenericPublicKey`` or
@@ -321,11 +683,22 @@ def _make_dnskey(public_key: PublicKey, algorithm: Union[int, str], flags:
Return DNSKEY ``Rdata``.
"""
- pass
+
+ algorithm = Algorithm.make(algorithm)
+
+ if isinstance(public_key, GenericPublicKey):
+ return public_key.to_dnskey(flags=flags, protocol=protocol)
+ else:
+ public_cls = get_algorithm_cls(algorithm).public_cls
+ return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol)
-def _make_cdnskey(public_key: PublicKey, algorithm: Union[int, str], flags:
- int=Flag.ZONE, protocol: int=3) ->CDNSKEY:
+def _make_cdnskey(
+ public_key: PublicKey,
+ algorithm: Union[int, str],
+ flags: int = Flag.ZONE,
+ protocol: int = 3,
+) -> CDNSKEY:
"""Convert a public key to CDNSKEY Rdata
*public_key*, the public key to convert, a
@@ -345,11 +718,25 @@ def _make_cdnskey(public_key: PublicKey, algorithm: Union[int, str], flags:
Return CDNSKEY ``Rdata``.
"""
- pass
+ dnskey = _make_dnskey(public_key, algorithm, flags, protocol)
-def nsec3_hash(domain: Union[dns.name.Name, str], salt: Optional[Union[str,
- bytes]], iterations: int, algorithm: Union[int, str]) ->str:
+ return CDNSKEY(
+ rdclass=dnskey.rdclass,
+ rdtype=dns.rdatatype.CDNSKEY,
+ flags=dnskey.flags,
+ protocol=dnskey.protocol,
+ algorithm=dnskey.algorithm,
+ key=dnskey.key,
+ )
+
+
+def nsec3_hash(
+ domain: Union[dns.name.Name, str],
+ salt: Optional[Union[str, bytes]],
+ iterations: int,
+ algorithm: Union[int, str],
+) -> str:
"""
Calculate the NSEC3 hash, according to
https://tools.ietf.org/html/rfc5155#section-5
@@ -366,12 +753,50 @@ def nsec3_hash(domain: Union[dns.name.Name, str], salt: Optional[Union[str,
Returns a ``str``, the encoded NSEC3 hash.
"""
- pass
-
-def make_ds_rdataset(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns
- .rdataset.Rdataset]], algorithms: Set[Union[DSDigest, str]], origin:
- Optional[dns.name.Name]=None) ->dns.rdataset.Rdataset:
+ b32_conversion = str.maketrans(
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", "0123456789ABCDEFGHIJKLMNOPQRSTUV"
+ )
+
+ try:
+ if isinstance(algorithm, str):
+ algorithm = NSEC3Hash[algorithm.upper()]
+ except Exception:
+ raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
+
+ if algorithm != NSEC3Hash.SHA1:
+ raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
+
+ if salt is None:
+ salt_encoded = b""
+ elif isinstance(salt, str):
+ if len(salt) % 2 == 0:
+ salt_encoded = bytes.fromhex(salt)
+ else:
+ raise ValueError("Invalid salt length")
+ else:
+ salt_encoded = salt
+
+ if not isinstance(domain, dns.name.Name):
+ domain = dns.name.from_text(domain)
+ domain_encoded = domain.canonicalize().to_wire()
+ assert domain_encoded is not None
+
+ digest = hashlib.sha1(domain_encoded + salt_encoded).digest()
+ for _ in range(iterations):
+ digest = hashlib.sha1(digest + salt_encoded).digest()
+
+ output = base64.b32encode(digest).decode("utf-8")
+ output = output.translate(b32_conversion)
+
+ return output
+
+
+def make_ds_rdataset(
+ rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+ algorithms: Set[Union[DSDigest, str]],
+ origin: Optional[dns.name.Name] = None,
+) -> dns.rdataset.Rdataset:
"""Create a DS record from DNSKEY/CDNSKEY/CDS.
*rrset*, the RRset to create DS Rdataset for. This can be a
@@ -391,11 +816,43 @@ def make_ds_rdataset(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns
Returns a ``dns.rdataset.Rdataset``
"""
- pass
-
-def cds_rdataset_to_ds_rdataset(rdataset: dns.rdataset.Rdataset
- ) ->dns.rdataset.Rdataset:
+ rrname, rdataset = _get_rrname_rdataset(rrset)
+
+ if rdataset.rdtype not in (
+ dns.rdatatype.DNSKEY,
+ dns.rdatatype.CDNSKEY,
+ dns.rdatatype.CDS,
+ ):
+ raise ValueError("rrset not a DNSKEY/CDNSKEY/CDS")
+
+ _algorithms = set()
+ for algorithm in algorithms:
+ try:
+ if isinstance(algorithm, str):
+ algorithm = DSDigest[algorithm.upper()]
+ except Exception:
+ raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
+ _algorithms.add(algorithm)
+
+ if rdataset.rdtype == dns.rdatatype.CDS:
+ res = []
+ for rdata in cds_rdataset_to_ds_rdataset(rdataset):
+ if rdata.digest_type in _algorithms:
+ res.append(rdata)
+ if len(res) == 0:
+ raise ValueError("no acceptable CDS rdata found")
+ return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+ res = []
+ for algorithm in _algorithms:
+ res.extend(dnskey_rdataset_to_cds_rdataset(rrname, rdataset, algorithm, origin))
+ return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+
+def cds_rdataset_to_ds_rdataset(
+ rdataset: dns.rdataset.Rdataset,
+) -> dns.rdataset.Rdataset:
"""Create a CDS record from DS.
*rdataset*, a ``dns.rdataset.Rdataset``, to create DS Rdataset for.
@@ -404,12 +861,30 @@ def cds_rdataset_to_ds_rdataset(rdataset: dns.rdataset.Rdataset
Returns a ``dns.rdataset.Rdataset``
"""
- pass
-
-def dnskey_rdataset_to_cds_rdataset(name: Union[dns.name.Name, str],
- rdataset: dns.rdataset.Rdataset, algorithm: Union[DSDigest, str],
- origin: Optional[dns.name.Name]=None) ->dns.rdataset.Rdataset:
+ if rdataset.rdtype != dns.rdatatype.CDS:
+ raise ValueError("rdataset not a CDS")
+ res = []
+ for rdata in rdataset:
+ res.append(
+ CDS(
+ rdclass=rdata.rdclass,
+ rdtype=dns.rdatatype.DS,
+ key_tag=rdata.key_tag,
+ algorithm=rdata.algorithm,
+ digest_type=rdata.digest_type,
+ digest=rdata.digest,
+ )
+ )
+ return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+
+def dnskey_rdataset_to_cds_rdataset(
+ name: Union[dns.name.Name, str],
+ rdataset: dns.rdataset.Rdataset,
+ algorithm: Union[DSDigest, str],
+ origin: Optional[dns.name.Name] = None,
+) -> dns.rdataset.Rdataset:
"""Create a CDS record from DNSKEY/CDNSKEY.
*name*, a ``dns.name.Name`` or ``str``, the owner name of the CDS record.
@@ -428,37 +903,95 @@ def dnskey_rdataset_to_cds_rdataset(name: Union[dns.name.Name, str],
Returns a ``dns.rdataset.Rdataset``
"""
- pass
+
+ if rdataset.rdtype not in (dns.rdatatype.DNSKEY, dns.rdatatype.CDNSKEY):
+ raise ValueError("rdataset not a DNSKEY/CDNSKEY")
+ res = []
+ for rdata in rdataset:
+ res.append(make_cds(name, rdata, algorithm, origin))
+ return dns.rdataset.from_rdata_list(rdataset.ttl, res)
-def dnskey_rdataset_to_cdnskey_rdataset(rdataset: dns.rdataset.Rdataset
- ) ->dns.rdataset.Rdataset:
+def dnskey_rdataset_to_cdnskey_rdataset(
+ rdataset: dns.rdataset.Rdataset,
+) -> dns.rdataset.Rdataset:
"""Create a CDNSKEY record from DNSKEY.
*rdataset*, a ``dns.rdataset.Rdataset``, to create CDNSKEY Rdataset for.
Returns a ``dns.rdataset.Rdataset``
"""
- pass
-
-def default_rrset_signer(txn: dns.transaction.Transaction, rrset: dns.rrset
- .RRset, signer: dns.name.Name, ksks: List[Tuple[PrivateKey, DNSKEY]],
- zsks: List[Tuple[PrivateKey, DNSKEY]], inception: Optional[Union[
- datetime, str, int, float]]=None, expiration: Optional[Union[datetime,
- str, int, float]]=None, lifetime: Optional[int]=None, policy: Optional[
- Policy]=None, origin: Optional[dns.name.Name]=None) ->None:
+ if rdataset.rdtype != dns.rdatatype.DNSKEY:
+ raise ValueError("rdataset not a DNSKEY")
+ res = []
+ for rdata in rdataset:
+ res.append(
+ CDNSKEY(
+ rdclass=rdataset.rdclass,
+ rdtype=rdataset.rdtype,
+ flags=rdata.flags,
+ protocol=rdata.protocol,
+ algorithm=rdata.algorithm,
+ key=rdata.key,
+ )
+ )
+ return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+
+def default_rrset_signer(
+ txn: dns.transaction.Transaction,
+ rrset: dns.rrset.RRset,
+ signer: dns.name.Name,
+ ksks: List[Tuple[PrivateKey, DNSKEY]],
+ zsks: List[Tuple[PrivateKey, DNSKEY]],
+ inception: Optional[Union[datetime, str, int, float]] = None,
+ expiration: Optional[Union[datetime, str, int, float]] = None,
+ lifetime: Optional[int] = None,
+ policy: Optional[Policy] = None,
+ origin: Optional[dns.name.Name] = None,
+) -> None:
"""Default RRset signer"""
- pass
-
-def sign_zone(zone: dns.zone.Zone, txn: Optional[dns.transaction.
- Transaction]=None, keys: Optional[List[Tuple[PrivateKey, DNSKEY]]]=None,
- add_dnskey: bool=True, dnskey_ttl: Optional[int]=None, inception:
- Optional[Union[datetime, str, int, float]]=None, expiration: Optional[
- Union[datetime, str, int, float]]=None, lifetime: Optional[int]=None,
- nsec3: Optional[NSEC3PARAM]=None, rrset_signer: Optional[RRsetSigner]=
- None, policy: Optional[Policy]=None) ->None:
+ if rrset.rdtype in set(
+ [
+ dns.rdatatype.RdataType.DNSKEY,
+ dns.rdatatype.RdataType.CDS,
+ dns.rdatatype.RdataType.CDNSKEY,
+ ]
+ ):
+ keys = ksks
+ else:
+ keys = zsks
+
+ for private_key, dnskey in keys:
+ rrsig = dns.dnssec.sign(
+ rrset=rrset,
+ private_key=private_key,
+ dnskey=dnskey,
+ inception=inception,
+ expiration=expiration,
+ lifetime=lifetime,
+ signer=signer,
+ policy=policy,
+ origin=origin,
+ )
+ txn.add(rrset.name, rrset.ttl, rrsig)
+
+
+def sign_zone(
+ zone: dns.zone.Zone,
+ txn: Optional[dns.transaction.Transaction] = None,
+ keys: Optional[List[Tuple[PrivateKey, DNSKEY]]] = None,
+ add_dnskey: bool = True,
+ dnskey_ttl: Optional[int] = None,
+ inception: Optional[Union[datetime, str, int, float]] = None,
+ expiration: Optional[Union[datetime, str, int, float]] = None,
+ lifetime: Optional[int] = None,
+ nsec3: Optional[NSEC3PARAM] = None,
+ rrset_signer: Optional[RRsetSigner] = None,
+ policy: Optional[Policy] = None,
+) -> None:
"""Sign zone.
*zone*, a ``dns.zone.Zone``, the zone to sign.
@@ -499,37 +1032,176 @@ def sign_zone(zone: dns.zone.Zone, txn: Optional[dns.transaction.
Returns ``None``.
"""
- pass
-
-def _sign_zone_nsec(zone: dns.zone.Zone, txn: dns.transaction.Transaction,
- rrset_signer: Optional[RRsetSigner]=None) ->None:
+ ksks = []
+ zsks = []
+
+ # if we have both KSKs and ZSKs, split by SEP flag. if not, sign all
+ # records with all keys
+ if keys:
+ for key in keys:
+ if key[1].flags & Flag.SEP:
+ ksks.append(key)
+ else:
+ zsks.append(key)
+ if not ksks:
+ ksks = keys
+ if not zsks:
+ zsks = keys
+ else:
+ keys = []
+
+ if txn:
+ cm: contextlib.AbstractContextManager = contextlib.nullcontext(txn)
+ else:
+ cm = zone.writer()
+
+ with cm as _txn:
+ if add_dnskey:
+ if dnskey_ttl is None:
+ dnskey = _txn.get(zone.origin, dns.rdatatype.DNSKEY)
+ if dnskey:
+ dnskey_ttl = dnskey.ttl
+ else:
+ soa = _txn.get(zone.origin, dns.rdatatype.SOA)
+ dnskey_ttl = soa.ttl
+ for _, dnskey in keys:
+ _txn.add(zone.origin, dnskey_ttl, dnskey)
+
+ if nsec3:
+ raise NotImplementedError("Signing with NSEC3 not yet implemented")
+ else:
+ _rrset_signer = rrset_signer or functools.partial(
+ default_rrset_signer,
+ signer=zone.origin,
+ ksks=ksks,
+ zsks=zsks,
+ inception=inception,
+ expiration=expiration,
+ lifetime=lifetime,
+ policy=policy,
+ origin=zone.origin,
+ )
+ return _sign_zone_nsec(zone, _txn, _rrset_signer)
+
+
+def _sign_zone_nsec(
+ zone: dns.zone.Zone,
+ txn: dns.transaction.Transaction,
+ rrset_signer: Optional[RRsetSigner] = None,
+) -> None:
"""NSEC zone signer"""
- pass
-
-if dns._features.have('dnssec'):
+ def _txn_add_nsec(
+ txn: dns.transaction.Transaction,
+ name: dns.name.Name,
+ next_secure: Optional[dns.name.Name],
+ rdclass: dns.rdataclass.RdataClass,
+ ttl: int,
+ rrset_signer: Optional[RRsetSigner] = None,
+ ) -> None:
+ """NSEC zone signer helper"""
+ mandatory_types = set(
+ [dns.rdatatype.RdataType.RRSIG, dns.rdatatype.RdataType.NSEC]
+ )
+ node = txn.get_node(name)
+ if node and next_secure:
+ types = (
+ set([rdataset.rdtype for rdataset in node.rdatasets]) | mandatory_types
+ )
+ windows = Bitmap.from_rdtypes(list(types))
+ rrset = dns.rrset.from_rdata(
+ name,
+ ttl,
+ NSEC(
+ rdclass=rdclass,
+ rdtype=dns.rdatatype.RdataType.NSEC,
+ next=next_secure,
+ windows=windows,
+ ),
+ )
+ txn.add(rrset)
+ if rrset_signer:
+ rrset_signer(txn, rrset)
+
+ rrsig_ttl = zone.get_soa().minimum
+ delegation = None
+ last_secure = None
+
+ for name in sorted(txn.iterate_names()):
+ if delegation and name.is_subdomain(delegation):
+ # names below delegations are not secure
+ continue
+ elif txn.get(name, dns.rdatatype.NS) and name != zone.origin:
+ # inside delegation
+ delegation = name
+ else:
+ # outside delegation
+ delegation = None
+
+ if rrset_signer:
+ node = txn.get_node(name)
+ if node:
+ for rdataset in node.rdatasets:
+ if rdataset.rdtype == dns.rdatatype.RRSIG:
+ # do not sign RRSIGs
+ continue
+ elif delegation and rdataset.rdtype != dns.rdatatype.DS:
+ # do not sign delegations except DS records
+ continue
+ else:
+ rrset = dns.rrset.from_rdata(name, rdataset.ttl, *rdataset)
+ rrset_signer(txn, rrset)
+
+ # We need "is not None" as the empty name is False because its length is 0.
+ if last_secure is not None:
+ _txn_add_nsec(txn, last_secure, name, zone.rdclass, rrsig_ttl, rrset_signer)
+ last_secure = name
+
+ if last_secure:
+ _txn_add_nsec(
+ txn, last_secure, zone.origin, zone.rdclass, rrsig_ttl, rrset_signer
+ )
+
+
+def _need_pyca(*args, **kwargs):
+ raise ImportError(
+ "DNSSEC validation requires python cryptography"
+ ) # pragma: no cover
+
+
+if dns._features.have("dnssec"):
from cryptography.exceptions import InvalidSignature
- from cryptography.hazmat.primitives.asymmetric import dsa
- from cryptography.hazmat.primitives.asymmetric import ec
- from cryptography.hazmat.primitives.asymmetric import ed448
- from cryptography.hazmat.primitives.asymmetric import rsa
- from cryptography.hazmat.primitives.asymmetric import ed25519
- from dns.dnssecalgs import get_algorithm_cls, get_algorithm_cls_from_dnskey
+ from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
+ from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
+ from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611
+ from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611
+ from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611
+ ed25519,
+ )
+
+ from dns.dnssecalgs import ( # pylint: disable=C0412
+ get_algorithm_cls,
+ get_algorithm_cls_from_dnskey,
+ )
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
- validate = _validate
- validate_rrsig = _validate_rrsig
+
+ validate = _validate # type: ignore
+ validate_rrsig = _validate_rrsig # type: ignore
sign = _sign
make_dnskey = _make_dnskey
make_cdnskey = _make_cdnskey
_have_pyca = True
-else:
+else: # pragma: no cover
validate = _need_pyca
validate_rrsig = _need_pyca
sign = _need_pyca
make_dnskey = _need_pyca
make_cdnskey = _need_pyca
_have_pyca = False
+
+### BEGIN generated Algorithm constants
+
RSAMD5 = Algorithm.RSAMD5
DH = Algorithm.DH
DSA = Algorithm.DSA
@@ -547,3 +1219,5 @@ ED448 = Algorithm.ED448
INDIRECT = Algorithm.INDIRECT
PRIVATEDNS = Algorithm.PRIVATEDNS
PRIVATEOID = Algorithm.PRIVATEOID
+
+### END generated Algorithm constants
diff --git a/dns/dnssecalgs/base.py b/dns/dnssecalgs/base.py
index 9fc70e6..e990575 100644
--- a/dns/dnssecalgs/base.py
+++ b/dns/dnssecalgs/base.py
@@ -1,5 +1,6 @@
-from abc import ABC, abstractmethod
+from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
from typing import Any, Optional, Type
+
import dns.rdataclass
import dns.rdatatype
from dns.dnssectypes import Algorithm
@@ -12,68 +13,72 @@ class GenericPublicKey(ABC):
algorithm: Algorithm
@abstractmethod
- def __init__(self, key: Any) ->None:
+ def __init__(self, key: Any) -> None:
pass
@abstractmethod
- def verify(self, signature: bytes, data: bytes) ->None:
+ def verify(self, signature: bytes, data: bytes) -> None:
"""Verify signed DNSSEC data"""
- pass
@abstractmethod
- def encode_key_bytes(self) ->bytes:
+ def encode_key_bytes(self) -> bytes:
"""Encode key as bytes for DNSKEY"""
- pass
- def to_dnskey(self, flags: int=Flag.ZONE, protocol: int=3) ->DNSKEY:
+ @classmethod
+ def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
+ if key.algorithm != cls.algorithm:
+ raise AlgorithmKeyMismatch
+
+ def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
"""Return public key as DNSKEY"""
- pass
+ return DNSKEY(
+ rdclass=dns.rdataclass.IN,
+ rdtype=dns.rdatatype.DNSKEY,
+ flags=flags,
+ protocol=protocol,
+ algorithm=self.algorithm,
+ key=self.encode_key_bytes(),
+ )
@classmethod
@abstractmethod
- def from_dnskey(cls, key: DNSKEY) ->'GenericPublicKey':
+ def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
"""Create public key from DNSKEY"""
- pass
@classmethod
@abstractmethod
- def from_pem(cls, public_pem: bytes) ->'GenericPublicKey':
+ def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
"""Create public key from PEM-encoded SubjectPublicKeyInfo as specified
in RFC 5280"""
- pass
@abstractmethod
- def to_pem(self) ->bytes:
+ def to_pem(self) -> bytes:
"""Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
in RFC 5280"""
- pass
class GenericPrivateKey(ABC):
public_cls: Type[GenericPublicKey]
@abstractmethod
- def __init__(self, key: Any) ->None:
+ def __init__(self, key: Any) -> None:
pass
@abstractmethod
- def sign(self, data: bytes, verify: bool=False) ->bytes:
+ def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign DNSSEC data"""
- pass
@abstractmethod
- def public_key(self) ->'GenericPublicKey':
+ def public_key(self) -> "GenericPublicKey":
"""Return public key instance"""
- pass
@classmethod
@abstractmethod
- def from_pem(cls, private_pem: bytes, password: Optional[bytes]=None
- ) ->'GenericPrivateKey':
+ def from_pem(
+ cls, private_pem: bytes, password: Optional[bytes] = None
+ ) -> "GenericPrivateKey":
"""Create private key from PEM-encoded PKCS#8"""
- pass
@abstractmethod
- def to_pem(self, password: Optional[bytes]=None) ->bytes:
+ def to_pem(self, password: Optional[bytes] = None) -> bytes:
"""Return private key as PEM-encoded PKCS#8"""
- pass
diff --git a/dns/dnssecalgs/cryptography.py b/dns/dnssecalgs/cryptography.py
index cdc4553..5a31a81 100644
--- a/dns/dnssecalgs/cryptography.py
+++ b/dns/dnssecalgs/cryptography.py
@@ -1,5 +1,7 @@
from typing import Any, Optional, Type
+
from cryptography.hazmat.primitives import serialization
+
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
from dns.exception import AlgorithmKeyMismatch
@@ -8,22 +10,59 @@ class CryptographyPublicKey(GenericPublicKey):
key: Any = None
key_cls: Any = None
- def __init__(self, key: Any) ->None:
+ def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
if self.key_cls is None:
- raise TypeError('Undefined private key class')
- if not isinstance(key, self.key_cls):
+ raise TypeError("Undefined private key class")
+ if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
+ key, self.key_cls
+ ):
raise AlgorithmKeyMismatch
self.key = key
+ @classmethod
+ def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
+ key = serialization.load_pem_public_key(public_pem)
+ return cls(key=key)
+
+ def to_pem(self) -> bytes:
+ return self.key.public_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
+ )
+
class CryptographyPrivateKey(GenericPrivateKey):
key: Any = None
key_cls: Any = None
public_cls: Type[CryptographyPublicKey]
- def __init__(self, key: Any) ->None:
+ def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
if self.key_cls is None:
- raise TypeError('Undefined private key class')
- if not isinstance(key, self.key_cls):
+ raise TypeError("Undefined private key class")
+ if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
+ key, self.key_cls
+ ):
raise AlgorithmKeyMismatch
self.key = key
+
+ def public_key(self) -> "CryptographyPublicKey":
+ return self.public_cls(key=self.key.public_key())
+
+ @classmethod
+ def from_pem(
+ cls, private_pem: bytes, password: Optional[bytes] = None
+ ) -> "GenericPrivateKey":
+ key = serialization.load_pem_private_key(private_pem, password=password)
+ return cls(key=key)
+
+ def to_pem(self, password: Optional[bytes] = None) -> bytes:
+ encryption_algorithm: serialization.KeySerializationEncryption
+ if password:
+ encryption_algorithm = serialization.BestAvailableEncryption(password)
+ else:
+ encryption_algorithm = serialization.NoEncryption()
+ return self.key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=encryption_algorithm,
+ )
diff --git a/dns/dnssecalgs/dsa.py b/dns/dnssecalgs/dsa.py
index d09a487..0fe4690 100644
--- a/dns/dnssecalgs/dsa.py
+++ b/dns/dnssecalgs/dsa.py
@@ -1,7 +1,9 @@
import struct
+
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import dsa, utils
+
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
@@ -13,9 +15,52 @@ class PublicDSA(CryptographyPublicKey):
algorithm = Algorithm.DSA
chosen_hash = hashes.SHA1()
- def encode_key_bytes(self) ->bytes:
+ def verify(self, signature: bytes, data: bytes) -> None:
+ sig_r = signature[1:21]
+ sig_s = signature[21:]
+ sig = utils.encode_dss_signature(
+ int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
+ )
+ self.key.verify(sig, data, self.chosen_hash)
+
+ def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 2536, section 2."""
- pass
+ pn = self.key.public_numbers()
+ dsa_t = (self.key.key_size // 8 - 64) // 8
+ if dsa_t > 8:
+ raise ValueError("unsupported DSA key size")
+ octets = 64 + dsa_t * 8
+ res = struct.pack("!B", dsa_t)
+ res += pn.parameter_numbers.q.to_bytes(20, "big")
+ res += pn.parameter_numbers.p.to_bytes(octets, "big")
+ res += pn.parameter_numbers.g.to_bytes(octets, "big")
+ res += pn.y.to_bytes(octets, "big")
+ return res
+
+ @classmethod
+ def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
+ cls._ensure_algorithm_key_combination(key)
+ keyptr = key.key
+ (t,) = struct.unpack("!B", keyptr[0:1])
+ keyptr = keyptr[1:]
+ octets = 64 + t * 8
+ dsa_q = keyptr[0:20]
+ keyptr = keyptr[20:]
+ dsa_p = keyptr[0:octets]
+ keyptr = keyptr[octets:]
+ dsa_g = keyptr[0:octets]
+ keyptr = keyptr[octets:]
+ dsa_y = keyptr[0:octets]
+ return cls(
+ key=dsa.DSAPublicNumbers( # type: ignore
+ int.from_bytes(dsa_y, "big"),
+ dsa.DSAParameterNumbers(
+ int.from_bytes(dsa_p, "big"),
+ int.from_bytes(dsa_q, "big"),
+ int.from_bytes(dsa_g, "big"),
+ ),
+ ).public_key(default_backend()),
+ )
class PrivateDSA(CryptographyPrivateKey):
@@ -23,9 +68,29 @@ class PrivateDSA(CryptographyPrivateKey):
key_cls = dsa.DSAPrivateKey
public_cls = PublicDSA
- def sign(self, data: bytes, verify: bool=False) ->bytes:
+ def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 2536, section 3."""
- pass
+ public_dsa_key = self.key.public_key()
+ if public_dsa_key.key_size > 1024:
+ raise ValueError("DSA key size overflow")
+ der_signature = self.key.sign(data, self.public_cls.chosen_hash)
+ dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
+ dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
+ octets = 20
+ signature = (
+ struct.pack("!B", dsa_t)
+ + int.to_bytes(dsa_r, length=octets, byteorder="big")
+ + int.to_bytes(dsa_s, length=octets, byteorder="big")
+ )
+ if verify:
+ self.public_key().verify(signature, data)
+ return signature
+
+ @classmethod
+ def generate(cls, key_size: int) -> "PrivateDSA":
+ return cls(
+ key=dsa.generate_private_key(key_size=key_size),
+ )
class PublicDSANSEC3SHA1(PublicDSA):
diff --git a/dns/dnssecalgs/ecdsa.py b/dns/dnssecalgs/ecdsa.py
index f482742..a31d79f 100644
--- a/dns/dnssecalgs/ecdsa.py
+++ b/dns/dnssecalgs/ecdsa.py
@@ -1,6 +1,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, utils
+
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
@@ -14,9 +15,31 @@ class PublicECDSA(CryptographyPublicKey):
curve: ec.EllipticCurve
octets: int
- def encode_key_bytes(self) ->bytes:
+ def verify(self, signature: bytes, data: bytes) -> None:
+ sig_r = signature[0 : self.octets]
+ sig_s = signature[self.octets :]
+ sig = utils.encode_dss_signature(
+ int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
+ )
+ self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))
+
+ def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 6605, section 4."""
- pass
+ pn = self.key.public_numbers()
+ return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")
+
+ @classmethod
+ def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
+ cls._ensure_algorithm_key_combination(key)
+ ecdsa_x = key.key[0 : cls.octets]
+ ecdsa_y = key.key[cls.octets : cls.octets * 2]
+ return cls(
+ key=ec.EllipticCurvePublicNumbers(
+ curve=cls.curve,
+ x=int.from_bytes(ecdsa_x, "big"),
+ y=int.from_bytes(ecdsa_y, "big"),
+ ).public_key(default_backend()),
+ )
class PrivateECDSA(CryptographyPrivateKey):
@@ -24,9 +47,24 @@ class PrivateECDSA(CryptographyPrivateKey):
key_cls = ec.EllipticCurvePrivateKey
public_cls = PublicECDSA
- def sign(self, data: bytes, verify: bool=False) ->bytes:
+ def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 6605, section 4."""
- pass
+ der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
+ dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
+ signature = int.to_bytes(
+ dsa_r, length=self.public_cls.octets, byteorder="big"
+ ) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big")
+ if verify:
+ self.public_key().verify(signature, data)
+ return signature
+
+ @classmethod
+ def generate(cls) -> "PrivateECDSA":
+ return cls(
+ key=ec.generate_private_key(
+ curve=cls.public_cls.curve, backend=default_backend()
+ ),
+ )
class PublicECDSAP256SHA256(PublicECDSA):
diff --git a/dns/dnssecalgs/eddsa.py b/dns/dnssecalgs/eddsa.py
index 7705e31..7050534 100644
--- a/dns/dnssecalgs/eddsa.py
+++ b/dns/dnssecalgs/eddsa.py
@@ -1,24 +1,44 @@
from typing import Type
+
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
+
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicEDDSA(CryptographyPublicKey):
+ def verify(self, signature: bytes, data: bytes) -> None:
+ self.key.verify(signature, data)
- def encode_key_bytes(self) ->bytes:
+ def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 8080, section 3."""
- pass
+ return self.key.public_bytes(
+ encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
+ )
+
+ @classmethod
+ def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
+ cls._ensure_algorithm_key_combination(key)
+ return cls(
+ key=cls.key_cls.from_public_bytes(key.key),
+ )
class PrivateEDDSA(CryptographyPrivateKey):
public_cls: Type[PublicEDDSA]
- def sign(self, data: bytes, verify: bool=False) ->bytes:
+ def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 8080, section 4."""
- pass
+ signature = self.key.sign(data)
+ if verify:
+ self.public_key().verify(signature, data)
+ return signature
+
+ @classmethod
+ def generate(cls) -> "PrivateEDDSA":
+ return cls(key=cls.key_cls.generate())
class PublicED25519(PublicEDDSA):
diff --git a/dns/dnssecalgs/rsa.py b/dns/dnssecalgs/rsa.py
index 91f1eaf..e95dcf1 100644
--- a/dns/dnssecalgs/rsa.py
+++ b/dns/dnssecalgs/rsa.py
@@ -1,8 +1,10 @@
import math
import struct
+
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
@@ -14,9 +16,38 @@ class PublicRSA(CryptographyPublicKey):
algorithm: Algorithm
chosen_hash: hashes.HashAlgorithm
- def encode_key_bytes(self) ->bytes:
+ def verify(self, signature: bytes, data: bytes) -> None:
+ self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
+
+ def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 3110, section 2."""
- pass
+ pn = self.key.public_numbers()
+ _exp_len = math.ceil(int.bit_length(pn.e) / 8)
+ exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
+ if _exp_len > 255:
+ exp_header = b"\0" + struct.pack("!H", _exp_len)
+ else:
+ exp_header = struct.pack("!B", _exp_len)
+ if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
+ raise ValueError("unsupported RSA key length")
+ return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
+
+ @classmethod
+ def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
+ cls._ensure_algorithm_key_combination(key)
+ keyptr = key.key
+ (bytes_,) = struct.unpack("!B", keyptr[0:1])
+ keyptr = keyptr[1:]
+ if bytes_ == 0:
+ (bytes_,) = struct.unpack("!H", keyptr[0:2])
+ keyptr = keyptr[2:]
+ rsa_e = keyptr[0:bytes_]
+ rsa_n = keyptr[bytes_:]
+ return cls(
+ key=rsa.RSAPublicNumbers(
+ int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
+ ).public_key(default_backend())
+ )
class PrivateRSA(CryptographyPrivateKey):
@@ -25,9 +56,22 @@ class PrivateRSA(CryptographyPrivateKey):
public_cls = PublicRSA
default_public_exponent = 65537
- def sign(self, data: bytes, verify: bool=False) ->bytes:
+ def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 3110, section 3."""
- pass
+ signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
+ if verify:
+ self.public_key().verify(signature, data)
+ return signature
+
+ @classmethod
+ def generate(cls, key_size: int) -> "PrivateRSA":
+ return cls(
+ key=rsa.generate_private_key(
+ public_exponent=cls.default_public_exponent,
+ key_size=key_size,
+ backend=default_backend(),
+ )
+ )
class PublicRSAMD5(PublicRSA):
diff --git a/dns/dnssectypes.py b/dns/dnssectypes.py
index 1d320c7..02131e0 100644
--- a/dns/dnssectypes.py
+++ b/dns/dnssectypes.py
@@ -1,4 +1,25 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Common DNSSEC-related types."""
+
+# This is a separate file to avoid import circularity between dns.dnssec and
+# the implementations of the DS and DNSKEY types.
+
import dns.enum
@@ -21,16 +42,30 @@ class Algorithm(dns.enum.IntEnum):
PRIVATEDNS = 253
PRIVATEOID = 254
+ @classmethod
+ def _maximum(cls):
+ return 255
+
class DSDigest(dns.enum.IntEnum):
"""DNSSEC Delegation Signer Digest Algorithm"""
+
NULL = 0
SHA1 = 1
SHA256 = 2
GOST = 3
SHA384 = 4
+ @classmethod
+ def _maximum(cls):
+ return 255
+
class NSEC3Hash(dns.enum.IntEnum):
"""NSEC3 hash algorithm"""
+
SHA1 = 1
+
+ @classmethod
+ def _maximum(cls):
+ return 255
diff --git a/dns/e164.py b/dns/e164.py
index 94218a3..453736d 100644
--- a/dns/e164.py
+++ b/dns/e164.py
@@ -1,13 +1,35 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS E.164 helpers."""
+
from typing import Iterable, Optional, Union
+
import dns.exception
import dns.name
import dns.resolver
-public_enum_domain = dns.name.from_text('e164.arpa.')
+
+#: The public E.164 domain.
+public_enum_domain = dns.name.from_text("e164.arpa.")
-def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain
- ) ->dns.name.Name:
+def from_e164(
+ text: str, origin: Optional[dns.name.Name] = public_enum_domain
+) -> dns.name.Name:
"""Convert an E.164 number in textual form into a Name object whose
value is the ENUM domain name for that number.
@@ -21,11 +43,17 @@ def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain
Returns a ``dns.name.Name``.
"""
- pass
+ parts = [d for d in text if d.isdigit()]
+ parts.reverse()
+ return dns.name.from_text(".".join(parts), origin=origin)
-def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=
- public_enum_domain, want_plus_prefix: bool=True) ->str:
+
+def to_e164(
+ name: dns.name.Name,
+ origin: Optional[dns.name.Name] = public_enum_domain,
+ want_plus_prefix: bool = True,
+) -> str:
"""Convert an ENUM domain name into an E.164 number.
Note that dnspython does not have any information about preferred
@@ -45,11 +73,23 @@ def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=
Returns a ``str``.
"""
- pass
-
-
-def query(number: str, domains: Iterable[Union[dns.name.Name, str]],
- resolver: Optional[dns.resolver.Resolver]=None) ->dns.resolver.Answer:
+ if origin is not None:
+ name = name.relativize(origin)
+ dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
+ if len(dlabels) != len(name.labels):
+ raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
+ dlabels.reverse()
+ text = b"".join(dlabels)
+ if want_plus_prefix:
+ text = b"+" + text
+ return text.decode()
+
+
+def query(
+ number: str,
+ domains: Iterable[Union[dns.name.Name, str]],
+ resolver: Optional[dns.resolver.Resolver] = None,
+) -> dns.resolver.Answer:
"""Look for NAPTR RRs for the specified number in the specified domains.
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
@@ -61,4 +101,16 @@ def query(number: str, domains: Iterable[Union[dns.name.Name, str]],
*resolver*, a ``dns.resolver.Resolver``, is the resolver to use. If
``None``, the default resolver is used.
"""
- pass
+
+ if resolver is None:
+ resolver = dns.resolver.get_default_resolver()
+ e_nx = dns.resolver.NXDOMAIN()
+ for domain in domains:
+ if isinstance(domain, str):
+ domain = dns.name.from_text(domain)
+ qname = dns.e164.from_e164(number, domain)
+ try:
+ return resolver.resolve(qname, "NAPTR")
+ except dns.resolver.NXDOMAIN as e:
+ e_nx += e
+ raise e_nx
diff --git a/dns/edns.py b/dns/edns.py
index ac75090..776e5ee 100644
--- a/dns/edns.py
+++ b/dns/edns.py
@@ -1,9 +1,28 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2009-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""EDNS Options"""
+
import binascii
import math
import socket
import struct
from typing import Any, Dict, Optional, Union
+
import dns.enum
import dns.inet
import dns.rdata
@@ -11,18 +30,33 @@ import dns.wire
class OptionType(dns.enum.IntEnum):
+ #: NSID
NSID = 3
+ #: DAU
DAU = 5
+ #: DHU
DHU = 6
+ #: N3U
N3U = 7
+ #: ECS (client-subnet)
ECS = 8
+ #: EXPIRE
EXPIRE = 9
+ #: COOKIE
COOKIE = 10
+ #: KEEPALIVE
KEEPALIVE = 11
+ #: PADDING
PADDING = 12
+ #: CHAIN
CHAIN = 13
+ #: EDE (extended-dns-error)
EDE = 15
+ @classmethod
+ def _maximum(cls):
+ return 65535
+
class Option:
"""Base class for all EDNS option types."""
@@ -34,17 +68,19 @@ class Option:
"""
self.otype = OptionType.make(otype)
- def to_wire(self, file: Optional[Any]=None) ->Optional[bytes]:
+ def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
"""Convert an option to wire format.
Returns a ``bytes`` or ``None``.
"""
- pass
+ raise NotImplementedError # pragma: no cover
+
+ def to_text(self) -> str:
+ raise NotImplementedError # pragma: no cover
@classmethod
- def from_wire_parser(cls, otype: OptionType, parser: 'dns.wire.Parser'
- ) ->'Option':
+ def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format.
*otype*, a ``dns.edns.OptionType``, is the option type.
@@ -54,14 +90,20 @@ class Option:
Returns a ``dns.edns.Option``.
"""
- pass
+ raise NotImplementedError # pragma: no cover
def _cmp(self, other):
"""Compare an EDNS option with another option of the same type.
Returns < 0 if < *other*, 0 if == *other*, and > 0 if > *other*.
"""
- pass
+ wire = self.to_wire()
+ owire = other.to_wire()
+ if wire == owire:
+ return 0
+ if wire > owire:
+ return 1
+ return -1
def __eq__(self, other):
if not isinstance(other, Option):
@@ -101,7 +143,7 @@ class Option:
return self.to_text()
-class GenericOption(Option):
+class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class
This class is used for EDNS option types for which we have no better
@@ -112,12 +154,27 @@ class GenericOption(Option):
super().__init__(otype)
self.data = dns.rdata.Rdata._as_bytes(data, True)
+ def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
+ if file:
+ file.write(self.data)
+ return None
+ else:
+ return self.data
+
+ def to_text(self) -> str:
+ return "Generic %d" % self.otype
+
+ @classmethod
+ def from_wire_parser(
+ cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
+ ) -> Option:
+ return cls(otype, parser.get_remaining())
+
-class ECSOption(Option):
+class ECSOption(Option): # lgtm[py/missing-equals]
"""EDNS Client Subnet (ECS, RFC7871)"""
- def __init__(self, address: str, srclen: Optional[int]=None, scopelen:
- int=0):
+ def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0):
"""*address*, a ``str``, is the client address information.
*srclen*, an ``int``, the source prefix length, which is the
@@ -127,8 +184,10 @@ class ECSOption(Option):
*scopelen*, an ``int``, the scope prefix length. This value
must be 0 in queries, and should be set in responses.
"""
+
super().__init__(OptionType.ECS)
af = dns.inet.af_for_address(address)
+
if af == socket.AF_INET6:
self.family = 2
if srclen is None:
@@ -143,22 +202,30 @@ class ECSOption(Option):
address = dns.rdata.Rdata._as_ipv4_address(address)
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
- else:
- raise ValueError('Bad address family')
+ else: # pragma: no cover (this will never happen)
+ raise ValueError("Bad address family")
+
assert srclen is not None
self.address = address
self.srclen = srclen
self.scopelen = scopelen
+
addrdata = dns.inet.inet_pton(af, address)
nbytes = int(math.ceil(srclen / 8.0))
+
+ # Truncate to srclen and pad to the end of the last octet needed
+ # See RFC section 6
self.addrdata = addrdata[:nbytes]
nbits = srclen % 8
if nbits != 0:
- last = struct.pack('B', ord(self.addrdata[-1:]) & 255 << 8 - nbits)
+ last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
self.addrdata = self.addrdata[:-1] + last
+ def to_text(self) -> str:
+ return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
+
@staticmethod
- def from_text(text: str) ->Option:
+ def from_text(text: str) -> Option:
"""Convert a string into a `dns.edns.ECSOption`
*text*, a `str`, the text form of the option.
@@ -181,7 +248,66 @@ class ECSOption(Option):
>>> # it understands results from `dns.edns.ECSOption.to_text()`
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
"""
- pass
+ optional_prefix = "ECS"
+ tokens = text.split()
+ ecs_text = None
+ if len(tokens) == 1:
+ ecs_text = tokens[0]
+ elif len(tokens) == 2:
+ if tokens[0] != optional_prefix:
+ raise ValueError('could not parse ECS from "{}"'.format(text))
+ ecs_text = tokens[1]
+ else:
+ raise ValueError('could not parse ECS from "{}"'.format(text))
+ n_slashes = ecs_text.count("/")
+ if n_slashes == 1:
+ address, tsrclen = ecs_text.split("/")
+ tscope = "0"
+ elif n_slashes == 2:
+ address, tsrclen, tscope = ecs_text.split("/")
+ else:
+ raise ValueError('could not parse ECS from "{}"'.format(text))
+ try:
+ scope = int(tscope)
+ except ValueError:
+ raise ValueError(
+ "invalid scope " + '"{}": scope must be an integer'.format(tscope)
+ )
+ try:
+ srclen = int(tsrclen)
+ except ValueError:
+ raise ValueError(
+ "invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
+ )
+ return ECSOption(address, srclen, scope)
+
+ def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
+ value = (
+ struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
+ )
+ if file:
+ file.write(value)
+ return None
+ else:
+ return value
+
+ @classmethod
+ def from_wire_parser(
+ cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
+ ) -> Option:
+ family, src, scope = parser.get_struct("!HBB")
+ addrlen = int(math.ceil(src / 8.0))
+ prefix = parser.get_bytes(addrlen)
+ if family == 1:
+ pad = 4 - addrlen
+ addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
+ elif family == 2:
+ pad = 16 - addrlen
+ addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
+ else:
+ raise ValueError("unsupported family")
+
+ return cls(addr, src, scope)
class EDECode(dns.enum.IntEnum):
@@ -211,47 +337,122 @@ class EDECode(dns.enum.IntEnum):
NETWORK_ERROR = 23
INVALID_DATA = 24
+ @classmethod
+ def _maximum(cls):
+ return 65535
+
-class EDEOption(Option):
+class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)"""
- _preserve_case = {'DNSKEY', 'DS', 'DNSSEC', 'RRSIGs', 'NSEC', 'NXDOMAIN'}
- def __init__(self, code: Union[EDECode, str], text: Optional[str]=None):
+ _preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"}
+
+ def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error.
*text*, a ``str`` or ``None``, specifying additional information about
the error.
"""
+
super().__init__(OptionType.EDE)
+
self.code = EDECode.make(code)
if text is not None and not isinstance(text, str):
- raise ValueError('text must be string or None')
+ raise ValueError("text must be string or None")
self.text = text
+ def to_text(self) -> str:
+ output = f"EDE {self.code}"
+ if self.code in EDECode:
+ desc = EDECode.to_text(self.code)
+ desc = " ".join(
+ word if word in self._preserve_case else word.title()
+ for word in desc.split("_")
+ )
+ output += f" ({desc})"
+ if self.text is not None:
+ output += f": {self.text}"
+ return output
+
+ def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
+ value = struct.pack("!H", self.code)
+ if self.text is not None:
+ value += self.text.encode("utf8")
+
+ if file:
+ file.write(value)
+ return None
+ else:
+ return value
+
+ @classmethod
+ def from_wire_parser(
+ cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
+ ) -> Option:
+ code = EDECode.make(parser.get_uint16())
+ text = parser.get_remaining()
+
+ if text:
+ if text[-1] == 0: # text MAY be null-terminated
+ text = text[:-1]
+ btext = text.decode("utf8")
+ else:
+ btext = None
+
+ return cls(code, btext)
-class NSIDOption(Option):
+class NSIDOption(Option):
def __init__(self, nsid: bytes):
super().__init__(OptionType.NSID)
self.nsid = nsid
+ def to_wire(self, file: Any = None) -> Optional[bytes]:
+ if file:
+ file.write(self.nsid)
+ return None
+ else:
+ return self.nsid
+
+ def to_text(self) -> str:
+ if all(c >= 0x20 and c <= 0x7E for c in self.nsid):
+ # All ASCII printable, so it's probably a string.
+ value = self.nsid.decode()
+ else:
+ value = binascii.hexlify(self.nsid).decode()
+ return f"NSID {value}"
+
+ @classmethod
+ def from_wire_parser(
+ cls, otype: Union[OptionType, str], parser: dns.wire.Parser
+ ) -> Option:
+ return cls(parser.get_remaining())
-_type_to_class: Dict[OptionType, Any] = {OptionType.ECS: ECSOption,
- OptionType.EDE: EDEOption, OptionType.NSID: NSIDOption}
+_type_to_class: Dict[OptionType, Any] = {
+ OptionType.ECS: ECSOption,
+ OptionType.EDE: EDEOption,
+ OptionType.NSID: NSIDOption,
+}
-def get_option_class(otype: OptionType) ->Any:
+
+def get_option_class(otype: OptionType) -> Any:
"""Return the class for the specified option type.
The GenericOption class is used if a more specific class is not
known.
"""
- pass
+
+ cls = _type_to_class.get(otype)
+ if cls is None:
+ cls = GenericOption
+ return cls
-def option_from_wire_parser(otype: Union[OptionType, str], parser:
- 'dns.wire.Parser') ->Option:
+def option_from_wire_parser(
+ otype: Union[OptionType, str], parser: "dns.wire.Parser"
+) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
@@ -261,11 +462,14 @@ def option_from_wire_parser(otype: Union[OptionType, str], parser:
Returns an instance of a subclass of ``dns.edns.Option``.
"""
- pass
+ otype = OptionType.make(otype)
+ cls = get_option_class(otype)
+ return cls.from_wire_parser(otype, parser)
-def option_from_wire(otype: Union[OptionType, str], wire: bytes, current:
- int, olen: int) ->Option:
+def option_from_wire(
+ otype: Union[OptionType, str], wire: bytes, current: int, olen: int
+) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
@@ -279,19 +483,24 @@ def option_from_wire(otype: Union[OptionType, str], wire: bytes, current:
Returns an instance of a subclass of ``dns.edns.Option``.
"""
- pass
+ parser = dns.wire.Parser(wire, current)
+ with parser.restrict_to(olen):
+ return option_from_wire_parser(otype, parser)
-def register_type(implementation: Any, otype: OptionType) ->None:
+def register_type(implementation: Any, otype: OptionType) -> None:
"""Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
*otype*, an ``int``, is the option type.
"""
- pass
+
+ _type_to_class[otype] = implementation
+### BEGIN generated OptionType constants
+
NSID = OptionType.NSID
DAU = OptionType.DAU
DHU = OptionType.DHU
@@ -303,3 +512,5 @@ KEEPALIVE = OptionType.KEEPALIVE
PADDING = OptionType.PADDING
CHAIN = OptionType.CHAIN
EDE = OptionType.EDE
+
+### END generated OptionType constants
diff --git a/dns/entropy.py b/dns/entropy.py
index 7e11b03..4dcdc62 100644
--- a/dns/entropy.py
+++ b/dns/entropy.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2009-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import hashlib
import os
import random
@@ -7,15 +24,19 @@ from typing import Any, Optional
class EntropyPool:
+ # This is an entropy pool for Python implementations that do not
+ # have a working SystemRandom. I'm not sure there are any, but
+ # leaving this code doesn't hurt anything as the library code
+ # is used if present.
- def __init__(self, seed: Optional[bytes]=None):
+ def __init__(self, seed: Optional[bytes] = None):
self.pool_index = 0
self.digest: Optional[bytearray] = None
self.next_byte = 0
self.lock = threading.Lock()
self.hash = hashlib.sha1()
self.hash_len = 20
- self.pool = bytearray(b'\x00' * self.hash_len)
+ self.pool = bytearray(b"\0" * self.hash_len)
if seed is not None:
self._stir(seed)
self.seeded = True
@@ -24,10 +45,86 @@ class EntropyPool:
self.seeded = False
self.seed_pid = 0
+ def _stir(self, entropy: bytes) -> None:
+ for c in entropy:
+ if self.pool_index == self.hash_len:
+ self.pool_index = 0
+ b = c & 0xFF
+ self.pool[self.pool_index] ^= b
+ self.pool_index += 1
+
+ def stir(self, entropy: bytes) -> None:
+ with self.lock:
+ self._stir(entropy)
+
+ def _maybe_seed(self) -> None:
+ if not self.seeded or self.seed_pid != os.getpid():
+ try:
+ seed = os.urandom(16)
+ except Exception: # pragma: no cover
+ try:
+ with open("/dev/urandom", "rb", 0) as r:
+ seed = r.read(16)
+ except Exception:
+ seed = str(time.time()).encode()
+ self.seeded = True
+ self.seed_pid = os.getpid()
+ self.digest = None
+ seed = bytearray(seed)
+ self._stir(seed)
+
+ def random_8(self) -> int:
+ with self.lock:
+ self._maybe_seed()
+ if self.digest is None or self.next_byte == self.hash_len:
+ self.hash.update(bytes(self.pool))
+ self.digest = bytearray(self.hash.digest())
+ self._stir(self.digest)
+ self.next_byte = 0
+ value = self.digest[self.next_byte]
+ self.next_byte += 1
+ return value
+
+ def random_16(self) -> int:
+ return self.random_8() * 256 + self.random_8()
+
+ def random_32(self) -> int:
+ return self.random_16() * 65536 + self.random_16()
+
+ def random_between(self, first: int, last: int) -> int:
+ size = last - first + 1
+ if size > 4294967296:
+ raise ValueError("too big")
+ if size > 65536:
+ rand = self.random_32
+ max = 4294967295
+ elif size > 256:
+ rand = self.random_16
+ max = 65535
+ else:
+ rand = self.random_8
+ max = 255
+ return first + size * rand() // (max + 1)
+
pool = EntropyPool()
+
system_random: Optional[Any]
try:
system_random = random.SystemRandom()
-except Exception:
+except Exception: # pragma: no cover
system_random = None
+
+
+def random_16() -> int:
+ if system_random is not None:
+ return system_random.randrange(0, 65536)
+ else:
+ return pool.random_16()
+
+
+def between(first: int, last: int) -> int:
+ if system_random is not None:
+ return system_random.randrange(first, last + 1)
+ else:
+ return pool.random_between(first, last)
diff --git a/dns/enum.py b/dns/enum.py
index c6d69c6..71461f1 100644
--- a/dns/enum.py
+++ b/dns/enum.py
@@ -1,12 +1,78 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import enum
from typing import Type, TypeVar, Union
-TIntEnum = TypeVar('TIntEnum', bound='IntEnum')
+
+TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
class IntEnum(enum.IntEnum):
+ @classmethod
+ def _missing_(cls, value):
+ cls._check_value(value)
+ val = int.__new__(cls, value)
+ val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
+ val._value_ = value
+ return val
+
+ @classmethod
+ def _check_value(cls, value):
+ max = cls._maximum()
+ if not isinstance(value, int):
+ raise TypeError
+ if value < 0 or value > max:
+ name = cls._short_name()
+ raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
@classmethod
- def make(cls: Type[TIntEnum], value: Union[int, str]) ->TIntEnum:
+ def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum:
+ text = text.upper()
+ try:
+ return cls[text]
+ except KeyError:
+ pass
+ value = cls._extra_from_text(text)
+ if value:
+ return value
+ prefix = cls._prefix()
+ if text.startswith(prefix) and text[len(prefix) :].isdigit():
+ value = int(text[len(prefix) :])
+ cls._check_value(value)
+ try:
+ return cls(value)
+ except ValueError:
+ return value
+ raise cls._unknown_exception_class()
+
+ @classmethod
+ def to_text(cls: Type[TIntEnum], value: int) -> str:
+ cls._check_value(value)
+ try:
+ text = cls(value).name
+ except ValueError:
+ text = None
+ text = cls._extra_to_text(value, text)
+ if text is None:
+ text = f"{cls._prefix()}{value}"
+ return text
+
+ @classmethod
+ def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum:
"""Convert text or a value into an enumerated type, if possible.
*value*, the ``int`` or ``str`` to convert.
@@ -19,4 +85,32 @@ class IntEnum(enum.IntEnum):
Returns an enumeration from the calling class corresponding to the
value, if one is defined, or an ``int`` otherwise.
"""
- pass
+
+ if isinstance(value, str):
+ return cls.from_text(value)
+ cls._check_value(value)
+ return cls(value)
+
+ @classmethod
+ def _maximum(cls):
+ raise NotImplementedError # pragma: no cover
+
+ @classmethod
+ def _short_name(cls):
+ return cls.__name__.lower()
+
+ @classmethod
+ def _prefix(cls):
+ return ""
+
+ @classmethod
+ def _extra_from_text(cls, text): # pylint: disable=W0613
+ return None
+
+ @classmethod
+ def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
+ return current_text
+
+ @classmethod
+ def _unknown_exception_class(cls):
+ return ValueError
diff --git a/dns/exception.py b/dns/exception.py
index 4f53e7b..6982373 100644
--- a/dns/exception.py
+++ b/dns/exception.py
@@ -1,8 +1,27 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Common DNS Exceptions.
Dnspython modules may also define their own exceptions, which will
always be subclasses of ``DNSException``.
"""
+
+
from typing import Optional, Set
@@ -28,18 +47,21 @@ class DNSException(Exception):
In the simplest case it is enough to override the ``supp_kwargs``
and ``fmt`` class variables to get nice parametrized messages.
"""
- msg: Optional[str] = None
- supp_kwargs: Set[str] = set()
- fmt: Optional[str] = None
+
+ msg: Optional[str] = None # non-parametrized message
+ supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
+ fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs
def __init__(self, *args, **kwargs):
self._check_params(*args, **kwargs)
if kwargs:
- self.kwargs = self._check_kwargs(**kwargs)
+ # This call to a virtual method from __init__ is ok in our usage
+ self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
self.msg = str(self)
else:
- self.kwargs = dict()
+ self.kwargs = dict() # defined but empty for old mode exceptions
if self.msg is None:
+ # doc string is better implicit message than empty string
self.msg = self.__doc__
if args:
super().__init__(*args)
@@ -50,7 +72,17 @@ class DNSException(Exception):
"""Old exceptions supported only args and not kwargs.
For sanity we do not allow to mix old and new behavior."""
- pass
+ if args or kwargs:
+ assert bool(args) != bool(
+ kwargs
+ ), "keyword arguments are mutually exclusive with positional args"
+
+ def _check_kwargs(self, **kwargs):
+ if kwargs:
+ assert (
+ set(kwargs.keys()) == self.supp_kwargs
+ ), "following set of keyword args is required: %s" % (self.supp_kwargs)
+ return kwargs
def _fmt_kwargs(self, **kwargs):
"""Format kwargs before printing them.
@@ -58,13 +90,25 @@ class DNSException(Exception):
Resulting dictionary has to have keys necessary for str.format call
on fmt class variable.
"""
- pass
+ fmtargs = {}
+ for kw, data in kwargs.items():
+ if isinstance(data, (list, set)):
+ # convert list of <someobj> to list of str(<someobj>)
+ fmtargs[kw] = list(map(str, data))
+ if len(fmtargs[kw]) == 1:
+ # remove list brackets [] from single-item lists
+ fmtargs[kw] = fmtargs[kw].pop()
+ else:
+ fmtargs[kw] = data
+ return fmtargs
def __str__(self):
if self.kwargs and self.fmt:
+ # provide custom message constructed from keyword arguments
fmtargs = self._fmt_kwargs(**self.kwargs)
return self.fmt.format(**fmtargs)
else:
+ # print *args directly in the same way as old DNSException
return super().__str__()
@@ -86,9 +130,12 @@ class TooBig(DNSException):
class Timeout(DNSException):
"""The DNS operation timed out."""
- supp_kwargs = {'timeout'}
- fmt = 'The DNS operation timed out after {timeout:.3f} seconds'
+ supp_kwargs = {"timeout"}
+ fmt = "The DNS operation timed out after {timeout:.3f} seconds"
+
+ # We do this as otherwise mypy complains about unexpected keyword argument
+ # idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -110,7 +157,6 @@ class DeniedByPolicy(DNSException):
class ExceptionWrapper:
-
def __init__(self, exception_class):
self.exception_class = exception_class
@@ -118,7 +164,6 @@ class ExceptionWrapper:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
- if exc_type is not None and not isinstance(exc_val, self.
- exception_class):
+ if exc_type is not None and not isinstance(exc_val, self.exception_class):
raise self.exception_class(str(exc_val)) from exc_val
return False
diff --git a/dns/flags.py b/dns/flags.py
index f682e9b..4c60be1 100644
--- a/dns/flags.py
+++ b/dns/flags.py
@@ -1,57 +1,110 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Message Flags."""
+
import enum
from typing import Any
+# Standard DNS flags
+
class Flag(enum.IntFlag):
- QR = 32768
- AA = 1024
- TC = 512
- RD = 256
- RA = 128
- AD = 32
- CD = 16
+ #: Query Response
+ QR = 0x8000
+ #: Authoritative Answer
+ AA = 0x0400
+ #: Truncated Response
+ TC = 0x0200
+ #: Recursion Desired
+ RD = 0x0100
+ #: Recursion Available
+ RA = 0x0080
+ #: Authentic Data
+ AD = 0x0020
+ #: Checking Disabled
+ CD = 0x0010
+
+
+# EDNS flags
class EDNSFlag(enum.IntFlag):
- DO = 32768
+ #: DNSSEC answer OK
+ DO = 0x8000
+
+
+def _from_text(text: str, enum_class: Any) -> int:
+ flags = 0
+ tokens = text.split()
+ for t in tokens:
+ flags |= enum_class[t.upper()]
+ return flags
+
+
+def _to_text(flags: int, enum_class: Any) -> str:
+ text_flags = []
+ for k, v in enum_class.__members__.items():
+ if flags & v != 0:
+ text_flags.append(k)
+ return " ".join(text_flags)
-def from_text(text: str) ->int:
+def from_text(text: str) -> int:
"""Convert a space-separated list of flag text values into a flags
value.
Returns an ``int``
"""
- pass
+ return _from_text(text, Flag)
-def to_text(flags: int) ->str:
+
+def to_text(flags: int) -> str:
"""Convert a flags value into a space-separated list of flag text
values.
Returns a ``str``.
"""
- pass
+
+ return _to_text(flags, Flag)
-def edns_from_text(text: str) ->int:
+def edns_from_text(text: str) -> int:
"""Convert a space-separated list of EDNS flag text values into a EDNS
flags value.
Returns an ``int``
"""
- pass
+
+ return _from_text(text, EDNSFlag)
-def edns_to_text(flags: int) ->str:
+def edns_to_text(flags: int) -> str:
"""Convert an EDNS flags value into a space-separated list of EDNS flag
text values.
Returns a ``str``.
"""
- pass
+ return _to_text(flags, EDNSFlag)
+
+
+### BEGIN generated Flag constants
QR = Flag.QR
AA = Flag.AA
@@ -60,4 +113,11 @@ RD = Flag.RD
RA = Flag.RA
AD = Flag.AD
CD = Flag.CD
+
+### END generated Flag constants
+
+### BEGIN generated EDNSFlag constants
+
DO = EDNSFlag.DO
+
+### END generated EDNSFlag constants
diff --git a/dns/grange.py b/dns/grange.py
index af2c200..3a52278 100644
--- a/dns/grange.py
+++ b/dns/grange.py
@@ -1,9 +1,28 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2012-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS GENERATE range conversion."""
+
from typing import Tuple
+
import dns
-def from_text(text: str) ->Tuple[int, int, int]:
+def from_text(text: str) -> Tuple[int, int, int]:
"""Convert the text form of a range in a ``$GENERATE`` statement to an
integer.
@@ -11,4 +30,43 @@ def from_text(text: str) ->Tuple[int, int, int]:
Returns a tuple of three ``int`` values ``(start, stop, step)``.
"""
- pass
+
+ start = -1
+ stop = -1
+ step = 1
+ cur = ""
+ state = 0
+ # state 0 1 2
+ # x - y / z
+
+ if text and text[0] == "-":
+ raise dns.exception.SyntaxError("Start cannot be a negative number")
+
+ for c in text:
+ if c == "-" and state == 0:
+ start = int(cur)
+ cur = ""
+ state = 1
+ elif c == "/":
+ stop = int(cur)
+ cur = ""
+ state = 2
+ elif c.isdigit():
+ cur += c
+ else:
+ raise dns.exception.SyntaxError("Could not parse %s" % (c))
+
+ if state == 0:
+ raise dns.exception.SyntaxError("no stop value specified")
+ elif state == 1:
+ stop = int(cur)
+ else:
+ assert state == 2
+ step = int(cur)
+
+ assert step >= 1
+ assert start >= 0
+ if start > stop:
+ raise dns.exception.SyntaxError("start must be <= stop")
+
+ return (start, stop, step)
diff --git a/dns/immutable.py b/dns/immutable.py
index 1170831..36b0362 100644
--- a/dns/immutable.py
+++ b/dns/immutable.py
@@ -1,13 +1,19 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import collections.abc
from typing import Any, Callable
+
from dns._immutable_ctx import immutable
@immutable
-class Dict(collections.abc.Mapping):
-
- def __init__(self, dictionary: Any, no_copy: bool=False, map_factory:
- Callable[[], collections.abc.MutableMapping]=dict):
+class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
+ def __init__(
+ self,
+ dictionary: Any,
+ no_copy: bool = False,
+ map_factory: Callable[[], collections.abc.MutableMapping] = dict,
+ ):
"""Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead
@@ -24,12 +30,13 @@ class Dict(collections.abc.Mapping):
def __getitem__(self, key):
return self._odict.__getitem__(key)
- def __hash__(self):
+ def __hash__(self): # pylint: disable=invalid-hash-returned
if self._hash is None:
h = 0
for key in sorted(self._odict.keys()):
h ^= hash(key)
- object.__setattr__(self, '_hash', h)
+ object.__setattr__(self, "_hash", h)
+ # this does return an int, but pylint doesn't figure that out
return self._hash
def __len__(self):
@@ -39,8 +46,23 @@ class Dict(collections.abc.Mapping):
return iter(self._odict)
-def constify(o: Any) ->Any:
+def constify(o: Any) -> Any:
"""
Convert mutable types to immutable types.
"""
- pass
+ if isinstance(o, bytearray):
+ return bytes(o)
+ if isinstance(o, tuple):
+ try:
+ hash(o)
+ return o
+ except Exception:
+ return tuple(constify(elt) for elt in o)
+ if isinstance(o, list):
+ return tuple(constify(elt) for elt in o)
+ if isinstance(o, dict):
+ cdict = dict()
+ for k, v in o.items():
+ cdict[k] = constify(v)
+ return Dict(cdict, True)
+ return o
diff --git a/dns/inet.py b/dns/inet.py
index 8c49f86..4a03f99 100644
--- a/dns/inet.py
+++ b/dns/inet.py
@@ -1,13 +1,36 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Generic Internet address helper functions."""
+
import socket
from typing import Any, Optional, Tuple
+
import dns.ipv4
import dns.ipv6
+
+# We assume that AF_INET and AF_INET6 are always defined. We keep
+# these here for the benefit of any old code (unlikely though that
+# is!).
AF_INET = socket.AF_INET
AF_INET6 = socket.AF_INET6
-def inet_pton(family: int, text: str) ->bytes:
+def inet_pton(family: int, text: str) -> bytes:
"""Convert the textual form of a network address into its binary form.
*family* is an ``int``, the address family.
@@ -19,10 +42,16 @@ def inet_pton(family: int, text: str) ->bytes:
Returns a ``bytes``.
"""
- pass
+
+ if family == AF_INET:
+ return dns.ipv4.inet_aton(text)
+ elif family == AF_INET6:
+ return dns.ipv6.inet_aton(text, True)
+ else:
+ raise NotImplementedError
-def inet_ntop(family: int, address: bytes) ->str:
+def inet_ntop(family: int, address: bytes) -> str:
"""Convert the binary form of a network address into its textual form.
*family* is an ``int``, the address family.
@@ -34,10 +63,16 @@ def inet_ntop(family: int, address: bytes) ->str:
Returns a ``str``.
"""
- pass
+ if family == AF_INET:
+ return dns.ipv4.inet_ntoa(address)
+ elif family == AF_INET6:
+ return dns.ipv6.inet_ntoa(address)
+ else:
+ raise NotImplementedError
-def af_for_address(text: str) ->int:
+
+def af_for_address(text: str) -> int:
"""Determine the address family of a textual-form network address.
*text*, a ``str``, the textual address.
@@ -47,10 +82,19 @@ def af_for_address(text: str) ->int:
Returns an ``int``.
"""
- pass
+
+ try:
+ dns.ipv4.inet_aton(text)
+ return AF_INET
+ except Exception:
+ try:
+ dns.ipv6.inet_aton(text, True)
+ return AF_INET6
+ except Exception:
+ raise ValueError
-def is_multicast(text: str) ->bool:
+def is_multicast(text: str) -> bool:
"""Is the textual-form network address a multicast address?
*text*, a ``str``, the textual address.
@@ -60,21 +104,40 @@ def is_multicast(text: str) ->bool:
Returns a ``bool``.
"""
- pass
+ try:
+ first = dns.ipv4.inet_aton(text)[0]
+ return first >= 224 and first <= 239
+ except Exception:
+ try:
+ first = dns.ipv6.inet_aton(text, True)[0]
+ return first == 255
+ except Exception:
+ raise ValueError
-def is_address(text: str) ->bool:
+
+def is_address(text: str) -> bool:
"""Is the specified string an IPv4 or IPv6 address?
*text*, a ``str``, the textual address.
Returns a ``bool``.
"""
- pass
+
+ try:
+ dns.ipv4.inet_aton(text)
+ return True
+ except Exception:
+ try:
+ dns.ipv6.inet_aton(text, True)
+ return True
+ except Exception:
+ return False
-def low_level_address_tuple(high_tuple: Tuple[str, int], af: Optional[int]=None
- ) ->Any:
+def low_level_address_tuple(
+ high_tuple: Tuple[str, int], af: Optional[int] = None
+) -> Any:
"""Given a "high-level" address tuple, i.e.
an (address, port) return the appropriate "low-level" address tuple
suitable for use in socket calls.
@@ -83,15 +146,41 @@ def low_level_address_tuple(high_tuple: Tuple[str, int], af: Optional[int]=None
address in the high-level tuple is valid and has that af. If af
is ``None``, then af_for_address will be called.
"""
- pass
+ address, port = high_tuple
+ if af is None:
+ af = af_for_address(address)
+ if af == AF_INET:
+ return (address, port)
+ elif af == AF_INET6:
+ i = address.find("%")
+ if i < 0:
+ # no scope, shortcut!
+ return (address, port, 0, 0)
+ # try to avoid getaddrinfo()
+ addrpart = address[:i]
+ scope = address[i + 1 :]
+ if scope.isdigit():
+ return (addrpart, port, 0, int(scope))
+ try:
+ return (addrpart, port, 0, socket.if_nametoindex(scope))
+ except AttributeError: # pragma: no cover (we can't really test this)
+ ai_flags = socket.AI_NUMERICHOST
+ ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
+ return tup
+ else:
+ raise NotImplementedError(f"unknown address family {af}")
def any_for_af(af):
"""Return the 'any' address for the specified address family."""
- pass
+ if af == socket.AF_INET:
+ return "0.0.0.0"
+ elif af == socket.AF_INET6:
+ return "::"
+ raise NotImplementedError(f"unknown address family {af}")
-def canonicalize(text: str) ->str:
+def canonicalize(text: str) -> str:
"""Verify that *address* is a valid text form IPv4 or IPv6 address and return its
canonical text form. IPv6 addresses with scopes are rejected.
@@ -99,4 +188,10 @@ def canonicalize(text: str) ->str:
Raises ``ValueError`` if the text is not valid.
"""
- pass
+ try:
+ return dns.ipv6.canonicalize(text)
+ except Exception:
+ try:
+ return dns.ipv4.canonicalize(text)
+ except Exception:
+ raise ValueError
diff --git a/dns/ipv4.py b/dns/ipv4.py
index 00864bd..65ee69c 100644
--- a/dns/ipv4.py
+++ b/dns/ipv4.py
@@ -1,30 +1,70 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""IPv4 helper functions."""
+
import struct
from typing import Union
+
import dns.exception
-def inet_ntoa(address: bytes) ->str:
+def inet_ntoa(address: bytes) -> str:
"""Convert an IPv4 address in binary form to text form.
*address*, a ``bytes``, the IPv4 address in binary form.
Returns a ``str``.
"""
- pass
+
+ if len(address) != 4:
+ raise dns.exception.SyntaxError
+ return "%u.%u.%u.%u" % (address[0], address[1], address[2], address[3])
-def inet_aton(text: Union[str, bytes]) ->bytes:
+def inet_aton(text: Union[str, bytes]) -> bytes:
"""Convert an IPv4 address in text form to binary form.
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Returns a ``bytes``.
"""
- pass
+
+ if not isinstance(text, bytes):
+ btext = text.encode()
+ else:
+ btext = text
+ parts = btext.split(b".")
+ if len(parts) != 4:
+ raise dns.exception.SyntaxError
+ for part in parts:
+ if not part.isdigit():
+ raise dns.exception.SyntaxError
+ if len(part) > 1 and part[0] == ord("0"):
+ # No leading zeros
+ raise dns.exception.SyntaxError
+ try:
+ b = [int(part) for part in parts]
+ return struct.pack("BBBB", *b)
+ except Exception:
+ raise dns.exception.SyntaxError
-def canonicalize(text: Union[str, bytes]) ->str:
+def canonicalize(text: Union[str, bytes]) -> str:
"""Verify that *address* is a valid text form IPv4 address and return its
canonical text form.
@@ -32,4 +72,6 @@ def canonicalize(text: Union[str, bytes]) ->str:
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
- pass
+ # Note that inet_aton() only accepts canonial form, but we still run through
+ # inet_ntoa() to ensure the output is a str.
+ return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))
diff --git a/dns/ipv6.py b/dns/ipv6.py
index 94ddeaa..44a1063 100644
--- a/dns/ipv6.py
+++ b/dns/ipv6.py
@@ -1,13 +1,33 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""IPv6 helper functions."""
+
import binascii
import re
from typing import List, Union
+
import dns.exception
import dns.ipv4
-_leading_zero = re.compile('0+([0-9a-f]+)')
+_leading_zero = re.compile(r"0+([0-9a-f]+)")
-def inet_ntoa(address: bytes) ->str:
+
+def inet_ntoa(address: bytes) -> str:
"""Convert an IPv6 address in binary form to text form.
*address*, a ``bytes``, the IPv6 address in binary form.
@@ -15,15 +35,73 @@ def inet_ntoa(address: bytes) ->str:
Raises ``ValueError`` if the address isn't 16 bytes long.
Returns a ``str``.
"""
- pass
-
-
-_v4_ending = re.compile(b'(.*):(\\d+\\.\\d+\\.\\d+\\.\\d+)$')
-_colon_colon_start = re.compile(b'::.*')
-_colon_colon_end = re.compile(b'.*::$')
-
-def inet_aton(text: Union[str, bytes], ignore_scope: bool=False) ->bytes:
+ if len(address) != 16:
+ raise ValueError("IPv6 addresses are 16 bytes long")
+ hex = binascii.hexlify(address)
+ chunks = []
+ i = 0
+ l = len(hex)
+ while i < l:
+ chunk = hex[i : i + 4].decode()
+ # strip leading zeros. we do this with an re instead of
+ # with lstrip() because lstrip() didn't support chars until
+ # python 2.2.2
+ m = _leading_zero.match(chunk)
+ if m is not None:
+ chunk = m.group(1)
+ chunks.append(chunk)
+ i += 4
+ #
+ # Compress the longest subsequence of 0-value chunks to ::
+ #
+ best_start = 0
+ best_len = 0
+ start = -1
+ last_was_zero = False
+ for i in range(8):
+ if chunks[i] != "0":
+ if last_was_zero:
+ end = i
+ current_len = end - start
+ if current_len > best_len:
+ best_start = start
+ best_len = current_len
+ last_was_zero = False
+ elif not last_was_zero:
+ start = i
+ last_was_zero = True
+ if last_was_zero:
+ end = 8
+ current_len = end - start
+ if current_len > best_len:
+ best_start = start
+ best_len = current_len
+ if best_len > 1:
+ if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
+ # We have an embedded IPv4 address
+ if best_len == 6:
+ prefix = "::"
+ else:
+ prefix = "::ffff:"
+ thex = prefix + dns.ipv4.inet_ntoa(address[12:])
+ else:
+ thex = (
+ ":".join(chunks[:best_start])
+ + "::"
+ + ":".join(chunks[best_start + best_len :])
+ )
+ else:
+ thex = ":".join(chunks)
+ return thex
+
+
+_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$")
+_colon_colon_start = re.compile(rb"::.*")
+_colon_colon_end = re.compile(rb".*::$")
+
+
+def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
"""Convert an IPv6 address in text form to binary form.
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
@@ -33,23 +111,104 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool=False) ->bytes:
Returns a ``bytes``.
"""
- pass
-
-
-_mapped_prefix = b'\x00' * 10 + b'\xff\xff'
-
-def is_mapped(address: bytes) ->bool:
+ #
+ # Our aim here is not something fast; we just want something that works.
+ #
+ if not isinstance(text, bytes):
+ btext = text.encode()
+ else:
+ btext = text
+
+ if ignore_scope:
+ parts = btext.split(b"%")
+ l = len(parts)
+ if l == 2:
+ btext = parts[0]
+ elif l > 2:
+ raise dns.exception.SyntaxError
+
+ if btext == b"":
+ raise dns.exception.SyntaxError
+ elif btext.endswith(b":") and not btext.endswith(b"::"):
+ raise dns.exception.SyntaxError
+ elif btext.startswith(b":") and not btext.startswith(b"::"):
+ raise dns.exception.SyntaxError
+ elif btext == b"::":
+ btext = b"0::"
+ #
+ # Get rid of the icky dot-quad syntax if we have it.
+ #
+ m = _v4_ending.match(btext)
+ if m is not None:
+ b = dns.ipv4.inet_aton(m.group(2))
+ btext = (
+ "{}:{:02x}{:02x}:{:02x}{:02x}".format(
+ m.group(1).decode(), b[0], b[1], b[2], b[3]
+ )
+ ).encode()
+ #
+ # Try to turn '::<whatever>' into ':<whatever>'; if no match try to
+ # turn '<whatever>::' into '<whatever>:'
+ #
+ m = _colon_colon_start.match(btext)
+ if m is not None:
+ btext = btext[1:]
+ else:
+ m = _colon_colon_end.match(btext)
+ if m is not None:
+ btext = btext[:-1]
+ #
+ # Now canonicalize into 8 chunks of 4 hex digits each
+ #
+ chunks = btext.split(b":")
+ l = len(chunks)
+ if l > 8:
+ raise dns.exception.SyntaxError
+ seen_empty = False
+ canonical: List[bytes] = []
+ for c in chunks:
+ if c == b"":
+ if seen_empty:
+ raise dns.exception.SyntaxError
+ seen_empty = True
+ for _ in range(0, 8 - l + 1):
+ canonical.append(b"0000")
+ else:
+ lc = len(c)
+ if lc > 4:
+ raise dns.exception.SyntaxError
+ if lc != 4:
+ c = (b"0" * (4 - lc)) + c
+ canonical.append(c)
+ if l < 8 and not seen_empty:
+ raise dns.exception.SyntaxError
+ btext = b"".join(canonical)
+
+ #
+ # Finally we can go to binary.
+ #
+ try:
+ return binascii.unhexlify(btext)
+ except (binascii.Error, TypeError):
+ raise dns.exception.SyntaxError
+
+
+_mapped_prefix = b"\x00" * 10 + b"\xff\xff"
+
+
+def is_mapped(address: bytes) -> bool:
"""Is the specified address a mapped IPv4 address?
*address*, a ``bytes`` is an IPv6 address in binary form.
Returns a ``bool``.
"""
- pass
+
+ return address.startswith(_mapped_prefix)
-def canonicalize(text: Union[str, bytes]) ->str:
+def canonicalize(text: Union[str, bytes]) -> str:
"""Verify that *address* is a valid text form IPv6 address and return its
canonical text form. Addresses with scopes are rejected.
@@ -57,4 +216,4 @@ def canonicalize(text: Union[str, bytes]) ->str:
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
- pass
+ return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text))
diff --git a/dns/message.py b/dns/message.py
index 657451c..44cacbd 100644
--- a/dns/message.py
+++ b/dns/message.py
@@ -1,8 +1,27 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Messages"""
+
import contextlib
import io
import time
from typing import Any, Dict, List, Optional, Tuple, Union
+
import dns.edns
import dns.entropy
import dns.enum
@@ -52,8 +71,11 @@ class UnknownTSIGKey(dns.exception.DNSException):
class Truncated(dns.exception.DNSException):
"""The truncated flag is set."""
- supp_kwargs = {'message'}
+ supp_kwargs = {"message"}
+
+ # We do this as otherwise mypy complains about unexpected keyword argument
+ # idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -62,7 +84,7 @@ class Truncated(dns.exception.DNSException):
Returns a ``dns.message.Message``.
"""
- pass
+ return self.kwargs["message"]
class NotQueryResponse(dns.exception.DNSException):
@@ -83,14 +105,18 @@ class NoPreviousName(dns.exception.SyntaxError):
class MessageSection(dns.enum.IntEnum):
"""Message sections"""
+
QUESTION = 0
ANSWER = 1
AUTHORITY = 2
ADDITIONAL = 3
+ @classmethod
+ def _maximum(cls):
+ return 3
-class MessageError:
+class MessageError:
def __init__(self, exception: Exception, offset: int):
self.exception = exception
self.offset = offset
@@ -98,18 +124,25 @@ class MessageError:
DEFAULT_EDNS_PAYLOAD = 1232
MAX_CHAIN = 16
-IndexKeyType = Tuple[int, dns.name.Name, dns.rdataclass.RdataClass, dns.
- rdatatype.RdataType, Optional[dns.rdatatype.RdataType], Optional[dns.
- rdataclass.RdataClass]]
+
+IndexKeyType = Tuple[
+ int,
+ dns.name.Name,
+ dns.rdataclass.RdataClass,
+ dns.rdatatype.RdataType,
+ Optional[dns.rdatatype.RdataType],
+ Optional[dns.rdataclass.RdataClass],
+]
IndexType = Dict[IndexKeyType, dns.rrset.RRset]
SectionType = Union[int, str, List[dns.rrset.RRset]]
class Message:
"""A DNS message."""
+
_section_enum = MessageSection
- def __init__(self, id: Optional[int]=None):
+ def __init__(self, id: Optional[int] = None):
if id is None:
self.id = dns.entropy.random_16()
else:
@@ -121,7 +154,7 @@ class Message:
self.pad = 0
self.keyring: Any = None
self.tsig: Optional[dns.rrset.RRset] = None
- self.request_mac = b''
+ self.request_mac = b""
self.xfr = False
self.origin: Optional[dns.name.Name] = None
self.tsig_ctx: Optional[Any] = None
@@ -130,33 +163,53 @@ class Message:
self.time = 0.0
@property
- def question(self) ->List[dns.rrset.RRset]:
+ def question(self) -> List[dns.rrset.RRset]:
"""The question section."""
- pass
+ return self.sections[0]
+
+ @question.setter
+ def question(self, v):
+ self.sections[0] = v
@property
- def answer(self) ->List[dns.rrset.RRset]:
+ def answer(self) -> List[dns.rrset.RRset]:
"""The answer section."""
- pass
+ return self.sections[1]
+
+ @answer.setter
+ def answer(self, v):
+ self.sections[1] = v
@property
- def authority(self) ->List[dns.rrset.RRset]:
+ def authority(self) -> List[dns.rrset.RRset]:
"""The authority section."""
- pass
+ return self.sections[2]
+
+ @authority.setter
+ def authority(self, v):
+ self.sections[2] = v
@property
- def additional(self) ->List[dns.rrset.RRset]:
+ def additional(self) -> List[dns.rrset.RRset]:
"""The additional data section."""
- pass
+ return self.sections[3]
+
+ @additional.setter
+ def additional(self, v):
+ self.sections[3] = v
def __repr__(self):
- return '<DNS message, ID ' + repr(self.id) + '>'
+ return "<DNS message, ID " + repr(self.id) + ">"
def __str__(self):
return self.to_text()
- def to_text(self, origin: Optional[dns.name.Name]=None, relativize:
- bool=True, **kw: Dict[str, Any]) ->str:
+ def to_text(
+ self,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ **kw: Dict[str, Any],
+ ) -> str:
"""Convert the message to text.
The *origin*, *relativize*, and any other keyword
@@ -164,7 +217,30 @@ class Message:
Returns a ``str``.
"""
- pass
+
+ s = io.StringIO()
+ s.write("id %d\n" % self.id)
+ s.write("opcode %s\n" % dns.opcode.to_text(self.opcode()))
+ s.write("rcode %s\n" % dns.rcode.to_text(self.rcode()))
+ s.write("flags %s\n" % dns.flags.to_text(self.flags))
+ if self.edns >= 0:
+ s.write("edns %s\n" % self.edns)
+ if self.ednsflags != 0:
+ s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags))
+ s.write("payload %d\n" % self.payload)
+ for opt in self.options:
+ s.write("option %s\n" % opt.to_text())
+ for name, which in self._section_enum.__members__.items():
+ s.write(f";{name}\n")
+ for rrset in self.section_from_number(which):
+ s.write(rrset.to_text(origin, relativize, **kw))
+ s.write("\n")
+ #
+ # We strip off the final \n so the caller can print the result without
+ # doing weird things to get around eccentricities in Python print
+ # formatting
+ #
+ return s.getvalue()[:-1]
def __eq__(self, other):
"""Two messages are equal if they have the same content in the
@@ -172,6 +248,7 @@ class Message:
Returns a ``bool``.
"""
+
if not isinstance(other, Message):
return False
if self.id != other.id:
@@ -191,15 +268,45 @@ class Message:
def __ne__(self, other):
return not self.__eq__(other)
- def is_response(self, other: 'Message') ->bool:
+ def is_response(self, other: "Message") -> bool:
"""Is *other*, also a ``dns.message.Message``, a response to this
message?
Returns a ``bool``.
"""
- pass
- def section_number(self, section: List[dns.rrset.RRset]) ->int:
+ if (
+ other.flags & dns.flags.QR == 0
+ or self.id != other.id
+ or dns.opcode.from_flags(self.flags) != dns.opcode.from_flags(other.flags)
+ ):
+ return False
+ if other.rcode() in {
+ dns.rcode.FORMERR,
+ dns.rcode.SERVFAIL,
+ dns.rcode.NOTIMP,
+ dns.rcode.REFUSED,
+ }:
+ # We don't check the question section in these cases if
+ # the other question section is empty, even though they
+ # still really ought to have a question section.
+ if len(other.question) == 0:
+ return True
+ if dns.opcode.is_update(self.flags):
+ # This is assuming the "sender doesn't include anything
+ # from the update", but we don't care to check the other
+ # case, which is that all the sections are returned and
+ # identical.
+ return True
+ for n in self.question:
+ if n not in other.question:
+ return False
+ for n in other.question:
+ if n not in self.question:
+ return False
+ return True
+
+ def section_number(self, section: List[dns.rrset.RRset]) -> int:
"""Return the "section number" of the specified section for use
in indexing.
@@ -209,9 +316,13 @@ class Message:
Returns an ``int``.
"""
- pass
- def section_from_number(self, number: int) ->List[dns.rrset.RRset]:
+ for i, our_section in enumerate(self.sections):
+ if section is our_section:
+ return self._section_enum(i)
+ raise ValueError("unknown section")
+
+ def section_from_number(self, number: int) -> List[dns.rrset.RRset]:
"""Return the section list associated with the specified section
number.
@@ -222,14 +333,22 @@ class Message:
Returns a ``list``.
"""
- pass
-
- def find_rrset(self, section: SectionType, name: dns.name.Name, rdclass:
- dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, covers:
- dns.rdatatype.RdataType=dns.rdatatype.NONE, deleting: Optional[dns.
- rdataclass.RdataClass]=None, create: bool=False, force_unique: bool
- =False, idna_codec: Optional[dns.name.IDNACodec]=None
- ) ->dns.rrset.RRset:
+
+ section = self._section_enum.make(number)
+ return self.sections[section]
+
+ def find_rrset(
+ self,
+ section: SectionType,
+ name: dns.name.Name,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ deleting: Optional[dns.rdataclass.RdataClass] = None,
+ create: bool = False,
+ force_unique: bool = False,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ ) -> dns.rrset.RRset:
"""Find the RRset with the given attributes in the specified section.
*section*, an ``int`` section number, a ``str`` section name, or one of
@@ -269,14 +388,52 @@ class Message:
Returns a ``dns.rrset.RRset object``.
"""
- pass
-
- def get_rrset(self, section: SectionType, name: dns.name.Name, rdclass:
- dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, covers:
- dns.rdatatype.RdataType=dns.rdatatype.NONE, deleting: Optional[dns.
- rdataclass.RdataClass]=None, create: bool=False, force_unique: bool
- =False, idna_codec: Optional[dns.name.IDNACodec]=None) ->Optional[dns
- .rrset.RRset]:
+
+ if isinstance(section, int):
+ section_number = section
+ section = self.section_from_number(section_number)
+ elif isinstance(section, str):
+ section_number = self._section_enum.from_text(section)
+ section = self.section_from_number(section_number)
+ else:
+ section_number = self.section_number(section)
+ if isinstance(name, str):
+ name = dns.name.from_text(name, idna_codec=idna_codec)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ covers = dns.rdatatype.RdataType.make(covers)
+ if deleting is not None:
+ deleting = dns.rdataclass.RdataClass.make(deleting)
+ key = (section_number, name, rdclass, rdtype, covers, deleting)
+ if not force_unique:
+ if self.index is not None:
+ rrset = self.index.get(key)
+ if rrset is not None:
+ return rrset
+ else:
+ for rrset in section:
+ if rrset.full_match(name, rdclass, rdtype, covers, deleting):
+ return rrset
+ if not create:
+ raise KeyError
+ rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
+ section.append(rrset)
+ if self.index is not None:
+ self.index[key] = rrset
+ return rrset
+
+ def get_rrset(
+ self,
+ section: SectionType,
+ name: dns.name.Name,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ deleting: Optional[dns.rdataclass.RdataClass] = None,
+ create: bool = False,
+ force_unique: bool = False,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ ) -> Optional[dns.rrset.RRset]:
"""Get the RRset with the given attributes in the specified section.
If the RRset is not found, None is returned.
@@ -315,9 +472,24 @@ class Message:
Returns a ``dns.rrset.RRset object`` or ``None``.
"""
- pass
- def section_count(self, section: SectionType) ->int:
+ try:
+ rrset = self.find_rrset(
+ section,
+ name,
+ rdclass,
+ rdtype,
+ covers,
+ deleting,
+ create,
+ force_unique,
+ idna_codec,
+ )
+ except KeyError:
+ rrset = None
+ return rrset
+
+ def section_count(self, section: SectionType) -> int:
"""Returns the number of records in the specified section.
*section*, an ``int`` section number, a ``str`` section name, or one of
@@ -328,20 +500,65 @@ class Message:
my_message.section_count(dns.message.ANSWER)
my_message.section_count("ANSWER")
"""
- pass
- def _compute_opt_reserve(self) ->int:
+ if isinstance(section, int):
+ section_number = section
+ section = self.section_from_number(section_number)
+ elif isinstance(section, str):
+ section_number = self._section_enum.from_text(section)
+ section = self.section_from_number(section_number)
+ else:
+ section_number = self.section_number(section)
+ count = sum(max(1, len(rrs)) for rrs in section)
+ if section_number == MessageSection.ADDITIONAL:
+ if self.opt is not None:
+ count += 1
+ if self.tsig is not None:
+ count += 1
+ return count
+
+ def _compute_opt_reserve(self) -> int:
"""Compute the size required for the OPT RR, padding excluded"""
- pass
-
- def _compute_tsig_reserve(self) ->int:
+ if not self.opt:
+ return 0
+ # 1 byte for the root name, 10 for the standard RR fields
+ size = 11
+ # This would be more efficient if options had a size() method, but we won't
+ # worry about that for now. We also don't worry if there is an existing padding
+ # option, as it is unlikely and probably harmless, as the worst case is that we
+ # may add another, and this seems to be legal.
+ for option in self.opt[0].options:
+ wire = option.to_wire()
+ # We add 4 here to account for the option type and length
+ size += len(wire) + 4
+ if self.pad:
+ # Padding will be added, so again add the option type and length.
+ size += 4
+ return size
+
+ def _compute_tsig_reserve(self) -> int:
"""Compute the size required for the TSIG RR"""
- pass
-
- def to_wire(self, origin: Optional[dns.name.Name]=None, max_size: int=0,
- multi: bool=False, tsig_ctx: Optional[Any]=None, prepend_length:
- bool=False, prefer_truncation: bool=False, **kw: Dict[str, Any]
- ) ->bytes:
+ # This would be more efficient if TSIGs had a size method, but we won't
+ # worry about for now. Also, we can't really cope with the potential
+ # compressibility of the TSIG owner name, so we estimate with the uncompressed
+ # size. We will disable compression when TSIG and padding are both is active
+ # so that the padding comes out right.
+ if not self.tsig:
+ return 0
+ f = io.BytesIO()
+ self.tsig.to_wire(f)
+ return len(f.getvalue())
+
+ def to_wire(
+ self,
+ origin: Optional[dns.name.Name] = None,
+ max_size: int = 0,
+ multi: bool = False,
+ tsig_ctx: Optional[Any] = None,
+ prepend_length: bool = False,
+ prefer_truncation: bool = False,
+ **kw: Dict[str, Any],
+ ) -> bytes:
"""Return a string containing the message in DNS compressed wire
format.
@@ -375,12 +592,90 @@ class Message:
Returns a ``bytes``.
"""
- pass
- def use_tsig(self, keyring: Any, keyname: Optional[Union[dns.name.Name,
- str]]=None, fudge: int=300, original_id: Optional[int]=None,
- tsig_error: int=0, other_data: bytes=b'', algorithm: Union[dns.name
- .Name, str]=dns.tsig.default_algorithm) ->None:
+ if origin is None and self.origin is not None:
+ origin = self.origin
+ if max_size == 0:
+ if self.request_payload != 0:
+ max_size = self.request_payload
+ else:
+ max_size = 65535
+ if max_size < 512:
+ max_size = 512
+ elif max_size > 65535:
+ max_size = 65535
+ r = dns.renderer.Renderer(self.id, self.flags, max_size, origin)
+ opt_reserve = self._compute_opt_reserve()
+ r.reserve(opt_reserve)
+ tsig_reserve = self._compute_tsig_reserve()
+ r.reserve(tsig_reserve)
+ try:
+ for rrset in self.question:
+ r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
+ for rrset in self.answer:
+ r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
+ for rrset in self.authority:
+ r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
+ for rrset in self.additional:
+ r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
+ except dns.exception.TooBig:
+ if prefer_truncation:
+ if r.section < dns.renderer.ADDITIONAL:
+ r.flags |= dns.flags.TC
+ else:
+ raise
+ r.release_reserved()
+ if self.opt is not None:
+ r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve)
+ r.write_header()
+ if self.tsig is not None:
+ (new_tsig, ctx) = dns.tsig.sign(
+ r.get_wire(),
+ self.keyring,
+ self.tsig[0],
+ int(time.time()),
+ self.request_mac,
+ tsig_ctx,
+ multi,
+ )
+ self.tsig.clear()
+ self.tsig.add(new_tsig)
+ r.add_rrset(dns.renderer.ADDITIONAL, self.tsig)
+ r.write_header()
+ if multi:
+ self.tsig_ctx = ctx
+ wire = r.get_wire()
+ if prepend_length:
+ wire = len(wire).to_bytes(2, "big") + wire
+ return wire
+
+ @staticmethod
+ def _make_tsig(
+ keyname, algorithm, time_signed, fudge, mac, original_id, error, other
+ ):
+ tsig = dns.rdtypes.ANY.TSIG.TSIG(
+ dns.rdataclass.ANY,
+ dns.rdatatype.TSIG,
+ algorithm,
+ time_signed,
+ fudge,
+ mac,
+ original_id,
+ error,
+ other,
+ )
+ return dns.rrset.from_rdata(keyname, 0, tsig)
+
+ def use_tsig(
+ self,
+ keyring: Any,
+ keyname: Optional[Union[dns.name.Name, str]] = None,
+ fudge: int = 300,
+ original_id: Optional[int] = None,
+ tsig_error: int = 0,
+ other_data: bytes = b"",
+ algorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm,
+ ) -> None:
"""When sending, a TSIG signature using the specified key
should be added.
@@ -417,12 +712,80 @@ class Message:
*algorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use. This is
only used if *keyring* is a ``dict``, and the key entry is a ``bytes``.
"""
- pass
- def use_edns(self, edns: Optional[Union[int, bool]]=0, ednsflags: int=0,
- payload: int=DEFAULT_EDNS_PAYLOAD, request_payload: Optional[int]=
- None, options: Optional[List[dns.edns.Option]]=None, pad: int=0
- ) ->None:
+ if isinstance(keyring, dns.tsig.Key):
+ key = keyring
+ keyname = key.name
+ elif callable(keyring):
+ key = keyring(self, keyname)
+ else:
+ if isinstance(keyname, str):
+ keyname = dns.name.from_text(keyname)
+ if keyname is None:
+ keyname = next(iter(keyring))
+ key = keyring[keyname]
+ if isinstance(key, bytes):
+ key = dns.tsig.Key(keyname, key, algorithm)
+ self.keyring = key
+ if original_id is None:
+ original_id = self.id
+ self.tsig = self._make_tsig(
+ keyname,
+ self.keyring.algorithm,
+ 0,
+ fudge,
+ b"\x00" * dns.tsig.mac_sizes[self.keyring.algorithm],
+ original_id,
+ tsig_error,
+ other_data,
+ )
+
+ @property
+ def keyname(self) -> Optional[dns.name.Name]:
+ if self.tsig:
+ return self.tsig.name
+ else:
+ return None
+
+ @property
+ def keyalgorithm(self) -> Optional[dns.name.Name]:
+ if self.tsig:
+ return self.tsig[0].algorithm
+ else:
+ return None
+
+ @property
+ def mac(self) -> Optional[bytes]:
+ if self.tsig:
+ return self.tsig[0].mac
+ else:
+ return None
+
+ @property
+ def tsig_error(self) -> Optional[int]:
+ if self.tsig:
+ return self.tsig[0].error
+ else:
+ return None
+
+ @property
+ def had_tsig(self) -> bool:
+ return bool(self.tsig)
+
+ @staticmethod
+ def _make_opt(flags=0, payload=DEFAULT_EDNS_PAYLOAD, options=None):
+ opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, options or ())
+ return dns.rrset.from_rdata(dns.name.root, int(flags), opt)
+
+ def use_edns(
+ self,
+ edns: Optional[Union[int, bool]] = 0,
+ ednsflags: int = 0,
+ payload: int = DEFAULT_EDNS_PAYLOAD,
+ request_payload: Optional[int] = None,
+ options: Optional[List[dns.edns.Option]] = None,
+ pad: int = 0,
+ ) -> None:
"""Configure EDNS behavior.
*edns*, an ``int``, is the EDNS level to use. Specifying ``None``, ``False``,
@@ -445,9 +808,64 @@ class Message:
padding is non-zero, an EDNS PADDING option will always be added to the
message.
"""
- pass
- def want_dnssec(self, wanted: bool=True) ->None:
+ if edns is None or edns is False:
+ edns = -1
+ elif edns is True:
+ edns = 0
+ if edns < 0:
+ self.opt = None
+ self.request_payload = 0
+ else:
+ # make sure the EDNS version in ednsflags agrees with edns
+ ednsflags &= 0xFF00FFFF
+ ednsflags |= edns << 16
+ if options is None:
+ options = []
+ self.opt = self._make_opt(ednsflags, payload, options)
+ if request_payload is None:
+ request_payload = payload
+ self.request_payload = request_payload
+ if pad < 0:
+ raise ValueError("pad must be non-negative")
+ self.pad = pad
+
+ @property
+ def edns(self) -> int:
+ if self.opt:
+ return (self.ednsflags & 0xFF0000) >> 16
+ else:
+ return -1
+
+ @property
+ def ednsflags(self) -> int:
+ if self.opt:
+ return self.opt.ttl
+ else:
+ return 0
+
+ @ednsflags.setter
+ def ednsflags(self, v):
+ if self.opt:
+ self.opt.ttl = v
+ elif v:
+ self.opt = self._make_opt(v)
+
+ @property
+ def payload(self) -> int:
+ if self.opt:
+ return self.opt[0].payload
+ else:
+ return 0
+
+ @property
+ def options(self) -> Tuple:
+ if self.opt:
+ return self.opt[0].options
+ else:
+ return ()
+
+ def want_dnssec(self, wanted: bool = True) -> None:
"""Enable or disable 'DNSSEC desired' flag in requests.
*wanted*, a ``bool``. If ``True``, then DNSSEC data is
@@ -455,35 +873,72 @@ class Message:
the DO bit is set. If ``False``, the DO bit is cleared if
EDNS is enabled.
"""
- pass
- def rcode(self) ->dns.rcode.Rcode:
+ if wanted:
+ self.ednsflags |= dns.flags.DO
+ elif self.opt:
+ self.ednsflags &= ~int(dns.flags.DO)
+
+ def rcode(self) -> dns.rcode.Rcode:
"""Return the rcode.
Returns a ``dns.rcode.Rcode``.
"""
- pass
+ return dns.rcode.from_flags(int(self.flags), int(self.ednsflags))
- def set_rcode(self, rcode: dns.rcode.Rcode) ->None:
+ def set_rcode(self, rcode: dns.rcode.Rcode) -> None:
"""Set the rcode.
*rcode*, a ``dns.rcode.Rcode``, is the rcode to set.
"""
- pass
+ (value, evalue) = dns.rcode.to_flags(rcode)
+ self.flags &= 0xFFF0
+ self.flags |= value
+ self.ednsflags &= 0x00FFFFFF
+ self.ednsflags |= evalue
- def opcode(self) ->dns.opcode.Opcode:
+ def opcode(self) -> dns.opcode.Opcode:
"""Return the opcode.
Returns a ``dns.opcode.Opcode``.
"""
- pass
+ return dns.opcode.from_flags(int(self.flags))
- def set_opcode(self, opcode: dns.opcode.Opcode) ->None:
+ def set_opcode(self, opcode: dns.opcode.Opcode) -> None:
"""Set the opcode.
*opcode*, a ``dns.opcode.Opcode``, is the opcode to set.
"""
- pass
+ self.flags &= 0x87FF
+ self.flags |= dns.opcode.to_flags(opcode)
+
+ def _get_one_rr_per_rrset(self, value):
+ # What the caller picked is fine.
+ return value
+
+ # pylint: disable=unused-argument
+
+ def _parse_rr_header(self, section, name, rdclass, rdtype):
+ return (rdclass, rdtype, None, False)
+
+ # pylint: enable=unused-argument
+
+ def _parse_special_rr_header(self, section, count, position, name, rdclass, rdtype):
+ if rdtype == dns.rdatatype.OPT:
+ if (
+ section != MessageSection.ADDITIONAL
+ or self.opt
+ or name != dns.name.root
+ ):
+ raise BadEDNS
+ elif rdtype == dns.rdatatype.TSIG:
+ if (
+ section != MessageSection.ADDITIONAL
+ or rdclass != dns.rdatatype.ANY
+ or position != count - 1
+ ):
+ raise BadTSIG
+ return (rdclass, rdtype, None, False)
class ChainingResult:
@@ -505,8 +960,13 @@ class ChainingResult:
get to the canonical name.
"""
- def __init__(self, canonical_name: dns.name.Name, answer: Optional[dns.
- rrset.RRset], minimum_ttl: int, cnames: List[dns.rrset.RRset]):
+ def __init__(
+ self,
+ canonical_name: dns.name.Name,
+ answer: Optional[dns.rrset.RRset],
+ minimum_ttl: int,
+ cnames: List[dns.rrset.RRset],
+ ):
self.canonical_name = canonical_name
self.answer = answer
self.minimum_ttl = minimum_ttl
@@ -514,8 +974,7 @@ class ChainingResult:
class QueryMessage(Message):
-
- def resolve_chaining(self) ->ChainingResult:
+ def resolve_chaining(self) -> ChainingResult:
"""Follow the CNAME chain in the response to determine the answer
RRset.
@@ -531,9 +990,66 @@ class QueryMessage(Message):
Returns a ChainingResult object.
"""
- pass
-
- def canonical_name(self) ->dns.name.Name:
+ if self.flags & dns.flags.QR == 0:
+ raise NotQueryResponse
+ if len(self.question) != 1:
+ raise dns.exception.FormError
+ question = self.question[0]
+ qname = question.name
+ min_ttl = dns.ttl.MAX_TTL
+ answer = None
+ count = 0
+ cnames = []
+ while count < MAX_CHAIN:
+ try:
+ answer = self.find_rrset(
+ self.answer, qname, question.rdclass, question.rdtype
+ )
+ min_ttl = min(min_ttl, answer.ttl)
+ break
+ except KeyError:
+ if question.rdtype != dns.rdatatype.CNAME:
+ try:
+ crrset = self.find_rrset(
+ self.answer, qname, question.rdclass, dns.rdatatype.CNAME
+ )
+ cnames.append(crrset)
+ min_ttl = min(min_ttl, crrset.ttl)
+ for rd in crrset:
+ qname = rd.target
+ break
+ count += 1
+ continue
+ except KeyError:
+ # Exit the chaining loop
+ break
+ else:
+ # Exit the chaining loop
+ break
+ if count >= MAX_CHAIN:
+ raise ChainTooLong
+ if self.rcode() == dns.rcode.NXDOMAIN and answer is not None:
+ raise AnswerForNXDOMAIN
+ if answer is None:
+ # Further minimize the TTL with NCACHE.
+ auname = qname
+ while True:
+ # Look for an SOA RR whose owner name is a superdomain
+ # of qname.
+ try:
+ srrset = self.find_rrset(
+ self.authority, auname, question.rdclass, dns.rdatatype.SOA
+ )
+ min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum)
+ break
+ except KeyError:
+ try:
+ auname = auname.parent()
+ except dns.name.NoParent:
+ break
+ return ChainingResult(qname, answer, min_ttl, cnames)
+
+ def canonical_name(self) -> dns.name.Name:
"""Return the canonical name of the first name in the question
section.
@@ -547,7 +1063,26 @@ class QueryMessage(Message):
Raises ``dns.exception.FormError`` if the question count is not 1.
"""
- pass
+ return self.resolve_chaining().canonical_name
+
+
+def _maybe_import_update():
+ # We avoid circular imports by doing this here. We do it in another
+ # function as doing it in _message_factory_from_opcode() makes "dns"
+ # a local symbol, and the first line fails :)
+
+ # pylint: disable=redefined-outer-name,import-outside-toplevel,unused-import
+ import dns.update # noqa: F401
+
+
+def _message_factory_from_opcode(opcode):
+ if opcode == dns.opcode.QUERY:
+ return QueryMessage
+ elif opcode == dns.opcode.UPDATE:
+ _maybe_import_update()
+ return dns.update.UpdateMessage
+ else:
+ return Message
class _WireReader:
@@ -567,9 +1102,17 @@ class _WireReader:
raising them.
"""
- def __init__(self, wire, initialize_message, question_only=False,
- one_rr_per_rrset=False, ignore_trailing=False, keyring=None, multi=
- False, continue_on_error=False):
+ def __init__(
+ self,
+ wire,
+ initialize_message,
+ question_only=False,
+ one_rr_per_rrset=False,
+ ignore_trailing=False,
+ keyring=None,
+ multi=False,
+ continue_on_error=False,
+ ):
self.parser = dns.wire.Parser(wire)
self.message = None
self.initialize_message = initialize_message
@@ -585,7 +1128,20 @@ class _WireReader:
"""Read the next *qcount* records from the wire data and add them to
the question section.
"""
- pass
+ assert self.message is not None
+ section = self.message.sections[section_number]
+ for _ in range(qcount):
+ qname = self.parser.get_name(self.message.origin)
+ (rdtype, rdclass) = self.parser.get_struct("!HH")
+ (rdclass, rdtype, _, _) = self.message._parse_rr_header(
+ section_number, qname, rdclass, rdtype
+ )
+ self.message.find_rrset(
+ section, qname, rdclass, rdtype, create=True, force_unique=True
+ )
+
+ def _add_error(self, e):
+ self.errors.append(MessageError(e, self.parser.current))
def _get_section(self, section_number, count):
"""Read the next I{count} records from the wire data and add them to
@@ -594,20 +1150,144 @@ class _WireReader:
section_number: the section of the message to which to add records
count: the number of records to read
"""
- pass
+ assert self.message is not None
+ section = self.message.sections[section_number]
+ force_unique = self.one_rr_per_rrset
+ for i in range(count):
+ rr_start = self.parser.current
+ absolute_name = self.parser.get_name()
+ if self.message.origin is not None:
+ name = absolute_name.relativize(self.message.origin)
+ else:
+ name = absolute_name
+ (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct("!HHIH")
+ if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG):
+ (
+ rdclass,
+ rdtype,
+ deleting,
+ empty,
+ ) = self.message._parse_special_rr_header(
+ section_number, count, i, name, rdclass, rdtype
+ )
+ else:
+ (rdclass, rdtype, deleting, empty) = self.message._parse_rr_header(
+ section_number, name, rdclass, rdtype
+ )
+ rdata_start = self.parser.current
+ try:
+ if empty:
+ if rdlen > 0:
+ raise dns.exception.FormError
+ rd = None
+ covers = dns.rdatatype.NONE
+ else:
+ with self.parser.restrict_to(rdlen):
+ rd = dns.rdata.from_wire_parser(
+ rdclass, rdtype, self.parser, self.message.origin
+ )
+ covers = rd.covers()
+ if self.message.xfr and rdtype == dns.rdatatype.SOA:
+ force_unique = True
+ if rdtype == dns.rdatatype.OPT:
+ self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
+ elif rdtype == dns.rdatatype.TSIG:
+ if self.keyring is None:
+ raise UnknownTSIGKey("got signed message without keyring")
+ if isinstance(self.keyring, dict):
+ key = self.keyring.get(absolute_name)
+ if isinstance(key, bytes):
+ key = dns.tsig.Key(absolute_name, key, rd.algorithm)
+ elif callable(self.keyring):
+ key = self.keyring(self.message, absolute_name)
+ else:
+ key = self.keyring
+ if key is None:
+ raise UnknownTSIGKey("key '%s' unknown" % name)
+ self.message.keyring = key
+ self.message.tsig_ctx = dns.tsig.validate(
+ self.parser.wire,
+ key,
+ absolute_name,
+ rd,
+ int(time.time()),
+ self.message.request_mac,
+ rr_start,
+ self.message.tsig_ctx,
+ self.multi,
+ )
+ self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
+ else:
+ rrset = self.message.find_rrset(
+ section,
+ name,
+ rdclass,
+ rdtype,
+ covers,
+ deleting,
+ True,
+ force_unique,
+ )
+ if rd is not None:
+ if ttl > 0x7FFFFFFF:
+ ttl = 0
+ rrset.add(rd, ttl)
+ except Exception as e:
+ if self.continue_on_error:
+ self._add_error(e)
+ self.parser.seek(rdata_start + rdlen)
+ else:
+ raise
def read(self):
"""Read a wire format DNS message and build a dns.message.Message
object."""
- pass
-
-def from_wire(wire: bytes, keyring: Optional[Any]=None, request_mac:
- Optional[bytes]=b'', xfr: bool=False, origin: Optional[dns.name.Name]=
- None, tsig_ctx: Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]]=
- None, multi: bool=False, question_only: bool=False, one_rr_per_rrset:
- bool=False, ignore_trailing: bool=False, raise_on_truncation: bool=
- False, continue_on_error: bool=False) ->Message:
+ if self.parser.remaining() < 12:
+ raise ShortHeader
+ (id, flags, qcount, ancount, aucount, adcount) = self.parser.get_struct(
+ "!HHHHHH"
+ )
+ factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
+ self.message = factory(id=id)
+ self.message.flags = dns.flags.Flag(flags)
+ self.initialize_message(self.message)
+ self.one_rr_per_rrset = self.message._get_one_rr_per_rrset(
+ self.one_rr_per_rrset
+ )
+ try:
+ self._get_question(MessageSection.QUESTION, qcount)
+ if self.question_only:
+ return self.message
+ self._get_section(MessageSection.ANSWER, ancount)
+ self._get_section(MessageSection.AUTHORITY, aucount)
+ self._get_section(MessageSection.ADDITIONAL, adcount)
+ if not self.ignore_trailing and self.parser.remaining() != 0:
+ raise TrailingJunk
+ if self.multi and self.message.tsig_ctx and not self.message.had_tsig:
+ self.message.tsig_ctx.update(self.parser.wire)
+ except Exception as e:
+ if self.continue_on_error:
+ self._add_error(e)
+ else:
+ raise
+ return self.message
+
+
+def from_wire(
+ wire: bytes,
+ keyring: Optional[Any] = None,
+ request_mac: Optional[bytes] = b"",
+ xfr: bool = False,
+ origin: Optional[dns.name.Name] = None,
+ tsig_ctx: Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None,
+ multi: bool = False,
+ question_only: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ raise_on_truncation: bool = False,
+ continue_on_error: bool = False,
+) -> Message:
"""Convert a DNS wire format message into a message object.
*keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message
@@ -662,7 +1342,46 @@ def from_wire(wire: bytes, keyring: Optional[Any]=None, request_mac:
Returns a ``dns.message.Message``.
"""
- pass
+
+ # We permit None for request_mac solely for backwards compatibility
+ if request_mac is None:
+ request_mac = b""
+
+ def initialize_message(message):
+ message.request_mac = request_mac
+ message.xfr = xfr
+ message.origin = origin
+ message.tsig_ctx = tsig_ctx
+
+ reader = _WireReader(
+ wire,
+ initialize_message,
+ question_only,
+ one_rr_per_rrset,
+ ignore_trailing,
+ keyring,
+ multi,
+ continue_on_error,
+ )
+ try:
+ m = reader.read()
+ except dns.exception.FormError:
+ if (
+ reader.message
+ and (reader.message.flags & dns.flags.TC)
+ and raise_on_truncation
+ ):
+ raise Truncated(message=reader.message)
+ else:
+ raise
+ # Reading a truncated message might not have any errors, so we
+ # have to do this check here too.
+ if m.flags & dns.flags.TC and raise_on_truncation:
+ raise Truncated(message=m)
+ if continue_on_error:
+ m.errors = reader.errors
+
+ return m
class _TextReader:
@@ -678,8 +1397,15 @@ class _TextReader:
relativize_to: the origin to relativize to.
"""
- def __init__(self, text, idna_codec, one_rr_per_rrset=False, origin=
- None, relativize=True, relativize_to=None):
+ def __init__(
+ self,
+ text,
+ idna_codec,
+ one_rr_per_rrset=False,
+ origin=None,
+ relativize=True,
+ relativize_to=None,
+ ):
self.message = None
self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec)
self.last_name = None
@@ -697,28 +1423,221 @@ class _TextReader:
def _header_line(self, _):
"""Process one line from the text format header section."""
- pass
+
+ token = self.tok.get()
+ what = token.value
+ if what == "id":
+ self.id = self.tok.get_int()
+ elif what == "flags":
+ while True:
+ token = self.tok.get()
+ if not token.is_identifier():
+ self.tok.unget(token)
+ break
+ self.flags = self.flags | dns.flags.from_text(token.value)
+ elif what == "edns":
+ self.edns = self.tok.get_int()
+ self.ednsflags = self.ednsflags | (self.edns << 16)
+ elif what == "eflags":
+ if self.edns < 0:
+ self.edns = 0
+ while True:
+ token = self.tok.get()
+ if not token.is_identifier():
+ self.tok.unget(token)
+ break
+ self.ednsflags = self.ednsflags | dns.flags.edns_from_text(token.value)
+ elif what == "payload":
+ self.payload = self.tok.get_int()
+ if self.edns < 0:
+ self.edns = 0
+ elif what == "opcode":
+ text = self.tok.get_string()
+ self.opcode = dns.opcode.from_text(text)
+ self.flags = self.flags | dns.opcode.to_flags(self.opcode)
+ elif what == "rcode":
+ text = self.tok.get_string()
+ self.rcode = dns.rcode.from_text(text)
+ else:
+ raise UnknownHeaderField
+ self.tok.get_eol()
def _question_line(self, section_number):
"""Process one line from the text format question section."""
- pass
+
+ section = self.message.sections[section_number]
+ token = self.tok.get(want_leading=True)
+ if not token.is_whitespace():
+ self.last_name = self.tok.as_name(
+ token, self.message.origin, self.relativize, self.relativize_to
+ )
+ name = self.last_name
+ if name is None:
+ raise NoPreviousName
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ # Class
+ try:
+ rdclass = dns.rdataclass.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.exception.SyntaxError:
+ raise dns.exception.SyntaxError
+ except Exception:
+ rdclass = dns.rdataclass.IN
+ # Type
+ rdtype = dns.rdatatype.from_text(token.value)
+ (rdclass, rdtype, _, _) = self.message._parse_rr_header(
+ section_number, name, rdclass, rdtype
+ )
+ self.message.find_rrset(
+ section, name, rdclass, rdtype, create=True, force_unique=True
+ )
+ self.tok.get_eol()
def _rr_line(self, section_number):
"""Process one line from the text format answer, authority, or
additional data sections.
"""
- pass
+
+ section = self.message.sections[section_number]
+ # Name
+ token = self.tok.get(want_leading=True)
+ if not token.is_whitespace():
+ self.last_name = self.tok.as_name(
+ token, self.message.origin, self.relativize, self.relativize_to
+ )
+ name = self.last_name
+ if name is None:
+ raise NoPreviousName
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ # TTL
+ try:
+ ttl = int(token.value, 0)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.exception.SyntaxError:
+ raise dns.exception.SyntaxError
+ except Exception:
+ ttl = 0
+ # Class
+ try:
+ rdclass = dns.rdataclass.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.exception.SyntaxError:
+ raise dns.exception.SyntaxError
+ except Exception:
+ rdclass = dns.rdataclass.IN
+ # Type
+ rdtype = dns.rdatatype.from_text(token.value)
+ (rdclass, rdtype, deleting, empty) = self.message._parse_rr_header(
+ section_number, name, rdclass, rdtype
+ )
+ token = self.tok.get()
+ if empty and not token.is_eol_or_eof():
+ raise dns.exception.SyntaxError
+ if not empty and token.is_eol_or_eof():
+ raise dns.exception.UnexpectedEnd
+ if not token.is_eol_or_eof():
+ self.tok.unget(token)
+ rd = dns.rdata.from_text(
+ rdclass,
+ rdtype,
+ self.tok,
+ self.message.origin,
+ self.relativize,
+ self.relativize_to,
+ )
+ covers = rd.covers()
+ else:
+ rd = None
+ covers = dns.rdatatype.NONE
+ rrset = self.message.find_rrset(
+ section,
+ name,
+ rdclass,
+ rdtype,
+ covers,
+ deleting,
+ True,
+ self.one_rr_per_rrset,
+ )
+ if rd is not None:
+ rrset.add(rd, ttl)
+
+ def _make_message(self):
+ factory = _message_factory_from_opcode(self.opcode)
+ message = factory(id=self.id)
+ message.flags = self.flags
+ if self.edns >= 0:
+ message.use_edns(self.edns, self.ednsflags, self.payload)
+ if self.rcode:
+ message.set_rcode(self.rcode)
+ if self.origin:
+ message.origin = self.origin
+ return message
def read(self):
"""Read a text format DNS message and build a dns.message.Message
object."""
- pass
-
-def from_text(text: str, idna_codec: Optional[dns.name.IDNACodec]=None,
- one_rr_per_rrset: bool=False, origin: Optional[dns.name.Name]=None,
- relativize: bool=True, relativize_to: Optional[dns.name.Name]=None
- ) ->Message:
+ line_method = self._header_line
+ section_number = None
+ while 1:
+ token = self.tok.get(True, True)
+ if token.is_eol_or_eof():
+ break
+ if token.is_comment():
+ u = token.value.upper()
+ if u == "HEADER":
+ line_method = self._header_line
+
+ if self.message:
+ message = self.message
+ else:
+ # If we don't have a message, create one with the current
+ # opcode, so that we know which section names to parse.
+ message = self._make_message()
+ try:
+ section_number = message._section_enum.from_text(u)
+ # We found a section name. If we don't have a message,
+ # use the one we just created.
+ if not self.message:
+ self.message = message
+ self.one_rr_per_rrset = message._get_one_rr_per_rrset(
+ self.one_rr_per_rrset
+ )
+ if section_number == MessageSection.QUESTION:
+ line_method = self._question_line
+ else:
+ line_method = self._rr_line
+ except Exception:
+ # It's just a comment.
+ pass
+ self.tok.get_eol()
+ continue
+ self.tok.unget(token)
+ line_method(section_number)
+ if not self.message:
+ self.message = self._make_message()
+ return self.message
+
+
+def from_text(
+ text: str,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ one_rr_per_rrset: bool = False,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ relativize_to: Optional[dns.name.Name] = None,
+) -> Message:
"""Convert the text format message into a message object.
The reader stops after reading the first blank line in the input to
@@ -748,11 +1667,22 @@ def from_text(text: str, idna_codec: Optional[dns.name.IDNACodec]=None,
Returns a ``dns.message.Message object``
"""
- pass
+ # 'text' can also be a file, but we don't publish that fact
+ # since it's an implementation detail. The official file
+ # interface is from_file().
+
+ reader = _TextReader(
+ text, idna_codec, one_rr_per_rrset, origin, relativize, relativize_to
+ )
+ return reader.read()
-def from_file(f: Any, idna_codec: Optional[dns.name.IDNACodec]=None,
- one_rr_per_rrset: bool=False) ->Message:
+
+def from_file(
+ f: Any,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ one_rr_per_rrset: bool = False,
+) -> Message:
"""Read the next text format message from the specified file.
Message blocks are separated by a single blank line.
@@ -773,17 +1703,31 @@ def from_file(f: Any, idna_codec: Optional[dns.name.IDNACodec]=None,
Returns a ``dns.message.Message object``
"""
- pass
-
-def make_query(qname: Union[dns.name.Name, str], rdtype: Union[dns.
- rdatatype.RdataType, str], rdclass: Union[dns.rdataclass.RdataClass,
- str]=dns.rdataclass.IN, use_edns: Optional[Union[int, bool]]=None,
- want_dnssec: bool=False, ednsflags: Optional[int]=None, payload:
- Optional[int]=None, request_payload: Optional[int]=None, options:
- Optional[List[dns.edns.Option]]=None, idna_codec: Optional[dns.name.
- IDNACodec]=None, id: Optional[int]=None, flags: int=dns.flags.RD, pad:
- int=0) ->QueryMessage:
+ if isinstance(f, str):
+ cm: contextlib.AbstractContextManager = open(f)
+ else:
+ cm = contextlib.nullcontext(f)
+ with cm as f:
+ return from_text(f, idna_codec, one_rr_per_rrset)
+ assert False # for mypy lgtm[py/unreachable-statement]
+
+
+def make_query(
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ use_edns: Optional[Union[int, bool]] = None,
+ want_dnssec: bool = False,
+ ednsflags: Optional[int] = None,
+ payload: Optional[int] = None,
+ request_payload: Optional[int] = None,
+ options: Optional[List[dns.edns.Option]] = None,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ id: Optional[int] = None,
+ flags: int = dns.flags.RD,
+ pad: int = 0,
+) -> QueryMessage:
"""Make a query message.
The query name, type, and class may all be specified either
@@ -838,12 +1782,43 @@ def make_query(qname: Union[dns.name.Name, str], rdtype: Union[dns.
Returns a ``dns.message.QueryMessage``
"""
- pass
-
-def make_response(query: Message, recursion_available: bool=False,
- our_payload: int=8192, fudge: int=300, tsig_error: int=0, pad: Optional
- [int]=None) ->Message:
+ if isinstance(qname, str):
+ qname = dns.name.from_text(qname, idna_codec=idna_codec)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ m = QueryMessage(id=id)
+ m.flags = dns.flags.Flag(flags)
+ m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True)
+ # only pass keywords on to use_edns if they have been set to a
+ # non-None value. Setting a field will turn EDNS on if it hasn't
+ # been configured.
+ kwargs: Dict[str, Any] = {}
+ if ednsflags is not None:
+ kwargs["ednsflags"] = ednsflags
+ if payload is not None:
+ kwargs["payload"] = payload
+ if request_payload is not None:
+ kwargs["request_payload"] = request_payload
+ if options is not None:
+ kwargs["options"] = options
+ if kwargs and use_edns is None:
+ use_edns = 0
+ kwargs["edns"] = use_edns
+ kwargs["pad"] = pad
+ m.use_edns(**kwargs)
+ m.want_dnssec(want_dnssec)
+ return m
+
+
+def make_response(
+ query: Message,
+ recursion_available: bool = False,
+ our_payload: int = 8192,
+ fudge: int = 300,
+ tsig_error: int = 0,
+ pad: Optional[int] = None,
+) -> Message:
"""Make a message which is a response for the specified query.
The message returned is really a response skeleton; it has all of the infrastructure
required of a response, but none of the content.
@@ -871,10 +1846,43 @@ def make_response(query: Message, recursion_available: bool=False,
query. For example, if query is a ``dns.update.UpdateMessage``, response will be
too.
"""
- pass
+ if query.flags & dns.flags.QR:
+ raise dns.exception.FormError("specified query message is not a query")
+ factory = _message_factory_from_opcode(query.opcode())
+ response = factory(id=query.id)
+ response.flags = dns.flags.QR | (query.flags & dns.flags.RD)
+ if recursion_available:
+ response.flags |= dns.flags.RA
+ response.set_opcode(query.opcode())
+ response.question = list(query.question)
+ if query.edns >= 0:
+ if pad is None:
+ # Set response padding per RFC 8467
+ pad = 0
+ for option in query.options:
+ if option.otype == dns.edns.OptionType.PADDING:
+ pad = 468
+ response.use_edns(0, 0, our_payload, query.payload, pad=pad)
+ if query.had_tsig:
+ response.use_tsig(
+ query.keyring,
+ query.keyname,
+ fudge,
+ None,
+ tsig_error,
+ b"",
+ query.keyalgorithm,
+ )
+ response.request_mac = query.mac
+ return response
+
+
+### BEGIN generated MessageSection constants
QUESTION = MessageSection.QUESTION
ANSWER = MessageSection.ANSWER
AUTHORITY = MessageSection.AUTHORITY
ADDITIONAL = MessageSection.ADDITIONAL
+
+### END generated MessageSection constants
diff --git a/dns/name.py b/dns/name.py
index b9a153e..22ccb39 100644
--- a/dns/name.py
+++ b/dns/name.py
@@ -1,32 +1,72 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Names.
"""
+
import copy
-import encodings.idna
+import encodings.idna # type: ignore
import functools
import struct
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
+
import dns._features
import dns.enum
import dns.exception
import dns.immutable
import dns.wire
-if dns._features.have('idna'):
- import idna
+
+if dns._features.have("idna"):
+ import idna # type: ignore
+
have_idna_2008 = True
-else:
+else: # pragma: no cover
have_idna_2008 = False
-CompressType = Dict['Name', int]
+
+CompressType = Dict["Name", int]
class NameRelation(dns.enum.IntEnum):
"""Name relation result from fullcompare()."""
+
+ # This is an IntEnum for backwards compatibility in case anyone
+ # has hardwired the constants.
+
+ #: The compared names have no relationship to each other.
NONE = 0
+ #: the first name is a superdomain of the second.
SUPERDOMAIN = 1
+ #: The first name is a subdomain of the second.
SUBDOMAIN = 2
+ #: The compared names are equal.
EQUAL = 3
+ #: The compared names have a common ancestor.
COMMONANCESTOR = 4
+ @classmethod
+ def _maximum(cls):
+ return cls.COMMONANCESTOR
+
+ @classmethod
+ def _short_name(cls):
+ return cls.__name__
+
+# Backwards compatibility
NAMERELN_NONE = NameRelation.NONE
NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN
NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN
@@ -80,9 +120,12 @@ class NoIDNA2008(dns.exception.DNSException):
class IDNAException(dns.exception.DNSException):
"""IDNA processing raised an exception."""
- supp_kwargs = {'idna_exception'}
- fmt = 'IDNA processing exception: {idna_exception}'
+ supp_kwargs = {"idna_exception"}
+ fmt = "IDNA processing exception: {idna_exception}"
+
+ # We do this as otherwise mypy complains about unexpected keyword argument
+ # idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -95,11 +138,33 @@ _escaped = b'"().;\\@$'
_escaped_text = '"().;\\@$'
-def _escapify(label: Union[bytes, str]) ->str:
+def _escapify(label: Union[bytes, str]) -> str:
"""Escape the characters in label which need it.
@returns: the escaped string
@rtype: string"""
- pass
+ if isinstance(label, bytes):
+ # Ordinary DNS label mode. Escape special characters and values
+ # < 0x20 or > 0x7f.
+ text = ""
+ for c in label:
+ if c in _escaped:
+ text += "\\" + chr(c)
+ elif c > 0x20 and c < 0x7F:
+ text += chr(c)
+ else:
+ text += "\\%03d" % c
+ return text
+
+ # Unicode label mode. Escape only special characters and values < 0x20
+ text = ""
+ for uc in label:
+ if uc in _escaped_text:
+ text += "\\" + uc
+ elif uc <= "\x20":
+ text += "\\%03d" % ord(uc)
+ else:
+ text += uc
+ return text
class IDNACodec:
@@ -108,34 +173,70 @@ class IDNACodec:
def __init__(self):
pass
+ def is_idna(self, label: bytes) -> bool:
+ return label.lower().startswith(b"xn--")
+
+ def encode(self, label: str) -> bytes:
+ raise NotImplementedError # pragma: no cover
+
+ def decode(self, label: bytes) -> str:
+ # We do not apply any IDNA policy on decode.
+ if self.is_idna(label):
+ try:
+ slabel = label[4:].decode("punycode")
+ return _escapify(slabel)
+ except Exception as e:
+ raise IDNAException(idna_exception=e)
+ else:
+ return _escapify(label)
+
class IDNA2003Codec(IDNACodec):
"""IDNA 2003 encoder/decoder."""
- def __init__(self, strict_decode: bool=False):
+ def __init__(self, strict_decode: bool = False):
"""Initialize the IDNA 2003 encoder/decoder.
*strict_decode* is a ``bool``. If `True`, then IDNA2003 checking
is done when decoding. This can cause failures if the name
was encoded with IDNA2008. The default is `False`.
"""
+
super().__init__()
self.strict_decode = strict_decode
- def encode(self, label: str) ->bytes:
+ def encode(self, label: str) -> bytes:
"""Encode *label*."""
- pass
- def decode(self, label: bytes) ->str:
+ if label == "":
+ return b""
+ try:
+ return encodings.idna.ToASCII(label)
+ except UnicodeError:
+ raise LabelTooLong
+
+ def decode(self, label: bytes) -> str:
"""Decode *label*."""
- pass
+ if not self.strict_decode:
+ return super().decode(label)
+ if label == b"":
+ return ""
+ try:
+ return _escapify(encodings.idna.ToUnicode(label))
+ except Exception as e:
+ raise IDNAException(idna_exception=e)
class IDNA2008Codec(IDNACodec):
"""IDNA 2008 encoder/decoder."""
- def __init__(self, uts_46: bool=False, transitional: bool=False,
- allow_pure_ascii: bool=False, strict_decode: bool=False):
+ def __init__(
+ self,
+ uts_46: bool = False,
+ transitional: bool = False,
+ allow_pure_ascii: bool = False,
+ strict_decode: bool = False,
+ ):
"""Initialize the IDNA 2008 encoder/decoder.
*uts_46* is a ``bool``. If True, apply Unicode IDNA
@@ -164,6 +265,41 @@ class IDNA2008Codec(IDNACodec):
self.allow_pure_ascii = allow_pure_ascii
self.strict_decode = strict_decode
+ def encode(self, label: str) -> bytes:
+ if label == "":
+ return b""
+ if self.allow_pure_ascii and is_all_ascii(label):
+ encoded = label.encode("ascii")
+ if len(encoded) > 63:
+ raise LabelTooLong
+ return encoded
+ if not have_idna_2008:
+ raise NoIDNA2008
+ try:
+ if self.uts_46:
+ label = idna.uts46_remap(label, False, self.transitional)
+ return idna.alabel(label)
+ except idna.IDNAError as e:
+ if e.args[0] == "Label too long":
+ raise LabelTooLong
+ else:
+ raise IDNAException(idna_exception=e)
+
+ def decode(self, label: bytes) -> str:
+ if not self.strict_decode:
+ return super().decode(label)
+ if label == b"":
+ return ""
+ if not have_idna_2008:
+ raise NoIDNA2008
+ try:
+ ulabel = idna.ulabel(label)
+ if self.uts_46:
+ ulabel = idna.uts46_remap(ulabel, False, self.transitional)
+ return _escapify(ulabel)
+ except (idna.IDNAError, UnicodeError) as e:
+ raise IDNAException(idna_exception=e)
+
IDNA_2003_Practical = IDNA2003Codec(False)
IDNA_2003_Strict = IDNA2003Codec(True)
@@ -175,7 +311,7 @@ IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False)
IDNA_2008 = IDNA_2008_Practical
-def _validate_labels(labels: Tuple[bytes, ...]) ->None:
+def _validate_labels(labels: Tuple[bytes, ...]) -> None:
"""Check for empty labels in the middle of a label sequence,
labels that are too long, and for too many labels.
@@ -186,15 +322,36 @@ def _validate_labels(labels: Tuple[bytes, ...]) ->None:
sequence
"""
- pass
-
-def _maybe_convert_to_binary(label: Union[bytes, str]) ->bytes:
+ l = len(labels)
+ total = 0
+ i = -1
+ j = 0
+ for label in labels:
+ ll = len(label)
+ total += ll + 1
+ if ll > 63:
+ raise LabelTooLong
+ if i < 0 and label == b"":
+ i = j
+ j += 1
+ if total > 255:
+ raise NameTooLong
+ if i >= 0 and i != l - 1:
+ raise EmptyLabel
+
+
+def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
"""If label is ``str``, convert it to ``bytes``. If it is already
``bytes`` just return it.
"""
- pass
+
+ if isinstance(label, bytes):
+ return label
+ if isinstance(label, str):
+ return label.encode()
+ raise ValueError # pragma: no cover
@dns.immutable.immutable
@@ -205,10 +362,12 @@ class Name:
labels. Each label is a ``bytes`` in DNS wire format. Instances
of the class are immutable.
"""
- __slots__ = ['labels']
+
+ __slots__ = ["labels"]
def __init__(self, labels: Iterable[Union[bytes, str]]):
"""*labels* is any iterable whose values are ``str`` or ``bytes``."""
+
blabels = [_maybe_convert_to_binary(x) for x in labels]
self.labels = tuple(blabels)
_validate_labels(self.labels)
@@ -220,38 +379,42 @@ class Name:
return Name(copy.deepcopy(self.labels, memo))
def __getstate__(self):
- return {'labels': self.labels}
+ # Names can be pickled
+ return {"labels": self.labels}
def __setstate__(self, state):
- super().__setattr__('labels', state['labels'])
+ super().__setattr__("labels", state["labels"])
_validate_labels(self.labels)
- def is_absolute(self) ->bool:
+ def is_absolute(self) -> bool:
"""Is the most significant label of this name the root label?
Returns a ``bool``.
"""
- pass
- def is_wild(self) ->bool:
+ return len(self.labels) > 0 and self.labels[-1] == b""
+
+ def is_wild(self) -> bool:
"""Is this name wild? (I.e. Is the least significant label '*'?)
Returns a ``bool``.
"""
- pass
- def __hash__(self) ->int:
+ return len(self.labels) > 0 and self.labels[0] == b"*"
+
+ def __hash__(self) -> int:
"""Return a case-insensitive hash of the name.
Returns an ``int``.
"""
+
h = 0
for label in self.labels:
for c in label.lower():
h += (h << 3) + c
return h
- def fullcompare(self, other: 'Name') ->Tuple[NameRelation, int, int]:
+ def fullcompare(self, other: "Name") -> Tuple[NameRelation, int, int]:
"""Compare two names, returning a 3-tuple
``(relation, order, nlabels)``.
@@ -282,9 +445,52 @@ class Name:
example1. example2 none > 0 0
============= ============= =========== ===== =======
"""
- pass
- def is_subdomain(self, other: 'Name') ->bool:
+ sabs = self.is_absolute()
+ oabs = other.is_absolute()
+ if sabs != oabs:
+ if sabs:
+ return (NameRelation.NONE, 1, 0)
+ else:
+ return (NameRelation.NONE, -1, 0)
+ l1 = len(self.labels)
+ l2 = len(other.labels)
+ ldiff = l1 - l2
+ if ldiff < 0:
+ l = l1
+ else:
+ l = l2
+
+ order = 0
+ nlabels = 0
+ namereln = NameRelation.NONE
+ while l > 0:
+ l -= 1
+ l1 -= 1
+ l2 -= 1
+ label1 = self.labels[l1].lower()
+ label2 = other.labels[l2].lower()
+ if label1 < label2:
+ order = -1
+ if nlabels > 0:
+ namereln = NameRelation.COMMONANCESTOR
+ return (namereln, order, nlabels)
+ elif label1 > label2:
+ order = 1
+ if nlabels > 0:
+ namereln = NameRelation.COMMONANCESTOR
+ return (namereln, order, nlabels)
+ nlabels += 1
+ order = ldiff
+ if ldiff < 0:
+ namereln = NameRelation.SUPERDOMAIN
+ elif ldiff > 0:
+ namereln = NameRelation.SUBDOMAIN
+ else:
+ namereln = NameRelation.EQUAL
+ return (namereln, order, nlabels)
+
+ def is_subdomain(self, other: "Name") -> bool:
"""Is self a subdomain of other?
Note that the notion of subdomain includes equality, e.g.
@@ -292,9 +498,13 @@ class Name:
Returns a ``bool``.
"""
- pass
- def is_superdomain(self, other: 'Name') ->bool:
+ (nr, _, _) = self.fullcompare(other)
+ if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL:
+ return True
+ return False
+
+ def is_superdomain(self, other: "Name") -> bool:
"""Is self a superdomain of other?
Note that the notion of superdomain includes equality, e.g.
@@ -302,13 +512,18 @@ class Name:
Returns a ``bool``.
"""
- pass
- def canonicalize(self) ->'Name':
+ (nr, _, _) = self.fullcompare(other)
+ if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL:
+ return True
+ return False
+
+ def canonicalize(self) -> "Name":
"""Return a name which is equal to the current name, but is in
DNSSEC canonical form.
"""
- pass
+
+ return Name([x.lower() for x in self.labels])
def __eq__(self, other):
if isinstance(other, Name):
@@ -347,12 +562,12 @@ class Name:
return NotImplemented
def __repr__(self):
- return '<DNS name ' + self.__str__() + '>'
+ return "<DNS name " + self.__str__() + ">"
def __str__(self):
return self.to_text(False)
- def to_text(self, omit_final_dot: bool=False) ->str:
+ def to_text(self, omit_final_dot: bool = False) -> str:
"""Convert name to DNS text format.
*omit_final_dot* is a ``bool``. If True, don't emit the final
@@ -361,10 +576,21 @@ class Name:
Returns a ``str``.
"""
- pass
- def to_unicode(self, omit_final_dot: bool=False, idna_codec: Optional[
- IDNACodec]=None) ->str:
+ if len(self.labels) == 0:
+ return "@"
+ if len(self.labels) == 1 and self.labels[0] == b"":
+ return "."
+ if omit_final_dot and self.is_absolute():
+ l = self.labels[:-1]
+ else:
+ l = self.labels
+ s = ".".join(map(_escapify, l))
+ return s
+
+ def to_unicode(
+ self, omit_final_dot: bool = False, idna_codec: Optional[IDNACodec] = None
+ ) -> str:
"""Convert name to Unicode text format.
IDN ACE labels are converted to Unicode.
@@ -381,9 +607,20 @@ class Name:
Returns a ``str``.
"""
- pass
- def to_digestable(self, origin: Optional['Name']=None) ->bytes:
+ if len(self.labels) == 0:
+ return "@"
+ if len(self.labels) == 1 and self.labels[0] == b"":
+ return "."
+ if omit_final_dot and self.is_absolute():
+ l = self.labels[:-1]
+ else:
+ l = self.labels
+ if idna_codec is None:
+ idna_codec = IDNA_2003_Practical
+ return ".".join([idna_codec.decode(x) for x in l])
+
+ def to_digestable(self, origin: Optional["Name"] = None) -> bytes:
"""Convert name to a format suitable for digesting in hashes.
The name is canonicalized and converted to uncompressed wire
@@ -399,11 +636,18 @@ class Name:
Returns a ``bytes``.
"""
- pass
- def to_wire(self, file: Optional[Any]=None, compress: Optional[
- CompressType]=None, origin: Optional['Name']=None, canonicalize:
- bool=False) ->Optional[bytes]:
+ digest = self.to_wire(origin=origin, canonicalize=True)
+ assert digest is not None
+ return digest
+
+ def to_wire(
+ self,
+ file: Optional[Any] = None,
+ compress: Optional[CompressType] = None,
+ origin: Optional["Name"] = None,
+ canonicalize: bool = False,
+ ) -> Optional[bytes]:
"""Convert name to wire format, possibly compressing it.
*file* is the file where the name is emitted (typically an
@@ -429,13 +673,67 @@ class Name:
Returns a ``bytes`` or ``None``.
"""
- pass
- def __len__(self) ->int:
+ if file is None:
+ out = bytearray()
+ for label in self.labels:
+ out.append(len(label))
+ if canonicalize:
+ out += label.lower()
+ else:
+ out += label
+ if not self.is_absolute():
+ if origin is None or not origin.is_absolute():
+ raise NeedAbsoluteNameOrOrigin
+ for label in origin.labels:
+ out.append(len(label))
+ if canonicalize:
+ out += label.lower()
+ else:
+ out += label
+ return bytes(out)
+
+ labels: Iterable[bytes]
+ if not self.is_absolute():
+ if origin is None or not origin.is_absolute():
+ raise NeedAbsoluteNameOrOrigin
+ labels = list(self.labels)
+ labels.extend(list(origin.labels))
+ else:
+ labels = self.labels
+ i = 0
+ for label in labels:
+ n = Name(labels[i:])
+ i += 1
+ if compress is not None:
+ pos = compress.get(n)
+ else:
+ pos = None
+ if pos is not None:
+ value = 0xC000 + pos
+ s = struct.pack("!H", value)
+ file.write(s)
+ break
+ else:
+ if compress is not None and len(n) > 1:
+ pos = file.tell()
+ if pos <= 0x3FFF:
+ compress[n] = pos
+ l = len(label)
+ file.write(struct.pack("!B", l))
+ if l > 0:
+ if canonicalize:
+ file.write(label.lower())
+ else:
+ file.write(label)
+ return None
+
+ def __len__(self) -> int:
"""The length of the name (in labels).
Returns an ``int``.
"""
+
return len(self.labels)
def __getitem__(self, index):
@@ -447,7 +745,7 @@ class Name:
def __sub__(self, other):
return self.relativize(other)
- def split(self, depth: int) ->Tuple['Name', 'Name']:
+ def split(self, depth: int) -> Tuple["Name", "Name"]:
"""Split a name into a prefix and suffix names at the specified depth.
*depth* is an ``int`` specifying the number of labels in the suffix
@@ -457,9 +755,17 @@ class Name:
Returns the tuple ``(prefix, suffix)``.
"""
- pass
- def concatenate(self, other: 'Name') ->'Name':
+ l = len(self.labels)
+ if depth == 0:
+ return (self, dns.name.empty)
+ elif depth == l:
+ return (dns.name.empty, self)
+ elif depth < 0 or depth > l:
+ raise ValueError("depth must be >= 0 and <= the length of the name")
+ return (Name(self[:-depth]), Name(self[-depth:]))
+
+ def concatenate(self, other: "Name") -> "Name":
"""Return a new name which is the concatenation of self and other.
Raises ``dns.name.AbsoluteConcatenation`` if the name is
@@ -467,9 +773,14 @@ class Name:
Returns a ``dns.name.Name``.
"""
- pass
- def relativize(self, origin: 'Name') ->'Name':
+ if self.is_absolute() and len(other) > 0:
+ raise AbsoluteConcatenation
+ labels = list(self.labels)
+ labels.extend(list(other.labels))
+ return Name(labels)
+
+ def relativize(self, origin: "Name") -> "Name":
"""If the name is a subdomain of *origin*, return a new name which is
the name relative to origin. Otherwise return the name.
@@ -479,9 +790,13 @@ class Name:
Returns a ``dns.name.Name``.
"""
- pass
- def derelativize(self, origin: 'Name') ->'Name':
+ if origin is not None and self.is_subdomain(origin):
+ return Name(self[: -len(origin)])
+ else:
+ return self
+
+ def derelativize(self, origin: "Name") -> "Name":
"""If the name is a relative name, return a new name which is the
concatenation of the name and origin. Otherwise return the name.
@@ -491,10 +806,15 @@ class Name:
Returns a ``dns.name.Name``.
"""
- pass
- def choose_relativity(self, origin: Optional['Name']=None, relativize:
- bool=True) ->'Name':
+ if not self.is_absolute():
+ return self.concatenate(origin)
+ else:
+ return self
+
+ def choose_relativity(
+ self, origin: Optional["Name"] = None, relativize: bool = True
+ ) -> "Name":
"""Return a name with the relativity desired by the caller.
If *origin* is ``None``, then the name is returned.
@@ -504,9 +824,16 @@ class Name:
Returns a ``dns.name.Name``.
"""
- pass
- def parent(self) ->'Name':
+ if origin:
+ if relativize:
+ return self.relativize(origin)
+ else:
+ return self.derelativize(origin)
+ else:
+ return self
+
+ def parent(self) -> "Name":
"""Return the parent of the name.
For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``.
@@ -516,9 +843,12 @@ class Name:
Returns a ``dns.name.Name``.
"""
- pass
- def predecessor(self, origin: 'Name', prefix_ok: bool=True) ->'Name':
+ if self == root or self == empty:
+ raise NoParent
+ return Name(self.labels[1:])
+
+ def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
"""Return the maximal predecessor of *name* in the DNSSEC ordering in the zone
whose origin is *origin*, or return the longest name under *origin* if the
name is origin (i.e. wrap around to the longest name, which may still be
@@ -532,9 +862,11 @@ class Name:
defaults to ``True``. Normally it is good to allow this, but if computing
a maximal predecessor at a zone cut point then ``False`` must be specified.
"""
- pass
+ return _handle_relativity_and_call(
+ _absolute_predecessor, self, origin, prefix_ok
+ )
- def successor(self, origin: 'Name', prefix_ok: bool=True) ->'Name':
+ def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
"""Return the minimal successor of *name* in the DNSSEC ordering in the zone
whose origin is *origin*, or return *origin* if the successor cannot be
computed due to name length limitations.
@@ -550,15 +882,19 @@ class Name:
defaults to ``True``. Normally it is good to allow this, but if computing
a minimal successor at a zone cut point then ``False`` must be specified.
"""
- pass
+ return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok)
-root = Name([b''])
+#: The root name, '.'
+root = Name([b""])
+
+#: The empty name.
empty = Name([])
-def from_unicode(text: str, origin: Optional[Name]=root, idna_codec:
- Optional[IDNACodec]=None) ->Name:
+def from_unicode(
+ text: str, origin: Optional[Name] = root, idna_codec: Optional[IDNACodec] = None
+) -> Name:
"""Convert unicode text into a Name object.
Labels are encoded in IDN ACE form according to rules specified by
@@ -575,11 +911,76 @@ def from_unicode(text: str, origin: Optional[Name]=root, idna_codec:
Returns a ``dns.name.Name``.
"""
- pass
+ if not isinstance(text, str):
+ raise ValueError("input to from_unicode() must be a unicode string")
+ if not (origin is None or isinstance(origin, Name)):
+ raise ValueError("origin must be a Name or None")
+ labels = []
+ label = ""
+ escaping = False
+ edigits = 0
+ total = 0
+ if idna_codec is None:
+ idna_codec = IDNA_2003
+ if text == "@":
+ text = ""
+ if text:
+ if text in [".", "\u3002", "\uff0e", "\uff61"]:
+ return Name([b""]) # no Unicode "u" on this constant!
+ for c in text:
+ if escaping:
+ if edigits == 0:
+ if c.isdigit():
+ total = int(c)
+ edigits += 1
+ else:
+ label += c
+ escaping = False
+ else:
+ if not c.isdigit():
+ raise BadEscape
+ total *= 10
+ total += int(c)
+ edigits += 1
+ if edigits == 3:
+ escaping = False
+ label += chr(total)
+ elif c in [".", "\u3002", "\uff0e", "\uff61"]:
+ if len(label) == 0:
+ raise EmptyLabel
+ labels.append(idna_codec.encode(label))
+ label = ""
+ elif c == "\\":
+ escaping = True
+ edigits = 0
+ total = 0
+ else:
+ label += c
+ if escaping:
+ raise BadEscape
+ if len(label) > 0:
+ labels.append(idna_codec.encode(label))
+ else:
+ labels.append(b"")
+
+ if (len(labels) == 0 or labels[-1] != b"") and origin is not None:
+ labels.extend(list(origin.labels))
+ return Name(labels)
-def from_text(text: Union[bytes, str], origin: Optional[Name]=root,
- idna_codec: Optional[IDNACodec]=None) ->Name:
+
+def is_all_ascii(text: str) -> bool:
+ for c in text:
+ if ord(c) > 0x7F:
+ return False
+ return True
+
+
+def from_text(
+ text: Union[bytes, str],
+ origin: Optional[Name] = root,
+ idna_codec: Optional[IDNACodec] = None,
+) -> Name:
"""Convert text into a Name object.
*text*, a ``bytes`` or ``str``, is the text to convert into a name.
@@ -593,10 +994,79 @@ def from_text(text: Union[bytes, str], origin: Optional[Name]=root,
Returns a ``dns.name.Name``.
"""
- pass
+ if isinstance(text, str):
+ if not is_all_ascii(text):
+ # Some codepoint in the input text is > 127, so IDNA applies.
+ return from_unicode(text, origin, idna_codec)
+ # The input is all ASCII, so treat this like an ordinary non-IDNA
+ # domain name. Note that "all ASCII" is about the input text,
+ # not the codepoints in the domain name. E.g. if text has value
+ #
+ # r'\150\151\152\153\154\155\156\157\158\159'
+ #
+ # then it's still "all ASCII" even though the domain name has
+ # codepoints > 127.
+ text = text.encode("ascii")
+ if not isinstance(text, bytes):
+ raise ValueError("input to from_text() must be a string")
+ if not (origin is None or isinstance(origin, Name)):
+ raise ValueError("origin must be a Name or None")
+ labels = []
+ label = b""
+ escaping = False
+ edigits = 0
+ total = 0
+ if text == b"@":
+ text = b""
+ if text:
+ if text == b".":
+ return Name([b""])
+ for c in text:
+ byte_ = struct.pack("!B", c)
+ if escaping:
+ if edigits == 0:
+ if byte_.isdigit():
+ total = int(byte_)
+ edigits += 1
+ else:
+ label += byte_
+ escaping = False
+ else:
+ if not byte_.isdigit():
+ raise BadEscape
+ total *= 10
+ total += int(byte_)
+ edigits += 1
+ if edigits == 3:
+ escaping = False
+ label += struct.pack("!B", total)
+ elif byte_ == b".":
+ if len(label) == 0:
+ raise EmptyLabel
+ labels.append(label)
+ label = b""
+ elif byte_ == b"\\":
+ escaping = True
+ edigits = 0
+ total = 0
+ else:
+ label += byte_
+ if escaping:
+ raise BadEscape
+ if len(label) > 0:
+ labels.append(label)
+ else:
+ labels.append(b"")
+ if (len(labels) == 0 or labels[-1] != b"") and origin is not None:
+ labels.extend(list(origin.labels))
+ return Name(labels)
+
+
+# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other.
-def from_wire_parser(parser: 'dns.wire.Parser') ->Name:
+
+def from_wire_parser(parser: "dns.wire.Parser") -> Name:
"""Convert possibly compressed wire format into a Name.
*parser* is a dns.wire.Parser.
@@ -608,10 +1078,28 @@ def from_wire_parser(parser: 'dns.wire.Parser') ->Name:
Returns a ``dns.name.Name``
"""
- pass
-
-def from_wire(message: bytes, current: int) ->Tuple[Name, int]:
+ labels = []
+ biggest_pointer = parser.current
+ with parser.restore_furthest():
+ count = parser.get_uint8()
+ while count != 0:
+ if count < 64:
+ labels.append(parser.get_bytes(count))
+ elif count >= 192:
+ current = (count & 0x3F) * 256 + parser.get_uint8()
+ if current >= biggest_pointer:
+ raise BadPointer
+ biggest_pointer = current
+ parser.seek(current)
+ else:
+ raise BadLabelType
+ count = parser.get_uint8()
+ labels.append(b"")
+ return Name(labels)
+
+
+def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
"""Convert possibly compressed wire format into a Name.
*message* is a ``bytes`` containing an entire DNS message in DNS
@@ -629,13 +1117,167 @@ def from_wire(message: bytes, current: int) ->Tuple[Name, int]:
that was read and the number of bytes of the wire format message
which were consumed reading it.
"""
- pass
+
+ if not isinstance(message, bytes):
+ raise ValueError("input to from_wire() must be a byte string")
+ parser = dns.wire.Parser(message, current)
+ name = from_wire_parser(parser)
+ return (name, parser.current - current)
-_MINIMAL_OCTET = b'\x00'
+# RFC 4471 Support
+
+_MINIMAL_OCTET = b"\x00"
_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET)
_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET])
-_MAXIMAL_OCTET = b'\xff'
+_MAXIMAL_OCTET = b"\xff"
_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET)
-_AT_SIGN_VALUE = ord('@')
-_LEFT_SQUARE_BRACKET_VALUE = ord('[')
+_AT_SIGN_VALUE = ord("@")
+_LEFT_SQUARE_BRACKET_VALUE = ord("[")
+
+
+def _wire_length(labels):
+ return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0)
+
+
+def _pad_to_max_name(name):
+ needed = 255 - _wire_length(name.labels)
+ new_labels = []
+ while needed > 64:
+ new_labels.append(_MAXIMAL_OCTET * 63)
+ needed -= 64
+ if needed >= 2:
+ new_labels.append(_MAXIMAL_OCTET * (needed - 1))
+ # Note we're already maximal in the needed == 1 case as while we'd like
+ # to add one more byte as a new label, we can't, as adding a new non-empty
+ # label requires at least 2 bytes.
+ new_labels = list(reversed(new_labels))
+ new_labels.extend(name.labels)
+ return Name(new_labels)
+
+
+def _pad_to_max_label(label, suffix_labels):
+ length = len(label)
+ # We have to subtract one here to account for the length byte of label.
+ remaining = 255 - _wire_length(suffix_labels) - length - 1
+ if remaining <= 0:
+ # Shouldn't happen!
+ return label
+ needed = min(63 - length, remaining)
+ return label + _MAXIMAL_OCTET * needed
+
+
+def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name:
+ # This is the RFC 4471 predecessor algorithm using the "absolute method" of section
+ # 3.1.1.
+ #
+ # Our caller must ensure that the name and origin are absolute, and that name is a
+ # subdomain of origin.
+ if name == origin:
+ return _pad_to_max_name(name)
+ least_significant_label = name[0]
+ if least_significant_label == _MINIMAL_OCTET:
+ return name.parent()
+ least_octet = least_significant_label[-1]
+ suffix_labels = name.labels[1:]
+ if least_octet == _MINIMAL_OCTET_VALUE:
+ new_labels = [least_significant_label[:-1]]
+ else:
+ octets = bytearray(least_significant_label)
+ octet = octets[-1]
+ if octet == _LEFT_SQUARE_BRACKET_VALUE:
+ octet = _AT_SIGN_VALUE
+ else:
+ octet -= 1
+ octets[-1] = octet
+ least_significant_label = bytes(octets)
+ new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)]
+ new_labels.extend(suffix_labels)
+ name = Name(new_labels)
+ if prefix_ok:
+ return _pad_to_max_name(name)
+ else:
+ return name
+
+
+def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name:
+ # This is the RFC 4471 successor algorithm using the "absolute method" of section
+ # 3.1.2.
+ #
+ # Our caller must ensure that the name and origin are absolute, and that name is a
+ # subdomain of origin.
+ if prefix_ok:
+ # Try prefixing \000 as new label
+ try:
+ return _SUCCESSOR_PREFIX.concatenate(name)
+ except NameTooLong:
+ pass
+ while name != origin:
+ # Try extending the least significant label.
+ least_significant_label = name[0]
+ if len(least_significant_label) < 63:
+ # We may be able to extend the least label with a minimal additional byte.
+ # This is only "may" because we could have a maximal length name even though
+ # the least significant label isn't maximally long.
+ new_labels = [least_significant_label + _MINIMAL_OCTET]
+ new_labels.extend(name.labels[1:])
+ try:
+ return dns.name.Name(new_labels)
+ except dns.name.NameTooLong:
+ pass
+ # We can't extend the label either, so we'll try to increment the least
+ # signficant non-maximal byte in it.
+ octets = bytearray(least_significant_label)
+ # We do this reversed iteration with an explicit indexing variable because
+ # if we find something to increment, we're going to want to truncate everything
+ # to the right of it.
+ for i in range(len(octets) - 1, -1, -1):
+ octet = octets[i]
+ if octet == _MAXIMAL_OCTET_VALUE:
+ # We can't increment this, so keep looking.
+ continue
+ # Finally, something we can increment. We have to apply a special rule for
+ # incrementing "@", sending it to "[", because RFC 4034 6.1 says that when
+ # comparing names, uppercase letters compare as if they were their
+ # lower-case equivalents. If we increment "@" to "A", then it would compare
+ # as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have
+ # skipped the most minimal successor, namely "[".
+ if octet == _AT_SIGN_VALUE:
+ octet = _LEFT_SQUARE_BRACKET_VALUE
+ else:
+ octet += 1
+ octets[i] = octet
+ # We can now truncate all of the maximal values we skipped (if any)
+ new_labels = [bytes(octets[: i + 1])]
+ new_labels.extend(name.labels[1:])
+ # We haven't changed the length of the name, so the Name constructor will
+ # always work.
+ return Name(new_labels)
+ # We couldn't increment, so chop off the least significant label and try
+ # again.
+ name = name.parent()
+
+ # We couldn't increment at all, so return the origin, as wrapping around is the
+ # DNSSEC way.
+ return origin
+
+
+def _handle_relativity_and_call(
+ function: Callable[[Name, Name, bool], Name],
+ name: Name,
+ origin: Name,
+ prefix_ok: bool,
+) -> Name:
+ # Make "name" absolute if needed, ensure that the origin is absolute,
+ # call function(), and then relativize the result if needed.
+ if not origin.is_absolute():
+ raise NeedAbsoluteNameOrOrigin
+ relative = not name.is_absolute()
+ if relative:
+ name = name.derelativize(origin)
+ elif not name.is_subdomain(origin):
+ raise NeedSubdomainOfOrigin
+ result_name = function(name, origin, prefix_ok)
+ if relative:
+ result_name = result_name.relativize(origin)
+ return result_name
diff --git a/dns/namedict.py b/dns/namedict.py
index e2fcfff..ca8b197 100644
--- a/dns/namedict.py
+++ b/dns/namedict.py
@@ -1,5 +1,35 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+# Copyright (C) 2016 Coresec Systems AB
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND CORESEC SYSTEMS AB DISCLAIMS ALL
+# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL CORESEC
+# SYSTEMS AB BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
+# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
+# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS name dictionary"""
-from collections.abc import MutableMapping
+
+# pylint seems to be confused about this one!
+from collections.abc import MutableMapping # pylint: disable=no-name-in-module
+
import dns.name
@@ -9,21 +39,31 @@ class NameDict(MutableMapping):
In addition to being like a regular Python dictionary, this
dictionary can also get the deepest match for a given key.
"""
- __slots__ = ['max_depth', 'max_depth_items', '__store']
+
+ __slots__ = ["max_depth", "max_depth_items", "__store"]
def __init__(self, *args, **kwargs):
super().__init__()
self.__store = dict()
+ #: the maximum depth of the keys that have ever been added
self.max_depth = 0
+ #: the number of items of maximum depth
self.max_depth_items = 0
self.update(dict(*args, **kwargs))
+ def __update_max_depth(self, key):
+ if len(key) == self.max_depth:
+ self.max_depth_items = self.max_depth_items + 1
+ elif len(key) > self.max_depth:
+ self.max_depth = len(key)
+ self.max_depth_items = 1
+
def __getitem__(self, key):
return self.__store[key]
def __setitem__(self, key, value):
if not isinstance(key, dns.name.Name):
- raise ValueError('NameDict key must be a name')
+ raise ValueError("NameDict key must be a name")
self.__store[key] = value
self.__update_max_depth(key)
@@ -42,6 +82,9 @@ class NameDict(MutableMapping):
def __len__(self):
return len(self.__store)
+ def has_key(self, key):
+ return key in self.__store
+
def get_deepest_match(self, name):
"""Find the deepest match to *name* in the dictionary.
@@ -54,4 +97,13 @@ class NameDict(MutableMapping):
Returns a ``(key, value)`` where *key* is the deepest
``dns.name.Name``, and *value* is the value associated with *key*.
"""
- pass
+
+ depth = len(name)
+ if depth > self.max_depth:
+ depth = self.max_depth
+ for i in range(-depth, 0):
+ n = dns.name.Name(name[i:])
+ if n in self:
+ return (n, self[n])
+ v = self[dns.name.empty]
+ return (dns.name.empty, v)
diff --git a/dns/nameserver.py b/dns/nameserver.py
index 3c1878b..5dbb4e8 100644
--- a/dns/nameserver.py
+++ b/dns/nameserver.py
@@ -1,5 +1,6 @@
from typing import Optional, Union
from urllib.parse import urlparse
+
import dns.asyncbackend
import dns.asyncquery
import dns.inet
@@ -8,59 +9,351 @@ import dns.query
class Nameserver:
-
def __init__(self):
pass
def __str__(self):
raise NotImplementedError
+ def kind(self) -> str:
+ raise NotImplementedError
-class AddressAndPortNameserver(Nameserver):
+ def is_always_max_size(self) -> bool:
+ raise NotImplementedError
+
+ def answer_nameserver(self) -> str:
+ raise NotImplementedError
+
+ def answer_port(self) -> int:
+ raise NotImplementedError
+
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ raise NotImplementedError
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ raise NotImplementedError
+
+class AddressAndPortNameserver(Nameserver):
def __init__(self, address: str, port: int):
super().__init__()
self.address = address
self.port = port
+ def kind(self) -> str:
+ raise NotImplementedError
+
+ def is_always_max_size(self) -> bool:
+ return False
+
def __str__(self):
ns_kind = self.kind()
- return f'{ns_kind}:{self.address}@{self.port}'
+ return f"{ns_kind}:{self.address}@{self.port}"
+ def answer_nameserver(self) -> str:
+ return self.address
+
+ def answer_port(self) -> int:
+ return self.port
-class Do53Nameserver(AddressAndPortNameserver):
- def __init__(self, address: str, port: int=53):
+class Do53Nameserver(AddressAndPortNameserver):
+ def __init__(self, address: str, port: int = 53):
super().__init__(address, port)
+ def kind(self):
+ return "Do53"
-class DoHNameserver(Nameserver):
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ if max_size:
+ response = dns.query.tcp(
+ request,
+ self.address,
+ timeout=timeout,
+ port=self.port,
+ source=source,
+ source_port=source_port,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ else:
+ response = dns.query.udp(
+ request,
+ self.address,
+ timeout=timeout,
+ port=self.port,
+ source=source,
+ source_port=source_port,
+ raise_on_truncation=True,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ ignore_errors=True,
+ ignore_unexpected=True,
+ )
+ return response
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ if max_size:
+ response = await dns.asyncquery.tcp(
+ request,
+ self.address,
+ timeout=timeout,
+ port=self.port,
+ source=source,
+ source_port=source_port,
+ backend=backend,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ else:
+ response = await dns.asyncquery.udp(
+ request,
+ self.address,
+ timeout=timeout,
+ port=self.port,
+ source=source,
+ source_port=source_port,
+ raise_on_truncation=True,
+ backend=backend,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ ignore_errors=True,
+ ignore_unexpected=True,
+ )
+ return response
- def __init__(self, url: str, bootstrap_address: Optional[str]=None,
- verify: Union[bool, str]=True, want_get: bool=False):
+
+class DoHNameserver(Nameserver):
+ def __init__(
+ self,
+ url: str,
+ bootstrap_address: Optional[str] = None,
+ verify: Union[bool, str] = True,
+ want_get: bool = False,
+ ):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
+ def kind(self):
+ return "DoH"
+
+ def is_always_max_size(self) -> bool:
+ return True
+
def __str__(self):
return self.url
+ def answer_nameserver(self) -> str:
+ return self.url
-class DoTNameserver(AddressAndPortNameserver):
+ def answer_port(self) -> int:
+ port = urlparse(self.url).port
+ if port is None:
+ port = 443
+ return port
+
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return dns.query.https(
+ request,
+ self.url,
+ timeout=timeout,
+ source=source,
+ source_port=source_port,
+ bootstrap_address=self.bootstrap_address,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ verify=self.verify,
+ post=(not self.want_get),
+ )
- def __init__(self, address: str, port: int=853, hostname: Optional[str]
- =None, verify: Union[bool, str]=True):
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return await dns.asyncquery.https(
+ request,
+ self.url,
+ timeout=timeout,
+ source=source,
+ source_port=source_port,
+ bootstrap_address=self.bootstrap_address,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ verify=self.verify,
+ post=(not self.want_get),
+ )
+
+
+class DoTNameserver(AddressAndPortNameserver):
+ def __init__(
+ self,
+ address: str,
+ port: int = 853,
+ hostname: Optional[str] = None,
+ verify: Union[bool, str] = True,
+ ):
super().__init__(address, port)
self.hostname = hostname
self.verify = verify
+ def kind(self):
+ return "DoT"
-class DoQNameserver(AddressAndPortNameserver):
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return dns.query.tls(
+ request,
+ self.address,
+ port=self.port,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ server_hostname=self.hostname,
+ verify=self.verify,
+ )
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return await dns.asyncquery.tls(
+ request,
+ self.address,
+ port=self.port,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ server_hostname=self.hostname,
+ verify=self.verify,
+ )
- def __init__(self, address: str, port: int=853, verify: Union[bool, str
- ]=True, server_hostname: Optional[str]=None):
+
+class DoQNameserver(AddressAndPortNameserver):
+ def __init__(
+ self,
+ address: str,
+ port: int = 853,
+ verify: Union[bool, str] = True,
+ server_hostname: Optional[str] = None,
+ ):
super().__init__(address, port)
self.verify = verify
self.server_hostname = server_hostname
+
+ def kind(self):
+ return "DoQ"
+
+ def query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return dns.query.quic(
+ request,
+ self.address,
+ port=self.port,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ verify=self.verify,
+ server_hostname=self.server_hostname,
+ )
+
+ async def async_query(
+ self,
+ request: dns.message.QueryMessage,
+ timeout: float,
+ source: Optional[str],
+ source_port: int,
+ max_size: bool,
+ backend: dns.asyncbackend.Backend,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ ) -> dns.message.Message:
+ return await dns.asyncquery.quic(
+ request,
+ self.address,
+ port=self.port,
+ timeout=timeout,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ verify=self.verify,
+ server_hostname=self.server_hostname,
+ )
diff --git a/dns/node.py b/dns/node.py
index 802f226..de85a82 100644
--- a/dns/node.py
+++ b/dns/node.py
@@ -1,7 +1,26 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS nodes. A node is a set of rdatasets."""
+
import enum
import io
from typing import Any, Dict, Optional
+
import dns.immutable
import dns.name
import dns.rdataclass
@@ -9,17 +28,46 @@ import dns.rdataset
import dns.rdatatype
import dns.renderer
import dns.rrset
-_cname_types = {dns.rdatatype.CNAME}
-_neutral_types = {dns.rdatatype.NSEC, dns.rdatatype.NSEC3, dns.rdatatype.KEY}
+
+_cname_types = {
+ dns.rdatatype.CNAME,
+}
+
+# "neutral" types can coexist with a CNAME and thus are not "other data"
+_neutral_types = {
+ dns.rdatatype.NSEC, # RFC 4035 section 2.5
+ dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible!
+ dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007
+}
+
+
+def _matches_type_or_its_signature(rdtypes, rdtype, covers):
+ return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
@enum.unique
class NodeKind(enum.Enum):
"""Rdatasets in nodes"""
- REGULAR = 0
+
+ REGULAR = 0 # a.k.a "other data"
NEUTRAL = 1
CNAME = 2
+ @classmethod
+ def classify(
+ cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType
+ ) -> "NodeKind":
+ if _matches_type_or_its_signature(_cname_types, rdtype, covers):
+ return NodeKind.CNAME
+ elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
+ return NodeKind.NEUTRAL
+ else:
+ return NodeKind.REGULAR
+
+ @classmethod
+ def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind":
+ return cls.classify(rdataset.rdtype, rdataset.covers)
+
class Node:
"""A Node is a set of rdatasets.
@@ -36,12 +84,14 @@ class Node:
an MX rdataset and add a CNAME rdataset, the MX rdataset will be
deleted.
"""
- __slots__ = ['rdatasets']
+
+ __slots__ = ["rdatasets"]
def __init__(self):
+ # the set of rdatasets, represented as a list.
self.rdatasets = []
- def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) ->str:
+ def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) -> str:
"""Convert a node to text format.
Each rdataset at the node is printed. Any keyword arguments
@@ -53,12 +103,21 @@ class Node:
Returns a ``str``.
"""
- pass
+
+ s = io.StringIO()
+ for rds in self.rdatasets:
+ if len(rds) > 0:
+ s.write(rds.to_text(name, **kw)) # type: ignore[arg-type]
+ s.write("\n")
+ return s.getvalue()[:-1]
def __repr__(self):
- return '<DNS node ' + str(id(self)) + '>'
+ return "<DNS node " + str(id(self)) + ">"
def __eq__(self, other):
+ #
+ # This is inefficient. Good thing we don't need to do it much.
+ #
for rd in self.rdatasets:
if rd not in other.rdatasets:
return False
@@ -85,11 +144,32 @@ class Node:
RRSIGs are deleted. If the rdataset being appended has
``NodeKind.REGULAR`` then CNAME and RRSIG(CNAME) are deleted.
"""
- pass
-
- def find_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns
- .rdatatype.RdataType, covers: dns.rdatatype.RdataType=dns.rdatatype
- .NONE, create: bool=False) ->dns.rdataset.Rdataset:
+ # Make having just one rdataset at the node fast.
+ if len(self.rdatasets) > 0:
+ kind = NodeKind.classify_rdataset(rdataset)
+ if kind == NodeKind.CNAME:
+ self.rdatasets = [
+ rds
+ for rds in self.rdatasets
+ if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR
+ ]
+ elif kind == NodeKind.REGULAR:
+ self.rdatasets = [
+ rds
+ for rds in self.rdatasets
+ if NodeKind.classify_rdataset(rds) != NodeKind.CNAME
+ ]
+ # Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to
+ # edit self.rdatasets.
+ self.rdatasets.append(rdataset)
+
+ def find_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> dns.rdataset.Rdataset:
"""Find an rdataset matching the specified properties in the
current node.
@@ -114,11 +194,23 @@ class Node:
Returns a ``dns.rdataset.Rdataset``.
"""
- pass
- def get_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.
- rdatatype.RdataType, covers: dns.rdatatype.RdataType=dns.rdatatype.
- NONE, create: bool=False) ->Optional[dns.rdataset.Rdataset]:
+ for rds in self.rdatasets:
+ if rds.match(rdclass, rdtype, covers):
+ return rds
+ if not create:
+ raise KeyError
+ rds = dns.rdataset.Rdataset(rdclass, rdtype, covers)
+ self._append_rdataset(rds)
+ return rds
+
+ def get_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> Optional[dns.rdataset.Rdataset]:
"""Get an rdataset matching the specified properties in the
current node.
@@ -142,11 +234,19 @@ class Node:
Returns a ``dns.rdataset.Rdataset`` or ``None``.
"""
- pass
- def delete_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype:
- dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType=dns.
- rdatatype.NONE) ->None:
+ try:
+ rds = self.find_rdataset(rdclass, rdtype, covers, create)
+ except KeyError:
+ rds = None
+ return rds
+
+ def delete_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ ) -> None:
"""Delete the rdataset matching the specified properties in the
current node.
@@ -158,9 +258,12 @@ class Node:
*covers*, an ``int``, the covered type.
"""
- pass
- def replace_rdataset(self, replacement: dns.rdataset.Rdataset) ->None:
+ rds = self.get_rdataset(rdclass, rdtype, covers)
+ if rds is not None:
+ self.rdatasets.remove(rds)
+
+ def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
"""Replace an rdataset.
It is not an error if there is no rdataset matching *replacement*.
@@ -174,9 +277,19 @@ class Node:
Raises ``ValueError`` if *replacement* is not a
``dns.rdataset.Rdataset``.
"""
- pass
- def classify(self) ->NodeKind:
+ if not isinstance(replacement, dns.rdataset.Rdataset):
+ raise ValueError("replacement is not an rdataset")
+ if isinstance(replacement, dns.rrset.RRset):
+ # RRsets are not good replacements as the match() method
+ # is not compatible.
+ replacement = replacement.to_rdataset()
+ self.delete_rdataset(
+ replacement.rdclass, replacement.rdtype, replacement.covers
+ )
+ self._append_rdataset(replacement)
+
+ def classify(self) -> NodeKind:
"""Classify a node.
A node which contains a CNAME or RRSIG(CNAME) is a
@@ -191,13 +304,56 @@ class Node:
or a neutral type is a a ``NodeKind.REGULAR`` node. Regular nodes are
also commonly referred to as "other data".
"""
- pass
+ for rdataset in self.rdatasets:
+ kind = NodeKind.classify(rdataset.rdtype, rdataset.covers)
+ if kind != NodeKind.NEUTRAL:
+ return kind
+ return NodeKind.NEUTRAL
+
+ def is_immutable(self) -> bool:
+ return False
@dns.immutable.immutable
class ImmutableNode(Node):
-
def __init__(self, node):
super().__init__()
- self.rdatasets = tuple([dns.rdataset.ImmutableRdataset(rds) for rds in
- node.rdatasets])
+ self.rdatasets = tuple(
+ [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+ )
+
+ def find_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> dns.rdataset.Rdataset:
+ if create:
+ raise TypeError("immutable")
+ return super().find_rdataset(rdclass, rdtype, covers, False)
+
+ def get_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> Optional[dns.rdataset.Rdataset]:
+ if create:
+ raise TypeError("immutable")
+ return super().get_rdataset(rdclass, rdtype, covers, False)
+
+ def delete_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ ) -> None:
+ raise TypeError("immutable")
+
+ def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
+ raise TypeError("immutable")
+
+ def is_immutable(self) -> bool:
+ return True
diff --git a/dns/opcode.py b/dns/opcode.py
index dfea1ae..78b43d2 100644
--- a/dns/opcode.py
+++ b/dns/opcode.py
@@ -1,21 +1,52 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Opcodes."""
+
import dns.enum
import dns.exception
class Opcode(dns.enum.IntEnum):
+ #: Query
QUERY = 0
+ #: Inverse Query (historical)
IQUERY = 1
+ #: Server Status (unspecified and unimplemented anywhere)
STATUS = 2
+ #: Notify
NOTIFY = 4
+ #: Dynamic Update
UPDATE = 5
+ @classmethod
+ def _maximum(cls):
+ return 15
+
+ @classmethod
+ def _unknown_exception_class(cls):
+ return UnknownOpcode
+
class UnknownOpcode(dns.exception.DNSException):
"""An DNS opcode is unknown."""
-def from_text(text: str) ->Opcode:
+def from_text(text: str) -> Opcode:
"""Convert text into an opcode.
*text*, a ``str``, the textual opcode
@@ -24,20 +55,22 @@ def from_text(text: str) ->Opcode:
Returns an ``int``.
"""
- pass
+
+ return Opcode.from_text(text)
-def from_flags(flags: int) ->Opcode:
+def from_flags(flags: int) -> Opcode:
"""Extract an opcode from DNS message flags.
*flags*, an ``int``, the DNS flags.
Returns an ``int``.
"""
- pass
+
+ return Opcode((flags & 0x7800) >> 11)
-def to_flags(value: Opcode) ->int:
+def to_flags(value: Opcode) -> int:
"""Convert an opcode to a value suitable for ORing into DNS message
flags.
@@ -45,10 +78,11 @@ def to_flags(value: Opcode) ->int:
Returns an ``int``.
"""
- pass
+ return (value << 11) & 0x7800
-def to_text(value: Opcode) ->str:
+
+def to_text(value: Opcode) -> str:
"""Convert an opcode to text.
*value*, an ``int`` the opcode value,
@@ -57,21 +91,27 @@ def to_text(value: Opcode) ->str:
Returns a ``str``.
"""
- pass
+
+ return Opcode.to_text(value)
-def is_update(flags: int) ->bool:
+def is_update(flags: int) -> bool:
"""Is the opcode in flags UPDATE?
*flags*, an ``int``, the DNS message flags.
Returns a ``bool``.
"""
- pass
+
+ return from_flags(flags) == Opcode.UPDATE
+### BEGIN generated Opcode constants
+
QUERY = Opcode.QUERY
IQUERY = Opcode.IQUERY
STATUS = Opcode.STATUS
NOTIFY = Opcode.NOTIFY
UPDATE = Opcode.UPDATE
+
+### END generated Opcode constants
diff --git a/dns/query.py b/dns/query.py
index 2ec4b4f..f0ee916 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Talk to a DNS server."""
+
import base64
import contextlib
import enum
@@ -10,6 +28,7 @@ import socket
import struct
import time
from typing import Any, Dict, Optional, Tuple, Union
+
import dns._features
import dns.exception
import dns.inet
@@ -23,16 +42,32 @@ import dns.serial
import dns.transaction
import dns.tsig
import dns.xfr
-_have_httpx = dns._features.have('doh')
+
+
+def _remaining(expiration):
+ if expiration is None:
+ return None
+ timeout = expiration - time.time()
+ if timeout <= 0.0:
+ raise dns.exception.Timeout
+ return timeout
+
+
+def _expiration_for_this_attempt(timeout, expiration):
+ if expiration is None:
+ return None
+ return min(time.time() + timeout, expiration)
+
+
+_have_httpx = dns._features.have("doh")
if _have_httpx:
import httpcore._backends.sync
import httpx
+
_CoreNetworkBackend = httpcore.NetworkBackend
_CoreSyncStream = httpcore._backends.sync.SyncStream
-
class _NetworkBackend(_CoreNetworkBackend):
-
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
@@ -40,46 +75,105 @@ if _have_httpx:
self._bootstrap_address = bootstrap_address
self._family = family
+ def connect_tcp(
+ self, host, port, timeout, local_address, socket_options=None
+ ): # pylint: disable=signature-differs
+ addresses = []
+ _, expiration = _compute_times(timeout)
+ if dns.inet.is_address(host):
+ addresses.append(host)
+ elif self._bootstrap_address is not None:
+ addresses.append(self._bootstrap_address)
+ else:
+ timeout = _remaining(expiration)
+ family = self._family
+ if local_address:
+ family = dns.inet.af_for_address(local_address)
+ answers = self._resolver.resolve_name(
+ host, family=family, lifetime=timeout
+ )
+ addresses = answers.addresses()
+ for address in addresses:
+ af = dns.inet.af_for_address(address)
+ if local_address is not None or self._local_port != 0:
+ source = dns.inet.low_level_address_tuple(
+ (local_address, self._local_port), af
+ )
+ else:
+ source = None
+ sock = _make_socket(af, socket.SOCK_STREAM, source)
+ attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
+ try:
+ _connect(
+ sock,
+ dns.inet.low_level_address_tuple((address, port), af),
+ attempt_expiration,
+ )
+ return _CoreSyncStream(sock)
+ except Exception:
+ pass
+ raise httpcore.ConnectError
+
+ def connect_unix_socket(
+ self, path, timeout, socket_options=None
+ ): # pylint: disable=signature-differs
+ raise NotImplementedError
class _HTTPTransport(httpx.HTTPTransport):
-
- def __init__(self, *args, local_port=0, bootstrap_address=None,
- resolver=None, family=socket.AF_UNSPEC, **kwargs):
+ def __init__(
+ self,
+ *args,
+ local_port=0,
+ bootstrap_address=None,
+ resolver=None,
+ family=socket.AF_UNSPEC,
+ **kwargs,
+ ):
if resolver is None:
+ # pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver
+
resolver = dns.resolver.Resolver()
super().__init__(*args, **kwargs)
- self._pool._network_backend = _NetworkBackend(resolver,
- local_port, bootstrap_address, family)
+ self._pool._network_backend = _NetworkBackend(
+ resolver, local_port, bootstrap_address, family
+ )
+
else:
+ class _HTTPTransport: # type: ignore
+ def connect_tcp(self, host, port, timeout, local_address):
+ raise NotImplementedError
+
- class _HTTPTransport:
- pass
have_doh = _have_httpx
+
try:
import ssl
-except ImportError:
-
+except ImportError: # pragma: no cover
- class ssl:
+ class ssl: # type: ignore
CERT_NONE = 0
-
class WantReadException(Exception):
pass
-
class WantWriteException(Exception):
pass
-
class SSLContext:
pass
-
class SSLSocket:
pass
+
+ @classmethod
+ def create_default_context(cls, *args, **kwargs):
+ raise Exception("no ssl support") # pylint: disable=broad-exception-raised
+
+
+# Function used to create a socket. Can be overridden if needed in special
+# situations.
socket_factory = socket.socket
@@ -101,20 +195,183 @@ class NoDOQ(dns.exception.DNSException):
available."""
+# for backwards compatibility
TransferError = dns.xfr.TransferError
-if hasattr(selectors, 'PollSelector'):
- _selector_class = selectors.PollSelector
-else:
- _selector_class = selectors.SelectSelector
-def https(q: dns.message.Message, where: str, timeout: Optional[float]=None,
- port: int=443, source: Optional[str]=None, source_port: int=0,
- one_rr_per_rrset: bool=False, ignore_trailing: bool=False, session:
- Optional[Any]=None, path: str='/dns-query', post: bool=True,
- bootstrap_address: Optional[str]=None, verify: Union[bool, str]=True,
- resolver: Optional['dns.resolver.Resolver']=None, family: Optional[int]
- =socket.AF_UNSPEC) ->dns.message.Message:
+def _compute_times(timeout):
+ now = time.time()
+ if timeout is None:
+ return (now, None)
+ else:
+ return (now, now + timeout)
+
+
+def _wait_for(fd, readable, writable, _, expiration):
+ # Use the selected selector class to wait for any of the specified
+ # events. An "expiration" absolute time is converted into a relative
+ # timeout.
+ #
+ # The unused parameter is 'error', which is always set when
+ # selecting for read or write, and we have no error-only selects.
+
+ if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
+ return True
+ sel = _selector_class()
+ events = 0
+ if readable:
+ events |= selectors.EVENT_READ
+ if writable:
+ events |= selectors.EVENT_WRITE
+ if events:
+ sel.register(fd, events)
+ if expiration is None:
+ timeout = None
+ else:
+ timeout = expiration - time.time()
+ if timeout <= 0.0:
+ raise dns.exception.Timeout
+ if not sel.select(timeout):
+ raise dns.exception.Timeout
+
+
+def _set_selector_class(selector_class):
+ # Internal API. Do not use.
+
+ global _selector_class
+
+ _selector_class = selector_class
+
+
+if hasattr(selectors, "PollSelector"):
+ # Prefer poll() on platforms that support it because it has no
+ # limits on the maximum value of a file descriptor (plus it will
+ # be more efficient for high values).
+ #
+ # We ignore typing here as we can't say _selector_class is Any
+ # on python < 3.8 due to a bug.
+ _selector_class = selectors.PollSelector # type: ignore
+else:
+ _selector_class = selectors.SelectSelector # type: ignore
+
+
+def _wait_for_readable(s, expiration):
+ _wait_for(s, True, False, True, expiration)
+
+
+def _wait_for_writable(s, expiration):
+ _wait_for(s, False, True, True, expiration)
+
+
+def _addresses_equal(af, a1, a2):
+ # Convert the first value of the tuple, which is a textual format
+ # address into binary form, so that we are not confused by different
+ # textual representations of the same address
+ try:
+ n1 = dns.inet.inet_pton(af, a1[0])
+ n2 = dns.inet.inet_pton(af, a2[0])
+ except dns.exception.SyntaxError:
+ return False
+ return n1 == n2 and a1[1:] == a2[1:]
+
+
+def _matches_destination(af, from_address, destination, ignore_unexpected):
+ # Check that from_address is appropriate for a response to a query
+ # sent to destination.
+ if not destination:
+ return True
+ if _addresses_equal(af, from_address, destination) or (
+ dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]
+ ):
+ return True
+ elif ignore_unexpected:
+ return False
+ raise UnexpectedSource(
+ f"got a response from {from_address} instead of " f"{destination}"
+ )
+
+
+def _destination_and_source(
+ where, port, source, source_port, where_must_be_address=True
+):
+ # Apply defaults and compute destination and source tuples
+ # suitable for use in connect(), sendto(), or bind().
+ af = None
+ destination = None
+ try:
+ af = dns.inet.af_for_address(where)
+ destination = where
+ except Exception:
+ if where_must_be_address:
+ raise
+ # URLs are ok so eat the exception
+ if source:
+ saf = dns.inet.af_for_address(source)
+ if af:
+ # We know the destination af, so source had better agree!
+ if saf != af:
+ raise ValueError(
+ "different address families for source and destination"
+ )
+ else:
+ # We didn't know the destination af, but we know the source,
+ # so that's our af.
+ af = saf
+ if source_port and not source:
+ # Caller has specified a source_port but not an address, so we
+ # need to return a source, and we need to use the appropriate
+ # wildcard address as the address.
+ try:
+ source = dns.inet.any_for_af(af)
+ except Exception:
+ # we catch this and raise ValueError for backwards compatibility
+ raise ValueError("source_port specified but address family is unknown")
+ # Convert high-level (address, port) tuples into low-level address
+ # tuples.
+ if destination:
+ destination = dns.inet.low_level_address_tuple((destination, port), af)
+ if source:
+ source = dns.inet.low_level_address_tuple((source, source_port), af)
+ return (af, destination, source)
+
+
+def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
+ s = socket_factory(af, type)
+ try:
+ s.setblocking(False)
+ if source is not None:
+ s.bind(source)
+ if ssl_context:
+ # LGTM gets a false positive here, as our default context is OK
+ return ssl_context.wrap_socket(
+ s,
+ do_handshake_on_connect=False, # lgtm[py/insecure-protocol]
+ server_hostname=server_hostname,
+ )
+ else:
+ return s
+ except Exception:
+ s.close()
+ raise
+
+
+def https(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 443,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ session: Optional[Any] = None,
+ path: str = "/dns-query",
+ post: bool = True,
+ bootstrap_address: Optional[str] = None,
+ verify: Union[bool, str] = True,
+ resolver: Optional["dns.resolver.Resolver"] = None,
+ family: Optional[int] = socket.AF_UNSPEC,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*q*, a ``dns.message.Message``, the query to send.
@@ -165,7 +422,85 @@ def https(q: dns.message.Message, where: str, timeout: Optional[float]=None,
Returns a ``dns.message.Message``.
"""
- pass
+
+ if not have_doh:
+ raise NoDOH # pragma: no cover
+ if session and not isinstance(session, httpx.Client):
+ raise ValueError("session parameter must be an httpx.Client")
+
+ wire = q.to_wire()
+ (af, _, the_source) = _destination_and_source(
+ where, port, source, source_port, False
+ )
+ transport = None
+ headers = {"accept": "application/dns-message"}
+ if af is not None and dns.inet.is_address(where):
+ if af == socket.AF_INET:
+ url = "https://{}:{}{}".format(where, port, path)
+ elif af == socket.AF_INET6:
+ url = "https://[{}]:{}{}".format(where, port, path)
+ else:
+ url = where
+
+ # set source port and source address
+
+ if the_source is None:
+ local_address = None
+ local_port = 0
+ else:
+ local_address = the_source[0]
+ local_port = the_source[1]
+ transport = _HTTPTransport(
+ local_address=local_address,
+ http1=True,
+ http2=True,
+ verify=verify,
+ local_port=local_port,
+ bootstrap_address=bootstrap_address,
+ resolver=resolver,
+ family=family,
+ )
+
+ if session:
+ cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
+ else:
+ cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
+ with cm as session:
+ # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
+ # GET and POST examples
+ if post:
+ headers.update(
+ {
+ "content-type": "application/dns-message",
+ "content-length": str(len(wire)),
+ }
+ )
+ response = session.post(url, headers=headers, content=wire, timeout=timeout)
+ else:
+ wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
+ twire = wire.decode() # httpx does a repr() if we give it bytes
+ response = session.get(
+ url, headers=headers, timeout=timeout, params={"dns": twire}
+ )
+
+ # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
+ # status codes
+ if response.status_code < 200 or response.status_code > 299:
+ raise ValueError(
+ "{} responded with status code {}"
+ "\nResponse body: {}".format(where, response.status_code, response.content)
+ )
+ r = dns.message.from_wire(
+ response.content,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ r.time = response.elapsed.total_seconds()
+ if not q.is_response(r):
+ raise BadResponse
+ return r
def _udp_recv(sock, max_size, expiration):
@@ -173,7 +508,11 @@ def _udp_recv(sock, max_size, expiration):
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
- pass
+ while True:
+ try:
+ return sock.recvfrom(max_size)
+ except BlockingIOError:
+ _wait_for_readable(sock, expiration)
def _udp_send(sock, data, destination, expiration):
@@ -181,11 +520,22 @@ def _udp_send(sock, data, destination, expiration):
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
- pass
-
-
-def send_udp(sock: Any, what: Union[dns.message.Message, bytes],
- destination: Any, expiration: Optional[float]=None) ->Tuple[int, float]:
+ while True:
+ try:
+ if destination:
+ return sock.sendto(data, destination)
+ else:
+ return sock.send(data)
+ except BlockingIOError: # pragma: no cover
+ _wait_for_writable(sock, expiration)
+
+
+def send_udp(
+ sock: Any,
+ what: Union[dns.message.Message, bytes],
+ destination: Any,
+ expiration: Optional[float] = None,
+) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
*sock*, a ``socket``.
@@ -201,15 +551,27 @@ def send_udp(sock: Any, what: Union[dns.message.Message, bytes],
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
- pass
-
-def receive_udp(sock: Any, destination: Optional[Any]=None, expiration:
- Optional[float]=None, ignore_unexpected: bool=False, one_rr_per_rrset:
- bool=False, keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None,
- request_mac: Optional[bytes]=b'', ignore_trailing: bool=False,
- raise_on_truncation: bool=False, ignore_errors: bool=False, query:
- Optional[dns.message.Message]=None) ->Any:
+ if isinstance(what, dns.message.Message):
+ what = what.to_wire()
+ sent_time = time.time()
+ n = _udp_send(sock, what, destination, expiration)
+ return (n, sent_time)
+
+
+def receive_udp(
+ sock: Any,
+ destination: Optional[Any] = None,
+ expiration: Optional[float] = None,
+ ignore_unexpected: bool = False,
+ one_rr_per_rrset: bool = False,
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
+ request_mac: Optional[bytes] = b"",
+ ignore_trailing: bool = False,
+ raise_on_truncation: bool = False,
+ ignore_errors: bool = False,
+ query: Optional[dns.message.Message] = None,
+) -> Any:
"""Read a DNS message from a UDP socket.
*sock*, a ``socket``.
@@ -258,14 +620,65 @@ def receive_udp(sock: Any, destination: Optional[Any]=None, expiration:
*ignore_errors* is ``True``, check that the received message is a response
to this query, and if not keep listening for a valid response.
"""
- pass
-
-def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None,
- port: int=53, source: Optional[str]=None, source_port: int=0,
- ignore_unexpected: bool=False, one_rr_per_rrset: bool=False,
- ignore_trailing: bool=False, raise_on_truncation: bool=False, sock:
- Optional[Any]=None, ignore_errors: bool=False) ->dns.message.Message:
+ wire = b""
+ while True:
+ (wire, from_address) = _udp_recv(sock, 65535, expiration)
+ if not _matches_destination(
+ sock.family, from_address, destination, ignore_unexpected
+ ):
+ continue
+ received_time = time.time()
+ try:
+ r = dns.message.from_wire(
+ wire,
+ keyring=keyring,
+ request_mac=request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ raise_on_truncation=raise_on_truncation,
+ )
+ except dns.message.Truncated as e:
+ # If we got Truncated and not FORMERR, we at least got the header with TC
+ # set, and very likely the question section, so we'll re-raise if the
+ # message seems to be a response as we need to know when truncation happens.
+ # We need to check that it seems to be a response as we don't want a random
+ # injected message with TC set to cause us to bail out.
+ if (
+ ignore_errors
+ and query is not None
+ and not query.is_response(e.message())
+ ):
+ continue
+ else:
+ raise
+ except Exception:
+ if ignore_errors:
+ continue
+ else:
+ raise
+ if ignore_errors and query is not None and not query.is_response(r):
+ continue
+ if destination:
+ return (r, received_time)
+ else:
+ return (r, received_time, from_address)
+
+
+def udp(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 53,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ ignore_unexpected: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ raise_on_truncation: bool = False,
+ sock: Optional[Any] = None,
+ ignore_errors: bool = False,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
*q*, a ``dns.message.Message``, the query to send
@@ -307,15 +720,56 @@ def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None,
Returns a ``dns.message.Message``.
"""
- pass
-
-def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional
- [float]=None, port: int=53, source: Optional[str]=None, source_port:
- int=0, ignore_unexpected: bool=False, one_rr_per_rrset: bool=False,
- ignore_trailing: bool=False, udp_sock: Optional[Any]=None, tcp_sock:
- Optional[Any]=None, ignore_errors: bool=False) ->Tuple[dns.message.
- Message, bool]:
+ wire = q.to_wire()
+ (af, destination, source) = _destination_and_source(
+ where, port, source, source_port
+ )
+ (begin_time, expiration) = _compute_times(timeout)
+ if sock:
+ cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
+ else:
+ cm = _make_socket(af, socket.SOCK_DGRAM, source)
+ with cm as s:
+ send_udp(s, wire, destination, expiration)
+ (r, received_time) = receive_udp(
+ s,
+ destination,
+ expiration,
+ ignore_unexpected,
+ one_rr_per_rrset,
+ q.keyring,
+ q.mac,
+ ignore_trailing,
+ raise_on_truncation,
+ ignore_errors,
+ q,
+ )
+ r.time = received_time - begin_time
+ # We don't need to check q.is_response() if we are in ignore_errors mode
+ # as receive_udp() will have checked it.
+ if not (ignore_errors or q.is_response(r)):
+ raise BadResponse
+ return r
+ assert (
+ False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
+ )
+
+
+def udp_with_fallback(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 53,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ ignore_unexpected: bool = False,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ udp_sock: Optional[Any] = None,
+ tcp_sock: Optional[Any] = None,
+ ignore_errors: bool = False,
+) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
@@ -359,7 +813,35 @@ def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if
TCP was used.
"""
- pass
+ try:
+ response = udp(
+ q,
+ where,
+ timeout,
+ port,
+ source,
+ source_port,
+ ignore_unexpected,
+ one_rr_per_rrset,
+ ignore_trailing,
+ True,
+ udp_sock,
+ ignore_errors,
+ )
+ return (response, False)
+ except dns.message.Truncated:
+ response = tcp(
+ q,
+ where,
+ timeout,
+ port,
+ source,
+ source_port,
+ one_rr_per_rrset,
+ ignore_trailing,
+ tcp_sock,
+ )
+ return (response, True)
def _net_read(sock, count, expiration):
@@ -368,7 +850,19 @@ def _net_read(sock, count, expiration):
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
- pass
+ s = b""
+ while count > 0:
+ try:
+ n = sock.recv(count)
+ if n == b"":
+ raise EOFError
+ count -= len(n)
+ s += n
+ except (BlockingIOError, ssl.SSLWantReadError):
+ _wait_for_readable(sock, expiration)
+ except ssl.SSLWantWriteError: # pragma: no cover
+ _wait_for_writable(sock, expiration)
+ return s
def _net_write(sock, data, expiration):
@@ -376,11 +870,22 @@ def _net_write(sock, data, expiration):
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
- pass
-
-
-def send_tcp(sock: Any, what: Union[dns.message.Message, bytes], expiration:
- Optional[float]=None) ->Tuple[int, float]:
+ current = 0
+ l = len(data)
+ while current < l:
+ try:
+ current += sock.send(data[current:])
+ except (BlockingIOError, ssl.SSLWantWriteError):
+ _wait_for_writable(sock, expiration)
+ except ssl.SSLWantReadError: # pragma: no cover
+ _wait_for_readable(sock, expiration)
+
+
+def send_tcp(
+ sock: Any,
+ what: Union[dns.message.Message, bytes],
+ expiration: Optional[float] = None,
+) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
*sock*, a ``socket``.
@@ -393,13 +898,27 @@ def send_tcp(sock: Any, what: Union[dns.message.Message, bytes], expiration:
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
- pass
-
-def receive_tcp(sock: Any, expiration: Optional[float]=None,
- one_rr_per_rrset: bool=False, keyring: Optional[Dict[dns.name.Name, dns
- .tsig.Key]]=None, request_mac: Optional[bytes]=b'', ignore_trailing:
- bool=False) ->Tuple[dns.message.Message, float]:
+ if isinstance(what, dns.message.Message):
+ tcpmsg = what.to_wire(prepend_length=True)
+ else:
+ # copying the wire into tcpmsg is inefficient, but lets us
+ # avoid writev() or doing a short write that would get pushed
+ # onto the net
+ tcpmsg = len(what).to_bytes(2, "big") + what
+ sent_time = time.time()
+ _net_write(sock, tcpmsg, expiration)
+ return (len(tcpmsg), sent_time)
+
+
+def receive_tcp(
+ sock: Any,
+ expiration: Optional[float] = None,
+ one_rr_per_rrset: bool = False,
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
+ request_mac: Optional[bytes] = b"",
+ ignore_trailing: bool = False,
+) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
*sock*, a ``socket``.
@@ -424,13 +943,43 @@ def receive_tcp(sock: Any, expiration: Optional[float]=None,
Returns a ``(dns.message.Message, float)`` tuple of the received message
and the received time.
"""
- pass
-
-def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None,
- port: int=53, source: Optional[str]=None, source_port: int=0,
- one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock:
- Optional[Any]=None) ->dns.message.Message:
+ ldata = _net_read(sock, 2, expiration)
+ (l,) = struct.unpack("!H", ldata)
+ wire = _net_read(sock, l, expiration)
+ received_time = time.time()
+ r = dns.message.from_wire(
+ wire,
+ keyring=keyring,
+ request_mac=request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ return (r, received_time)
+
+
+def _connect(s, address, expiration):
+ err = s.connect_ex(address)
+ if err == 0:
+ return
+ if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY):
+ _wait_for_writable(s, expiration)
+ err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ if err != 0:
+ raise OSError(err, os.strerror(err))
+
+
+def tcp(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 53,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ sock: Optional[Any] = None,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
*q*, a ``dns.message.Message``, the query to send
@@ -462,15 +1011,79 @@ def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None,
Returns a ``dns.message.Message``.
"""
- pass
-
-def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
- port: int=853, source: Optional[str]=None, source_port: int=0,
- one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock:
- Optional[ssl.SSLSocket]=None, ssl_context: Optional[ssl.SSLContext]=
- None, server_hostname: Optional[str]=None, verify: Union[bool, str]=True
- ) ->dns.message.Message:
+ wire = q.to_wire()
+ (begin_time, expiration) = _compute_times(timeout)
+ if sock:
+ cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
+ else:
+ (af, destination, source) = _destination_and_source(
+ where, port, source, source_port
+ )
+ cm = _make_socket(af, socket.SOCK_STREAM, source)
+ with cm as s:
+ if not sock:
+ _connect(s, destination, expiration)
+ send_tcp(s, wire, expiration)
+ (r, received_time) = receive_tcp(
+ s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
+ )
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+ assert (
+ False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
+ )
+
+
+def _tls_handshake(s, expiration):
+ while True:
+ try:
+ s.do_handshake()
+ return
+ except ssl.SSLWantReadError:
+ _wait_for_readable(s, expiration)
+ except ssl.SSLWantWriteError: # pragma: no cover
+ _wait_for_writable(s, expiration)
+
+
+def _make_dot_ssl_context(
+ server_hostname: Optional[str], verify: Union[bool, str]
+) -> ssl.SSLContext:
+ cafile: Optional[str] = None
+ capath: Optional[str] = None
+ if isinstance(verify, str):
+ if os.path.isfile(verify):
+ cafile = verify
+ elif os.path.isdir(verify):
+ capath = verify
+ else:
+ raise ValueError("invalid verify string")
+ ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
+ ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
+ if server_hostname is None:
+ ssl_context.check_hostname = False
+ ssl_context.set_alpn_protocols(["dot"])
+ if verify is False:
+ ssl_context.verify_mode = ssl.CERT_NONE
+ return ssl_context
+
+
+def tls(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 853,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ sock: Optional[ssl.SSLSocket] = None,
+ ssl_context: Optional[ssl.SSLContext] = None,
+ server_hostname: Optional[str] = None,
+ verify: Union[bool, str] = True,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
*q*, a ``dns.message.Message``, the query to send
@@ -517,14 +1130,66 @@ def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
Returns a ``dns.message.Message``.
"""
- pass
-
-def quic(q: dns.message.Message, where: str, timeout: Optional[float]=None,
- port: int=853, source: Optional[str]=None, source_port: int=0,
- one_rr_per_rrset: bool=False, ignore_trailing: bool=False, connection:
- Optional[dns.quic.SyncQuicConnection]=None, verify: Union[bool, str]=
- True, server_hostname: Optional[str]=None) ->dns.message.Message:
+ if sock:
+ #
+ # If a socket was provided, there's no special TLS handling needed.
+ #
+ return tcp(
+ q,
+ where,
+ timeout,
+ port,
+ source,
+ source_port,
+ one_rr_per_rrset,
+ ignore_trailing,
+ sock,
+ )
+
+ wire = q.to_wire()
+ (begin_time, expiration) = _compute_times(timeout)
+ (af, destination, source) = _destination_and_source(
+ where, port, source, source_port
+ )
+ if ssl_context is None and not sock:
+ ssl_context = _make_dot_ssl_context(server_hostname, verify)
+
+ with _make_socket(
+ af,
+ socket.SOCK_STREAM,
+ source,
+ ssl_context=ssl_context,
+ server_hostname=server_hostname,
+ ) as s:
+ _connect(s, destination, expiration)
+ _tls_handshake(s, expiration)
+ send_tcp(s, wire, expiration)
+ (r, received_time) = receive_tcp(
+ s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
+ )
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+ assert (
+ False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
+ )
+
+
+def quic(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 853,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ connection: Optional[dns.quic.SyncQuicConnection] = None,
+ verify: Union[bool, str] = True,
+ server_hostname: Optional[str] = None,
+) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-QUIC.
*q*, a ``dns.message.Message``, the query to send.
@@ -561,17 +1226,61 @@ def quic(q: dns.message.Message, where: str, timeout: Optional[float]=None,
Returns a ``dns.message.Message``.
"""
- pass
-
-def xfr(where: str, zone: Union[dns.name.Name, str], rdtype: Union[dns.
- rdatatype.RdataType, str]=dns.rdatatype.AXFR, rdclass: Union[dns.
- rdataclass.RdataClass, str]=dns.rdataclass.IN, timeout: Optional[float]
- =None, port: int=53, keyring: Optional[Dict[dns.name.Name, dns.tsig.Key
- ]]=None, keyname: Optional[Union[dns.name.Name, str]]=None, relativize:
- bool=True, lifetime: Optional[float]=None, source: Optional[str]=None,
- source_port: int=0, serial: int=0, use_udp: bool=False, keyalgorithm:
- Union[dns.name.Name, str]=dns.tsig.default_algorithm) ->Any:
+ if not dns.quic.have_quic:
+ raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
+
+ q.id = 0
+ wire = q.to_wire()
+ the_connection: dns.quic.SyncQuicConnection
+ the_manager: dns.quic.SyncQuicManager
+ if connection:
+ manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
+ the_connection = connection
+ else:
+ manager = dns.quic.SyncQuicManager(
+ verify_mode=verify, server_name=server_hostname
+ )
+ the_manager = manager # for type checking happiness
+
+ with manager:
+ if not connection:
+ the_connection = the_manager.connect(where, port, source, source_port)
+ (start, expiration) = _compute_times(timeout)
+ with the_connection.make_stream(timeout) as stream:
+ stream.send(wire, True)
+ wire = stream.receive(_remaining(expiration))
+ finish = time.time()
+ r = dns.message.from_wire(
+ wire,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ r.time = max(finish - start, 0.0)
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
+def xfr(
+ where: str,
+ zone: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.AXFR,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ timeout: Optional[float] = None,
+ port: int = 53,
+ keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
+ keyname: Optional[Union[dns.name.Name, str]] = None,
+ relativize: bool = True,
+ lifetime: Optional[float] = None,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ serial: int = 0,
+ use_udp: bool = False,
+ keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm,
+) -> Any:
"""Return a generator for the responses to a zone transfer.
*where*, a ``str`` containing an IPv4 or IPv6 address, where
@@ -623,7 +1332,122 @@ def xfr(where: str, zone: Union[dns.name.Name, str], rdtype: Union[dns.
Returns a generator of ``dns.message.Message`` objects.
"""
- pass
+
+ if isinstance(zone, str):
+ zone = dns.name.from_text(zone)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ q = dns.message.make_query(zone, rdtype, rdclass)
+ if rdtype == dns.rdatatype.IXFR:
+ rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial)
+ q.authority.append(rrset)
+ if keyring is not None:
+ q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
+ wire = q.to_wire()
+ (af, destination, source) = _destination_and_source(
+ where, port, source, source_port
+ )
+ if use_udp and rdtype != dns.rdatatype.IXFR:
+ raise ValueError("cannot do a UDP AXFR")
+ sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
+ with _make_socket(af, sock_type, source) as s:
+ (_, expiration) = _compute_times(lifetime)
+ _connect(s, destination, expiration)
+ l = len(wire)
+ if use_udp:
+ _udp_send(s, wire, None, expiration)
+ else:
+ tcpmsg = struct.pack("!H", l) + wire
+ _net_write(s, tcpmsg, expiration)
+ done = False
+ delete_mode = True
+ expecting_SOA = False
+ soa_rrset = None
+ if relativize:
+ origin = zone
+ oname = dns.name.empty
+ else:
+ origin = None
+ oname = zone
+ tsig_ctx = None
+ while not done:
+ (_, mexpiration) = _compute_times(timeout)
+ if mexpiration is None or (
+ expiration is not None and mexpiration > expiration
+ ):
+ mexpiration = expiration
+ if use_udp:
+ (wire, _) = _udp_recv(s, 65535, mexpiration)
+ else:
+ ldata = _net_read(s, 2, mexpiration)
+ (l,) = struct.unpack("!H", ldata)
+ wire = _net_read(s, l, mexpiration)
+ is_ixfr = rdtype == dns.rdatatype.IXFR
+ r = dns.message.from_wire(
+ wire,
+ keyring=q.keyring,
+ request_mac=q.mac,
+ xfr=True,
+ origin=origin,
+ tsig_ctx=tsig_ctx,
+ multi=True,
+ one_rr_per_rrset=is_ixfr,
+ )
+ rcode = r.rcode()
+ if rcode != dns.rcode.NOERROR:
+ raise TransferError(rcode)
+ tsig_ctx = r.tsig_ctx
+ answer_index = 0
+ if soa_rrset is None:
+ if not r.answer or r.answer[0].name != oname:
+ raise dns.exception.FormError("No answer or RRset not for qname")
+ rrset = r.answer[0]
+ if rrset.rdtype != dns.rdatatype.SOA:
+ raise dns.exception.FormError("first RRset is not an SOA")
+ answer_index = 1
+ soa_rrset = rrset.copy()
+ if rdtype == dns.rdatatype.IXFR:
+ if dns.serial.Serial(soa_rrset[0].serial) <= serial:
+ #
+ # We're already up-to-date.
+ #
+ done = True
+ else:
+ expecting_SOA = True
+ #
+ # Process SOAs in the answer section (other than the initial
+ # SOA in the first message).
+ #
+ for rrset in r.answer[answer_index:]:
+ if done:
+ raise dns.exception.FormError("answers after final SOA")
+ if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
+ if expecting_SOA:
+ if rrset[0].serial != serial:
+ raise dns.exception.FormError("IXFR base serial mismatch")
+ expecting_SOA = False
+ elif rdtype == dns.rdatatype.IXFR:
+ delete_mode = not delete_mode
+ #
+ # If this SOA RRset is equal to the first we saw then we're
+ # finished. If this is an IXFR we also check that we're
+ # seeing the record in the expected part of the response.
+ #
+ if rrset == soa_rrset and (
+ rdtype == dns.rdatatype.AXFR
+ or (rdtype == dns.rdatatype.IXFR and delete_mode)
+ ):
+ done = True
+ elif expecting_SOA:
+ #
+ # We made an IXFR request and are expecting another
+ # SOA RR, but saw something else, so this must be an
+ # AXFR response.
+ #
+ rdtype = dns.rdatatype.AXFR
+ expecting_SOA = False
+ if done and q.keyring and not r.had_tsig:
+ raise dns.exception.FormError("missing TSIG")
+ yield r
class UDPMode(enum.IntEnum):
@@ -633,15 +1457,23 @@ class UDPMode(enum.IntEnum):
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
"""
+
NEVER = 0
TRY_FIRST = 1
ONLY = 2
-def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager,
- query: Optional[dns.message.Message]=None, port: int=53, timeout:
- Optional[float]=None, lifetime: Optional[float]=None, source: Optional[
- str]=None, source_port: int=0, udp_mode: UDPMode=UDPMode.NEVER) ->None:
+def inbound_xfr(
+ where: str,
+ txn_manager: dns.transaction.TransactionManager,
+ query: Optional[dns.message.Message] = None,
+ port: int = 53,
+ timeout: Optional[float] = None,
+ lifetime: Optional[float] = None,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ udp_mode: UDPMode = UDPMode.NEVER,
+) -> None:
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
@@ -678,4 +1510,69 @@ def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager,
Raises on errors.
"""
- pass
+ if query is None:
+ (query, serial) = dns.xfr.make_query(txn_manager)
+ else:
+ serial = dns.xfr.extract_serial_from_query(query)
+ rdtype = query.question[0].rdtype
+ is_ixfr = rdtype == dns.rdatatype.IXFR
+ origin = txn_manager.from_wire_origin()
+ wire = query.to_wire()
+ (af, destination, source) = _destination_and_source(
+ where, port, source, source_port
+ )
+ (_, expiration) = _compute_times(lifetime)
+ retry = True
+ while retry:
+ retry = False
+ if is_ixfr and udp_mode != UDPMode.NEVER:
+ sock_type = socket.SOCK_DGRAM
+ is_udp = True
+ else:
+ sock_type = socket.SOCK_STREAM
+ is_udp = False
+ with _make_socket(af, sock_type, source) as s:
+ _connect(s, destination, expiration)
+ if is_udp:
+ _udp_send(s, wire, None, expiration)
+ else:
+ tcpmsg = struct.pack("!H", len(wire)) + wire
+ _net_write(s, tcpmsg, expiration)
+ with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
+ done = False
+ tsig_ctx = None
+ while not done:
+ (_, mexpiration) = _compute_times(timeout)
+ if mexpiration is None or (
+ expiration is not None and mexpiration > expiration
+ ):
+ mexpiration = expiration
+ if is_udp:
+ (rwire, _) = _udp_recv(s, 65535, mexpiration)
+ else:
+ ldata = _net_read(s, 2, mexpiration)
+ (l,) = struct.unpack("!H", ldata)
+ rwire = _net_read(s, l, mexpiration)
+ r = dns.message.from_wire(
+ rwire,
+ keyring=query.keyring,
+ request_mac=query.mac,
+ xfr=True,
+ origin=origin,
+ tsig_ctx=tsig_ctx,
+ multi=(not is_udp),
+ one_rr_per_rrset=is_ixfr,
+ )
+ try:
+ done = inbound.process_message(r)
+ except dns.xfr.UseTCP:
+ assert is_udp # should not happen if we used TCP!
+ if udp_mode == UDPMode.ONLY:
+ raise
+ done = True
+ retry = True
+ udp_mode = UDPMode.NEVER
+ continue
+ tsig_ctx = r.tsig_ctx
+ if not retry and query.keyring and not r.had_tsig:
+ raise dns.exception.FormError("missing TSIG")
diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py
index 046f81d..0f44331 100644
--- a/dns/quic/_asyncio.py
+++ b/dns/quic/_asyncio.py
@@ -1,23 +1,69 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import asyncio
import socket
import ssl
import struct
import time
-import aioquic.quic.configuration
-import aioquic.quic.connection
-import aioquic.quic.events
+
+import aioquic.quic.configuration # type: ignore
+import aioquic.quic.connection # type: ignore
+import aioquic.quic.events # type: ignore
+
import dns.asyncbackend
import dns.exception
import dns.inet
-from dns.quic._common import QUIC_MAX_DATAGRAM, AsyncQuicConnection, AsyncQuicManager, BaseQuicStream, UnexpectedEOF
+from dns.quic._common import (
+ QUIC_MAX_DATAGRAM,
+ AsyncQuicConnection,
+ AsyncQuicManager,
+ BaseQuicStream,
+ UnexpectedEOF,
+)
class AsyncioQuicStream(BaseQuicStream):
-
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = asyncio.Condition()
+ async def _wait_for_wake_up(self):
+ async with self._wake_up:
+ await self._wake_up.wait()
+
+ async def wait_for(self, amount, expiration):
+ while True:
+ timeout = self._timeout_from_expiration(expiration)
+ if self._buffer.have(amount):
+ return
+ self._expecting = amount
+ try:
+ await asyncio.wait_for(self._wait_for_wake_up(), timeout)
+ except TimeoutError:
+ raise dns.exception.Timeout
+ self._expecting = 0
+
+ async def receive(self, timeout=None):
+ expiration = self._expiration_from_timeout(timeout)
+ await self.wait_for(2, expiration)
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ await self.wait_for(size, expiration)
+ return self._buffer.get(size)
+
+ async def send(self, datagram, is_end=False):
+ data = self._encapsulate(datagram)
+ await self._connection.write(self._stream_id, data, is_end)
+
+ async def _add_input(self, data, is_end):
+ if self._common_add_input(data, is_end):
+ async with self._wake_up:
+ self._wake_up.notify()
+
+ async def close(self):
+ self._close()
+
+ # Streams are async context managers
+
async def __aenter__(self):
return self
@@ -29,11 +75,8 @@ class AsyncioQuicStream(BaseQuicStream):
class AsyncioQuicConnection(AsyncQuicConnection):
-
- def __init__(self, connection, address, port, source, source_port,
- manager=None):
- super().__init__(connection, address, port, source, source_port,
- manager)
+ def __init__(self, connection, address, port, source, source_port, manager=None):
+ super().__init__(connection, address, port, source, source_port, manager)
self._socket = None
self._handshake_complete = asyncio.Event()
self._socket_created = asyncio.Event()
@@ -41,17 +84,144 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._receiver_task = None
self._sender_task = None
+ async def _receiver(self):
+ try:
+ af = dns.inet.af_for_address(self._address)
+ backend = dns.asyncbackend.get_backend("asyncio")
+ # Note that peer is a low-level address tuple, but make_socket() wants
+ # a high-level address tuple, so we convert.
+ self._socket = await backend.make_socket(
+ af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
+ )
+ self._socket_created.set()
+ async with self._socket:
+ while not self._done:
+ (datagram, address) = await self._socket.recvfrom(
+ QUIC_MAX_DATAGRAM, None
+ )
+ if address[0] != self._peer[0] or address[1] != self._peer[1]:
+ continue
+ self._connection.receive_datagram(datagram, address, time.time())
+ # Wake up the timer in case the sender is sleeping, as there may be
+ # stuff to send now.
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
+ except Exception:
+ pass
+ finally:
+ self._done = True
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
+ self._handshake_complete.set()
-class AsyncioQuicManager(AsyncQuicManager):
+ async def _wait_for_wake_timer(self):
+ async with self._wake_timer:
+ await self._wake_timer.wait()
+
+ async def _sender(self):
+ await self._socket_created.wait()
+ while not self._done:
+ datagrams = self._connection.datagrams_to_send(time.time())
+ for datagram, address in datagrams:
+ assert address == self._peer
+ await self._socket.sendto(datagram, self._peer, None)
+ (expiration, interval) = self._get_timer_values()
+ try:
+ await asyncio.wait_for(self._wait_for_wake_timer(), interval)
+ except Exception:
+ pass
+ self._handle_timer(expiration)
+ await self._handle_events()
+
+ async def _handle_events(self):
+ count = 0
+ while True:
+ event = self._connection.next_event()
+ if event is None:
+ return
+ if isinstance(event, aioquic.quic.events.StreamDataReceived):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(event.data, event.end_stream)
+ elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
+ self._handshake_complete.set()
+ elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
+ self._done = True
+ self._receiver_task.cancel()
+ elif isinstance(event, aioquic.quic.events.StreamReset):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(b"", True)
+
+ count += 1
+ if count > 10:
+ # yield
+ count = 0
+ await asyncio.sleep(0)
+
+ async def write(self, stream, data, is_end=False):
+ self._connection.send_stream_data(stream, data, is_end)
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
- def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED,
- server_name=None):
+ def run(self):
+ if self._closed:
+ return
+ self._receiver_task = asyncio.Task(self._receiver())
+ self._sender_task = asyncio.Task(self._sender())
+
+ async def make_stream(self, timeout=None):
+ try:
+ await asyncio.wait_for(self._handshake_complete.wait(), timeout)
+ except TimeoutError:
+ raise dns.exception.Timeout
+ if self._done:
+ raise UnexpectedEOF
+ stream_id = self._connection.get_next_available_stream_id(False)
+ stream = AsyncioQuicStream(self, stream_id)
+ self._streams[stream_id] = stream
+ return stream
+
+ async def close(self):
+ if not self._closed:
+ self._manager.closed(self._peer[0], self._peer[1])
+ self._closed = True
+ self._connection.close()
+ # sender might be blocked on this, so set it
+ self._socket_created.set()
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
+ try:
+ await self._receiver_task
+ except asyncio.CancelledError:
+ pass
+ try:
+ await self._sender_task
+ except asyncio.CancelledError:
+ pass
+ await self._socket.close()
+
+
+class AsyncioQuicManager(AsyncQuicManager):
+ def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
+ def connect(
+ self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ ):
+ (connection, start) = self._connect(
+ address, port, source, source_port, want_session_ticket
+ )
+ if start:
+ connection.run()
+ return connection
+
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
+ # Copy the iterator into a list as exiting things will mutate the connections
+ # table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
diff --git a/dns/quic/_common.py b/dns/quic/_common.py
index d37beb9..0eacc69 100644
--- a/dns/quic/_common.py
+++ b/dns/quic/_common.py
@@ -1,14 +1,21 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import copy
import functools
import socket
import struct
import time
from typing import Any, Optional
-import aioquic.quic.configuration
-import aioquic.quic.connection
+
+import aioquic.quic.configuration # type: ignore
+import aioquic.quic.connection # type: ignore
+
import dns.inet
+
QUIC_MAX_DATAGRAM = 2048
MAX_SESSION_TICKETS = 8
+# If we hit the max sessions limit we will delete this many of the oldest connections.
+# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
@@ -17,25 +24,81 @@ class UnexpectedEOF(Exception):
class Buffer:
-
def __init__(self):
- self._buffer = b''
+ self._buffer = b""
self._seen_end = False
+ def put(self, data, is_end):
+ if self._seen_end:
+ return
+ self._buffer += data
+ if is_end:
+ self._seen_end = True
-class BaseQuicStream:
+ def have(self, amount):
+ if len(self._buffer) >= amount:
+ return True
+ if self._seen_end:
+ raise UnexpectedEOF
+ return False
+
+ def seen_end(self):
+ return self._seen_end
+ def get(self, amount):
+ assert self.have(amount)
+ data = self._buffer[:amount]
+ self._buffer = self._buffer[amount:]
+ return data
+
+
+class BaseQuicStream:
def __init__(self, connection, stream_id):
self._connection = connection
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
+ def id(self):
+ return self._stream_id
-class BaseQuicConnection:
+ def _expiration_from_timeout(self, timeout):
+ if timeout is not None:
+ expiration = time.time() + timeout
+ else:
+ expiration = None
+ return expiration
+
+ def _timeout_from_expiration(self, expiration):
+ if expiration is not None:
+ timeout = max(expiration - time.time(), 0.0)
+ else:
+ timeout = None
+ return timeout
- def __init__(self, connection, address, port, source=None, source_port=
- 0, manager=None):
+ # Subclass must implement receive() as sync / async and which returns a message
+ # or raises UnexpectedEOF.
+
+ def _encapsulate(self, datagram):
+ l = len(datagram)
+ return struct.pack("!H", l) + datagram
+
+ def _common_add_input(self, data, is_end):
+ self._buffer.put(data, is_end)
+ try:
+ return self._expecting > 0 and self._buffer.have(self._expecting)
+ except UnexpectedEOF:
+ return True
+
+ def _close(self):
+ self._connection.close_stream(self._stream_id)
+ self._buffer.put(b"", True) # send EOF in case we haven't seen it.
+
+
+class BaseQuicConnection:
+ def __init__(
+ self, connection, address, port, source=None, source_port=0, manager=None
+ ):
self._done = False
self._connection = connection
self._address = address
@@ -47,25 +110,45 @@ class BaseQuicConnection:
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
if self._af == socket.AF_INET:
- source = '0.0.0.0'
+ source = "0.0.0.0"
elif self._af == socket.AF_INET6:
- source = '::'
+ source = "::"
else:
raise NotImplementedError
if source:
- self._source = source, source_port
+ self._source = (source, source_port)
else:
self._source = None
+ def close_stream(self, stream_id):
+ del self._streams[stream_id]
+
+ def _get_timer_values(self, closed_is_special=True):
+ now = time.time()
+ expiration = self._connection.get_timer()
+ if expiration is None:
+ expiration = now + 3600 # arbitrary "big" value
+ interval = max(expiration - now, 0)
+ if self._closed and closed_is_special:
+ # lower sleep interval to avoid a race in the closing process
+ # which can lead to higher latency closing due to sleeping when
+ # we have events.
+ interval = min(interval, 0.05)
+ return (expiration, interval)
+
+ def _handle_timer(self, expiration):
+ now = time.time()
+ if expiration <= now:
+ self._connection.handle_timer(now)
+
class AsyncQuicConnection(BaseQuicConnection):
- pass
+ async def make_stream(self, timeout: Optional[float] = None) -> Any:
+ pass
class BaseQuicManager:
-
- def __init__(self, conf, verify_mode, connection_factory, server_name=None
- ):
+ def __init__(self, conf, verify_mode, connection_factory, server_name=None):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
@@ -74,13 +157,68 @@ class BaseQuicManager:
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
- conf = aioquic.quic.configuration.QuicConfiguration(alpn_protocols
- =['doq', 'doq-i03'], verify_mode=verify_mode, server_name=
- server_name)
+ conf = aioquic.quic.configuration.QuicConfiguration(
+ alpn_protocols=["doq", "doq-i03"],
+ verify_mode=verify_mode,
+ server_name=server_name,
+ )
if verify_path is not None:
conf.load_verify_locations(verify_path)
self._conf = conf
+ def _connect(
+ self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ ):
+ connection = self._connections.get((address, port))
+ if connection is not None:
+ return (connection, False)
+ conf = self._conf
+ if want_session_ticket:
+ try:
+ session_ticket = self._session_tickets.pop((address, port))
+ # We found a session ticket, so make a configuration that uses it.
+ conf = copy.copy(conf)
+ conf.session_ticket = session_ticket
+ except KeyError:
+ # No session ticket.
+ pass
+ # Whether or not we found a session ticket, we want a handler to save
+ # one.
+ session_ticket_handler = functools.partial(
+ self.save_session_ticket, address, port
+ )
+ else:
+ session_ticket_handler = None
+ qconn = aioquic.quic.connection.QuicConnection(
+ configuration=conf,
+ session_ticket_handler=session_ticket_handler,
+ )
+ lladdress = dns.inet.low_level_address_tuple((address, port))
+ qconn.connect(lladdress, time.time())
+ connection = self._connection_factory(
+ qconn, address, port, source, source_port, self
+ )
+ self._connections[(address, port)] = connection
+ return (connection, True)
+
+ def closed(self, address, port):
+ try:
+ del self._connections[(address, port)]
+ except KeyError:
+ pass
+
+ def save_session_ticket(self, address, port, ticket):
+ # We rely on dictionaries keys() being in insertion order here. We
+ # can't just popitem() as that would be LIFO which is the opposite of
+ # what we want.
+ l = len(self._session_tickets)
+ if l >= MAX_SESSION_TICKETS:
+ keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
+ for key in keys_to_delete:
+ del self._session_tickets[key]
+ self._session_tickets[(address, port)] = ticket
+
class AsyncQuicManager(BaseQuicManager):
- pass
+ def connect(self, address, port=853, source=None, source_port=0):
+ raise NotImplementedError
diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py
index 5052983..120cb5f 100644
--- a/dns/quic/_sync.py
+++ b/dns/quic/_sync.py
@@ -1,28 +1,73 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import selectors
import socket
import ssl
import struct
import threading
import time
-import aioquic.quic.configuration
-import aioquic.quic.connection
-import aioquic.quic.events
+
+import aioquic.quic.configuration # type: ignore
+import aioquic.quic.connection # type: ignore
+import aioquic.quic.events # type: ignore
+
import dns.exception
import dns.inet
-from dns.quic._common import QUIC_MAX_DATAGRAM, BaseQuicConnection, BaseQuicManager, BaseQuicStream, UnexpectedEOF
-if hasattr(selectors, 'PollSelector'):
- _selector_class = selectors.PollSelector
+from dns.quic._common import (
+ QUIC_MAX_DATAGRAM,
+ BaseQuicConnection,
+ BaseQuicManager,
+ BaseQuicStream,
+ UnexpectedEOF,
+)
+
+# Avoid circularity with dns.query
+if hasattr(selectors, "PollSelector"):
+ _selector_class = selectors.PollSelector # type: ignore
else:
- _selector_class = selectors.SelectSelector
+ _selector_class = selectors.SelectSelector # type: ignore
class SyncQuicStream(BaseQuicStream):
-
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = threading.Condition()
self._lock = threading.Lock()
+ def wait_for(self, amount, expiration):
+ while True:
+ timeout = self._timeout_from_expiration(expiration)
+ with self._lock:
+ if self._buffer.have(amount):
+ return
+ self._expecting = amount
+ with self._wake_up:
+ if not self._wake_up.wait(timeout):
+ raise dns.exception.Timeout
+ self._expecting = 0
+
+ def receive(self, timeout=None):
+ expiration = self._expiration_from_timeout(timeout)
+ self.wait_for(2, expiration)
+ with self._lock:
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ self.wait_for(size, expiration)
+ with self._lock:
+ return self._buffer.get(size)
+
+ def send(self, datagram, is_end=False):
+ data = self._encapsulate(datagram)
+ self._connection.write(self._stream_id, data, is_end)
+
+ def _add_input(self, data, is_end):
+ if self._common_add_input(data, is_end):
+ with self._wake_up:
+ self._wake_up.notify()
+
+ def close(self):
+ with self._lock:
+ self._close()
+
def __enter__(self):
return self
@@ -34,39 +79,159 @@ class SyncQuicStream(BaseQuicStream):
class SyncQuicConnection(BaseQuicConnection):
-
- def __init__(self, connection, address, port, source, source_port, manager
- ):
- super().__init__(connection, address, port, source, source_port,
- manager)
+ def __init__(self, connection, address, port, source, source_port, manager):
+ super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
if self._source is not None:
try:
- self._socket.bind(dns.inet.low_level_address_tuple(self.
- _source, self._af))
+ self._socket.bind(
+ dns.inet.low_level_address_tuple(self._source, self._af)
+ )
except Exception:
self._socket.close()
raise
self._socket.connect(self._peer)
- self._send_wakeup, self._receive_wakeup = socket.socketpair()
+ (self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
self._handshake_complete = threading.Event()
self._worker_thread = None
self._lock = threading.Lock()
+ def _read(self):
+ count = 0
+ while count < 10:
+ count += 1
+ try:
+ datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
+ except BlockingIOError:
+ return
+ with self._lock:
+ self._connection.receive_datagram(datagram, self._peer, time.time())
-class SyncQuicManager(BaseQuicManager):
+ def _drain_wakeup(self):
+ while True:
+ try:
+ self._receive_wakeup.recv(32)
+ except BlockingIOError:
+ return
+
+ def _worker(self):
+ try:
+ sel = _selector_class()
+ sel.register(self._socket, selectors.EVENT_READ, self._read)
+ sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
+ while not self._done:
+ (expiration, interval) = self._get_timer_values(False)
+ items = sel.select(interval)
+ for key, _ in items:
+ key.data()
+ with self._lock:
+ self._handle_timer(expiration)
+ self._handle_events()
+ with self._lock:
+ datagrams = self._connection.datagrams_to_send(time.time())
+ for datagram, _ in datagrams:
+ try:
+ self._socket.send(datagram)
+ except BlockingIOError:
+ # we let QUIC handle any lossage
+ pass
+ finally:
+ with self._lock:
+ self._done = True
+ # Ensure anyone waiting for this gets woken up.
+ self._handshake_complete.set()
+
+ def _handle_events(self):
+ while True:
+ with self._lock:
+ event = self._connection.next_event()
+ if event is None:
+ return
+ if isinstance(event, aioquic.quic.events.StreamDataReceived):
+ with self._lock:
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ stream._add_input(event.data, event.end_stream)
+ elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
+ self._handshake_complete.set()
+ elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
+ with self._lock:
+ self._done = True
+ elif isinstance(event, aioquic.quic.events.StreamReset):
+ with self._lock:
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ stream._add_input(b"", True)
+
+ def write(self, stream, data, is_end=False):
+ with self._lock:
+ self._connection.send_stream_data(stream, data, is_end)
+ self._send_wakeup.send(b"\x01")
+
+ def run(self):
+ if self._closed:
+ return
+ self._worker_thread = threading.Thread(target=self._worker)
+ self._worker_thread.start()
+
+ def make_stream(self, timeout=None):
+ if not self._handshake_complete.wait(timeout):
+ raise dns.exception.Timeout
+ with self._lock:
+ if self._done:
+ raise UnexpectedEOF
+ stream_id = self._connection.get_next_available_stream_id(False)
+ stream = SyncQuicStream(self, stream_id)
+ self._streams[stream_id] = stream
+ return stream
+
+ def close_stream(self, stream_id):
+ with self._lock:
+ super().close_stream(stream_id)
- def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED,
- server_name=None):
+ def close(self):
+ with self._lock:
+ if self._closed:
+ return
+ self._manager.closed(self._peer[0], self._peer[1])
+ self._closed = True
+ self._connection.close()
+ self._send_wakeup.send(b"\x01")
+ self._worker_thread.join()
+
+
+class SyncQuicManager(BaseQuicManager):
+ def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
self._lock = threading.Lock()
+ def connect(
+ self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ ):
+ with self._lock:
+ (connection, start) = self._connect(
+ address, port, source, source_port, want_session_ticket
+ )
+ if start:
+ connection.run()
+ return connection
+
+ def closed(self, address, port):
+ with self._lock:
+ super().closed(address, port)
+
+ def save_session_ticket(self, address, port, ticket):
+ with self._lock:
+ super().save_session_ticket(address, port, ticket)
+
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
+ # Copy the iterator into a list as exiting things will mutate the connections
+ # table.
connections = list(self._connections.values())
for connection in connections:
connection.close()
diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py
index 70d5e90..35e36b9 100644
--- a/dns/quic/_trio.py
+++ b/dns/quic/_trio.py
@@ -1,23 +1,67 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import socket
import ssl
import struct
import time
-import aioquic.quic.configuration
-import aioquic.quic.connection
-import aioquic.quic.events
+
+import aioquic.quic.configuration # type: ignore
+import aioquic.quic.connection # type: ignore
+import aioquic.quic.events # type: ignore
import trio
+
import dns.exception
import dns.inet
from dns._asyncbackend import NullContext
-from dns.quic._common import QUIC_MAX_DATAGRAM, AsyncQuicConnection, AsyncQuicManager, BaseQuicStream, UnexpectedEOF
+from dns.quic._common import (
+ QUIC_MAX_DATAGRAM,
+ AsyncQuicConnection,
+ AsyncQuicManager,
+ BaseQuicStream,
+ UnexpectedEOF,
+)
class TrioQuicStream(BaseQuicStream):
-
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = trio.Condition()
+ async def wait_for(self, amount):
+ while True:
+ if self._buffer.have(amount):
+ return
+ self._expecting = amount
+ async with self._wake_up:
+ await self._wake_up.wait()
+ self._expecting = 0
+
+ async def receive(self, timeout=None):
+ if timeout is None:
+ context = NullContext(None)
+ else:
+ context = trio.move_on_after(timeout)
+ with context:
+ await self.wait_for(2)
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ await self.wait_for(size)
+ return self._buffer.get(size)
+ raise dns.exception.Timeout
+
+ async def send(self, datagram, is_end=False):
+ data = self._encapsulate(datagram)
+ await self._connection.write(self._stream_id, data, is_end)
+
+ async def _add_input(self, data, is_end):
+ if self._common_add_input(data, is_end):
+ async with self._wake_up:
+ self._wake_up.notify()
+
+ async def close(self):
+ self._close()
+
+ # Streams are async context managers
+
async def __aenter__(self):
return self
@@ -29,29 +73,137 @@ class TrioQuicStream(BaseQuicStream):
class TrioQuicConnection(AsyncQuicConnection):
-
- def __init__(self, connection, address, port, source, source_port,
- manager=None):
- super().__init__(connection, address, port, source, source_port,
- manager)
+ def __init__(self, connection, address, port, source, source_port, manager=None):
+ super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._handshake_complete = trio.Event()
self._run_done = trio.Event()
self._worker_scope = None
self._send_pending = False
+ async def _worker(self):
+ try:
+ if self._source:
+ await self._socket.bind(
+ dns.inet.low_level_address_tuple(self._source, self._af)
+ )
+ await self._socket.connect(self._peer)
+ while not self._done:
+ (expiration, interval) = self._get_timer_values(False)
+ if self._send_pending:
+ # Do not block forever if sends are pending. Even though we
+ # have a wake-up mechanism if we've already started the blocking
+ # read, the possibility of context switching in send means that
+ # more writes can happen while we have no wake up context, so
+ # we need self._send_pending to avoid (effectively) a "lost wakeup"
+ # race.
+ interval = 0.0
+ with trio.CancelScope(
+ deadline=trio.current_time() + interval
+ ) as self._worker_scope:
+ datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
+ self._connection.receive_datagram(datagram, self._peer, time.time())
+ self._worker_scope = None
+ self._handle_timer(expiration)
+ await self._handle_events()
+ # We clear this now, before sending anything, as sending can cause
+ # context switches that do more sends. We want to know if that
+ # happens so we don't block a long time on the recv() above.
+ self._send_pending = False
+ datagrams = self._connection.datagrams_to_send(time.time())
+ for datagram, _ in datagrams:
+ await self._socket.send(datagram)
+ finally:
+ self._done = True
+ self._handshake_complete.set()
-class TrioQuicManager(AsyncQuicManager):
+ async def _handle_events(self):
+ count = 0
+ while True:
+ event = self._connection.next_event()
+ if event is None:
+ return
+ if isinstance(event, aioquic.quic.events.StreamDataReceived):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(event.data, event.end_stream)
+ elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
+ self._handshake_complete.set()
+ elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
+ self._done = True
+ self._socket.close()
+ elif isinstance(event, aioquic.quic.events.StreamReset):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(b"", True)
+ count += 1
+ if count > 10:
+ # yield
+ count = 0
+ await trio.sleep(0)
+
+ async def write(self, stream, data, is_end=False):
+ self._connection.send_stream_data(stream, data, is_end)
+ self._send_pending = True
+ if self._worker_scope is not None:
+ self._worker_scope.cancel()
+
+ async def run(self):
+ if self._closed:
+ return
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(self._worker)
+ self._run_done.set()
+
+ async def make_stream(self, timeout=None):
+ if timeout is None:
+ context = NullContext(None)
+ else:
+ context = trio.move_on_after(timeout)
+ with context:
+ await self._handshake_complete.wait()
+ if self._done:
+ raise UnexpectedEOF
+ stream_id = self._connection.get_next_available_stream_id(False)
+ stream = TrioQuicStream(self, stream_id)
+ self._streams[stream_id] = stream
+ return stream
+ raise dns.exception.Timeout
- def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED,
- server_name=None):
+ async def close(self):
+ if not self._closed:
+ self._manager.closed(self._peer[0], self._peer[1])
+ self._closed = True
+ self._connection.close()
+ self._send_pending = True
+ if self._worker_scope is not None:
+ self._worker_scope.cancel()
+ await self._run_done.wait()
+
+
+class TrioQuicManager(AsyncQuicManager):
+ def __init__(
+ self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
+ ):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
self._nursery = nursery
+ def connect(
+ self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ ):
+ (connection, start) = self._connect(
+ address, port, source, source_port, want_session_ticket
+ )
+ if start:
+ self._nursery.start_soon(connection.run)
+ return connection
+
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
+ # Copy the iterator into a list as exiting things will mutate the connections
+ # table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
diff --git a/dns/rcode.py b/dns/rcode.py
index 820a695..8e6386f 100644
--- a/dns/rcode.py
+++ b/dns/rcode.py
@@ -1,38 +1,86 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Result Codes."""
+
from typing import Tuple
+
import dns.enum
import dns.exception
class Rcode(dns.enum.IntEnum):
+ #: No error
NOERROR = 0
+ #: Format error
FORMERR = 1
+ #: Server failure
SERVFAIL = 2
+ #: Name does not exist ("Name Error" in RFC 1025 terminology).
NXDOMAIN = 3
+ #: Not implemented
NOTIMP = 4
+ #: Refused
REFUSED = 5
+ #: Name exists.
YXDOMAIN = 6
+ #: RRset exists.
YXRRSET = 7
+ #: RRset does not exist.
NXRRSET = 8
+ #: Not authoritative.
NOTAUTH = 9
+ #: Name not in zone.
NOTZONE = 10
+ #: DSO-TYPE Not Implemented
DSOTYPENI = 11
+ #: Bad EDNS version.
BADVERS = 16
+ #: TSIG Signature Failure
BADSIG = 16
+ #: Key not recognized.
BADKEY = 17
+ #: Signature out of time window.
BADTIME = 18
+ #: Bad TKEY Mode.
BADMODE = 19
+ #: Duplicate key name.
BADNAME = 20
+ #: Algorithm not supported.
BADALG = 21
+ #: Bad Truncation
BADTRUNC = 22
+ #: Bad/missing Server Cookie
BADCOOKIE = 23
+ @classmethod
+ def _maximum(cls):
+ return 4095
+
+ @classmethod
+ def _unknown_exception_class(cls):
+ return UnknownRcode
+
class UnknownRcode(dns.exception.DNSException):
"""A DNS rcode is unknown."""
-def from_text(text: str) ->Rcode:
+def from_text(text: str) -> Rcode:
"""Convert text into an rcode.
*text*, a ``str``, the textual rcode or an integer in textual form.
@@ -41,10 +89,11 @@ def from_text(text: str) ->Rcode:
Returns a ``dns.rcode.Rcode``.
"""
- pass
+
+ return Rcode.from_text(text)
-def from_flags(flags: int, ednsflags: int) ->Rcode:
+def from_flags(flags: int, ednsflags: int) -> Rcode:
"""Return the rcode value encoded by flags and ednsflags.
*flags*, an ``int``, the DNS flags field.
@@ -55,10 +104,12 @@ def from_flags(flags: int, ednsflags: int) ->Rcode:
Returns a ``dns.rcode.Rcode``.
"""
- pass
+ value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0)
+ return Rcode.make(value)
-def to_flags(value: Rcode) ->Tuple[int, int]:
+
+def to_flags(value: Rcode) -> Tuple[int, int]:
"""Return a (flags, ednsflags) tuple which encodes the rcode.
*value*, a ``dns.rcode.Rcode``, the rcode.
@@ -67,10 +118,15 @@ def to_flags(value: Rcode) ->Tuple[int, int]:
Returns an ``(int, int)`` tuple.
"""
- pass
+
+ if value < 0 or value > 4095:
+ raise ValueError("rcode must be >= 0 and <= 4095")
+ v = value & 0xF
+ ev = (value & 0xFF0) << 20
+ return (v, ev)
-def to_text(value: Rcode, tsig: bool=False) ->str:
+def to_text(value: Rcode, tsig: bool = False) -> str:
"""Convert rcode into text.
*value*, a ``dns.rcode.Rcode``, the rcode.
@@ -79,9 +135,14 @@ def to_text(value: Rcode, tsig: bool=False) ->str:
Returns a ``str``.
"""
- pass
+
+ if tsig and value == Rcode.BADVERS:
+ return "BADSIG"
+ return Rcode.to_text(value)
+### BEGIN generated Rcode constants
+
NOERROR = Rcode.NOERROR
FORMERR = Rcode.FORMERR
SERVFAIL = Rcode.SERVFAIL
@@ -103,3 +164,5 @@ BADNAME = Rcode.BADNAME
BADALG = Rcode.BADALG
BADTRUNC = Rcode.BADTRUNC
BADCOOKIE = Rcode.BADCOOKIE
+
+### END generated Rcode constants
diff --git a/dns/rdata.py b/dns/rdata.py
index 11cbf5e..024fd8f 100644
--- a/dns/rdata.py
+++ b/dns/rdata.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS rdata."""
+
import base64
import binascii
import inspect
@@ -7,6 +25,7 @@ import itertools
import random
from importlib import import_module
from typing import Any, Dict, Optional, Tuple, Union
+
import dns.exception
import dns.immutable
import dns.ipv4
@@ -17,7 +36,15 @@ import dns.rdatatype
import dns.tokenizer
import dns.ttl
import dns.wire
+
_chunksize = 32
+
+# We currently allow comparisons for rdata with relative names for backwards
+# compatibility, but in the future we will not, as these kinds of comparisons
+# can lead to subtle bugs if code is not carefully written.
+#
+# This switch allows the future behavior to be turned on so code can be
+# tested with it.
_allow_relative_comparisons = True
@@ -29,49 +56,81 @@ class NoRelativeRdataOrdering(dns.exception.DNSException):
"""
-def _wordbreak(data, chunksize=_chunksize, separator=b' '):
+def _wordbreak(data, chunksize=_chunksize, separator=b" "):
"""Break a binary string into chunks of chunksize characters separated by
a space.
"""
- pass
+ if not chunksize:
+ return data.decode()
+ return separator.join(
+ [data[i : i + chunksize] for i in range(0, len(data), chunksize)]
+ ).decode()
+
+
+# pylint: disable=unused-argument
-def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
+
+def _hexify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its hex encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
- pass
+
+ return _wordbreak(binascii.hexlify(data), chunksize, separator)
-def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw):
+def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its base64 encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
- pass
+
+ return _wordbreak(base64.b64encode(data), chunksize, separator)
+# pylint: enable=unused-argument
+
__escaped = b'"\\'
def _escapify(qstring):
"""Escape the characters in a quoted string which need it."""
- pass
+
+ if isinstance(qstring, str):
+ qstring = qstring.encode()
+ if not isinstance(qstring, bytearray):
+ qstring = bytearray(qstring)
+
+ text = ""
+ for c in qstring:
+ if c in __escaped:
+ text += "\\" + chr(c)
+ elif c >= 0x20 and c < 0x7F:
+ text += chr(c)
+ else:
+ text += "\\%03d" % c
+ return text
def _truncate_bitmap(what):
"""Determine the index of greatest byte that isn't all zeros, and
return the bitmap that contains all the bytes less than that index.
"""
- pass
+ for i in range(len(what) - 1, -1, -1):
+ if what[i] != 0:
+ return what[0 : i + 1]
+ return what[0:1]
+
+# So we don't have to edit all the rdata classes...
_constify = dns.immutable.constify
@dns.immutable.immutable
class Rdata:
"""Base class for all DNS rdata types."""
- __slots__ = ['rdclass', 'rdtype', 'rdcomment']
+
+ __slots__ = ["rdclass", "rdtype", "rdcomment"]
def __init__(self, rdclass, rdtype):
"""Initialize an rdata.
@@ -80,11 +139,25 @@ class Rdata:
*rdtype*, an ``int`` is the rdatatype of the Rdata.
"""
+
self.rdclass = self._as_rdataclass(rdclass)
self.rdtype = self._as_rdatatype(rdtype)
self.rdcomment = None
+ def _get_all_slots(self):
+ return itertools.chain.from_iterable(
+ getattr(cls, "__slots__", []) for cls in self.__class__.__mro__
+ )
+
def __getstate__(self):
+ # We used to try to do a tuple of all slots here, but it
+ # doesn't work as self._all_slots isn't available at
+ # __setstate__() time. Before that we tried to store a tuple
+ # of __slots__, but that didn't work as it didn't store the
+ # slots defined by ancestors. This older way didn't fail
+ # outright, but ended up with partially broken objects, e.g.
+ # if you unpickled an A RR it wouldn't have rdclass and rdtype
+ # attributes, and would compare badly.
state = {}
for slot in self._get_all_slots():
state[slot] = getattr(self, slot)
@@ -93,10 +166,12 @@ class Rdata:
def __setstate__(self, state):
for slot, val in state.items():
object.__setattr__(self, slot, val)
- if not hasattr(self, 'rdcomment'):
- object.__setattr__(self, 'rdcomment', None)
+ if not hasattr(self, "rdcomment"):
+ # Pickled rdata from 2.0.x might not have a rdcomment, so add
+ # it if needed.
+ object.__setattr__(self, "rdcomment", None)
- def covers(self) ->dns.rdatatype.RdataType:
+ def covers(self) -> dns.rdatatype.RdataType:
"""Return the type a Rdata covers.
DNS SIG/RRSIG rdatas apply to a specific type; this type is
@@ -107,59 +182,96 @@ class Rdata:
Returns a ``dns.rdatatype.RdataType``.
"""
- pass
- def extended_rdatatype(self) ->int:
+ return dns.rdatatype.NONE
+
+ def extended_rdatatype(self) -> int:
"""Return a 32-bit type value, the least significant 16 bits of
which are the ordinary DNS type, and the upper 16 bits of which are
the "covered" type, if any.
Returns an ``int``.
"""
- pass
- def to_text(self, origin: Optional[dns.name.Name]=None, relativize:
- bool=True, **kw: Dict[str, Any]) ->str:
+ return self.covers() << 16 | self.rdtype
+
+ def to_text(
+ self,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ **kw: Dict[str, Any],
+ ) -> str:
"""Convert an rdata to text format.
Returns a ``str``.
"""
- pass
- def to_wire(self, file: Optional[Any]=None, compress: Optional[dns.name
- .CompressType]=None, origin: Optional[dns.name.Name]=None,
- canonicalize: bool=False) ->bytes:
+ raise NotImplementedError # pragma: no cover
+
+ def _to_wire(
+ self,
+ file: Optional[Any],
+ compress: Optional[dns.name.CompressType] = None,
+ origin: Optional[dns.name.Name] = None,
+ canonicalize: bool = False,
+ ) -> bytes:
+ raise NotImplementedError # pragma: no cover
+
+ def to_wire(
+ self,
+ file: Optional[Any] = None,
+ compress: Optional[dns.name.CompressType] = None,
+ origin: Optional[dns.name.Name] = None,
+ canonicalize: bool = False,
+ ) -> bytes:
"""Convert an rdata to wire format.
Returns a ``bytes`` or ``None``.
"""
- pass
- def to_generic(self, origin: Optional[dns.name.Name]=None
- ) ->'dns.rdata.GenericRdata':
+ if file:
+ return self._to_wire(file, compress, origin, canonicalize)
+ else:
+ f = io.BytesIO()
+ self._to_wire(f, compress, origin, canonicalize)
+ return f.getvalue()
+
+ def to_generic(
+ self, origin: Optional[dns.name.Name] = None
+ ) -> "dns.rdata.GenericRdata":
"""Creates a dns.rdata.GenericRdata equivalent of this rdata.
Returns a ``dns.rdata.GenericRdata``.
"""
- pass
+ return dns.rdata.GenericRdata(
+ self.rdclass, self.rdtype, self.to_wire(origin=origin)
+ )
- def to_digestable(self, origin: Optional[dns.name.Name]=None) ->bytes:
+ def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes:
"""Convert rdata to a format suitable for digesting in hashes. This
is also the DNSSEC canonical form.
Returns a ``bytes``.
"""
- pass
+
+ return self.to_wire(origin=origin, canonicalize=True)
def __repr__(self):
covers = self.covers()
if covers == dns.rdatatype.NONE:
- ctext = ''
+ ctext = ""
else:
- ctext = '(' + dns.rdatatype.to_text(covers) + ')'
- return '<DNS ' + dns.rdataclass.to_text(self.rdclass
- ) + ' ' + dns.rdatatype.to_text(self.rdtype
- ) + ctext + ' rdata: ' + str(self) + '>'
+ ctext = "(" + dns.rdatatype.to_text(covers) + ")"
+ return (
+ "<DNS "
+ + dns.rdataclass.to_text(self.rdclass)
+ + " "
+ + dns.rdatatype.to_text(self.rdtype)
+ + ctext
+ + " rdata: "
+ + str(self)
+ + ">"
+ )
def __str__(self):
return self.to_text()
@@ -180,7 +292,36 @@ class Rdata:
In the future, all ordering comparisons for rdata with
relative names will be disallowed.
"""
- pass
+ try:
+ our = self.to_digestable()
+ our_relative = False
+ except dns.name.NeedAbsoluteNameOrOrigin:
+ if _allow_relative_comparisons:
+ our = self.to_digestable(dns.name.root)
+ our_relative = True
+ try:
+ their = other.to_digestable()
+ their_relative = False
+ except dns.name.NeedAbsoluteNameOrOrigin:
+ if _allow_relative_comparisons:
+ their = other.to_digestable(dns.name.root)
+ their_relative = True
+ if _allow_relative_comparisons:
+ if our_relative != their_relative:
+ # For the purpose of comparison, all rdata with at least one
+ # relative name is less than an rdata with only absolute names.
+ if our_relative:
+ return -1
+ else:
+ return 1
+ elif our_relative or their_relative:
+ raise NoRelativeRdataOrdering
+ if our == their:
+ return 0
+ elif our > their:
+ return 1
+ else:
+ return -1
def __eq__(self, other):
if not isinstance(other, Rdata):
@@ -211,33 +352,67 @@ class Rdata:
return not self.__eq__(other)
def __lt__(self, other):
- if not isinstance(other, Rdata
- ) or self.rdclass != other.rdclass or self.rdtype != other.rdtype:
+ if (
+ not isinstance(other, Rdata)
+ or self.rdclass != other.rdclass
+ or self.rdtype != other.rdtype
+ ):
return NotImplemented
return self._cmp(other) < 0
def __le__(self, other):
- if not isinstance(other, Rdata
- ) or self.rdclass != other.rdclass or self.rdtype != other.rdtype:
+ if (
+ not isinstance(other, Rdata)
+ or self.rdclass != other.rdclass
+ or self.rdtype != other.rdtype
+ ):
return NotImplemented
return self._cmp(other) <= 0
def __ge__(self, other):
- if not isinstance(other, Rdata
- ) or self.rdclass != other.rdclass or self.rdtype != other.rdtype:
+ if (
+ not isinstance(other, Rdata)
+ or self.rdclass != other.rdclass
+ or self.rdtype != other.rdtype
+ ):
return NotImplemented
return self._cmp(other) >= 0
def __gt__(self, other):
- if not isinstance(other, Rdata
- ) or self.rdclass != other.rdclass or self.rdtype != other.rdtype:
+ if (
+ not isinstance(other, Rdata)
+ or self.rdclass != other.rdclass
+ or self.rdtype != other.rdtype
+ ):
return NotImplemented
return self._cmp(other) > 0
def __hash__(self):
return hash(self.to_digestable(dns.name.root))
- def replace(self, **kwargs: Any) ->'Rdata':
+ @classmethod
+ def from_text(
+ cls,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ tok: dns.tokenizer.Tokenizer,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ relativize_to: Optional[dns.name.Name] = None,
+ ) -> "Rdata":
+ raise NotImplementedError # pragma: no cover
+
+ @classmethod
+ def from_wire_parser(
+ cls,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ parser: dns.wire.Parser,
+ origin: Optional[dns.name.Name] = None,
+ ) -> "Rdata":
+ raise NotImplementedError # pragma: no cover
+
+ def replace(self, **kwargs: Any) -> "Rdata":
"""
Create a new Rdata instance based on the instance replace was
invoked on. It is possible to pass different parameters to
@@ -248,7 +423,179 @@ class Rdata:
Returns an instance of the same Rdata subclass as *self*.
"""
- pass
+
+ # Get the constructor parameters.
+ parameters = inspect.signature(self.__init__).parameters # type: ignore
+
+ # Ensure that all of the arguments correspond to valid fields.
+ # Don't allow rdclass or rdtype to be changed, though.
+ for key in kwargs:
+ if key == "rdcomment":
+ continue
+ if key not in parameters:
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(
+ self.__class__.__name__, key
+ )
+ )
+ if key in ("rdclass", "rdtype"):
+ raise AttributeError(
+ "Cannot overwrite '{}' attribute '{}'".format(
+ self.__class__.__name__, key
+ )
+ )
+
+ # Construct the parameter list. For each field, use the value in
+ # kwargs if present, and the current value otherwise.
+ args = (kwargs.get(key, getattr(self, key)) for key in parameters)
+
+ # Create, validate, and return the new object.
+ rd = self.__class__(*args)
+ # The comment is not set in the constructor, so give it special
+ # handling.
+ rdcomment = kwargs.get("rdcomment", self.rdcomment)
+ if rdcomment is not None:
+ object.__setattr__(rd, "rdcomment", rdcomment)
+ return rd
+
+ # Type checking and conversion helpers. These are class methods as
+ # they don't touch object state and may be useful to others.
+
+ @classmethod
+ def _as_rdataclass(cls, value):
+ return dns.rdataclass.RdataClass.make(value)
+
+ @classmethod
+ def _as_rdatatype(cls, value):
+ return dns.rdatatype.RdataType.make(value)
+
+ @classmethod
+ def _as_bytes(
+ cls,
+ value: Any,
+ encode: bool = False,
+ max_length: Optional[int] = None,
+ empty_ok: bool = True,
+ ) -> bytes:
+ if encode and isinstance(value, str):
+ bvalue = value.encode()
+ elif isinstance(value, bytearray):
+ bvalue = bytes(value)
+ elif isinstance(value, bytes):
+ bvalue = value
+ else:
+ raise ValueError("not bytes")
+ if max_length is not None and len(bvalue) > max_length:
+ raise ValueError("too long")
+ if not empty_ok and len(bvalue) == 0:
+ raise ValueError("empty bytes not allowed")
+ return bvalue
+
+ @classmethod
+ def _as_name(cls, value):
+ # Note that proper name conversion (e.g. with origin and IDNA
+ # awareness) is expected to be done via from_text. This is just
+ # a simple thing for people invoking the constructor directly.
+ if isinstance(value, str):
+ return dns.name.from_text(value)
+ elif not isinstance(value, dns.name.Name):
+ raise ValueError("not a name")
+ return value
+
+ @classmethod
+ def _as_uint8(cls, value):
+ if not isinstance(value, int):
+ raise ValueError("not an integer")
+ if value < 0 or value > 255:
+ raise ValueError("not a uint8")
+ return value
+
+ @classmethod
+ def _as_uint16(cls, value):
+ if not isinstance(value, int):
+ raise ValueError("not an integer")
+ if value < 0 or value > 65535:
+ raise ValueError("not a uint16")
+ return value
+
+ @classmethod
+ def _as_uint32(cls, value):
+ if not isinstance(value, int):
+ raise ValueError("not an integer")
+ if value < 0 or value > 4294967295:
+ raise ValueError("not a uint32")
+ return value
+
+ @classmethod
+ def _as_uint48(cls, value):
+ if not isinstance(value, int):
+ raise ValueError("not an integer")
+ if value < 0 or value > 281474976710655:
+ raise ValueError("not a uint48")
+ return value
+
+ @classmethod
+ def _as_int(cls, value, low=None, high=None):
+ if not isinstance(value, int):
+ raise ValueError("not an integer")
+ if low is not None and value < low:
+ raise ValueError("value too small")
+ if high is not None and value > high:
+ raise ValueError("value too large")
+ return value
+
+ @classmethod
+ def _as_ipv4_address(cls, value):
+ if isinstance(value, str):
+ return dns.ipv4.canonicalize(value)
+ elif isinstance(value, bytes):
+ return dns.ipv4.inet_ntoa(value)
+ else:
+ raise ValueError("not an IPv4 address")
+
+ @classmethod
+ def _as_ipv6_address(cls, value):
+ if isinstance(value, str):
+ return dns.ipv6.canonicalize(value)
+ elif isinstance(value, bytes):
+ return dns.ipv6.inet_ntoa(value)
+ else:
+ raise ValueError("not an IPv6 address")
+
+ @classmethod
+ def _as_bool(cls, value):
+ if isinstance(value, bool):
+ return value
+ else:
+ raise ValueError("not a boolean")
+
+ @classmethod
+ def _as_ttl(cls, value):
+ if isinstance(value, int):
+ return cls._as_int(value, 0, dns.ttl.MAX_TTL)
+ elif isinstance(value, str):
+ return dns.ttl.from_text(value)
+ else:
+ raise ValueError("not a TTL")
+
+ @classmethod
+ def _as_tuple(cls, value, as_value):
+ try:
+ # For user convenience, if value is a singleton of the list
+ # element type, wrap it in a tuple.
+ return (as_value(value),)
+ except Exception:
+ # Otherwise, check each element of the iterable *value*
+ # against *as_value*.
+ return tuple(as_value(v) for v in value)
+
+ # Processing order
+
+ @classmethod
+ def _processing_order(cls, iterable):
+ items = list(iterable)
+ random.shuffle(items)
+ return items
@dns.immutable.immutable
@@ -258,23 +605,86 @@ class GenericRdata(Rdata):
This class is used for rdata types for which we have no better
implementation. It implements the DNS "unknown RRs" scheme.
"""
- __slots__ = ['data']
+
+ __slots__ = ["data"]
def __init__(self, rdclass, rdtype, data):
super().__init__(rdclass, rdtype)
self.data = data
-
-_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.
- RdataType], Any] = {}
-_module_prefix = 'dns.rdtypes'
-
-
-def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union
- [dns.rdatatype.RdataType, str], tok: Union[dns.tokenizer.Tokenizer, str
- ], origin: Optional[dns.name.Name]=None, relativize: bool=True,
- relativize_to: Optional[dns.name.Name]=None, idna_codec: Optional[dns.
- name.IDNACodec]=None) ->Rdata:
+ def to_text(
+ self,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ **kw: Dict[str, Any],
+ ) -> str:
+ return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ token = tok.get()
+ if not token.is_identifier() or token.value != r"\#":
+ raise dns.exception.SyntaxError(r"generic rdata does not start with \#")
+ length = tok.get_int()
+ hex = tok.concatenate_remaining_identifiers(True).encode()
+ data = binascii.unhexlify(hex)
+ if len(data) != length:
+ raise dns.exception.SyntaxError("generic rdata hex data has wrong length")
+ return cls(rdclass, rdtype, data)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(self.data)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ return cls(rdclass, rdtype, parser.get_remaining())
+
+
+_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = (
+ {}
+)
+_module_prefix = "dns.rdtypes"
+
+
+def get_rdata_class(rdclass, rdtype):
+ cls = _rdata_classes.get((rdclass, rdtype))
+ if not cls:
+ cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype))
+ if not cls:
+ rdclass_text = dns.rdataclass.to_text(rdclass)
+ rdtype_text = dns.rdatatype.to_text(rdtype)
+ rdtype_text = rdtype_text.replace("-", "_")
+ try:
+ mod = import_module(
+ ".".join([_module_prefix, rdclass_text, rdtype_text])
+ )
+ cls = getattr(mod, rdtype_text)
+ _rdata_classes[(rdclass, rdtype)] = cls
+ except ImportError:
+ try:
+ mod = import_module(".".join([_module_prefix, "ANY", rdtype_text]))
+ cls = getattr(mod, rdtype_text)
+ _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls
+ _rdata_classes[(rdclass, rdtype)] = cls
+ except ImportError:
+ pass
+ if not cls:
+ cls = GenericRdata
+ _rdata_classes[(rdclass, rdtype)] = cls
+ return cls
+
+
+def from_text(
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ tok: Union[dns.tokenizer.Tokenizer, str],
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ relativize_to: Optional[dns.name.Name] = None,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+) -> Rdata:
"""Build an rdata object from text format.
This function attempts to dynamically load a class which
@@ -311,12 +721,57 @@ def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union
Returns an instance of the chosen Rdata subclass.
"""
- pass
-
-
-def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype:
- Union[dns.rdatatype.RdataType, str], parser: dns.wire.Parser, origin:
- Optional[dns.name.Name]=None) ->Rdata:
+ if isinstance(tok, str):
+ tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ cls = get_rdata_class(rdclass, rdtype)
+ with dns.exception.ExceptionWrapper(dns.exception.SyntaxError):
+ rdata = None
+ if cls != GenericRdata:
+ # peek at first token
+ token = tok.get()
+ tok.unget(token)
+ if token.is_identifier() and token.value == r"\#":
+ #
+ # Known type using the generic syntax. Extract the
+ # wire form from the generic syntax, and then run
+ # from_wire on it.
+ #
+ grdata = GenericRdata.from_text(
+ rdclass, rdtype, tok, origin, relativize, relativize_to
+ )
+ rdata = from_wire(
+ rdclass, rdtype, grdata.data, 0, len(grdata.data), origin
+ )
+ #
+ # If this comparison isn't equal, then there must have been
+ # compressed names in the wire format, which is an error,
+ # there being no reasonable context to decompress with.
+ #
+ rwire = rdata.to_wire()
+ if rwire != grdata.data:
+ raise dns.exception.SyntaxError(
+ "compressed data in "
+ "generic syntax form "
+ "of known rdatatype"
+ )
+ if rdata is None:
+ rdata = cls.from_text(
+ rdclass, rdtype, tok, origin, relativize, relativize_to
+ )
+ token = tok.get_eol_as_token()
+ if token.comment is not None:
+ object.__setattr__(rdata, "rdcomment", token.comment)
+ return rdata
+
+
+def from_wire_parser(
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ parser: dns.wire.Parser,
+ origin: Optional[dns.name.Name] = None,
+) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
@@ -339,12 +794,22 @@ def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype:
Returns an instance of the chosen Rdata subclass.
"""
- pass
-
-def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union
- [dns.rdatatype.RdataType, str], wire: bytes, current: int, rdlen: int,
- origin: Optional[dns.name.Name]=None) ->Rdata:
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ cls = get_rdata_class(rdclass, rdtype)
+ with dns.exception.ExceptionWrapper(dns.exception.FormError):
+ return cls.from_wire_parser(rdclass, rdtype, parser, origin)
+
+
+def from_wire(
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ wire: bytes,
+ current: int,
+ rdlen: int,
+ origin: Optional[dns.name.Name] = None,
+) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
@@ -371,19 +836,28 @@ def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union
Returns an instance of the chosen Rdata subclass.
"""
- pass
+ parser = dns.wire.Parser(wire, current)
+ with parser.restrict_to(rdlen):
+ return from_wire_parser(rdclass, rdtype, parser, origin)
class RdatatypeExists(dns.exception.DNSException):
"""DNS rdatatype already exists."""
- supp_kwargs = {'rdclass', 'rdtype'}
- fmt = ('The rdata type with class {rdclass:d} and rdtype {rdtype:d} ' +
- 'already exists.')
+
+ supp_kwargs = {"rdclass", "rdtype"}
+ fmt = (
+ "The rdata type with class {rdclass:d} and rdtype {rdtype:d} "
+ + "already exists."
+ )
-def register_type(implementation: Any, rdtype: int, rdtype_text: str,
- is_singleton: bool=False, rdclass: dns.rdataclass.RdataClass=dns.
- rdataclass.IN) ->None:
+def register_type(
+ implementation: Any,
+ rdtype: int,
+ rdtype_text: str,
+ is_singleton: bool = False,
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+) -> None:
"""Dynamically register a module to handle an rdatatype.
*implementation*, a module implementing the type in the usual dnspython
@@ -399,4 +873,12 @@ def register_type(implementation: Any, rdtype: int, rdtype_text: str,
*rdclass*, the rdataclass of the type, or ``dns.rdataclass.ANY`` if
it applies to all classes.
"""
- pass
+
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ existing_cls = get_rdata_class(rdclass, rdtype)
+ if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
+ raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
+ _rdata_classes[(rdclass, rdtype)] = getattr(
+ implementation, rdtype_text.replace("-", "_")
+ )
+ dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
diff --git a/dns/rdataclass.py b/dns/rdataclass.py
index 2db3e64..89b85a7 100644
--- a/dns/rdataclass.py
+++ b/dns/rdataclass.py
@@ -1,10 +1,29 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Rdata Classes."""
+
import dns.enum
import dns.exception
class RdataClass(dns.enum.IntEnum):
"""DNS Rdata Class"""
+
RESERVED0 = 0
IN = 1
INTERNET = IN
@@ -15,6 +34,22 @@ class RdataClass(dns.enum.IntEnum):
NONE = 254
ANY = 255
+ @classmethod
+ def _maximum(cls):
+ return 65535
+
+ @classmethod
+ def _short_name(cls):
+ return "class"
+
+ @classmethod
+ def _prefix(cls):
+ return "CLASS"
+
+ @classmethod
+ def _unknown_exception_class(cls):
+ return UnknownRdataclass
+
_metaclasses = {RdataClass.NONE, RdataClass.ANY}
@@ -23,7 +58,7 @@ class UnknownRdataclass(dns.exception.DNSException):
"""A DNS class is unknown."""
-def from_text(text: str) ->RdataClass:
+def from_text(text: str) -> RdataClass:
"""Convert text into a DNS rdata class value.
The input text can be a defined DNS RR class mnemonic or
@@ -37,10 +72,11 @@ def from_text(text: str) ->RdataClass:
Returns a ``dns.rdataclass.RdataClass``.
"""
- pass
+
+ return RdataClass.from_text(text)
-def to_text(value: RdataClass) ->str:
+def to_text(value: RdataClass) -> str:
"""Convert a DNS rdata class value to text.
If the value has a known mnemonic, it will be used, otherwise the
@@ -50,18 +86,24 @@ def to_text(value: RdataClass) ->str:
Returns a ``str``.
"""
- pass
+
+ return RdataClass.to_text(value)
-def is_metaclass(rdclass: RdataClass) ->bool:
+def is_metaclass(rdclass: RdataClass) -> bool:
"""True if the specified class is a metaclass.
The currently defined metaclasses are ANY and NONE.
*rdclass* is a ``dns.rdataclass.RdataClass``.
"""
- pass
+ if rdclass in _metaclasses:
+ return True
+ return False
+
+
+### BEGIN generated RdataClass constants
RESERVED0 = RdataClass.RESERVED0
IN = RdataClass.IN
@@ -72,3 +114,5 @@ HS = RdataClass.HS
HESIOD = RdataClass.HESIOD
NONE = RdataClass.NONE
ANY = RdataClass.ANY
+
+### END generated RdataClass constants
diff --git a/dns/rdataset.py b/dns/rdataset.py
index 7228b61..8bff58d 100644
--- a/dns/rdataset.py
+++ b/dns/rdataset.py
@@ -1,8 +1,27 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
+
import io
import random
import struct
from typing import Any, Collection, Dict, List, Optional, Union, cast
+
import dns.exception
import dns.immutable
import dns.name
@@ -12,6 +31,8 @@ import dns.rdatatype
import dns.renderer
import dns.set
import dns.ttl
+
+# define SimpleSet here for backwards compatibility
SimpleSet = dns.set.Set
@@ -26,11 +47,16 @@ class IncompatibleTypes(dns.exception.DNSException):
class Rdataset(dns.set.Set):
"""A DNS rdataset."""
- __slots__ = ['rdclass', 'rdtype', 'covers', 'ttl']
- def __init__(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.
- rdatatype.RdataType, covers: dns.rdatatype.RdataType=dns.rdatatype.
- NONE, ttl: int=0):
+ __slots__ = ["rdclass", "rdtype", "covers", "ttl"]
+
+ def __init__(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ ttl: int = 0,
+ ):
"""Create a new rdataset of the specified class and type.
*rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass.
@@ -41,13 +67,22 @@ class Rdataset(dns.set.Set):
*ttl*, an ``int``, the TTL.
"""
+
super().__init__()
self.rdclass = rdclass
self.rdtype: dns.rdatatype.RdataType = rdtype
self.covers: dns.rdatatype.RdataType = covers
self.ttl = ttl
- def update_ttl(self, ttl: int) ->None:
+ def _clone(self):
+ obj = super()._clone()
+ obj.rdclass = self.rdclass
+ obj.rdtype = self.rdtype
+ obj.covers = self.covers
+ obj.ttl = self.ttl
+ return obj
+
+ def update_ttl(self, ttl: int) -> None:
"""Perform TTL minimization.
Set the TTL of the rdataset to be the lesser of the set's current
@@ -56,9 +91,15 @@ class Rdataset(dns.set.Set):
*ttl*, an ``int`` or ``str``.
"""
- pass
-
- def add(self, rd: dns.rdata.Rdata, ttl: Optional[int]=None) ->None:
+ ttl = dns.ttl.make(ttl)
+ if len(self) == 0:
+ self.ttl = ttl
+ elif ttl < self.ttl:
+ self.ttl = ttl
+
+ def add( # pylint: disable=arguments-differ,arguments-renamed
+ self, rd: dns.rdata.Rdata, ttl: Optional[int] = None
+ ) -> None:
"""Add the specified rdata to the rdataset.
If the optional *ttl* parameter is supplied, then
@@ -74,7 +115,34 @@ class Rdataset(dns.set.Set):
Raises ``dns.rdataset.DifferingCovers`` if the type is a signature
type and the covered type does not match that of the rdataset.
"""
- pass
+
+ #
+ # If we're adding a signature, do some special handling to
+ # check that the signature covers the same type as the
+ # other rdatas in this rdataset. If this is the first rdata
+ # in the set, initialize the covers field.
+ #
+ if self.rdclass != rd.rdclass or self.rdtype != rd.rdtype:
+ raise IncompatibleTypes
+ if ttl is not None:
+ self.update_ttl(ttl)
+ if self.rdtype == dns.rdatatype.RRSIG or self.rdtype == dns.rdatatype.SIG:
+ covers = rd.covers()
+ if len(self) == 0 and self.covers == dns.rdatatype.NONE:
+ self.covers = covers
+ elif self.covers != covers:
+ raise DifferingCovers
+ if dns.rdatatype.is_singleton(rd.rdtype) and len(self) > 0:
+ self.clear()
+ super().add(rd)
+
+ def union_update(self, other):
+ self.update_ttl(other.ttl)
+ super().union_update(other)
+
+ def intersection_update(self, other):
+ self.update_ttl(other.ttl)
+ super().intersection_update(other)
def update(self, other):
"""Add all rdatas in other to self.
@@ -82,16 +150,33 @@ class Rdataset(dns.set.Set):
*other*, a ``dns.rdataset.Rdataset``, the rdataset from which
to update.
"""
- pass
+
+ self.update_ttl(other.ttl)
+ super().update(other)
+
+ def _rdata_repr(self):
+ def maybe_truncate(s):
+ if len(s) > 100:
+ return s[:100] + "..."
+ return s
+
+ return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self)
def __repr__(self):
if self.covers == 0:
- ctext = ''
+ ctext = ""
else:
- ctext = '(' + dns.rdatatype.to_text(self.covers) + ')'
- return '<DNS ' + dns.rdataclass.to_text(self.rdclass
- ) + ' ' + dns.rdatatype.to_text(self.rdtype
- ) + ctext + ' rdataset: ' + self._rdata_repr() + '>'
+ ctext = "(" + dns.rdatatype.to_text(self.covers) + ")"
+ return (
+ "<DNS "
+ + dns.rdataclass.to_text(self.rdclass)
+ + " "
+ + dns.rdatatype.to_text(self.rdtype)
+ + ctext
+ + " rdataset: "
+ + self._rdata_repr()
+ + ">"
+ )
def __str__(self):
return self.to_text()
@@ -99,18 +184,26 @@ class Rdataset(dns.set.Set):
def __eq__(self, other):
if not isinstance(other, Rdataset):
return False
- if (self.rdclass != other.rdclass or self.rdtype != other.rdtype or
- self.covers != other.covers):
+ if (
+ self.rdclass != other.rdclass
+ or self.rdtype != other.rdtype
+ or self.covers != other.covers
+ ):
return False
return super().__eq__(other)
def __ne__(self, other):
return not self.__eq__(other)
- def to_text(self, name: Optional[dns.name.Name]=None, origin: Optional[
- dns.name.Name]=None, relativize: bool=True, override_rdclass:
- Optional[dns.rdataclass.RdataClass]=None, want_comments: bool=False,
- **kw: Dict[str, Any]) ->str:
+ def to_text(
+ self,
+ name: Optional[dns.name.Name] = None,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
+ want_comments: bool = False,
+ **kw: Dict[str, Any],
+ ) -> str:
"""Convert the rdataset into DNS zone file format.
See ``dns.name.Name.choose_relativity`` for more information
@@ -135,12 +228,65 @@ class Rdataset(dns.set.Set):
*want_comments*, a ``bool``. If ``True``, emit comments for rdata
which have them. The default is ``False``.
"""
- pass
- def to_wire(self, name: dns.name.Name, file: Any, compress: Optional[
- dns.name.CompressType]=None, origin: Optional[dns.name.Name]=None,
- override_rdclass: Optional[dns.rdataclass.RdataClass]=None,
- want_shuffle: bool=True) ->int:
+ if name is not None:
+ name = name.choose_relativity(origin, relativize)
+ ntext = str(name)
+ pad = " "
+ else:
+ ntext = ""
+ pad = ""
+ s = io.StringIO()
+ if override_rdclass is not None:
+ rdclass = override_rdclass
+ else:
+ rdclass = self.rdclass
+ if len(self) == 0:
+ #
+ # Empty rdatasets are used for the question section, and in
+ # some dynamic updates, so we don't need to print out the TTL
+ # (which is meaningless anyway).
+ #
+ s.write(
+ "{}{}{} {}\n".format(
+ ntext,
+ pad,
+ dns.rdataclass.to_text(rdclass),
+ dns.rdatatype.to_text(self.rdtype),
+ )
+ )
+ else:
+ for rd in self:
+ extra = ""
+ if want_comments:
+ if rd.rdcomment:
+ extra = f" ;{rd.rdcomment}"
+ s.write(
+ "%s%s%d %s %s %s%s\n"
+ % (
+ ntext,
+ pad,
+ self.ttl,
+ dns.rdataclass.to_text(rdclass),
+ dns.rdatatype.to_text(self.rdtype),
+ rd.to_text(origin=origin, relativize=relativize, **kw),
+ extra,
+ )
+ )
+ #
+ # We strip off the final \n for the caller's convenience in printing
+ #
+ return s.getvalue()[:-1]
+
+ def to_wire(
+ self,
+ name: dns.name.Name,
+ file: Any,
+ compress: Optional[dns.name.CompressType] = None,
+ origin: Optional[dns.name.Name] = None,
+ override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
+ want_shuffle: bool = True,
+ ) -> int:
"""Convert the rdataset to wire format.
*name*, a ``dns.name.Name`` is the owner name to use.
@@ -164,16 +310,44 @@ class Rdataset(dns.set.Set):
Returns an ``int``, the number of records emitted.
"""
- pass
- def match(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.
- rdatatype.RdataType, covers: dns.rdatatype.RdataType) ->bool:
+ if override_rdclass is not None:
+ rdclass = override_rdclass
+ want_shuffle = False
+ else:
+ rdclass = self.rdclass
+ if len(self) == 0:
+ name.to_wire(file, compress, origin)
+ file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0))
+ return 1
+ else:
+ l: Union[Rdataset, List[dns.rdata.Rdata]]
+ if want_shuffle:
+ l = list(self)
+ random.shuffle(l)
+ else:
+ l = self
+ for rd in l:
+ name.to_wire(file, compress, origin)
+ file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl))
+ with dns.renderer.prefixed_length(file, 2):
+ rd.to_wire(file, compress, origin)
+ return len(self)
+
+ def match(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType,
+ ) -> bool:
"""Returns ``True`` if this rdataset matches the specified class,
type, and covers.
"""
- pass
+ if self.rdclass == rdclass and self.rdtype == rdtype and self.covers == covers:
+ return True
+ return False
- def processing_order(self) ->List[dns.rdata.Rdata]:
+ def processing_order(self) -> List[dns.rdata.Rdata]:
"""Return rdatas in a valid processing order according to the type's
specification. For example, MX records are in preference order from
lowest to highest preferences, with items of the same preference
@@ -182,44 +356,92 @@ class Rdataset(dns.set.Set):
For types that do not define a processing order, the rdatas are
simply shuffled.
"""
- pass
+ if len(self) == 0:
+ return []
+ else:
+ return self[0]._processing_order(iter(self))
@dns.immutable.immutable
-class ImmutableRdataset(Rdataset):
+class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals]
"""An immutable DNS rdataset."""
+
_clone_class = Rdataset
def __init__(self, rdataset: Rdataset):
"""Create an immutable rdataset from the specified rdataset."""
- super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers,
- rdataset.ttl)
+
+ super().__init__(
+ rdataset.rdclass, rdataset.rdtype, rdataset.covers, rdataset.ttl
+ )
self.items = dns.immutable.Dict(rdataset.items)
+ def update_ttl(self, ttl):
+ raise TypeError("immutable")
+
+ def add(self, rd, ttl=None):
+ raise TypeError("immutable")
+
+ def union_update(self, other):
+ raise TypeError("immutable")
+
+ def intersection_update(self, other):
+ raise TypeError("immutable")
+
+ def update(self, other):
+ raise TypeError("immutable")
+
def __delitem__(self, i):
- raise TypeError('immutable')
+ raise TypeError("immutable")
+
+ # lgtm complains about these not raising ArithmeticError, but there is
+ # precedent for overrides of these methods in other classes to raise
+ # TypeError, and it seems like the better exception.
- def __ior__(self, other):
- raise TypeError('immutable')
+ def __ior__(self, other): # lgtm[py/unexpected-raise-in-special-method]
+ raise TypeError("immutable")
- def __iand__(self, other):
- raise TypeError('immutable')
+ def __iand__(self, other): # lgtm[py/unexpected-raise-in-special-method]
+ raise TypeError("immutable")
- def __iadd__(self, other):
- raise TypeError('immutable')
+ def __iadd__(self, other): # lgtm[py/unexpected-raise-in-special-method]
+ raise TypeError("immutable")
- def __isub__(self, other):
- raise TypeError('immutable')
+ def __isub__(self, other): # lgtm[py/unexpected-raise-in-special-method]
+ raise TypeError("immutable")
+
+ def clear(self):
+ raise TypeError("immutable")
def __copy__(self):
return ImmutableRdataset(super().copy())
+ def copy(self):
+ return ImmutableRdataset(super().copy())
-def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype:
- Union[dns.rdatatype.RdataType, str], ttl: int, text_rdatas: Collection[
- str], idna_codec: Optional[dns.name.IDNACodec]=None, origin: Optional[
- dns.name.Name]=None, relativize: bool=True, relativize_to: Optional[dns
- .name.Name]=None) ->Rdataset:
+ def union(self, other):
+ return ImmutableRdataset(super().union(other))
+
+ def intersection(self, other):
+ return ImmutableRdataset(super().intersection(other))
+
+ def difference(self, other):
+ return ImmutableRdataset(super().difference(other))
+
+ def symmetric_difference(self, other):
+ return ImmutableRdataset(super().symmetric_difference(other))
+
+
+def from_text_list(
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ ttl: int,
+ text_rdatas: Collection[str],
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ relativize_to: Optional[dns.name.Name] = None,
+) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with
the specified list of rdatas in text format.
@@ -237,32 +459,58 @@ def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype:
Returns a ``dns.rdataset.Rdataset`` object.
"""
- pass
-
-def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union
- [dns.rdatatype.RdataType, str], ttl: int, *text_rdatas: Any) ->Rdataset:
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ r = Rdataset(rdclass, rdtype)
+ r.update_ttl(ttl)
+ for t in text_rdatas:
+ rd = dns.rdata.from_text(
+ r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec
+ )
+ r.add(rd)
+ return r
+
+
+def from_text(
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ ttl: int,
+ *text_rdatas: Any,
+) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with
the specified rdatas in text format.
Returns a ``dns.rdataset.Rdataset`` object.
"""
- pass
+
+ return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas))
-def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) ->Rdataset:
+def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset:
"""Create an rdataset with the specified TTL, and with
the specified list of rdata objects.
Returns a ``dns.rdataset.Rdataset`` object.
"""
- pass
+ if len(rdatas) == 0:
+ raise ValueError("rdata list must not be empty")
+ r = None
+ for rd in rdatas:
+ if r is None:
+ r = Rdataset(rd.rdclass, rd.rdtype)
+ r.update_ttl(ttl)
+ r.add(rd)
+ assert r is not None
+ return r
-def from_rdata(ttl: int, *rdatas: Any) ->Rdataset:
+
+def from_rdata(ttl: int, *rdatas: Any) -> Rdataset:
"""Create an rdataset with the specified TTL, and with
the specified rdata objects.
Returns a ``dns.rdataset.Rdataset`` object.
"""
- pass
+
+ return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas))
diff --git a/dns/rdatatype.py b/dns/rdatatype.py
index f375d83..e6c5818 100644
--- a/dns/rdatatype.py
+++ b/dns/rdatatype.py
@@ -1,11 +1,31 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Rdata Types."""
+
from typing import Dict
+
import dns.enum
import dns.exception
class RdataType(dns.enum.IntEnum):
"""DNS Rdata Type"""
+
TYPE0 = 0
NONE = 0
A = 1
@@ -88,19 +108,59 @@ class RdataType(dns.enum.IntEnum):
TA = 32768
DLV = 32769
+ @classmethod
+ def _maximum(cls):
+ return 65535
+
+ @classmethod
+ def _short_name(cls):
+ return "type"
+
+ @classmethod
+ def _prefix(cls):
+ return "TYPE"
+
+ @classmethod
+ def _extra_from_text(cls, text):
+ if text.find("-") >= 0:
+ try:
+ return cls[text.replace("-", "_")]
+ except KeyError:
+ pass
+ return _registered_by_text.get(text)
+
+ @classmethod
+ def _extra_to_text(cls, value, current_text):
+ if current_text is None:
+ return _registered_by_value.get(value)
+ if current_text.find("_") >= 0:
+ return current_text.replace("_", "-")
+ return current_text
+
+ @classmethod
+ def _unknown_exception_class(cls):
+ return UnknownRdatatype
+
_registered_by_text: Dict[str, RdataType] = {}
_registered_by_value: Dict[RdataType, str] = {}
+
_metatypes = {RdataType.OPT}
-_singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME, RdataType.
- NSEC, RdataType.CNAME}
+
+_singletons = {
+ RdataType.SOA,
+ RdataType.NXT,
+ RdataType.DNAME,
+ RdataType.NSEC,
+ RdataType.CNAME,
+}
class UnknownRdatatype(dns.exception.DNSException):
"""DNS resource record type is unknown."""
-def from_text(text: str) ->RdataType:
+def from_text(text: str) -> RdataType:
"""Convert text into a DNS rdata type value.
The input text can be a defined DNS RR type mnemonic or
@@ -114,10 +174,11 @@ def from_text(text: str) ->RdataType:
Returns a ``dns.rdatatype.RdataType``.
"""
- pass
+
+ return RdataType.from_text(text)
-def to_text(value: RdataType) ->str:
+def to_text(value: RdataType) -> str:
"""Convert a DNS rdata type value to text.
If the value has a known mnemonic, it will be used, otherwise the
@@ -127,10 +188,11 @@ def to_text(value: RdataType) ->str:
Returns a ``str``.
"""
- pass
+ return RdataType.to_text(value)
-def is_metatype(rdtype: RdataType) ->bool:
+
+def is_metatype(rdtype: RdataType) -> bool:
"""True if the specified type is a metatype.
*rdtype* is a ``dns.rdatatype.RdataType``.
@@ -140,10 +202,11 @@ def is_metatype(rdtype: RdataType) ->bool:
Returns a ``bool``.
"""
- pass
+
+ return (256 > rdtype >= 128) or rdtype in _metatypes
-def is_singleton(rdtype: RdataType) ->bool:
+def is_singleton(rdtype: RdataType) -> bool:
"""Is the specified type a singleton type?
Singleton types can only have a single rdata in an rdataset, or a single
@@ -156,11 +219,16 @@ def is_singleton(rdtype: RdataType) ->bool:
Returns a ``bool``.
"""
- pass
+ if rdtype in _singletons:
+ return True
+ return False
-def register_type(rdtype: RdataType, rdtype_text: str, is_singleton: bool=False
- ) ->None:
+
+# pylint: disable=redefined-outer-name
+def register_type(
+ rdtype: RdataType, rdtype_text: str, is_singleton: bool = False
+) -> None:
"""Dynamically register an rdatatype.
*rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register.
@@ -170,9 +238,15 @@ def register_type(rdtype: RdataType, rdtype_text: str, is_singleton: bool=False
*is_singleton*, a ``bool``, indicating if the type is a singleton (i.e.
RRsets of the type can have only one member.)
"""
- pass
+
+ _registered_by_text[rdtype_text] = rdtype
+ _registered_by_value[rdtype] = rdtype_text
+ if is_singleton:
+ _singletons.add(rdtype)
+### BEGIN generated RdataType constants
+
TYPE0 = RdataType.TYPE0
NONE = RdataType.NONE
A = RdataType.A
@@ -254,3 +328,5 @@ AVC = RdataType.AVC
AMTRELAY = RdataType.AMTRELAY
TA = RdataType.TA
DLV = RdataType.DLV
+
+### END generated RdataType constants
diff --git a/dns/rdtypes/ANY/AFSDB.py b/dns/rdtypes/ANY/AFSDB.py
index 085e3a6..06a3b97 100644
--- a/dns/rdtypes/ANY/AFSDB.py
+++ b/dns/rdtypes/ANY/AFSDB.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.mxbase
@@ -6,12 +23,23 @@ import dns.rdtypes.mxbase
class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""AFSDB record"""
+ # Use the property mechanism to make "subtype" an alias for the
+ # "preference" attribute, and "hostname" an alias for the "exchange"
+ # attribute.
+ #
+ # This lets us inherit the UncompressedMX implementation but lets
+ # the caller use appropriate attribute names for the rdata type.
+ #
+ # We probably lose some performance vs. a cut-and-paste
+ # implementation, but this way we don't copy code, and that's
+ # good.
+
@property
def subtype(self):
- """the AFSDB subtype"""
- pass
+ "the AFSDB subtype"
+ return self.preference
@property
def hostname(self):
- """the AFSDB hostname"""
- pass
+ "the AFSDB hostname"
+ return self.exchange
diff --git a/dns/rdtypes/ANY/AMTRELAY.py b/dns/rdtypes/ANY/AMTRELAY.py
index a0c35ed..ed2b072 100644
--- a/dns/rdtypes/ANY/AMTRELAY.py
+++ b/dns/rdtypes/ANY/AMTRELAY.py
@@ -1,23 +1,91 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.rdtypes.util
class Relay(dns.rdtypes.util.Gateway):
- name = 'AMTRELAY relay'
+ name = "AMTRELAY relay"
+
+ @property
+ def relay(self):
+ return self.gateway
@dns.immutable.immutable
class AMTRELAY(dns.rdata.Rdata):
"""AMTRELAY record"""
- __slots__ = ['precedence', 'discovery_optional', 'relay_type', 'relay']
- def __init__(self, rdclass, rdtype, precedence, discovery_optional,
- relay_type, relay):
+ # see: RFC 8777
+
+ __slots__ = ["precedence", "discovery_optional", "relay_type", "relay"]
+
+ def __init__(
+ self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay
+ ):
super().__init__(rdclass, rdtype)
relay = Relay(relay_type, relay)
self.precedence = self._as_uint8(precedence)
self.discovery_optional = self._as_bool(discovery_optional)
self.relay_type = relay.type
self.relay = relay.relay
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ relay = Relay(self.relay_type, self.relay).to_text(origin, relativize)
+ return "%d %d %d %s" % (
+ self.precedence,
+ self.discovery_optional,
+ self.relay_type,
+ relay,
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ precedence = tok.get_uint8()
+ discovery_optional = tok.get_uint8()
+ if discovery_optional > 1:
+ raise dns.exception.SyntaxError("expecting 0 or 1")
+ discovery_optional = bool(discovery_optional)
+ relay_type = tok.get_uint8()
+ if relay_type > 0x7F:
+ raise dns.exception.SyntaxError("expecting an integer <= 127")
+ relay = Relay.from_text(relay_type, tok, origin, relativize, relativize_to)
+ return cls(
+ rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
+ )
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ relay_type = self.relay_type | (self.discovery_optional << 7)
+ header = struct.pack("!BB", self.precedence, relay_type)
+ file.write(header)
+ Relay(self.relay_type, self.relay).to_wire(file, compress, origin, canonicalize)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (precedence, relay_type) = parser.get_struct("!BB")
+ discovery_optional = bool(relay_type >> 7)
+ relay_type &= 0x7F
+ relay = Relay.from_wire_parser(relay_type, parser, origin)
+ return cls(
+ rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
+ )
diff --git a/dns/rdtypes/ANY/AVC.py b/dns/rdtypes/ANY/AVC.py
index fdcfaeb..a27ae2d 100644
--- a/dns/rdtypes/ANY/AVC.py
+++ b/dns/rdtypes/ANY/AVC.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2016 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.txtbase
@@ -5,3 +22,5 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable
class AVC(dns.rdtypes.txtbase.TXTBase):
"""AVC record"""
+
+ # See: IANA dns parameters for AVC
diff --git a/dns/rdtypes/ANY/CAA.py b/dns/rdtypes/ANY/CAA.py
index e9e0c89..2e6a7e7 100644
--- a/dns/rdtypes/ANY/CAA.py
+++ b/dns/rdtypes/ANY/CAA.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -8,12 +26,46 @@ import dns.tokenizer
@dns.immutable.immutable
class CAA(dns.rdata.Rdata):
"""CAA (Certification Authority Authorization) record"""
- __slots__ = ['flags', 'tag', 'value']
+
+ # see: RFC 6844
+
+ __slots__ = ["flags", "tag", "value"]
def __init__(self, rdclass, rdtype, flags, tag, value):
super().__init__(rdclass, rdtype)
self.flags = self._as_uint8(flags)
self.tag = self._as_bytes(tag, True, 255)
if not tag.isalnum():
- raise ValueError('tag is not alphanumeric')
+ raise ValueError("tag is not alphanumeric")
self.value = self._as_bytes(value)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return '%u %s "%s"' % (
+ self.flags,
+ dns.rdata._escapify(self.tag),
+ dns.rdata._escapify(self.value),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ flags = tok.get_uint8()
+ tag = tok.get_string().encode()
+ value = tok.get_string().encode()
+ return cls(rdclass, rdtype, flags, tag, value)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(struct.pack("!B", self.flags))
+ l = len(self.tag)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.tag)
+ file.write(self.value)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ flags = parser.get_uint8()
+ tag = parser.get_counted_bytes()
+ value = parser.get_remaining()
+ return cls(rdclass, rdtype, flags, tag, value)
diff --git a/dns/rdtypes/ANY/CDNSKEY.py b/dns/rdtypes/ANY/CDNSKEY.py
index 4c3c80e..b613409 100644
--- a/dns/rdtypes/ANY/CDNSKEY.py
+++ b/dns/rdtypes/ANY/CDNSKEY.py
@@ -1,6 +1,31 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
-import dns.rdtypes.dnskeybase
-from dns.rdtypes.dnskeybase import REVOKE, SEP, ZONE
+import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
+
+# pylint: disable=unused-import
+from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
+ REVOKE,
+ SEP,
+ ZONE,
+)
+
+# pylint: enable=unused-import
@dns.immutable.immutable
diff --git a/dns/rdtypes/ANY/CDS.py b/dns/rdtypes/ANY/CDS.py
index d79117e..8312b97 100644
--- a/dns/rdtypes/ANY/CDS.py
+++ b/dns/rdtypes/ANY/CDS.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.dsbase
@@ -5,5 +22,8 @@ import dns.rdtypes.dsbase
@dns.immutable.immutable
class CDS(dns.rdtypes.dsbase.DSBase):
"""CDS record"""
- _digest_length_by_type = {**dns.rdtypes.dsbase.DSBase.
- _digest_length_by_type, (0): 1}
+
+ _digest_length_by_type = {
+ **dns.rdtypes.dsbase.DSBase._digest_length_by_type,
+ 0: 1, # delete, RFC 8078 Sec. 4 (including Errata ID 5049)
+ }
diff --git a/dns/rdtypes/ANY/CERT.py b/dns/rdtypes/ANY/CERT.py
index 922f467..f369cc8 100644
--- a/dns/rdtypes/ANY/CERT.py
+++ b/dns/rdtypes/ANY/CERT.py
@@ -1,26 +1,116 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
import struct
+
import dns.dnssectypes
import dns.exception
import dns.immutable
import dns.rdata
import dns.tokenizer
-_ctype_by_value = {(1): 'PKIX', (2): 'SPKI', (3): 'PGP', (4): 'IPKIX', (5):
- 'ISPKI', (6): 'IPGP', (7): 'ACPKIX', (8): 'IACPKIX', (253): 'URI', (254
- ): 'OID'}
-_ctype_by_name = {'PKIX': 1, 'SPKI': 2, 'PGP': 3, 'IPKIX': 4, 'ISPKI': 5,
- 'IPGP': 6, 'ACPKIX': 7, 'IACPKIX': 8, 'URI': 253, 'OID': 254}
+
+_ctype_by_value = {
+ 1: "PKIX",
+ 2: "SPKI",
+ 3: "PGP",
+ 4: "IPKIX",
+ 5: "ISPKI",
+ 6: "IPGP",
+ 7: "ACPKIX",
+ 8: "IACPKIX",
+ 253: "URI",
+ 254: "OID",
+}
+
+_ctype_by_name = {
+ "PKIX": 1,
+ "SPKI": 2,
+ "PGP": 3,
+ "IPKIX": 4,
+ "ISPKI": 5,
+ "IPGP": 6,
+ "ACPKIX": 7,
+ "IACPKIX": 8,
+ "URI": 253,
+ "OID": 254,
+}
+
+
+def _ctype_from_text(what):
+ v = _ctype_by_name.get(what)
+ if v is not None:
+ return v
+ return int(what)
+
+
+def _ctype_to_text(what):
+ v = _ctype_by_value.get(what)
+ if v is not None:
+ return v
+ return str(what)
@dns.immutable.immutable
class CERT(dns.rdata.Rdata):
"""CERT record"""
- __slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate']
- def __init__(self, rdclass, rdtype, certificate_type, key_tag,
- algorithm, certificate):
+ # see RFC 4398
+
+ __slots__ = ["certificate_type", "key_tag", "algorithm", "certificate"]
+
+ def __init__(
+ self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate
+ ):
super().__init__(rdclass, rdtype)
self.certificate_type = self._as_uint16(certificate_type)
self.key_tag = self._as_uint16(key_tag)
self.algorithm = self._as_uint8(algorithm)
self.certificate = self._as_bytes(certificate)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ certificate_type = _ctype_to_text(self.certificate_type)
+ return "%s %d %s %s" % (
+ certificate_type,
+ self.key_tag,
+ dns.dnssectypes.Algorithm.to_text(self.algorithm),
+ dns.rdata._base64ify(self.certificate, **kw),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ certificate_type = _ctype_from_text(tok.get_string())
+ key_tag = tok.get_uint16()
+ algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
+ b64 = tok.concatenate_remaining_identifiers().encode()
+ certificate = base64.b64decode(b64)
+ return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ prefix = struct.pack(
+ "!HHB", self.certificate_type, self.key_tag, self.algorithm
+ )
+ file.write(prefix)
+ file.write(self.certificate)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (certificate_type, key_tag, algorithm) = parser.get_struct("!HHB")
+ certificate = parser.get_remaining()
+ return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)
diff --git a/dns/rdtypes/ANY/CNAME.py b/dns/rdtypes/ANY/CNAME.py
index 573f74e..665e407 100644
--- a/dns/rdtypes/ANY/CNAME.py
+++ b/dns/rdtypes/ANY/CNAME.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.nsbase
diff --git a/dns/rdtypes/ANY/CSYNC.py b/dns/rdtypes/ANY/CSYNC.py
index 88807e3..2f972f6 100644
--- a/dns/rdtypes/ANY/CSYNC.py
+++ b/dns/rdtypes/ANY/CSYNC.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2007, 2009-2011, 2016 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.name
@@ -9,13 +27,14 @@ import dns.rdtypes.util
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
- type_name = 'CSYNC'
+ type_name = "CSYNC"
@dns.immutable.immutable
class CSYNC(dns.rdata.Rdata):
"""CSYNC record"""
- __slots__ = ['serial', 'flags', 'windows']
+
+ __slots__ = ["serial", "flags", "windows"]
def __init__(self, rdclass, rdtype, serial, flags, windows):
super().__init__(rdclass, rdtype)
@@ -24,3 +43,26 @@ class CSYNC(dns.rdata.Rdata):
if not isinstance(windows, Bitmap):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ text = Bitmap(self.windows).to_text()
+ return "%d %d%s" % (self.serial, self.flags, text)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ serial = tok.get_uint32()
+ flags = tok.get_uint16()
+ bitmap = Bitmap.from_text(tok)
+ return cls(rdclass, rdtype, serial, flags, bitmap)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(struct.pack("!IH", self.serial, self.flags))
+ Bitmap(self.windows).to_wire(file)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (serial, flags) = parser.get_struct("!IH")
+ bitmap = Bitmap.from_wire_parser(parser)
+ return cls(rdclass, rdtype, serial, flags, bitmap)
diff --git a/dns/rdtypes/ANY/DLV.py b/dns/rdtypes/ANY/DLV.py
index 19c3328..6c134f1 100644
--- a/dns/rdtypes/ANY/DLV.py
+++ b/dns/rdtypes/ANY/DLV.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.dsbase
diff --git a/dns/rdtypes/ANY/DNAME.py b/dns/rdtypes/ANY/DNAME.py
index 279fb81..bbf9186 100644
--- a/dns/rdtypes/ANY/DNAME.py
+++ b/dns/rdtypes/ANY/DNAME.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.nsbase
@@ -5,3 +22,6 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable
class DNAME(dns.rdtypes.nsbase.UncompressedNS):
"""DNAME record"""
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.target.to_wire(file, None, origin, canonicalize)
diff --git a/dns/rdtypes/ANY/DNSKEY.py b/dns/rdtypes/ANY/DNSKEY.py
index 9b0bfaf..6d961a9 100644
--- a/dns/rdtypes/ANY/DNSKEY.py
+++ b/dns/rdtypes/ANY/DNSKEY.py
@@ -1,6 +1,31 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
-import dns.rdtypes.dnskeybase
-from dns.rdtypes.dnskeybase import REVOKE, SEP, ZONE
+import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
+
+# pylint: disable=unused-import
+from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
+ REVOKE,
+ SEP,
+ ZONE,
+)
+
+# pylint: enable=unused-import
@dns.immutable.immutable
diff --git a/dns/rdtypes/ANY/DS.py b/dns/rdtypes/ANY/DS.py
index d5a27a9..58b3108 100644
--- a/dns/rdtypes/ANY/DS.py
+++ b/dns/rdtypes/ANY/DS.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.dsbase
diff --git a/dns/rdtypes/ANY/EUI48.py b/dns/rdtypes/ANY/EUI48.py
index b99cfbe..c843be5 100644
--- a/dns/rdtypes/ANY/EUI48.py
+++ b/dns/rdtypes/ANY/EUI48.py
@@ -1,3 +1,21 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2015 Red Hat, Inc.
+# Author: Petr Spacek <pspacek@redhat.com>
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED 'AS IS' AND RED HAT DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.euibase
@@ -5,5 +23,8 @@ import dns.rdtypes.euibase
@dns.immutable.immutable
class EUI48(dns.rdtypes.euibase.EUIBase):
"""EUI48 record"""
- byte_len = 6
- text_len = byte_len * 3 - 1
+
+ # see: rfc7043.txt
+
+ byte_len = 6 # 0123456789ab (in hex)
+ text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab
diff --git a/dns/rdtypes/ANY/EUI64.py b/dns/rdtypes/ANY/EUI64.py
index 1789d71..f6d7e25 100644
--- a/dns/rdtypes/ANY/EUI64.py
+++ b/dns/rdtypes/ANY/EUI64.py
@@ -1,3 +1,21 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2015 Red Hat, Inc.
+# Author: Petr Spacek <pspacek@redhat.com>
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED 'AS IS' AND RED HAT DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.euibase
@@ -5,5 +23,8 @@ import dns.rdtypes.euibase
@dns.immutable.immutable
class EUI64(dns.rdtypes.euibase.EUIBase):
"""EUI64 record"""
- byte_len = 8
- text_len = byte_len * 3 - 1
+
+ # see: rfc7043.txt
+
+ byte_len = 8 # 0123456789abcdef (in hex)
+ text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab-cd-ef
diff --git a/dns/rdtypes/ANY/GPOS.py b/dns/rdtypes/ANY/GPOS.py
index 1f28c06..312338f 100644
--- a/dns/rdtypes/ANY/GPOS.py
+++ b/dns/rdtypes/ANY/GPOS.py
@@ -1,14 +1,54 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
import dns.tokenizer
+def _validate_float_string(what):
+ if len(what) == 0:
+ raise dns.exception.FormError
+ if what[0] == b"-"[0] or what[0] == b"+"[0]:
+ what = what[1:]
+ if what.isdigit():
+ return
+ try:
+ (left, right) = what.split(b".")
+ except ValueError:
+ raise dns.exception.FormError
+ if left == b"" and right == b"":
+ raise dns.exception.FormError
+ if not left == b"" and not left.decode().isdigit():
+ raise dns.exception.FormError
+ if not right == b"" and not right.decode().isdigit():
+ raise dns.exception.FormError
+
+
@dns.immutable.immutable
class GPOS(dns.rdata.Rdata):
"""GPOS record"""
- __slots__ = ['latitude', 'longitude', 'altitude']
+
+ # see: RFC 1712
+
+ __slots__ = ["latitude", "longitude", "altitude"]
def __init__(self, rdclass, rdtype, latitude, longitude, altitude):
super().__init__(rdclass, rdtype)
@@ -29,22 +69,57 @@ class GPOS(dns.rdata.Rdata):
self.altitude = altitude
flat = self.float_latitude
if flat < -90.0 or flat > 90.0:
- raise dns.exception.FormError('bad latitude')
+ raise dns.exception.FormError("bad latitude")
flong = self.float_longitude
if flong < -180.0 or flong > 180.0:
- raise dns.exception.FormError('bad longitude')
+ raise dns.exception.FormError("bad longitude")
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return "{} {} {}".format(
+ self.latitude.decode(), self.longitude.decode(), self.altitude.decode()
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ latitude = tok.get_string()
+ longitude = tok.get_string()
+ altitude = tok.get_string()
+ return cls(rdclass, rdtype, latitude, longitude, altitude)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ l = len(self.latitude)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.latitude)
+ l = len(self.longitude)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.longitude)
+ l = len(self.altitude)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.altitude)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ latitude = parser.get_counted_bytes()
+ longitude = parser.get_counted_bytes()
+ altitude = parser.get_counted_bytes()
+ return cls(rdclass, rdtype, latitude, longitude, altitude)
@property
def float_latitude(self):
- """latitude as a floating point value"""
- pass
+ "latitude as a floating point value"
+ return float(self.latitude)
@property
def float_longitude(self):
- """longitude as a floating point value"""
- pass
+ "longitude as a floating point value"
+ return float(self.longitude)
@property
def float_altitude(self):
- """altitude as a floating point value"""
- pass
+ "altitude as a floating point value"
+ return float(self.altitude)
diff --git a/dns/rdtypes/ANY/HINFO.py b/dns/rdtypes/ANY/HINFO.py
index b0da043..c2c45de 100644
--- a/dns/rdtypes/ANY/HINFO.py
+++ b/dns/rdtypes/ANY/HINFO.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -8,9 +26,41 @@ import dns.tokenizer
@dns.immutable.immutable
class HINFO(dns.rdata.Rdata):
"""HINFO record"""
- __slots__ = ['cpu', 'os']
+
+ # see: RFC 1035
+
+ __slots__ = ["cpu", "os"]
def __init__(self, rdclass, rdtype, cpu, os):
super().__init__(rdclass, rdtype)
self.cpu = self._as_bytes(cpu, True, 255)
self.os = self._as_bytes(os, True, 255)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return '"{}" "{}"'.format(
+ dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os)
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ cpu = tok.get_string(max_length=255)
+ os = tok.get_string(max_length=255)
+ return cls(rdclass, rdtype, cpu, os)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ l = len(self.cpu)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.cpu)
+ l = len(self.os)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.os)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ cpu = parser.get_counted_bytes()
+ os = parser.get_counted_bytes()
+ return cls(rdclass, rdtype, cpu, os)
diff --git a/dns/rdtypes/ANY/HIP.py b/dns/rdtypes/ANY/HIP.py
index 5590282..9166913 100644
--- a/dns/rdtypes/ANY/HIP.py
+++ b/dns/rdtypes/ANY/HIP.py
@@ -1,6 +1,24 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2010, 2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
import binascii
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -10,7 +28,10 @@ import dns.rdatatype
@dns.immutable.immutable
class HIP(dns.rdata.Rdata):
"""HIP record"""
- __slots__ = ['hit', 'algorithm', 'key', 'servers']
+
+ # see: RFC 5205
+
+ __slots__ = ["hit", "algorithm", "key", "servers"]
def __init__(self, rdclass, rdtype, hit, algorithm, key, servers):
super().__init__(rdclass, rdtype)
@@ -18,3 +39,47 @@ class HIP(dns.rdata.Rdata):
self.algorithm = self._as_uint8(algorithm)
self.key = self._as_bytes(key, True)
self.servers = self._as_tuple(servers, self._as_name)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ hit = binascii.hexlify(self.hit).decode()
+ key = base64.b64encode(self.key).replace(b"\n", b"").decode()
+ text = ""
+ servers = []
+ for server in self.servers:
+ servers.append(server.choose_relativity(origin, relativize))
+ if len(servers) > 0:
+ text += " " + " ".join((x.to_unicode() for x in servers))
+ return "%u %s %s%s" % (self.algorithm, hit, key, text)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ algorithm = tok.get_uint8()
+ hit = binascii.unhexlify(tok.get_string().encode())
+ key = base64.b64decode(tok.get_string().encode())
+ servers = []
+ for token in tok.get_remaining():
+ server = tok.as_name(token, origin, relativize, relativize_to)
+ servers.append(server)
+ return cls(rdclass, rdtype, hit, algorithm, key, servers)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ lh = len(self.hit)
+ lk = len(self.key)
+ file.write(struct.pack("!BBH", lh, self.algorithm, lk))
+ file.write(self.hit)
+ file.write(self.key)
+ for server in self.servers:
+ server.to_wire(file, None, origin, False)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (lh, algorithm, lk) = parser.get_struct("!BBH")
+ hit = parser.get_bytes(lh)
+ key = parser.get_bytes(lk)
+ servers = []
+ while parser.remaining() > 0:
+ server = parser.get_name(origin)
+ servers.append(server)
+ return cls(rdclass, rdtype, hit, algorithm, key, servers)
diff --git a/dns/rdtypes/ANY/ISDN.py b/dns/rdtypes/ANY/ISDN.py
index c961c24..fb01eab 100644
--- a/dns/rdtypes/ANY/ISDN.py
+++ b/dns/rdtypes/ANY/ISDN.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -8,9 +26,52 @@ import dns.tokenizer
@dns.immutable.immutable
class ISDN(dns.rdata.Rdata):
"""ISDN record"""
- __slots__ = ['address', 'subaddress']
+
+ # see: RFC 1183
+
+ __slots__ = ["address", "subaddress"]
def __init__(self, rdclass, rdtype, address, subaddress):
super().__init__(rdclass, rdtype)
self.address = self._as_bytes(address, True, 255)
self.subaddress = self._as_bytes(subaddress, True, 255)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ if self.subaddress:
+ return '"{}" "{}"'.format(
+ dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress)
+ )
+ else:
+ return '"%s"' % dns.rdata._escapify(self.address)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ address = tok.get_string()
+ tokens = tok.get_remaining(max_tokens=1)
+ if len(tokens) >= 1:
+ subaddress = tokens[0].unescape().value
+ else:
+ subaddress = ""
+ return cls(rdclass, rdtype, address, subaddress)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ l = len(self.address)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.address)
+ l = len(self.subaddress)
+ if l > 0:
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.subaddress)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_counted_bytes()
+ if parser.remaining() > 0:
+ subaddress = parser.get_counted_bytes()
+ else:
+ subaddress = b""
+ return cls(rdclass, rdtype, address, subaddress)
diff --git a/dns/rdtypes/ANY/L32.py b/dns/rdtypes/ANY/L32.py
index 67c8691..09804c2 100644
--- a/dns/rdtypes/ANY/L32.py
+++ b/dns/rdtypes/ANY/L32.py
@@ -1,4 +1,7 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import struct
+
import dns.immutable
import dns.rdata
@@ -6,9 +9,33 @@ import dns.rdata
@dns.immutable.immutable
class L32(dns.rdata.Rdata):
"""L32 record"""
- __slots__ = ['preference', 'locator32']
+
+ # see: rfc6742.txt
+
+ __slots__ = ["preference", "locator32"]
def __init__(self, rdclass, rdtype, preference, locator32):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
self.locator32 = self._as_ipv4_address(locator32)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return f"{self.preference} {self.locator32}"
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ preference = tok.get_uint16()
+ nodeid = tok.get_identifier()
+ return cls(rdclass, rdtype, preference, nodeid)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(struct.pack("!H", self.preference))
+ file.write(dns.ipv4.inet_aton(self.locator32))
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ preference = parser.get_uint16()
+ locator32 = parser.get_remaining()
+ return cls(rdclass, rdtype, preference, locator32)
diff --git a/dns/rdtypes/ANY/L64.py b/dns/rdtypes/ANY/L64.py
index 34cf23b..fb76808 100644
--- a/dns/rdtypes/ANY/L64.py
+++ b/dns/rdtypes/ANY/L64.py
@@ -1,4 +1,7 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import struct
+
import dns.immutable
import dns.rdtypes.util
@@ -6,15 +9,39 @@ import dns.rdtypes.util
@dns.immutable.immutable
class L64(dns.rdata.Rdata):
"""L64 record"""
- __slots__ = ['preference', 'locator64']
+
+ # see: rfc6742.txt
+
+ __slots__ = ["preference", "locator64"]
def __init__(self, rdclass, rdtype, preference, locator64):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
if isinstance(locator64, bytes):
if len(locator64) != 8:
- raise ValueError('invalid locator64')
- self.locator64 = dns.rdata._hexify(locator64, 4, b':')
+ raise ValueError("invalid locator64")
+ self.locator64 = dns.rdata._hexify(locator64, 4, b":")
else:
- dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ':')
+ dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ":")
self.locator64 = locator64
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return f"{self.preference} {self.locator64}"
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ preference = tok.get_uint16()
+ locator64 = tok.get_identifier()
+ return cls(rdclass, rdtype, preference, locator64)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(struct.pack("!H", self.preference))
+ file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, 4, 4, ":"))
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ preference = parser.get_uint16()
+ locator64 = parser.get_remaining()
+ return cls(rdclass, rdtype, preference, locator64)
diff --git a/dns/rdtypes/ANY/LOC.py b/dns/rdtypes/ANY/LOC.py
index 2be3324..a36a2c1 100644
--- a/dns/rdtypes/ANY/LOC.py
+++ b/dns/rdtypes/ANY/LOC.py
@@ -1,25 +1,134 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
-_pows = tuple(10 ** i for i in range(0, 11))
+
+_pows = tuple(10**i for i in range(0, 11))
+
+# default values are in centimeters
_default_size = 100.0
_default_hprec = 1000000.0
_default_vprec = 1000.0
-_MAX_LATITUDE = 2147483648 + 90 * 3600000
-_MIN_LATITUDE = 2147483648 - 90 * 3600000
-_MAX_LONGITUDE = 2147483648 + 180 * 3600000
-_MIN_LONGITUDE = 2147483648 - 180 * 3600000
+
+# for use by from_wire()
+_MAX_LATITUDE = 0x80000000 + 90 * 3600000
+_MIN_LATITUDE = 0x80000000 - 90 * 3600000
+_MAX_LONGITUDE = 0x80000000 + 180 * 3600000
+_MIN_LONGITUDE = 0x80000000 - 180 * 3600000
+
+
+def _exponent_of(what, desc):
+ if what == 0:
+ return 0
+ exp = None
+ for i, pow in enumerate(_pows):
+ if what < pow:
+ exp = i - 1
+ break
+ if exp is None or exp < 0:
+ raise dns.exception.SyntaxError("%s value out of bounds" % desc)
+ return exp
+
+
+def _float_to_tuple(what):
+ if what < 0:
+ sign = -1
+ what *= -1
+ else:
+ sign = 1
+ what = round(what * 3600000)
+ degrees = int(what // 3600000)
+ what -= degrees * 3600000
+ minutes = int(what // 60000)
+ what -= minutes * 60000
+ seconds = int(what // 1000)
+ what -= int(seconds * 1000)
+ what = int(what)
+ return (degrees, minutes, seconds, what, sign)
+
+
+def _tuple_to_float(what):
+ value = float(what[0])
+ value += float(what[1]) / 60.0
+ value += float(what[2]) / 3600.0
+ value += float(what[3]) / 3600000.0
+ return float(what[4]) * value
+
+
+def _encode_size(what, desc):
+ what = int(what)
+ exponent = _exponent_of(what, desc) & 0xF
+ base = what // pow(10, exponent) & 0xF
+ return base * 16 + exponent
+
+
+def _decode_size(what, desc):
+ exponent = what & 0x0F
+ if exponent > 9:
+ raise dns.exception.FormError("bad %s exponent" % desc)
+ base = (what & 0xF0) >> 4
+ if base > 9:
+ raise dns.exception.FormError("bad %s base" % desc)
+ return base * pow(10, exponent)
+
+
+def _check_coordinate_list(value, low, high):
+ if value[0] < low or value[0] > high:
+ raise ValueError(f"not in range [{low}, {high}]")
+ if value[1] < 0 or value[1] > 59:
+ raise ValueError("bad minutes value")
+ if value[2] < 0 or value[2] > 59:
+ raise ValueError("bad seconds value")
+ if value[3] < 0 or value[3] > 999:
+ raise ValueError("bad milliseconds value")
+ if value[4] != 1 and value[4] != -1:
+ raise ValueError("bad hemisphere value")
@dns.immutable.immutable
class LOC(dns.rdata.Rdata):
"""LOC record"""
- __slots__ = ['latitude', 'longitude', 'altitude', 'size',
- 'horizontal_precision', 'vertical_precision']
- def __init__(self, rdclass, rdtype, latitude, longitude, altitude, size
- =_default_size, hprec=_default_hprec, vprec=_default_vprec):
+ # see: RFC 1876
+
+ __slots__ = [
+ "latitude",
+ "longitude",
+ "altitude",
+ "size",
+ "horizontal_precision",
+ "vertical_precision",
+ ]
+
+ def __init__(
+ self,
+ rdclass,
+ rdtype,
+ latitude,
+ longitude,
+ altitude,
+ size=_default_size,
+ hprec=_default_hprec,
+ vprec=_default_vprec,
+ ):
"""Initialize a LOC record instance.
The parameters I{latitude} and I{longitude} may be either a 4-tuple
@@ -27,6 +136,7 @@ class LOC(dns.rdata.Rdata):
or they may be floating point values specifying the number of
degrees. The other parameters are floats. Size, horizontal precision,
and vertical precision are specified in centimeters."""
+
super().__init__(rdclass, rdtype)
if isinstance(latitude, int):
latitude = float(latitude)
@@ -45,12 +155,200 @@ class LOC(dns.rdata.Rdata):
self.horizontal_precision = float(hprec)
self.vertical_precision = float(vprec)
+ def to_text(self, origin=None, relativize=True, **kw):
+ if self.latitude[4] > 0:
+ lat_hemisphere = "N"
+ else:
+ lat_hemisphere = "S"
+ if self.longitude[4] > 0:
+ long_hemisphere = "E"
+ else:
+ long_hemisphere = "W"
+ text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % (
+ self.latitude[0],
+ self.latitude[1],
+ self.latitude[2],
+ self.latitude[3],
+ lat_hemisphere,
+ self.longitude[0],
+ self.longitude[1],
+ self.longitude[2],
+ self.longitude[3],
+ long_hemisphere,
+ self.altitude / 100.0,
+ )
+
+ # do not print default values
+ if (
+ self.size != _default_size
+ or self.horizontal_precision != _default_hprec
+ or self.vertical_precision != _default_vprec
+ ):
+ text += " {:0.2f}m {:0.2f}m {:0.2f}m".format(
+ self.size / 100.0,
+ self.horizontal_precision / 100.0,
+ self.vertical_precision / 100.0,
+ )
+ return text
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ latitude = [0, 0, 0, 0, 1]
+ longitude = [0, 0, 0, 0, 1]
+ size = _default_size
+ hprec = _default_hprec
+ vprec = _default_vprec
+
+ latitude[0] = tok.get_int()
+ t = tok.get_string()
+ if t.isdigit():
+ latitude[1] = int(t)
+ t = tok.get_string()
+ if "." in t:
+ (seconds, milliseconds) = t.split(".")
+ if not seconds.isdigit():
+ raise dns.exception.SyntaxError("bad latitude seconds value")
+ latitude[2] = int(seconds)
+ l = len(milliseconds)
+ if l == 0 or l > 3 or not milliseconds.isdigit():
+ raise dns.exception.SyntaxError("bad latitude milliseconds value")
+ if l == 1:
+ m = 100
+ elif l == 2:
+ m = 10
+ else:
+ m = 1
+ latitude[3] = m * int(milliseconds)
+ t = tok.get_string()
+ elif t.isdigit():
+ latitude[2] = int(t)
+ t = tok.get_string()
+ if t == "S":
+ latitude[4] = -1
+ elif t != "N":
+ raise dns.exception.SyntaxError("bad latitude hemisphere value")
+
+ longitude[0] = tok.get_int()
+ t = tok.get_string()
+ if t.isdigit():
+ longitude[1] = int(t)
+ t = tok.get_string()
+ if "." in t:
+ (seconds, milliseconds) = t.split(".")
+ if not seconds.isdigit():
+ raise dns.exception.SyntaxError("bad longitude seconds value")
+ longitude[2] = int(seconds)
+ l = len(milliseconds)
+ if l == 0 or l > 3 or not milliseconds.isdigit():
+ raise dns.exception.SyntaxError("bad longitude milliseconds value")
+ if l == 1:
+ m = 100
+ elif l == 2:
+ m = 10
+ else:
+ m = 1
+ longitude[3] = m * int(milliseconds)
+ t = tok.get_string()
+ elif t.isdigit():
+ longitude[2] = int(t)
+ t = tok.get_string()
+ if t == "W":
+ longitude[4] = -1
+ elif t != "E":
+ raise dns.exception.SyntaxError("bad longitude hemisphere value")
+
+ t = tok.get_string()
+ if t[-1] == "m":
+ t = t[0:-1]
+ altitude = float(t) * 100.0 # m -> cm
+
+ tokens = tok.get_remaining(max_tokens=3)
+ if len(tokens) >= 1:
+ value = tokens[0].unescape().value
+ if value[-1] == "m":
+ value = value[0:-1]
+ size = float(value) * 100.0 # m -> cm
+ if len(tokens) >= 2:
+ value = tokens[1].unescape().value
+ if value[-1] == "m":
+ value = value[0:-1]
+ hprec = float(value) * 100.0 # m -> cm
+ if len(tokens) >= 3:
+ value = tokens[2].unescape().value
+ if value[-1] == "m":
+ value = value[0:-1]
+ vprec = float(value) * 100.0 # m -> cm
+
+ # Try encoding these now so we raise if they are bad
+ _encode_size(size, "size")
+ _encode_size(hprec, "horizontal precision")
+ _encode_size(vprec, "vertical precision")
+
+ return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ milliseconds = (
+ self.latitude[0] * 3600000
+ + self.latitude[1] * 60000
+ + self.latitude[2] * 1000
+ + self.latitude[3]
+ ) * self.latitude[4]
+ latitude = 0x80000000 + milliseconds
+ milliseconds = (
+ self.longitude[0] * 3600000
+ + self.longitude[1] * 60000
+ + self.longitude[2] * 1000
+ + self.longitude[3]
+ ) * self.longitude[4]
+ longitude = 0x80000000 + milliseconds
+ altitude = int(self.altitude) + 10000000
+ size = _encode_size(self.size, "size")
+ hprec = _encode_size(self.horizontal_precision, "horizontal precision")
+ vprec = _encode_size(self.vertical_precision, "vertical precision")
+ wire = struct.pack(
+ "!BBBBIII", 0, size, hprec, vprec, latitude, longitude, altitude
+ )
+ file.write(wire)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (
+ version,
+ size,
+ hprec,
+ vprec,
+ latitude,
+ longitude,
+ altitude,
+ ) = parser.get_struct("!BBBBIII")
+ if version != 0:
+ raise dns.exception.FormError("LOC version not zero")
+ if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE:
+ raise dns.exception.FormError("bad latitude")
+ if latitude > 0x80000000:
+ latitude = (latitude - 0x80000000) / 3600000
+ else:
+ latitude = -1 * (0x80000000 - latitude) / 3600000
+ if longitude < _MIN_LONGITUDE or longitude > _MAX_LONGITUDE:
+ raise dns.exception.FormError("bad longitude")
+ if longitude > 0x80000000:
+ longitude = (longitude - 0x80000000) / 3600000
+ else:
+ longitude = -1 * (0x80000000 - longitude) / 3600000
+ altitude = float(altitude) - 10000000.0
+ size = _decode_size(size, "size")
+ hprec = _decode_size(hprec, "horizontal precision")
+ vprec = _decode_size(vprec, "vertical precision")
+ return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
+
@property
def float_latitude(self):
- """latitude as a floating point value"""
- pass
+ "latitude as a floating point value"
+ return _tuple_to_float(self.latitude)
@property
def float_longitude(self):
- """longitude as a floating point value"""
- pass
+ "longitude as a floating point value"
+ return _tuple_to_float(self.longitude)
diff --git a/dns/rdtypes/ANY/LP.py b/dns/rdtypes/ANY/LP.py
index 44d48de..312663f 100644
--- a/dns/rdtypes/ANY/LP.py
+++ b/dns/rdtypes/ANY/LP.py
@@ -1,4 +1,7 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import struct
+
import dns.immutable
import dns.rdata
@@ -6,9 +9,34 @@ import dns.rdata
@dns.immutable.immutable
class LP(dns.rdata.Rdata):
"""LP record"""
- __slots__ = ['preference', 'fqdn']
+
+ # see: rfc6742.txt
+
+ __slots__ = ["preference", "fqdn"]
def __init__(self, rdclass, rdtype, preference, fqdn):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
self.fqdn = self._as_name(fqdn)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ fqdn = self.fqdn.choose_relativity(origin, relativize)
+ return "%d %s" % (self.preference, fqdn)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ preference = tok.get_uint16()
+ fqdn = tok.get_name(origin, relativize, relativize_to)
+ return cls(rdclass, rdtype, preference, fqdn)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(struct.pack("!H", self.preference))
+ self.fqdn.to_wire(file, compress, origin, canonicalize)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ preference = parser.get_uint16()
+ fqdn = parser.get_name(origin)
+ return cls(rdclass, rdtype, preference, fqdn)
diff --git a/dns/rdtypes/ANY/MX.py b/dns/rdtypes/ANY/MX.py
index 560cf68..0c300c5 100644
--- a/dns/rdtypes/ANY/MX.py
+++ b/dns/rdtypes/ANY/MX.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.mxbase
diff --git a/dns/rdtypes/ANY/NID.py b/dns/rdtypes/ANY/NID.py
index 4f8b115..2f64917 100644
--- a/dns/rdtypes/ANY/NID.py
+++ b/dns/rdtypes/ANY/NID.py
@@ -1,4 +1,7 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import struct
+
import dns.immutable
import dns.rdtypes.util
@@ -6,15 +9,39 @@ import dns.rdtypes.util
@dns.immutable.immutable
class NID(dns.rdata.Rdata):
"""NID record"""
- __slots__ = ['preference', 'nodeid']
+
+ # see: rfc6742.txt
+
+ __slots__ = ["preference", "nodeid"]
def __init__(self, rdclass, rdtype, preference, nodeid):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
if isinstance(nodeid, bytes):
if len(nodeid) != 8:
- raise ValueError('invalid nodeid')
- self.nodeid = dns.rdata._hexify(nodeid, 4, b':')
+ raise ValueError("invalid nodeid")
+ self.nodeid = dns.rdata._hexify(nodeid, 4, b":")
else:
- dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ':')
+ dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ":")
self.nodeid = nodeid
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return f"{self.preference} {self.nodeid}"
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ preference = tok.get_uint16()
+ nodeid = tok.get_identifier()
+ return cls(rdclass, rdtype, preference, nodeid)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(struct.pack("!H", self.preference))
+ file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ":"))
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ preference = parser.get_uint16()
+ nodeid = parser.get_remaining()
+ return cls(rdclass, rdtype, preference, nodeid)
diff --git a/dns/rdtypes/ANY/NINFO.py b/dns/rdtypes/ANY/NINFO.py
index cb66fe7..b177bdd 100644
--- a/dns/rdtypes/ANY/NINFO.py
+++ b/dns/rdtypes/ANY/NINFO.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.txtbase
@@ -5,3 +22,5 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable
class NINFO(dns.rdtypes.txtbase.TXTBase):
"""NINFO record"""
+
+ # see: draft-reid-dnsext-zs-01
diff --git a/dns/rdtypes/ANY/NS.py b/dns/rdtypes/ANY/NS.py
index bc1f8b0..c3f34ce 100644
--- a/dns/rdtypes/ANY/NS.py
+++ b/dns/rdtypes/ANY/NS.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.nsbase
diff --git a/dns/rdtypes/ANY/NSEC.py b/dns/rdtypes/ANY/NSEC.py
index 3a68e73..340525a 100644
--- a/dns/rdtypes/ANY/NSEC.py
+++ b/dns/rdtypes/ANY/NSEC.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.exception
import dns.immutable
import dns.name
@@ -8,13 +25,14 @@ import dns.rdtypes.util
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
- type_name = 'NSEC'
+ type_name = "NSEC"
@dns.immutable.immutable
class NSEC(dns.rdata.Rdata):
"""NSEC record"""
- __slots__ = ['next', 'windows']
+
+ __slots__ = ["next", "windows"]
def __init__(self, rdclass, rdtype, next, windows):
super().__init__(rdclass, rdtype)
@@ -22,3 +40,28 @@ class NSEC(dns.rdata.Rdata):
if not isinstance(windows, Bitmap):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ next = self.next.choose_relativity(origin, relativize)
+ text = Bitmap(self.windows).to_text()
+ return "{}{}".format(next, text)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ next = tok.get_name(origin, relativize, relativize_to)
+ windows = Bitmap.from_text(tok)
+ return cls(rdclass, rdtype, next, windows)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ # Note that NSEC downcasing, originally mandated by RFC 4034
+ # section 6.2 was removed by RFC 6840 section 5.1.
+ self.next.to_wire(file, None, origin, False)
+ Bitmap(self.windows).to_wire(file)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ next = parser.get_name(origin)
+ bitmap = Bitmap.from_wire_parser(parser)
+ return cls(rdclass, rdtype, next, bitmap)
diff --git a/dns/rdtypes/ANY/NSEC3.py b/dns/rdtypes/ANY/NSEC3.py
index 4f6caed..d71302b 100644
--- a/dns/rdtypes/ANY/NSEC3.py
+++ b/dns/rdtypes/ANY/NSEC3.py
@@ -1,31 +1,58 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
import binascii
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
import dns.rdatatype
import dns.rdtypes.util
-b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV',
- b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567')
-b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567',
- b'0123456789ABCDEFGHIJKLMNOPQRSTUV')
+
+b32_hex_to_normal = bytes.maketrans(
+ b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
+)
+b32_normal_to_hex = bytes.maketrans(
+ b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", b"0123456789ABCDEFGHIJKLMNOPQRSTUV"
+)
+
+# hash algorithm constants
SHA1 = 1
+
+# flag constants
OPTOUT = 1
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
- type_name = 'NSEC3'
+ type_name = "NSEC3"
@dns.immutable.immutable
class NSEC3(dns.rdata.Rdata):
"""NSEC3 record"""
- __slots__ = ['algorithm', 'flags', 'iterations', 'salt', 'next', 'windows']
- def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt,
- next, windows):
+ __slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"]
+
+ def __init__(
+ self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows
+ ):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm)
self.flags = self._as_uint8(flags)
@@ -35,3 +62,65 @@ class NSEC3(dns.rdata.Rdata):
if not isinstance(windows, Bitmap):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
+
+ def _next_text(self):
+ next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
+ next = next.rstrip("=")
+ return next
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ next = self._next_text()
+ if self.salt == b"":
+ salt = "-"
+ else:
+ salt = binascii.hexlify(self.salt).decode()
+ text = Bitmap(self.windows).to_text()
+ return "%u %u %u %s %s%s" % (
+ self.algorithm,
+ self.flags,
+ self.iterations,
+ salt,
+ next,
+ text,
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ algorithm = tok.get_uint8()
+ flags = tok.get_uint8()
+ iterations = tok.get_uint16()
+ salt = tok.get_string()
+ if salt == "-":
+ salt = b""
+ else:
+ salt = binascii.unhexlify(salt.encode("ascii"))
+ next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
+ if next.endswith(b"="):
+ raise binascii.Error("Incorrect padding")
+ if len(next) % 8 != 0:
+ next += b"=" * (8 - len(next) % 8)
+ next = base64.b32decode(next)
+ bitmap = Bitmap.from_text(tok)
+ return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ l = len(self.salt)
+ file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
+ file.write(self.salt)
+ l = len(self.next)
+ file.write(struct.pack("!B", l))
+ file.write(self.next)
+ Bitmap(self.windows).to_wire(file)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (algorithm, flags, iterations) = parser.get_struct("!BBH")
+ salt = parser.get_counted_bytes()
+ next = parser.get_counted_bytes()
+ bitmap = Bitmap.from_wire_parser(parser)
+ return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
+
+ def next_name(self, origin=None):
+ return dns.name.from_text(self._next_text(), origin)
diff --git a/dns/rdtypes/ANY/NSEC3PARAM.py b/dns/rdtypes/ANY/NSEC3PARAM.py
index 0d40074..d1e62eb 100644
--- a/dns/rdtypes/ANY/NSEC3PARAM.py
+++ b/dns/rdtypes/ANY/NSEC3PARAM.py
@@ -1,5 +1,23 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import binascii
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -8,7 +26,8 @@ import dns.rdata
@dns.immutable.immutable
class NSEC3PARAM(dns.rdata.Rdata):
"""NSEC3PARAM record"""
- __slots__ = ['algorithm', 'flags', 'iterations', 'salt']
+
+ __slots__ = ["algorithm", "flags", "iterations", "salt"]
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt):
super().__init__(rdclass, rdtype)
@@ -16,3 +35,35 @@ class NSEC3PARAM(dns.rdata.Rdata):
self.flags = self._as_uint8(flags)
self.iterations = self._as_uint16(iterations)
self.salt = self._as_bytes(salt, True, 255)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ if self.salt == b"":
+ salt = "-"
+ else:
+ salt = binascii.hexlify(self.salt).decode()
+ return "%u %u %u %s" % (self.algorithm, self.flags, self.iterations, salt)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ algorithm = tok.get_uint8()
+ flags = tok.get_uint8()
+ iterations = tok.get_uint16()
+ salt = tok.get_string()
+ if salt == "-":
+ salt = ""
+ else:
+ salt = binascii.unhexlify(salt.encode())
+ return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ l = len(self.salt)
+ file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
+ file.write(self.salt)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (algorithm, flags, iterations) = parser.get_struct("!BBH")
+ salt = parser.get_counted_bytes()
+ return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
diff --git a/dns/rdtypes/ANY/OPENPGPKEY.py b/dns/rdtypes/ANY/OPENPGPKEY.py
index 2722935..4d7a4b6 100644
--- a/dns/rdtypes/ANY/OPENPGPKEY.py
+++ b/dns/rdtypes/ANY/OPENPGPKEY.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2016 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -9,6 +27,27 @@ import dns.tokenizer
class OPENPGPKEY(dns.rdata.Rdata):
"""OPENPGPKEY record"""
+ # see: RFC 7929
+
def __init__(self, rdclass, rdtype, key):
super().__init__(rdclass, rdtype)
self.key = self._as_bytes(key)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return dns.rdata._base64ify(self.key, chunksize=None, **kw)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ b64 = tok.concatenate_remaining_identifiers().encode()
+ key = base64.b64decode(b64)
+ return cls(rdclass, rdtype, key)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(self.key)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ key = parser.get_remaining()
+ return cls(rdclass, rdtype, key)
diff --git a/dns/rdtypes/ANY/OPT.py b/dns/rdtypes/ANY/OPT.py
index 096cbcb..d343dfa 100644
--- a/dns/rdtypes/ANY/OPT.py
+++ b/dns/rdtypes/ANY/OPT.py
@@ -1,14 +1,36 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.edns
import dns.exception
import dns.immutable
import dns.rdata
+# We don't implement from_text, and that's ok.
+# pylint: disable=abstract-method
+
@dns.immutable.immutable
class OPT(dns.rdata.Rdata):
"""OPT record"""
- __slots__ = ['options']
+
+ __slots__ = ["options"]
def __init__(self, rdclass, rdtype, options):
"""Initialize an OPT rdata.
@@ -20,15 +42,36 @@ class OPT(dns.rdata.Rdata):
*options*, a tuple of ``bytes``
"""
+
super().__init__(rdclass, rdtype)
def as_option(option):
if not isinstance(option, dns.edns.Option):
- raise ValueError('option is not a dns.edns.option')
+ raise ValueError("option is not a dns.edns.option")
return option
+
self.options = self._as_tuple(options, as_option)
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ for opt in self.options:
+ owire = opt.to_wire()
+ file.write(struct.pack("!HH", opt.otype, len(owire)))
+ file.write(owire)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return " ".join(opt.to_text() for opt in self.options)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ options = []
+ while parser.remaining() > 0:
+ (otype, olen) = parser.get_struct("!HH")
+ with parser.restrict_to(olen):
+ opt = dns.edns.option_from_wire_parser(otype, parser)
+ options.append(opt)
+ return cls(rdclass, rdtype, options)
+
@property
def payload(self):
- """payload size"""
- pass
+ "payload size"
+ return self.rdclass
diff --git a/dns/rdtypes/ANY/PTR.py b/dns/rdtypes/ANY/PTR.py
index b0a3915..98c3616 100644
--- a/dns/rdtypes/ANY/PTR.py
+++ b/dns/rdtypes/ANY/PTR.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.nsbase
diff --git a/dns/rdtypes/ANY/RP.py b/dns/rdtypes/ANY/RP.py
index ab9501f..9b74549 100644
--- a/dns/rdtypes/ANY/RP.py
+++ b/dns/rdtypes/ANY/RP.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.exception
import dns.immutable
import dns.name
@@ -7,9 +24,35 @@ import dns.rdata
@dns.immutable.immutable
class RP(dns.rdata.Rdata):
"""RP record"""
- __slots__ = ['mbox', 'txt']
+
+ # see: RFC 1183
+
+ __slots__ = ["mbox", "txt"]
def __init__(self, rdclass, rdtype, mbox, txt):
super().__init__(rdclass, rdtype)
self.mbox = self._as_name(mbox)
self.txt = self._as_name(txt)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ mbox = self.mbox.choose_relativity(origin, relativize)
+ txt = self.txt.choose_relativity(origin, relativize)
+ return "{} {}".format(str(mbox), str(txt))
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ mbox = tok.get_name(origin, relativize, relativize_to)
+ txt = tok.get_name(origin, relativize, relativize_to)
+ return cls(rdclass, rdtype, mbox, txt)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.mbox.to_wire(file, None, origin, canonicalize)
+ self.txt.to_wire(file, None, origin, canonicalize)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ mbox = parser.get_name(origin)
+ txt = parser.get_name(origin)
+ return cls(rdclass, rdtype, mbox, txt)
diff --git a/dns/rdtypes/ANY/RRSIG.py b/dns/rdtypes/ANY/RRSIG.py
index d8b408e..8beb423 100644
--- a/dns/rdtypes/ANY/RRSIG.py
+++ b/dns/rdtypes/ANY/RRSIG.py
@@ -1,7 +1,25 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
import calendar
import struct
import time
+
import dns.dnssectypes
import dns.exception
import dns.immutable
@@ -13,14 +31,54 @@ class BadSigTime(dns.exception.DNSException):
"""Time in DNS SIG or RRSIG resource record cannot be parsed."""
+def sigtime_to_posixtime(what):
+ if len(what) <= 10 and what.isdigit():
+ return int(what)
+ if len(what) != 14:
+ raise BadSigTime
+ year = int(what[0:4])
+ month = int(what[4:6])
+ day = int(what[6:8])
+ hour = int(what[8:10])
+ minute = int(what[10:12])
+ second = int(what[12:14])
+ return calendar.timegm((year, month, day, hour, minute, second, 0, 0, 0))
+
+
+def posixtime_to_sigtime(what):
+ return time.strftime("%Y%m%d%H%M%S", time.gmtime(what))
+
+
@dns.immutable.immutable
class RRSIG(dns.rdata.Rdata):
"""RRSIG record"""
- __slots__ = ['type_covered', 'algorithm', 'labels', 'original_ttl',
- 'expiration', 'inception', 'key_tag', 'signer', 'signature']
- def __init__(self, rdclass, rdtype, type_covered, algorithm, labels,
- original_ttl, expiration, inception, key_tag, signer, signature):
+ __slots__ = [
+ "type_covered",
+ "algorithm",
+ "labels",
+ "original_ttl",
+ "expiration",
+ "inception",
+ "key_tag",
+ "signer",
+ "signature",
+ ]
+
+ def __init__(
+ self,
+ rdclass,
+ rdtype,
+ type_covered,
+ algorithm,
+ labels,
+ original_ttl,
+ expiration,
+ inception,
+ key_tag,
+ signer,
+ signature,
+ ):
super().__init__(rdclass, rdtype)
self.type_covered = self._as_rdatatype(type_covered)
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
@@ -31,3 +89,69 @@ class RRSIG(dns.rdata.Rdata):
self.key_tag = self._as_uint16(key_tag)
self.signer = self._as_name(signer)
self.signature = self._as_bytes(signature)
+
+ def covers(self):
+ return self.type_covered
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return "%s %d %d %d %s %s %d %s %s" % (
+ dns.rdatatype.to_text(self.type_covered),
+ self.algorithm,
+ self.labels,
+ self.original_ttl,
+ posixtime_to_sigtime(self.expiration),
+ posixtime_to_sigtime(self.inception),
+ self.key_tag,
+ self.signer.choose_relativity(origin, relativize),
+ dns.rdata._base64ify(self.signature, **kw),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ type_covered = dns.rdatatype.from_text(tok.get_string())
+ algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
+ labels = tok.get_int()
+ original_ttl = tok.get_ttl()
+ expiration = sigtime_to_posixtime(tok.get_string())
+ inception = sigtime_to_posixtime(tok.get_string())
+ key_tag = tok.get_int()
+ signer = tok.get_name(origin, relativize, relativize_to)
+ b64 = tok.concatenate_remaining_identifiers().encode()
+ signature = base64.b64decode(b64)
+ return cls(
+ rdclass,
+ rdtype,
+ type_covered,
+ algorithm,
+ labels,
+ original_ttl,
+ expiration,
+ inception,
+ key_tag,
+ signer,
+ signature,
+ )
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ header = struct.pack(
+ "!HBBIIIH",
+ self.type_covered,
+ self.algorithm,
+ self.labels,
+ self.original_ttl,
+ self.expiration,
+ self.inception,
+ self.key_tag,
+ )
+ file.write(header)
+ self.signer.to_wire(file, None, origin, canonicalize)
+ file.write(self.signature)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("!HBBIIIH")
+ signer = parser.get_name(origin)
+ signature = parser.get_remaining()
+ return cls(rdclass, rdtype, *header, signer, signature)
diff --git a/dns/rdtypes/ANY/RT.py b/dns/rdtypes/ANY/RT.py
index 93a6252..5a4d45c 100644
--- a/dns/rdtypes/ANY/RT.py
+++ b/dns/rdtypes/ANY/RT.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.mxbase
diff --git a/dns/rdtypes/ANY/SMIMEA.py b/dns/rdtypes/ANY/SMIMEA.py
index f61fd49..55d87bf 100644
--- a/dns/rdtypes/ANY/SMIMEA.py
+++ b/dns/rdtypes/ANY/SMIMEA.py
@@ -1,3 +1,5 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import dns.immutable
import dns.rdtypes.tlsabase
diff --git a/dns/rdtypes/ANY/SOA.py b/dns/rdtypes/ANY/SOA.py
index 62d0470..09aa832 100644
--- a/dns/rdtypes/ANY/SOA.py
+++ b/dns/rdtypes/ANY/SOA.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.name
@@ -8,11 +26,14 @@ import dns.rdata
@dns.immutable.immutable
class SOA(dns.rdata.Rdata):
"""SOA record"""
- __slots__ = ['mname', 'rname', 'serial', 'refresh', 'retry', 'expire',
- 'minimum']
- def __init__(self, rdclass, rdtype, mname, rname, serial, refresh,
- retry, expire, minimum):
+ # see: RFC 1035
+
+ __slots__ = ["mname", "rname", "serial", "refresh", "retry", "expire", "minimum"]
+
+ def __init__(
+ self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum
+ ):
super().__init__(rdclass, rdtype)
self.mname = self._as_name(mname)
self.rname = self._as_name(rname)
@@ -21,3 +42,45 @@ class SOA(dns.rdata.Rdata):
self.retry = self._as_ttl(retry)
self.expire = self._as_ttl(expire)
self.minimum = self._as_ttl(minimum)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ mname = self.mname.choose_relativity(origin, relativize)
+ rname = self.rname.choose_relativity(origin, relativize)
+ return "%s %s %d %d %d %d %d" % (
+ mname,
+ rname,
+ self.serial,
+ self.refresh,
+ self.retry,
+ self.expire,
+ self.minimum,
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ mname = tok.get_name(origin, relativize, relativize_to)
+ rname = tok.get_name(origin, relativize, relativize_to)
+ serial = tok.get_uint32()
+ refresh = tok.get_ttl()
+ retry = tok.get_ttl()
+ expire = tok.get_ttl()
+ minimum = tok.get_ttl()
+ return cls(
+ rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum
+ )
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.mname.to_wire(file, compress, origin, canonicalize)
+ self.rname.to_wire(file, compress, origin, canonicalize)
+ five_ints = struct.pack(
+ "!IIIII", self.serial, self.refresh, self.retry, self.expire, self.minimum
+ )
+ file.write(five_ints)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ mname = parser.get_name(origin)
+ rname = parser.get_name(origin)
+ return cls(rdclass, rdtype, mname, rname, *parser.get_struct("!IIIII"))
diff --git a/dns/rdtypes/ANY/SPF.py b/dns/rdtypes/ANY/SPF.py
index 1f512e9..1df3b70 100644
--- a/dns/rdtypes/ANY/SPF.py
+++ b/dns/rdtypes/ANY/SPF.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.txtbase
@@ -5,3 +22,5 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable
class SPF(dns.rdtypes.txtbase.TXTBase):
"""SPF record"""
+
+ # see: RFC 4408
diff --git a/dns/rdtypes/ANY/SSHFP.py b/dns/rdtypes/ANY/SSHFP.py
index 5bf303c..d2c4b07 100644
--- a/dns/rdtypes/ANY/SSHFP.py
+++ b/dns/rdtypes/ANY/SSHFP.py
@@ -1,5 +1,23 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2005-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import binascii
import struct
+
import dns.immutable
import dns.rdata
import dns.rdatatype
@@ -8,10 +26,43 @@ import dns.rdatatype
@dns.immutable.immutable
class SSHFP(dns.rdata.Rdata):
"""SSHFP record"""
- __slots__ = ['algorithm', 'fp_type', 'fingerprint']
+
+ # See RFC 4255
+
+ __slots__ = ["algorithm", "fp_type", "fingerprint"]
def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm)
self.fp_type = self._as_uint8(fp_type)
self.fingerprint = self._as_bytes(fingerprint, True)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ kw = kw.copy()
+ chunksize = kw.pop("chunksize", 128)
+ return "%d %d %s" % (
+ self.algorithm,
+ self.fp_type,
+ dns.rdata._hexify(self.fingerprint, chunksize=chunksize, **kw),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ algorithm = tok.get_uint8()
+ fp_type = tok.get_uint8()
+ fingerprint = tok.concatenate_remaining_identifiers().encode()
+ fingerprint = binascii.unhexlify(fingerprint)
+ return cls(rdclass, rdtype, algorithm, fp_type, fingerprint)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ header = struct.pack("!BB", self.algorithm, self.fp_type)
+ file.write(header)
+ file.write(self.fingerprint)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("BB")
+ fingerprint = parser.get_remaining()
+ return cls(rdclass, rdtype, header[0], header[1], fingerprint)
diff --git a/dns/rdtypes/ANY/TKEY.py b/dns/rdtypes/ANY/TKEY.py
index 43462af..5b490b8 100644
--- a/dns/rdtypes/ANY/TKEY.py
+++ b/dns/rdtypes/ANY/TKEY.py
@@ -1,5 +1,23 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -8,11 +26,29 @@ import dns.rdata
@dns.immutable.immutable
class TKEY(dns.rdata.Rdata):
"""TKEY Record"""
- __slots__ = ['algorithm', 'inception', 'expiration', 'mode', 'error',
- 'key', 'other']
- def __init__(self, rdclass, rdtype, algorithm, inception, expiration,
- mode, error, key, other=b''):
+ __slots__ = [
+ "algorithm",
+ "inception",
+ "expiration",
+ "mode",
+ "error",
+ "key",
+ "other",
+ ]
+
+ def __init__(
+ self,
+ rdclass,
+ rdtype,
+ algorithm,
+ inception,
+ expiration,
+ mode,
+ error,
+ key,
+ other=b"",
+ ):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_name(algorithm)
self.inception = self._as_uint32(inception)
@@ -21,6 +57,84 @@ class TKEY(dns.rdata.Rdata):
self.error = self._as_uint16(error)
self.key = self._as_bytes(key)
self.other = self._as_bytes(other)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ _algorithm = self.algorithm.choose_relativity(origin, relativize)
+ text = "%s %u %u %u %u %s" % (
+ str(_algorithm),
+ self.inception,
+ self.expiration,
+ self.mode,
+ self.error,
+ dns.rdata._base64ify(self.key, 0),
+ )
+ if len(self.other) > 0:
+ text += " %s" % (dns.rdata._base64ify(self.other, 0))
+
+ return text
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ algorithm = tok.get_name(relativize=False)
+ inception = tok.get_uint32()
+ expiration = tok.get_uint32()
+ mode = tok.get_uint16()
+ error = tok.get_uint16()
+ key_b64 = tok.get_string().encode()
+ key = base64.b64decode(key_b64)
+ other_b64 = tok.concatenate_remaining_identifiers(True).encode()
+ other = base64.b64decode(other_b64)
+
+ return cls(
+ rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other
+ )
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.algorithm.to_wire(file, compress, origin)
+ file.write(
+ struct.pack("!IIHH", self.inception, self.expiration, self.mode, self.error)
+ )
+ file.write(struct.pack("!H", len(self.key)))
+ file.write(self.key)
+ file.write(struct.pack("!H", len(self.other)))
+ if len(self.other) > 0:
+ file.write(self.other)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ algorithm = parser.get_name(origin)
+ inception, expiration, mode, error = parser.get_struct("!IIHH")
+ key = parser.get_counted_bytes(2)
+ other = parser.get_counted_bytes(2)
+
+ return cls(
+ rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other
+ )
+
+ # Constants for the mode field - from RFC 2930:
+ # 2.5 The Mode Field
+ #
+ # The mode field specifies the general scheme for key agreement or
+ # the purpose of the TKEY DNS message. Servers and resolvers
+ # supporting this specification MUST implement the Diffie-Hellman key
+ # agreement mode and the key deletion mode for queries. All other
+ # modes are OPTIONAL. A server supporting TKEY that receives a TKEY
+ # request with a mode it does not support returns the BADMODE error.
+ # The following values of the Mode octet are defined, available, or
+ # reserved:
+ #
+ # Value Description
+ # ----- -----------
+ # 0 - reserved, see section 7
+ # 1 server assignment
+ # 2 Diffie-Hellman exchange
+ # 3 GSS-API negotiation
+ # 4 resolver assignment
+ # 5 key deletion
+ # 6-65534 - available, see section 7
+ # 65535 - reserved, see section 7
SERVER_ASSIGNMENT = 1
DIFFIE_HELLMAN_EXCHANGE = 2
GSSAPI_NEGOTIATION = 3
diff --git a/dns/rdtypes/ANY/TLSA.py b/dns/rdtypes/ANY/TLSA.py
index 3cc2a11..4dffc55 100644
--- a/dns/rdtypes/ANY/TLSA.py
+++ b/dns/rdtypes/ANY/TLSA.py
@@ -1,3 +1,5 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import dns.immutable
import dns.rdtypes.tlsabase
diff --git a/dns/rdtypes/ANY/TSIG.py b/dns/rdtypes/ANY/TSIG.py
index 9f3e67e..7942382 100644
--- a/dns/rdtypes/ANY/TSIG.py
+++ b/dns/rdtypes/ANY/TSIG.py
@@ -1,5 +1,23 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
import struct
+
import dns.exception
import dns.immutable
import dns.rcode
@@ -9,11 +27,29 @@ import dns.rdata
@dns.immutable.immutable
class TSIG(dns.rdata.Rdata):
"""TSIG record"""
- __slots__ = ['algorithm', 'time_signed', 'fudge', 'mac', 'original_id',
- 'error', 'other']
- def __init__(self, rdclass, rdtype, algorithm, time_signed, fudge, mac,
- original_id, error, other):
+ __slots__ = [
+ "algorithm",
+ "time_signed",
+ "fudge",
+ "mac",
+ "original_id",
+ "error",
+ "other",
+ ]
+
+ def __init__(
+ self,
+ rdclass,
+ rdtype,
+ algorithm,
+ time_signed,
+ fudge,
+ mac,
+ original_id,
+ error,
+ other,
+ ):
"""Initialize a TSIG rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata.
@@ -34,6 +70,7 @@ class TSIG(dns.rdata.Rdata):
*other*, a ``bytes``
"""
+
super().__init__(rdclass, rdtype)
self.algorithm = self._as_name(algorithm)
self.time_signed = self._as_uint48(time_signed)
@@ -42,3 +79,82 @@ class TSIG(dns.rdata.Rdata):
self.original_id = self._as_uint16(original_id)
self.error = dns.rcode.Rcode.make(error)
self.other = self._as_bytes(other)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ algorithm = self.algorithm.choose_relativity(origin, relativize)
+ error = dns.rcode.to_text(self.error, True)
+ text = (
+ f"{algorithm} {self.time_signed} {self.fudge} "
+ + f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} "
+ + f"{self.original_id} {error} {len(self.other)}"
+ )
+ if self.other:
+ text += f" {dns.rdata._base64ify(self.other, 0)}"
+ return text
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ algorithm = tok.get_name(relativize=False)
+ time_signed = tok.get_uint48()
+ fudge = tok.get_uint16()
+ mac_len = tok.get_uint16()
+ mac = base64.b64decode(tok.get_string())
+ if len(mac) != mac_len:
+ raise SyntaxError("invalid MAC")
+ original_id = tok.get_uint16()
+ error = dns.rcode.from_text(tok.get_string())
+ other_len = tok.get_uint16()
+ if other_len > 0:
+ other = base64.b64decode(tok.get_string())
+ if len(other) != other_len:
+ raise SyntaxError("invalid other data")
+ else:
+ other = b""
+ return cls(
+ rdclass,
+ rdtype,
+ algorithm,
+ time_signed,
+ fudge,
+ mac,
+ original_id,
+ error,
+ other,
+ )
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.algorithm.to_wire(file, None, origin, False)
+ file.write(
+ struct.pack(
+ "!HIHH",
+ (self.time_signed >> 32) & 0xFFFF,
+ self.time_signed & 0xFFFFFFFF,
+ self.fudge,
+ len(self.mac),
+ )
+ )
+ file.write(self.mac)
+ file.write(struct.pack("!HHH", self.original_id, self.error, len(self.other)))
+ file.write(self.other)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ algorithm = parser.get_name()
+ time_signed = parser.get_uint48()
+ fudge = parser.get_uint16()
+ mac = parser.get_counted_bytes(2)
+ (original_id, error) = parser.get_struct("!HH")
+ other = parser.get_counted_bytes(2)
+ return cls(
+ rdclass,
+ rdtype,
+ algorithm,
+ time_signed,
+ fudge,
+ mac,
+ original_id,
+ error,
+ other,
+ )
diff --git a/dns/rdtypes/ANY/TXT.py b/dns/rdtypes/ANY/TXT.py
index ecbfa14..6d4dae2 100644
--- a/dns/rdtypes/ANY/TXT.py
+++ b/dns/rdtypes/ANY/TXT.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.txtbase
diff --git a/dns/rdtypes/ANY/URI.py b/dns/rdtypes/ANY/URI.py
index 9764c07..2efbb30 100644
--- a/dns/rdtypes/ANY/URI.py
+++ b/dns/rdtypes/ANY/URI.py
@@ -1,4 +1,23 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+# Copyright (C) 2015 Red Hat, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.name
@@ -9,7 +28,10 @@ import dns.rdtypes.util
@dns.immutable.immutable
class URI(dns.rdata.Rdata):
"""URI record"""
- __slots__ = ['priority', 'weight', 'target']
+
+ # see RFC 7553
+
+ __slots__ = ["priority", "weight", "target"]
def __init__(self, rdclass, rdtype, priority, weight, target):
super().__init__(rdclass, rdtype)
@@ -17,4 +39,41 @@ class URI(dns.rdata.Rdata):
self.weight = self._as_uint16(weight)
self.target = self._as_bytes(target, True)
if len(self.target) == 0:
- raise dns.exception.SyntaxError('URI target cannot be empty')
+ raise dns.exception.SyntaxError("URI target cannot be empty")
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return '%d %d "%s"' % (self.priority, self.weight, self.target.decode())
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ priority = tok.get_uint16()
+ weight = tok.get_uint16()
+ target = tok.get().unescape()
+ if not (target.is_quoted_string() or target.is_identifier()):
+ raise dns.exception.SyntaxError("URI target must be a string")
+ return cls(rdclass, rdtype, priority, weight, target.value)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ two_ints = struct.pack("!HH", self.priority, self.weight)
+ file.write(two_ints)
+ file.write(self.target)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (priority, weight) = parser.get_struct("!HH")
+ target = parser.get_remaining()
+ if len(target) == 0:
+ raise dns.exception.FormError("URI target may not be empty")
+ return cls(rdclass, rdtype, priority, weight, target)
+
+ def _processing_priority(self):
+ return self.priority
+
+ def _processing_weight(self):
+ return self.weight
+
+ @classmethod
+ def _processing_order(cls, iterable):
+ return dns.rdtypes.util.weighted_processing_order(iterable)
diff --git a/dns/rdtypes/ANY/X25.py b/dns/rdtypes/ANY/X25.py
index 81d5a53..8375611 100644
--- a/dns/rdtypes/ANY/X25.py
+++ b/dns/rdtypes/ANY/X25.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -8,8 +26,32 @@ import dns.tokenizer
@dns.immutable.immutable
class X25(dns.rdata.Rdata):
"""X25 record"""
- __slots__ = ['address']
+
+ # see RFC 1183
+
+ __slots__ = ["address"]
def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype)
self.address = self._as_bytes(address, True, 255)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return '"%s"' % dns.rdata._escapify(self.address)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ address = tok.get_string()
+ return cls(rdclass, rdtype, address)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ l = len(self.address)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(self.address)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_counted_bytes()
+ return cls(rdclass, rdtype, address)
diff --git a/dns/rdtypes/ANY/ZONEMD.py b/dns/rdtypes/ANY/ZONEMD.py
index 6beade3..c90e3ee 100644
--- a/dns/rdtypes/ANY/ZONEMD.py
+++ b/dns/rdtypes/ANY/ZONEMD.py
@@ -1,5 +1,8 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import binascii
import struct
+
import dns.immutable
import dns.rdata
import dns.rdatatype
@@ -9,20 +12,55 @@ import dns.zonetypes
@dns.immutable.immutable
class ZONEMD(dns.rdata.Rdata):
"""ZONEMD record"""
- __slots__ = ['serial', 'scheme', 'hash_algorithm', 'digest']
- def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest
- ):
+ # See RFC 8976
+
+ __slots__ = ["serial", "scheme", "hash_algorithm", "digest"]
+
+ def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest):
super().__init__(rdclass, rdtype)
self.serial = self._as_uint32(serial)
self.scheme = dns.zonetypes.DigestScheme.make(scheme)
- self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(
- hash_algorithm)
+ self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(hash_algorithm)
self.digest = self._as_bytes(digest)
- if self.scheme == 0:
- raise ValueError('scheme 0 is reserved')
- if self.hash_algorithm == 0:
- raise ValueError('hash_algorithm 0 is reserved')
+
+ if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2
+ raise ValueError("scheme 0 is reserved")
+ if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3
+ raise ValueError("hash_algorithm 0 is reserved")
+
hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm)
if hasher and hasher().digest_size != len(self.digest):
- raise ValueError('digest length inconsistent with hash algorithm')
+ raise ValueError("digest length inconsistent with hash algorithm")
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ kw = kw.copy()
+ chunksize = kw.pop("chunksize", 128)
+ return "%d %d %d %s" % (
+ self.serial,
+ self.scheme,
+ self.hash_algorithm,
+ dns.rdata._hexify(self.digest, chunksize=chunksize, **kw),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ serial = tok.get_uint32()
+ scheme = tok.get_uint8()
+ hash_algorithm = tok.get_uint8()
+ digest = tok.concatenate_remaining_identifiers().encode()
+ digest = binascii.unhexlify(digest)
+ return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ header = struct.pack("!IBB", self.serial, self.scheme, self.hash_algorithm)
+ file.write(header)
+ file.write(self.digest)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("!IBB")
+ digest = parser.get_remaining()
+ return cls(rdclass, rdtype, header[0], header[1], header[2], digest)
diff --git a/dns/rdtypes/CH/A.py b/dns/rdtypes/CH/A.py
index 49c36d8..583a88a 100644
--- a/dns/rdtypes/CH/A.py
+++ b/dns/rdtypes/CH/A.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.immutable
import dns.rdtypes.mxbase
@@ -6,9 +24,36 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable
class A(dns.rdata.Rdata):
"""A record for Chaosnet"""
- __slots__ = ['domain', 'address']
+
+ # domain: the domain of the address
+ # address: the 16-bit address
+
+ __slots__ = ["domain", "address"]
def __init__(self, rdclass, rdtype, domain, address):
super().__init__(rdclass, rdtype)
self.domain = self._as_name(domain)
self.address = self._as_uint16(address)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ domain = self.domain.choose_relativity(origin, relativize)
+ return "%s %o" % (domain, self.address)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ domain = tok.get_name(origin, relativize, relativize_to)
+ address = tok.get_uint16(base=8)
+ return cls(rdclass, rdtype, domain, address)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.domain.to_wire(file, compress, origin, canonicalize)
+ pref = struct.pack("!H", self.address)
+ file.write(pref)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ domain = parser.get_name(origin)
+ address = parser.get_uint16()
+ return cls(rdclass, rdtype, domain, address)
diff --git a/dns/rdtypes/IN/A.py b/dns/rdtypes/IN/A.py
index 5c5f664..e09d611 100644
--- a/dns/rdtypes/IN/A.py
+++ b/dns/rdtypes/IN/A.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.exception
import dns.immutable
import dns.ipv4
@@ -8,8 +25,27 @@ import dns.tokenizer
@dns.immutable.immutable
class A(dns.rdata.Rdata):
"""A record."""
- __slots__ = ['address']
+
+ __slots__ = ["address"]
def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype)
self.address = self._as_ipv4_address(address)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return self.address
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ address = tok.get_identifier()
+ return cls(rdclass, rdtype, address)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(dns.ipv4.inet_aton(self.address))
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_remaining()
+ return cls(rdclass, rdtype, address)
diff --git a/dns/rdtypes/IN/AAAA.py b/dns/rdtypes/IN/AAAA.py
index b24c5fb..0cd139e 100644
--- a/dns/rdtypes/IN/AAAA.py
+++ b/dns/rdtypes/IN/AAAA.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.exception
import dns.immutable
import dns.ipv6
@@ -8,8 +25,27 @@ import dns.tokenizer
@dns.immutable.immutable
class AAAA(dns.rdata.Rdata):
"""AAAA record."""
- __slots__ = ['address']
+
+ __slots__ = ["address"]
def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype)
self.address = self._as_ipv6_address(address)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return self.address
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ address = tok.get_identifier()
+ return cls(rdclass, rdtype, address)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(dns.ipv6.inet_aton(self.address))
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_remaining()
+ return cls(rdclass, rdtype, address)
diff --git a/dns/rdtypes/IN/APL.py b/dns/rdtypes/IN/APL.py
index 6572e67..44cb3fe 100644
--- a/dns/rdtypes/IN/APL.py
+++ b/dns/rdtypes/IN/APL.py
@@ -1,6 +1,24 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import binascii
import codecs
import struct
+
import dns.exception
import dns.immutable
import dns.ipv4
@@ -12,7 +30,8 @@ import dns.tokenizer
@dns.immutable.immutable
class APLItem:
"""An APL list item."""
- __slots__ = ['family', 'negation', 'address', 'prefix']
+
+ __slots__ = ["family", "negation", "address", "prefix"]
def __init__(self, family, negation, address, prefix):
self.family = dns.rdata.Rdata._as_uint16(family)
@@ -29,19 +48,103 @@ class APLItem:
def __str__(self):
if self.negation:
- return '!%d:%s/%s' % (self.family, self.address, self.prefix)
+ return "!%d:%s/%s" % (self.family, self.address, self.prefix)
+ else:
+ return "%d:%s/%s" % (self.family, self.address, self.prefix)
+
+ def to_wire(self, file):
+ if self.family == 1:
+ address = dns.ipv4.inet_aton(self.address)
+ elif self.family == 2:
+ address = dns.ipv6.inet_aton(self.address)
else:
- return '%d:%s/%s' % (self.family, self.address, self.prefix)
+ address = binascii.unhexlify(self.address)
+ #
+ # Truncate least significant zero bytes.
+ #
+ last = 0
+ for i in range(len(address) - 1, -1, -1):
+ if address[i] != 0:
+ last = i + 1
+ break
+ address = address[0:last]
+ l = len(address)
+ assert l < 128
+ if self.negation:
+ l |= 0x80
+ header = struct.pack("!HBB", self.family, self.prefix, l)
+ file.write(header)
+ file.write(address)
@dns.immutable.immutable
class APL(dns.rdata.Rdata):
"""APL record."""
- __slots__ = ['items']
+
+ # see: RFC 3123
+
+ __slots__ = ["items"]
def __init__(self, rdclass, rdtype, items):
super().__init__(rdclass, rdtype)
for item in items:
if not isinstance(item, APLItem):
- raise ValueError('item not an APLItem')
+ raise ValueError("item not an APLItem")
self.items = tuple(items)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return " ".join(map(str, self.items))
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ items = []
+ for token in tok.get_remaining():
+ item = token.unescape().value
+ if item[0] == "!":
+ negation = True
+ item = item[1:]
+ else:
+ negation = False
+ (family, rest) = item.split(":", 1)
+ family = int(family)
+ (address, prefix) = rest.split("/", 1)
+ prefix = int(prefix)
+ item = APLItem(family, negation, address, prefix)
+ items.append(item)
+
+ return cls(rdclass, rdtype, items)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ for item in self.items:
+ item.to_wire(file)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ items = []
+ while parser.remaining() > 0:
+ header = parser.get_struct("!HBB")
+ afdlen = header[2]
+ if afdlen > 127:
+ negation = True
+ afdlen -= 128
+ else:
+ negation = False
+ address = parser.get_bytes(afdlen)
+ l = len(address)
+ if header[0] == 1:
+ if l < 4:
+ address += b"\x00" * (4 - l)
+ elif header[0] == 2:
+ if l < 16:
+ address += b"\x00" * (16 - l)
+ else:
+ #
+ # This isn't really right according to the RFC, but it
+ # seems better than throwing an exception
+ #
+ address = codecs.encode(address, "hex_codec")
+ item = APLItem(header[0], negation, address, header[1])
+ items.append(item)
+ return cls(rdclass, rdtype, items)
diff --git a/dns/rdtypes/IN/DHCID.py b/dns/rdtypes/IN/DHCID.py
index ab8928f..723492f 100644
--- a/dns/rdtypes/IN/DHCID.py
+++ b/dns/rdtypes/IN/DHCID.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -7,8 +25,30 @@ import dns.rdata
@dns.immutable.immutable
class DHCID(dns.rdata.Rdata):
"""DHCID record"""
- __slots__ = ['data']
+
+ # see: RFC 4701
+
+ __slots__ = ["data"]
def __init__(self, rdclass, rdtype, data):
super().__init__(rdclass, rdtype)
self.data = self._as_bytes(data)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return dns.rdata._base64ify(self.data, **kw)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ b64 = tok.concatenate_remaining_identifiers().encode()
+ data = base64.b64decode(b64)
+ return cls(rdclass, rdtype, data)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(self.data)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ data = parser.get_remaining()
+ return cls(rdclass, rdtype, data)
diff --git a/dns/rdtypes/IN/HTTPS.py b/dns/rdtypes/IN/HTTPS.py
index 4e56d14..15464cb 100644
--- a/dns/rdtypes/IN/HTTPS.py
+++ b/dns/rdtypes/IN/HTTPS.py
@@ -1,3 +1,5 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import dns.immutable
import dns.rdtypes.svcbbase
diff --git a/dns/rdtypes/IN/IPSECKEY.py b/dns/rdtypes/IN/IPSECKEY.py
index cb4b002..e3a6615 100644
--- a/dns/rdtypes/IN/IPSECKEY.py
+++ b/dns/rdtypes/IN/IPSECKEY.py
@@ -1,21 +1,43 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
import struct
+
import dns.exception
import dns.immutable
import dns.rdtypes.util
class Gateway(dns.rdtypes.util.Gateway):
- name = 'IPSECKEY gateway'
+ name = "IPSECKEY gateway"
@dns.immutable.immutable
class IPSECKEY(dns.rdata.Rdata):
"""IPSECKEY record"""
- __slots__ = ['precedence', 'gateway_type', 'algorithm', 'gateway', 'key']
- def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm,
- gateway, key):
+ # see: RFC 4025
+
+ __slots__ = ["precedence", "gateway_type", "algorithm", "gateway", "key"]
+
+ def __init__(
+ self, rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key
+ ):
super().__init__(rdclass, rdtype)
gateway = Gateway(gateway_type, gateway)
self.precedence = self._as_uint8(precedence)
@@ -23,3 +45,47 @@ class IPSECKEY(dns.rdata.Rdata):
self.algorithm = self._as_uint8(algorithm)
self.gateway = gateway.gateway
self.key = self._as_bytes(key)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, relativize)
+ return "%d %d %d %s %s" % (
+ self.precedence,
+ self.gateway_type,
+ self.algorithm,
+ gateway,
+ dns.rdata._base64ify(self.key, **kw),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ precedence = tok.get_uint8()
+ gateway_type = tok.get_uint8()
+ algorithm = tok.get_uint8()
+ gateway = Gateway.from_text(
+ gateway_type, tok, origin, relativize, relativize_to
+ )
+ b64 = tok.concatenate_remaining_identifiers().encode()
+ key = base64.b64decode(b64)
+ return cls(
+ rdclass, rdtype, precedence, gateway_type, algorithm, gateway.gateway, key
+ )
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ header = struct.pack("!BBB", self.precedence, self.gateway_type, self.algorithm)
+ file.write(header)
+ Gateway(self.gateway_type, self.gateway).to_wire(
+ file, compress, origin, canonicalize
+ )
+ file.write(self.key)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("!BBB")
+ gateway_type = header[1]
+ gateway = Gateway.from_wire_parser(gateway_type, parser, origin)
+ key = parser.get_remaining()
+ return cls(
+ rdclass, rdtype, header[0], gateway_type, header[2], gateway.gateway, key
+ )
diff --git a/dns/rdtypes/IN/KX.py b/dns/rdtypes/IN/KX.py
index b6f7705..6073df4 100644
--- a/dns/rdtypes/IN/KX.py
+++ b/dns/rdtypes/IN/KX.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.mxbase
diff --git a/dns/rdtypes/IN/NAPTR.py b/dns/rdtypes/IN/NAPTR.py
index eb0966e..195d1cb 100644
--- a/dns/rdtypes/IN/NAPTR.py
+++ b/dns/rdtypes/IN/NAPTR.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.name
@@ -6,14 +24,24 @@ import dns.rdata
import dns.rdtypes.util
+def _write_string(file, s):
+ l = len(s)
+ assert l < 256
+ file.write(struct.pack("!B", l))
+ file.write(s)
+
+
@dns.immutable.immutable
class NAPTR(dns.rdata.Rdata):
"""NAPTR record"""
- __slots__ = ['order', 'preference', 'flags', 'service', 'regexp',
- 'replacement']
- def __init__(self, rdclass, rdtype, order, preference, flags, service,
- regexp, replacement):
+ # see: RFC 3403
+
+ __slots__ = ["order", "preference", "flags", "service", "regexp", "replacement"]
+
+ def __init__(
+ self, rdclass, rdtype, order, preference, flags, service, regexp, replacement
+ ):
super().__init__(rdclass, rdtype)
self.flags = self._as_bytes(flags, True, 255)
self.service = self._as_bytes(service, True, 255)
@@ -21,3 +49,62 @@ class NAPTR(dns.rdata.Rdata):
self.order = self._as_uint16(order)
self.preference = self._as_uint16(preference)
self.replacement = self._as_name(replacement)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ replacement = self.replacement.choose_relativity(origin, relativize)
+ return '%d %d "%s" "%s" "%s" %s' % (
+ self.order,
+ self.preference,
+ dns.rdata._escapify(self.flags),
+ dns.rdata._escapify(self.service),
+ dns.rdata._escapify(self.regexp),
+ replacement,
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ order = tok.get_uint16()
+ preference = tok.get_uint16()
+ flags = tok.get_string()
+ service = tok.get_string()
+ regexp = tok.get_string()
+ replacement = tok.get_name(origin, relativize, relativize_to)
+ return cls(
+ rdclass, rdtype, order, preference, flags, service, regexp, replacement
+ )
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ two_ints = struct.pack("!HH", self.order, self.preference)
+ file.write(two_ints)
+ _write_string(file, self.flags)
+ _write_string(file, self.service)
+ _write_string(file, self.regexp)
+ self.replacement.to_wire(file, compress, origin, canonicalize)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (order, preference) = parser.get_struct("!HH")
+ strings = []
+ for _ in range(3):
+ s = parser.get_counted_bytes()
+ strings.append(s)
+ replacement = parser.get_name(origin)
+ return cls(
+ rdclass,
+ rdtype,
+ order,
+ preference,
+ strings[0],
+ strings[1],
+ strings[2],
+ replacement,
+ )
+
+ def _processing_priority(self):
+ return (self.order, self.preference)
+
+ @classmethod
+ def _processing_order(cls, iterable):
+ return dns.rdtypes.util.priority_processing_order(iterable)
diff --git a/dns/rdtypes/IN/NSAP.py b/dns/rdtypes/IN/NSAP.py
index 7af3b5e..a4854b3 100644
--- a/dns/rdtypes/IN/NSAP.py
+++ b/dns/rdtypes/IN/NSAP.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import binascii
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -8,8 +26,35 @@ import dns.tokenizer
@dns.immutable.immutable
class NSAP(dns.rdata.Rdata):
"""NSAP record."""
- __slots__ = ['address']
+
+ # see: RFC 1706
+
+ __slots__ = ["address"]
def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype)
self.address = self._as_bytes(address)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return "0x%s" % binascii.hexlify(self.address).decode()
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ address = tok.get_string()
+ if address[0:2] != "0x":
+ raise dns.exception.SyntaxError("string does not start with 0x")
+ address = address[2:].replace(".", "")
+ if len(address) % 2 != 0:
+ raise dns.exception.SyntaxError("hexstring has odd length")
+ address = binascii.unhexlify(address.encode())
+ return cls(rdclass, rdtype, address)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(self.address)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_remaining()
+ return cls(rdclass, rdtype, address)
diff --git a/dns/rdtypes/IN/NSAP_PTR.py b/dns/rdtypes/IN/NSAP_PTR.py
index 6b64d23..ce1c663 100644
--- a/dns/rdtypes/IN/NSAP_PTR.py
+++ b/dns/rdtypes/IN/NSAP_PTR.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import dns.immutable
import dns.rdtypes.nsbase
diff --git a/dns/rdtypes/IN/PX.py b/dns/rdtypes/IN/PX.py
index 59ed238..cdca153 100644
--- a/dns/rdtypes/IN/PX.py
+++ b/dns/rdtypes/IN/PX.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.name
@@ -9,10 +27,47 @@ import dns.rdtypes.util
@dns.immutable.immutable
class PX(dns.rdata.Rdata):
"""PX record."""
- __slots__ = ['preference', 'map822', 'mapx400']
+
+ # see: RFC 2163
+
+ __slots__ = ["preference", "map822", "mapx400"]
def __init__(self, rdclass, rdtype, preference, map822, mapx400):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
self.map822 = self._as_name(map822)
self.mapx400 = self._as_name(mapx400)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ map822 = self.map822.choose_relativity(origin, relativize)
+ mapx400 = self.mapx400.choose_relativity(origin, relativize)
+ return "%d %s %s" % (self.preference, map822, mapx400)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ preference = tok.get_uint16()
+ map822 = tok.get_name(origin, relativize, relativize_to)
+ mapx400 = tok.get_name(origin, relativize, relativize_to)
+ return cls(rdclass, rdtype, preference, map822, mapx400)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ pref = struct.pack("!H", self.preference)
+ file.write(pref)
+ self.map822.to_wire(file, None, origin, canonicalize)
+ self.mapx400.to_wire(file, None, origin, canonicalize)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ preference = parser.get_uint16()
+ map822 = parser.get_name(origin)
+ mapx400 = parser.get_name(origin)
+ return cls(rdclass, rdtype, preference, map822, mapx400)
+
+ def _processing_priority(self):
+ return self.preference
+
+ @classmethod
+ def _processing_order(cls, iterable):
+ return dns.rdtypes.util.priority_processing_order(iterable)
diff --git a/dns/rdtypes/IN/SRV.py b/dns/rdtypes/IN/SRV.py
index 014a72f..5adef98 100644
--- a/dns/rdtypes/IN/SRV.py
+++ b/dns/rdtypes/IN/SRV.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import struct
+
import dns.exception
import dns.immutable
import dns.name
@@ -9,7 +27,10 @@ import dns.rdtypes.util
@dns.immutable.immutable
class SRV(dns.rdata.Rdata):
"""SRV record"""
- __slots__ = ['priority', 'weight', 'port', 'target']
+
+ # see: RFC 2782
+
+ __slots__ = ["priority", "weight", "port", "target"]
def __init__(self, rdclass, rdtype, priority, weight, port, target):
super().__init__(rdclass, rdtype)
@@ -17,3 +38,38 @@ class SRV(dns.rdata.Rdata):
self.weight = self._as_uint16(weight)
self.port = self._as_uint16(port)
self.target = self._as_name(target)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ target = self.target.choose_relativity(origin, relativize)
+ return "%d %d %d %s" % (self.priority, self.weight, self.port, target)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ priority = tok.get_uint16()
+ weight = tok.get_uint16()
+ port = tok.get_uint16()
+ target = tok.get_name(origin, relativize, relativize_to)
+ return cls(rdclass, rdtype, priority, weight, port, target)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ three_ints = struct.pack("!HHH", self.priority, self.weight, self.port)
+ file.write(three_ints)
+ self.target.to_wire(file, compress, origin, canonicalize)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ (priority, weight, port) = parser.get_struct("!HHH")
+ target = parser.get_name(origin)
+ return cls(rdclass, rdtype, priority, weight, port, target)
+
+ def _processing_priority(self):
+ return self.priority
+
+ def _processing_weight(self):
+ return self.weight
+
+ @classmethod
+ def _processing_order(cls, iterable):
+ return dns.rdtypes.util.weighted_processing_order(iterable)
diff --git a/dns/rdtypes/IN/SVCB.py b/dns/rdtypes/IN/SVCB.py
index a9446c8..ff3e932 100644
--- a/dns/rdtypes/IN/SVCB.py
+++ b/dns/rdtypes/IN/SVCB.py
@@ -1,3 +1,5 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import dns.immutable
import dns.rdtypes.svcbbase
diff --git a/dns/rdtypes/IN/WKS.py b/dns/rdtypes/IN/WKS.py
index 9bb41d9..881a784 100644
--- a/dns/rdtypes/IN/WKS.py
+++ b/dns/rdtypes/IN/WKS.py
@@ -1,12 +1,32 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import socket
import struct
+
import dns.immutable
import dns.ipv4
import dns.rdata
+
try:
- _proto_tcp = socket.getprotobyname('tcp')
- _proto_udp = socket.getprotobyname('udp')
+ _proto_tcp = socket.getprotobyname("tcp")
+ _proto_udp = socket.getprotobyname("udp")
except OSError:
+ # Fall back to defaults in case /etc/protocols is unavailable.
_proto_tcp = 6
_proto_udp = 17
@@ -14,10 +34,67 @@ except OSError:
@dns.immutable.immutable
class WKS(dns.rdata.Rdata):
"""WKS record"""
- __slots__ = ['address', 'protocol', 'bitmap']
+
+ # see: RFC 1035
+
+ __slots__ = ["address", "protocol", "bitmap"]
def __init__(self, rdclass, rdtype, address, protocol, bitmap):
super().__init__(rdclass, rdtype)
self.address = self._as_ipv4_address(address)
self.protocol = self._as_uint8(protocol)
self.bitmap = self._as_bytes(bitmap)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ bits = []
+ for i, byte in enumerate(self.bitmap):
+ for j in range(0, 8):
+ if byte & (0x80 >> j):
+ bits.append(str(i * 8 + j))
+ text = " ".join(bits)
+ return "%s %d %s" % (self.address, self.protocol, text)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ address = tok.get_string()
+ protocol = tok.get_string()
+ if protocol.isdigit():
+ protocol = int(protocol)
+ else:
+ protocol = socket.getprotobyname(protocol)
+ bitmap = bytearray()
+ for token in tok.get_remaining():
+ value = token.unescape().value
+ if value.isdigit():
+ serv = int(value)
+ else:
+ if protocol != _proto_udp and protocol != _proto_tcp:
+ raise NotImplementedError("protocol must be TCP or UDP")
+ if protocol == _proto_udp:
+ protocol_text = "udp"
+ else:
+ protocol_text = "tcp"
+ serv = socket.getservbyname(value, protocol_text)
+ i = serv // 8
+ l = len(bitmap)
+ if l < i + 1:
+ for _ in range(l, i + 1):
+ bitmap.append(0)
+ bitmap[i] = bitmap[i] | (0x80 >> (serv % 8))
+ bitmap = dns.rdata._truncate_bitmap(bitmap)
+ return cls(rdclass, rdtype, address, protocol, bitmap)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(dns.ipv4.inet_aton(self.address))
+ protocol = struct.pack("!B", self.protocol)
+ file.write(protocol)
+ file.write(self.bitmap)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ address = parser.get_bytes(4)
+ protocol = parser.get_uint8()
+ bitmap = parser.get_remaining()
+ return cls(rdclass, rdtype, address, protocol, bitmap)
diff --git a/dns/rdtypes/dnskeybase.py b/dns/rdtypes/dnskeybase.py
index 9fd68d0..db300f8 100644
--- a/dns/rdtypes/dnskeybase.py
+++ b/dns/rdtypes/dnskeybase.py
@@ -1,23 +1,44 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import base64
import enum
import struct
+
import dns.dnssectypes
import dns.exception
import dns.immutable
import dns.rdata
-__all__ = ['SEP', 'REVOKE', 'ZONE']
+
+# wildcard import
+__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822
class Flag(enum.IntFlag):
- SEP = 1
- REVOKE = 128
- ZONE = 256
+ SEP = 0x0001
+ REVOKE = 0x0080
+ ZONE = 0x0100
@dns.immutable.immutable
class DNSKEYBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DNSKEY record"""
- __slots__ = ['flags', 'protocol', 'algorithm', 'key']
+
+ __slots__ = ["flags", "protocol", "algorithm", "key"]
def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
super().__init__(rdclass, rdtype)
@@ -26,7 +47,41 @@ class DNSKEYBase(dns.rdata.Rdata):
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.key = self._as_bytes(key)
+ def to_text(self, origin=None, relativize=True, **kw):
+ return "%d %d %d %s" % (
+ self.flags,
+ self.protocol,
+ self.algorithm,
+ dns.rdata._base64ify(self.key, **kw),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ flags = tok.get_uint16()
+ protocol = tok.get_uint8()
+ algorithm = tok.get_string()
+ b64 = tok.concatenate_remaining_identifiers().encode()
+ key = base64.b64decode(b64)
+ return cls(rdclass, rdtype, flags, protocol, algorithm, key)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ header = struct.pack("!HBB", self.flags, self.protocol, self.algorithm)
+ file.write(header)
+ file.write(self.key)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("!HBB")
+ key = parser.get_remaining()
+ return cls(rdclass, rdtype, header[0], header[1], header[2], key)
+
+
+### BEGIN generated Flag constants
SEP = Flag.SEP
REVOKE = Flag.REVOKE
ZONE = Flag.ZONE
+
+### END generated Flag constants
diff --git a/dns/rdtypes/dsbase.py b/dns/rdtypes/dsbase.py
index a394668..cd21f02 100644
--- a/dns/rdtypes/dsbase.py
+++ b/dns/rdtypes/dsbase.py
@@ -1,5 +1,23 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2010, 2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import binascii
import struct
+
import dns.dnssectypes
import dns.immutable
import dns.rdata
@@ -9,21 +27,59 @@ import dns.rdatatype
@dns.immutable.immutable
class DSBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DS record"""
- __slots__ = ['key_tag', 'algorithm', 'digest_type', 'digest']
- _digest_length_by_type = {(1): 20, (2): 32, (3): 32, (4): 48}
- def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest
- ):
+ __slots__ = ["key_tag", "algorithm", "digest_type", "digest"]
+
+ # Digest types registry:
+ # https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml
+ _digest_length_by_type = {
+ 1: 20, # SHA-1, RFC 3658 Sec. 2.4
+ 2: 32, # SHA-256, RFC 4509 Sec. 2.2
+ 3: 32, # GOST R 34.11-94, RFC 5933 Sec. 4 in conjunction with RFC 4490 Sec. 2.1
+ 4: 48, # SHA-384, RFC 6605 Sec. 2
+ }
+
+ def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest):
super().__init__(rdclass, rdtype)
self.key_tag = self._as_uint16(key_tag)
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
- self.digest_type = dns.dnssectypes.DSDigest.make(self._as_uint8(
- digest_type))
+ self.digest_type = dns.dnssectypes.DSDigest.make(self._as_uint8(digest_type))
self.digest = self._as_bytes(digest)
try:
- if len(self.digest) != self._digest_length_by_type[self.digest_type
- ]:
- raise ValueError('digest length inconsistent with digest type')
+ if len(self.digest) != self._digest_length_by_type[self.digest_type]:
+ raise ValueError("digest length inconsistent with digest type")
except KeyError:
- if self.digest_type == 0:
- raise ValueError('digest type 0 is reserved')
+ if self.digest_type == 0: # reserved, RFC 3658 Sec. 2.4
+ raise ValueError("digest type 0 is reserved")
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ kw = kw.copy()
+ chunksize = kw.pop("chunksize", 128)
+ return "%d %d %d %s" % (
+ self.key_tag,
+ self.algorithm,
+ self.digest_type,
+ dns.rdata._hexify(self.digest, chunksize=chunksize, **kw),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ key_tag = tok.get_uint16()
+ algorithm = tok.get_string()
+ digest_type = tok.get_uint8()
+ digest = tok.concatenate_remaining_identifiers().encode()
+ digest = binascii.unhexlify(digest)
+ return cls(rdclass, rdtype, key_tag, algorithm, digest_type, digest)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ header = struct.pack("!HBB", self.key_tag, self.algorithm, self.digest_type)
+ file.write(header)
+ file.write(self.digest)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("!HBB")
+ digest = parser.get_remaining()
+ return cls(rdclass, rdtype, header[0], header[1], header[2], digest)
diff --git a/dns/rdtypes/euibase.py b/dns/rdtypes/euibase.py
index dc72327..751087b 100644
--- a/dns/rdtypes/euibase.py
+++ b/dns/rdtypes/euibase.py
@@ -1,4 +1,21 @@
+# Copyright (C) 2015 Red Hat, Inc.
+# Author: Petr Spacek <pspacek@redhat.com>
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED 'AS IS' AND RED HAT DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import binascii
+
import dns.immutable
import dns.rdata
@@ -6,12 +23,48 @@ import dns.rdata
@dns.immutable.immutable
class EUIBase(dns.rdata.Rdata):
"""EUIxx record"""
- __slots__ = ['eui']
+
+ # see: rfc7043.txt
+
+ __slots__ = ["eui"]
+ # define these in subclasses
+ # byte_len = 6 # 0123456789ab (in hex)
+ # text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab
def __init__(self, rdclass, rdtype, eui):
super().__init__(rdclass, rdtype)
self.eui = self._as_bytes(eui)
if len(self.eui) != self.byte_len:
raise dns.exception.FormError(
- 'EUI%s rdata has to have %s bytes' % (self.byte_len * 8,
- self.byte_len))
+ "EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len)
+ )
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ return dns.rdata._hexify(self.eui, chunksize=2, separator=b"-", **kw)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ text = tok.get_string()
+ if len(text) != cls.text_len:
+ raise dns.exception.SyntaxError(
+ "Input text must have %s characters" % cls.text_len
+ )
+ for i in range(2, cls.byte_len * 3 - 1, 3):
+ if text[i] != "-":
+ raise dns.exception.SyntaxError("Dash expected at position %s" % i)
+ text = text.replace("-", "")
+ try:
+ data = binascii.unhexlify(text.encode())
+ except (ValueError, TypeError) as ex:
+ raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex))
+ return cls(rdclass, rdtype, data)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(self.eui)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ eui = parser.get_bytes(cls.byte_len)
+ return cls(rdclass, rdtype, eui)
diff --git a/dns/rdtypes/mxbase.py b/dns/rdtypes/mxbase.py
index 47ac142..6d5e3d8 100644
--- a/dns/rdtypes/mxbase.py
+++ b/dns/rdtypes/mxbase.py
@@ -1,5 +1,24 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""MX-like base classes."""
+
import struct
+
import dns.exception
import dns.immutable
import dns.name
@@ -10,13 +29,44 @@ import dns.rdtypes.util
@dns.immutable.immutable
class MXBase(dns.rdata.Rdata):
"""Base class for rdata that is like an MX record."""
- __slots__ = ['preference', 'exchange']
+
+ __slots__ = ["preference", "exchange"]
def __init__(self, rdclass, rdtype, preference, exchange):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
self.exchange = self._as_name(exchange)
+ def to_text(self, origin=None, relativize=True, **kw):
+ exchange = self.exchange.choose_relativity(origin, relativize)
+ return "%d %s" % (self.preference, exchange)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ preference = tok.get_uint16()
+ exchange = tok.get_name(origin, relativize, relativize_to)
+ return cls(rdclass, rdtype, preference, exchange)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ pref = struct.pack("!H", self.preference)
+ file.write(pref)
+ self.exchange.to_wire(file, compress, origin, canonicalize)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ preference = parser.get_uint16()
+ exchange = parser.get_name(origin)
+ return cls(rdclass, rdtype, preference, exchange)
+
+ def _processing_priority(self):
+ return self.preference
+
+ @classmethod
+ def _processing_order(cls, iterable):
+ return dns.rdtypes.util.priority_processing_order(iterable)
+
@dns.immutable.immutable
class UncompressedMX(MXBase):
@@ -24,8 +74,14 @@ class UncompressedMX(MXBase):
is not compressed when converted to DNS wire format, and whose
digestable form is not downcased."""
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ super()._to_wire(file, None, origin, False)
+
@dns.immutable.immutable
class UncompressedDowncasingMX(MXBase):
"""Base class for rdata that is like an MX record, but whose name
is not compressed when convert to DNS wire format."""
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ super()._to_wire(file, None, origin, canonicalize)
diff --git a/dns/rdtypes/nsbase.py b/dns/rdtypes/nsbase.py
index 21eef60..904224f 100644
--- a/dns/rdtypes/nsbase.py
+++ b/dns/rdtypes/nsbase.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""NS-like base classes."""
+
import dns.exception
import dns.immutable
import dns.name
@@ -8,15 +26,38 @@ import dns.rdata
@dns.immutable.immutable
class NSBase(dns.rdata.Rdata):
"""Base class for rdata that is like an NS record."""
- __slots__ = ['target']
+
+ __slots__ = ["target"]
def __init__(self, rdclass, rdtype, target):
super().__init__(rdclass, rdtype)
self.target = self._as_name(target)
+ def to_text(self, origin=None, relativize=True, **kw):
+ target = self.target.choose_relativity(origin, relativize)
+ return str(target)
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ target = tok.get_name(origin, relativize, relativize_to)
+ return cls(rdclass, rdtype, target)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.target.to_wire(file, compress, origin, canonicalize)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ target = parser.get_name(origin)
+ return cls(rdclass, rdtype, target)
+
@dns.immutable.immutable
class UncompressedNS(NSBase):
"""Base class for rdata that is like an NS record, but whose name
is not compressed when convert to DNS wire format, and whose
digestable form is not downcased."""
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.target.to_wire(file, None, origin, False)
diff --git a/dns/rdtypes/svcbbase.py b/dns/rdtypes/svcbbase.py
index e82b8fd..0565241 100644
--- a/dns/rdtypes/svcbbase.py
+++ b/dns/rdtypes/svcbbase.py
@@ -1,6 +1,9 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import base64
import enum
import struct
+
import dns.enum
import dns.exception
import dns.immutable
@@ -13,6 +16,9 @@ import dns.renderer
import dns.tokenizer
import dns.wire
+# Until there is an RFC, this module is experimental and may be changed in
+# incompatible ways.
+
class UnknownParamKey(dns.exception.DNSException):
"""Unknown SVCB ParamKey"""
@@ -20,6 +26,7 @@ class UnknownParamKey(dns.exception.DNSException):
class ParamKey(dns.enum.IntEnum):
"""SVCB ParamKey"""
+
MANDATORY = 0
ALPN = 1
NO_DEFAULT_ALPN = 2
@@ -29,6 +36,22 @@ class ParamKey(dns.enum.IntEnum):
IPV6HINT = 6
DOHPATH = 7
+ @classmethod
+ def _maximum(cls):
+ return 65535
+
+ @classmethod
+ def _short_name(cls):
+ return "SVCBParamKey"
+
+ @classmethod
+ def _prefix(cls):
+ return "KEY"
+
+ @classmethod
+ def _unknown_exception_class(cls):
+ return UnknownParamKey
+
class Emptiness(enum.IntEnum):
NEVER = 0
@@ -36,13 +59,108 @@ class Emptiness(enum.IntEnum):
ALLOWED = 2
+def _validate_key(key):
+ force_generic = False
+ if isinstance(key, bytes):
+ # We decode to latin-1 so we get 0-255 as valid and do NOT interpret
+ # UTF-8 sequences
+ key = key.decode("latin-1")
+ if isinstance(key, str):
+ if key.lower().startswith("key"):
+ force_generic = True
+ if key[3:].startswith("0") and len(key) != 4:
+ # key has leading zeros
+ raise ValueError("leading zeros in key")
+ key = key.replace("-", "_")
+ return (ParamKey.make(key), force_generic)
+
+
+def key_to_text(key):
+ return ParamKey.to_text(key).replace("_", "-").lower()
+
+
+# Like rdata escapify, but escapes ',' too.
+
_escaped = b'",\\'
+def _escapify(qstring):
+ text = ""
+ for c in qstring:
+ if c in _escaped:
+ text += "\\" + chr(c)
+ elif c >= 0x20 and c < 0x7F:
+ text += chr(c)
+ else:
+ text += "\\%03d" % c
+ return text
+
+
+def _unescape(value):
+ if value == "":
+ return value
+ unescaped = b""
+ l = len(value)
+ i = 0
+ while i < l:
+ c = value[i]
+ i += 1
+ if c == "\\":
+ if i >= l: # pragma: no cover (can't happen via tokenizer get())
+ raise dns.exception.UnexpectedEnd
+ c = value[i]
+ i += 1
+ if c.isdigit():
+ if i >= l:
+ raise dns.exception.UnexpectedEnd
+ c2 = value[i]
+ i += 1
+ if i >= l:
+ raise dns.exception.UnexpectedEnd
+ c3 = value[i]
+ i += 1
+ if not (c2.isdigit() and c3.isdigit()):
+ raise dns.exception.SyntaxError
+ codepoint = int(c) * 100 + int(c2) * 10 + int(c3)
+ if codepoint > 255:
+ raise dns.exception.SyntaxError
+ unescaped += b"%c" % (codepoint)
+ continue
+ unescaped += c.encode()
+ return unescaped
+
+
+def _split(value):
+ l = len(value)
+ i = 0
+ items = []
+ unescaped = b""
+ while i < l:
+ c = value[i]
+ i += 1
+ if c == ord("\\"):
+ if i >= l: # pragma: no cover (can't happen via tokenizer get())
+ raise dns.exception.UnexpectedEnd
+ c = value[i]
+ i += 1
+ unescaped += b"%c" % (c)
+ elif c == ord(","):
+ items.append(unescaped)
+ unescaped = b""
+ else:
+ unescaped += b"%c" % (c)
+ items.append(unescaped)
+ return items
+
+
@dns.immutable.immutable
class Param:
"""Abstract base class for SVCB parameters"""
+ @classmethod
+ def emptiness(cls):
+ return Emptiness.NEVER
+
@dns.immutable.immutable
class GenericParam(Param):
@@ -51,75 +169,269 @@ class GenericParam(Param):
def __init__(self, value):
self.value = dns.rdata.Rdata._as_bytes(value, True)
+ @classmethod
+ def emptiness(cls):
+ return Emptiness.ALLOWED
+
+ @classmethod
+ def from_value(cls, value):
+ if value is None or len(value) == 0:
+ return None
+ else:
+ return cls(_unescape(value))
+
+ def to_text(self):
+ return '"' + dns.rdata._escapify(self.value) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ value = parser.get_bytes(parser.remaining())
+ if len(value) == 0:
+ return None
+ else:
+ return cls(value)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ file.write(self.value)
+
@dns.immutable.immutable
class MandatoryParam(Param):
-
def __init__(self, keys):
+ # check for duplicates
keys = sorted([_validate_key(key)[0] for key in keys])
prior_k = None
for k in keys:
if k == prior_k:
- raise ValueError(f'duplicate key {k:d}')
+ raise ValueError(f"duplicate key {k:d}")
prior_k = k
if k == ParamKey.MANDATORY:
- raise ValueError('listed the mandatory key as mandatory')
+ raise ValueError("listed the mandatory key as mandatory")
self.keys = tuple(keys)
+ @classmethod
+ def from_value(cls, value):
+ keys = [k.encode() for k in value.split(",")]
+ return cls(keys)
+
+ def to_text(self):
+ return '"' + ",".join([key_to_text(key) for key in self.keys]) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ keys = []
+ last_key = -1
+ while parser.remaining() > 0:
+ key = parser.get_uint16()
+ if key < last_key:
+ raise dns.exception.FormError("manadatory keys not ascending")
+ last_key = key
+ keys.append(key)
+ return cls(keys)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ for key in self.keys:
+ file.write(struct.pack("!H", key))
+
@dns.immutable.immutable
class ALPNParam(Param):
-
def __init__(self, ids):
- self.ids = dns.rdata.Rdata._as_tuple(ids, lambda x: dns.rdata.Rdata
- ._as_bytes(x, True, 255, False))
+ self.ids = dns.rdata.Rdata._as_tuple(
+ ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False)
+ )
+
+ @classmethod
+ def from_value(cls, value):
+ return cls(_split(_unescape(value)))
+
+ def to_text(self):
+ value = ",".join([_escapify(id) for id in self.ids])
+ return '"' + dns.rdata._escapify(value.encode()) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ ids = []
+ while parser.remaining() > 0:
+ id = parser.get_counted_bytes()
+ ids.append(id)
+ return cls(ids)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ for id in self.ids:
+ file.write(struct.pack("!B", len(id)))
+ file.write(id)
@dns.immutable.immutable
class NoDefaultALPNParam(Param):
- pass
+ # We don't ever expect to instantiate this class, but we need
+ # a from_value() and a from_wire_parser(), so we just return None
+ # from the class methods when things are OK.
+
+ @classmethod
+ def emptiness(cls):
+ return Emptiness.ALWAYS
+
+ @classmethod
+ def from_value(cls, value):
+ if value is None or value == "":
+ return None
+ else:
+ raise ValueError("no-default-alpn with non-empty value")
+
+ def to_text(self):
+ raise NotImplementedError # pragma: no cover
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ if parser.remaining() != 0:
+ raise dns.exception.FormError
+ return None
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ raise NotImplementedError # pragma: no cover
@dns.immutable.immutable
class PortParam(Param):
-
def __init__(self, port):
self.port = dns.rdata.Rdata._as_uint16(port)
+ @classmethod
+ def from_value(cls, value):
+ value = int(value)
+ return cls(value)
+
+ def to_text(self):
+ return f'"{self.port}"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ port = parser.get_uint16()
+ return cls(port)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ file.write(struct.pack("!H", self.port))
+
@dns.immutable.immutable
class IPv4HintParam(Param):
-
def __init__(self, addresses):
- self.addresses = dns.rdata.Rdata._as_tuple(addresses, dns.rdata.
- Rdata._as_ipv4_address)
+ self.addresses = dns.rdata.Rdata._as_tuple(
+ addresses, dns.rdata.Rdata._as_ipv4_address
+ )
+
+ @classmethod
+ def from_value(cls, value):
+ addresses = value.split(",")
+ return cls(addresses)
+
+ def to_text(self):
+ return '"' + ",".join(self.addresses) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ addresses = []
+ while parser.remaining() > 0:
+ ip = parser.get_bytes(4)
+ addresses.append(dns.ipv4.inet_ntoa(ip))
+ return cls(addresses)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ for address in self.addresses:
+ file.write(dns.ipv4.inet_aton(address))
@dns.immutable.immutable
class IPv6HintParam(Param):
-
def __init__(self, addresses):
- self.addresses = dns.rdata.Rdata._as_tuple(addresses, dns.rdata.
- Rdata._as_ipv6_address)
+ self.addresses = dns.rdata.Rdata._as_tuple(
+ addresses, dns.rdata.Rdata._as_ipv6_address
+ )
+
+ @classmethod
+ def from_value(cls, value):
+ addresses = value.split(",")
+ return cls(addresses)
+
+ def to_text(self):
+ return '"' + ",".join(self.addresses) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ addresses = []
+ while parser.remaining() > 0:
+ ip = parser.get_bytes(16)
+ addresses.append(dns.ipv6.inet_ntoa(ip))
+ return cls(addresses)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ for address in self.addresses:
+ file.write(dns.ipv6.inet_aton(address))
@dns.immutable.immutable
class ECHParam(Param):
-
def __init__(self, ech):
self.ech = dns.rdata.Rdata._as_bytes(ech, True)
-
-_class_for_key = {ParamKey.MANDATORY: MandatoryParam, ParamKey.ALPN:
- ALPNParam, ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam, ParamKey.PORT:
- PortParam, ParamKey.IPV4HINT: IPv4HintParam, ParamKey.ECH: ECHParam,
- ParamKey.IPV6HINT: IPv6HintParam}
+ @classmethod
+ def from_value(cls, value):
+ if "\\" in value:
+ raise ValueError("escape in ECH value")
+ value = base64.b64decode(value.encode())
+ return cls(value)
+
+ def to_text(self):
+ b64 = base64.b64encode(self.ech).decode("ascii")
+ return f'"{b64}"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ value = parser.get_bytes(parser.remaining())
+ return cls(value)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ file.write(self.ech)
+
+
+_class_for_key = {
+ ParamKey.MANDATORY: MandatoryParam,
+ ParamKey.ALPN: ALPNParam,
+ ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam,
+ ParamKey.PORT: PortParam,
+ ParamKey.IPV4HINT: IPv4HintParam,
+ ParamKey.ECH: ECHParam,
+ ParamKey.IPV6HINT: IPv6HintParam,
+}
+
+
+def _validate_and_define(params, key, value):
+ (key, force_generic) = _validate_key(_unescape(key))
+ if key in params:
+ raise SyntaxError(f'duplicate key "{key:d}"')
+ cls = _class_for_key.get(key, GenericParam)
+ emptiness = cls.emptiness()
+ if value is None:
+ if emptiness == Emptiness.NEVER:
+ raise SyntaxError("value cannot be empty")
+ value = cls.from_value(value)
+ else:
+ if force_generic:
+ value = cls.from_wire_parser(dns.wire.Parser(_unescape(value)))
+ else:
+ value = cls.from_value(value)
+ params[key] = value
@dns.immutable.immutable
class SVCBBase(dns.rdata.Rdata):
"""Base class for SVCB-like records"""
- __slots__ = ['priority', 'target', 'params']
+
+ # see: draft-ietf-dnsop-svcb-https-11
+
+ __slots__ = ["priority", "target", "params"]
def __init__(self, rdclass, rdtype, priority, target, params):
super().__init__(rdclass, rdtype)
@@ -128,14 +440,114 @@ class SVCBBase(dns.rdata.Rdata):
for k, v in params.items():
k = ParamKey.make(k)
if not isinstance(v, Param) and v is not None:
- raise ValueError(f'{k:d} not a Param')
+ raise ValueError(f"{k:d} not a Param")
self.params = dns.immutable.Dict(params)
+ # Make sure any parameter listed as mandatory is present in the
+ # record.
mandatory = params.get(ParamKey.MANDATORY)
if mandatory:
for key in mandatory.keys:
+ # Note we have to say "not in" as we have None as a value
+ # so a get() and a not None test would be wrong.
if key not in params:
- raise ValueError(
- f'key {key:d} declared mandatory but not present')
+ raise ValueError(f"key {key:d} declared mandatory but not present")
+ # The no-default-alpn parameter requires the alpn parameter.
if ParamKey.NO_DEFAULT_ALPN in params:
if ParamKey.ALPN not in params:
- raise ValueError('no-default-alpn present, but alpn missing')
+ raise ValueError("no-default-alpn present, but alpn missing")
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ target = self.target.choose_relativity(origin, relativize)
+ params = []
+ for key in sorted(self.params.keys()):
+ value = self.params[key]
+ if value is None:
+ params.append(key_to_text(key))
+ else:
+ kv = key_to_text(key) + "=" + value.to_text()
+ params.append(kv)
+ if len(params) > 0:
+ space = " "
+ else:
+ space = ""
+ return "%d %s%s%s" % (self.priority, target, space, " ".join(params))
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ priority = tok.get_uint16()
+ target = tok.get_name(origin, relativize, relativize_to)
+ if priority == 0:
+ token = tok.get()
+ if not token.is_eol_or_eof():
+ raise SyntaxError("parameters in AliasMode")
+ tok.unget(token)
+ params = {}
+ while True:
+ token = tok.get()
+ if token.is_eol_or_eof():
+ tok.unget(token)
+ break
+ if token.ttype != dns.tokenizer.IDENTIFIER:
+ raise SyntaxError("parameter is not an identifier")
+ equals = token.value.find("=")
+ if equals == len(token.value) - 1:
+ # 'key=', so next token should be a quoted string without
+ # any intervening whitespace.
+ key = token.value[:-1]
+ token = tok.get(want_leading=True)
+ if token.ttype != dns.tokenizer.QUOTED_STRING:
+ raise SyntaxError("whitespace after =")
+ value = token.value
+ elif equals > 0:
+ # key=value
+ key = token.value[:equals]
+ value = token.value[equals + 1 :]
+ elif equals == 0:
+ # =key
+ raise SyntaxError('parameter cannot start with "="')
+ else:
+ # key
+ key = token.value
+ value = None
+ _validate_and_define(params, key, value)
+ return cls(rdclass, rdtype, priority, target, params)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(struct.pack("!H", self.priority))
+ self.target.to_wire(file, None, origin, False)
+ for key in sorted(self.params):
+ file.write(struct.pack("!H", key))
+ value = self.params[key]
+ with dns.renderer.prefixed_length(file, 2):
+ # Note that we're still writing a length of zero if the value is None
+ if value is not None:
+ value.to_wire(file, origin)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ priority = parser.get_uint16()
+ target = parser.get_name(origin)
+ if priority == 0 and parser.remaining() != 0:
+ raise dns.exception.FormError("parameters in AliasMode")
+ params = {}
+ prior_key = -1
+ while parser.remaining() > 0:
+ key = parser.get_uint16()
+ if key < prior_key:
+ raise dns.exception.FormError("keys not in order")
+ prior_key = key
+ vlen = parser.get_uint16()
+ pcls = _class_for_key.get(key, GenericParam)
+ with parser.restrict_to(vlen):
+ value = pcls.from_wire_parser(parser, origin)
+ params[key] = value
+ return cls(rdclass, rdtype, priority, target, params)
+
+ def _processing_priority(self):
+ return self.priority
+
+ @classmethod
+ def _processing_order(cls, iterable):
+ return dns.rdtypes.util.priority_processing_order(iterable)
diff --git a/dns/rdtypes/tlsabase.py b/dns/rdtypes/tlsabase.py
index fc65a48..a059d2c 100644
--- a/dns/rdtypes/tlsabase.py
+++ b/dns/rdtypes/tlsabase.py
@@ -1,5 +1,23 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2005-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import binascii
import struct
+
import dns.immutable
import dns.rdata
import dns.rdatatype
@@ -8,7 +26,10 @@ import dns.rdatatype
@dns.immutable.immutable
class TLSABase(dns.rdata.Rdata):
"""Base class for TLSA and SMIMEA records"""
- __slots__ = ['usage', 'selector', 'mtype', 'cert']
+
+ # see: RFC 6698
+
+ __slots__ = ["usage", "selector", "mtype", "cert"]
def __init__(self, rdclass, rdtype, usage, selector, mtype, cert):
super().__init__(rdclass, rdtype)
@@ -16,3 +37,35 @@ class TLSABase(dns.rdata.Rdata):
self.selector = self._as_uint8(selector)
self.mtype = self._as_uint8(mtype)
self.cert = self._as_bytes(cert)
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ kw = kw.copy()
+ chunksize = kw.pop("chunksize", 128)
+ return "%d %d %d %s" % (
+ self.usage,
+ self.selector,
+ self.mtype,
+ dns.rdata._hexify(self.cert, chunksize=chunksize, **kw),
+ )
+
+ @classmethod
+ def from_text(
+ cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ usage = tok.get_uint8()
+ selector = tok.get_uint8()
+ mtype = tok.get_uint8()
+ cert = tok.concatenate_remaining_identifiers().encode()
+ cert = binascii.unhexlify(cert)
+ return cls(rdclass, rdtype, usage, selector, mtype, cert)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ header = struct.pack("!BBB", self.usage, self.selector, self.mtype)
+ file.write(header)
+ file.write(self.cert)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ header = parser.get_struct("BBB")
+ cert = parser.get_remaining()
+ return cls(rdclass, rdtype, header[0], header[1], header[2], cert)
diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py
index 5fba0da..44d6df5 100644
--- a/dns/rdtypes/txtbase.py
+++ b/dns/rdtypes/txtbase.py
@@ -1,5 +1,24 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""TXT-like base class."""
+
from typing import Any, Dict, Iterable, Optional, Tuple, Union
+
import dns.exception
import dns.immutable
import dns.rdata
@@ -10,10 +29,15 @@ import dns.tokenizer
@dns.immutable.immutable
class TXTBase(dns.rdata.Rdata):
"""Base class for rdata that is like a TXT record (see RFC 1035)."""
- __slots__ = ['strings']
- def __init__(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.
- rdatatype.RdataType, strings: Iterable[Union[bytes, str]]):
+ __slots__ = ["strings"]
+
+ def __init__(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ strings: Iterable[Union[bytes, str]],
+ ):
"""Initialize a TXT-like rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata.
@@ -23,5 +47,58 @@ class TXTBase(dns.rdata.Rdata):
*strings*, a tuple of ``bytes``
"""
super().__init__(rdclass, rdtype)
- self.strings: Tuple[bytes] = self._as_tuple(strings, lambda x: self
- ._as_bytes(x, True, 255))
+ self.strings: Tuple[bytes] = self._as_tuple(
+ strings, lambda x: self._as_bytes(x, True, 255)
+ )
+
+ def to_text(
+ self,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ **kw: Dict[str, Any],
+ ) -> str:
+ txt = ""
+ prefix = ""
+ for s in self.strings:
+ txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s))
+ prefix = " "
+ return txt
+
+ @classmethod
+ def from_text(
+ cls,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ tok: dns.tokenizer.Tokenizer,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ relativize_to: Optional[dns.name.Name] = None,
+ ) -> dns.rdata.Rdata:
+ strings = []
+ for token in tok.get_remaining():
+ token = token.unescape_to_bytes()
+ # The 'if' below is always true in the current code, but we
+ # are leaving this check in in case things change some day.
+ if not (
+ token.is_quoted_string() or token.is_identifier()
+ ): # pragma: no cover
+ raise dns.exception.SyntaxError("expected a string")
+ if len(token.value) > 255:
+ raise dns.exception.SyntaxError("string too long")
+ strings.append(token.value)
+ if len(strings) == 0:
+ raise dns.exception.UnexpectedEnd
+ return cls(rdclass, rdtype, strings)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ for s in self.strings:
+ with dns.renderer.prefixed_length(file, 1):
+ file.write(s)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ strings = []
+ while parser.remaining() > 0:
+ s = parser.get_counted_bytes()
+ strings.append(s)
+ return cls(rdclass, rdtype, strings)
diff --git a/dns/rdtypes/util.py b/dns/rdtypes/util.py
index 920b9b3..54908fd 100644
--- a/dns/rdtypes/util.py
+++ b/dns/rdtypes/util.py
@@ -1,7 +1,25 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import collections
import random
import struct
from typing import Any, List
+
import dns.exception
import dns.ipv4
import dns.ipv6
@@ -11,33 +29,229 @@ import dns.rdata
class Gateway:
"""A helper class for the IPSECKEY gateway and AMTRELAY relay fields"""
- name = ''
+
+ name = ""
def __init__(self, type, gateway=None):
self.type = dns.rdata.Rdata._as_uint8(type)
self.gateway = gateway
self._check()
+ @classmethod
+ def _invalid_type(cls, gateway_type):
+ return f"invalid {cls.name} type: {gateway_type}"
+
+ def _check(self):
+ if self.type == 0:
+ if self.gateway not in (".", None):
+ raise SyntaxError(f"invalid {self.name} for type 0")
+ self.gateway = None
+ elif self.type == 1:
+ # check that it's OK
+ dns.ipv4.inet_aton(self.gateway)
+ elif self.type == 2:
+ # check that it's OK
+ dns.ipv6.inet_aton(self.gateway)
+ elif self.type == 3:
+ if not isinstance(self.gateway, dns.name.Name):
+ raise SyntaxError(f"invalid {self.name}; not a name")
+ else:
+ raise SyntaxError(self._invalid_type(self.type))
+
+ def to_text(self, origin=None, relativize=True):
+ if self.type == 0:
+ return "."
+ elif self.type in (1, 2):
+ return self.gateway
+ elif self.type == 3:
+ return str(self.gateway.choose_relativity(origin, relativize))
+ else:
+ raise ValueError(self._invalid_type(self.type)) # pragma: no cover
+
+ @classmethod
+ def from_text(
+ cls, gateway_type, tok, origin=None, relativize=True, relativize_to=None
+ ):
+ if gateway_type in (0, 1, 2):
+ gateway = tok.get_string()
+ elif gateway_type == 3:
+ gateway = tok.get_name(origin, relativize, relativize_to)
+ else:
+ raise dns.exception.SyntaxError(
+ cls._invalid_type(gateway_type)
+ ) # pragma: no cover
+ return cls(gateway_type, gateway)
+
+ # pylint: disable=unused-argument
+ def to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ if self.type == 0:
+ pass
+ elif self.type == 1:
+ file.write(dns.ipv4.inet_aton(self.gateway))
+ elif self.type == 2:
+ file.write(dns.ipv6.inet_aton(self.gateway))
+ elif self.type == 3:
+ self.gateway.to_wire(file, None, origin, False)
+ else:
+ raise ValueError(self._invalid_type(self.type)) # pragma: no cover
+
+ # pylint: enable=unused-argument
+
+ @classmethod
+ def from_wire_parser(cls, gateway_type, parser, origin=None):
+ if gateway_type == 0:
+ gateway = None
+ elif gateway_type == 1:
+ gateway = dns.ipv4.inet_ntoa(parser.get_bytes(4))
+ elif gateway_type == 2:
+ gateway = dns.ipv6.inet_ntoa(parser.get_bytes(16))
+ elif gateway_type == 3:
+ gateway = parser.get_name(origin)
+ else:
+ raise dns.exception.FormError(cls._invalid_type(gateway_type))
+ return cls(gateway_type, gateway)
+
class Bitmap:
"""A helper class for the NSEC/NSEC3/CSYNC type bitmaps"""
- type_name = ''
+
+ type_name = ""
def __init__(self, windows=None):
last_window = -1
self.windows = windows
for window, bitmap in self.windows:
if not isinstance(window, int):
- raise ValueError(f'bad {self.type_name} window type')
+ raise ValueError(f"bad {self.type_name} window type")
if window <= last_window:
- raise ValueError(f'bad {self.type_name} window order')
+ raise ValueError(f"bad {self.type_name} window order")
if window > 256:
- raise ValueError(f'bad {self.type_name} window number')
+ raise ValueError(f"bad {self.type_name} window number")
last_window = window
if not isinstance(bitmap, bytes):
- raise ValueError(f'bad {self.type_name} octets type')
+ raise ValueError(f"bad {self.type_name} octets type")
if len(bitmap) == 0 or len(bitmap) > 32:
- raise ValueError(f'bad {self.type_name} octets')
+ raise ValueError(f"bad {self.type_name} octets")
+
+ def to_text(self) -> str:
+ text = ""
+ for window, bitmap in self.windows:
+ bits = []
+ for i, byte in enumerate(bitmap):
+ for j in range(0, 8):
+ if byte & (0x80 >> j):
+ rdtype = window * 256 + i * 8 + j
+ bits.append(dns.rdatatype.to_text(rdtype))
+ text += " " + " ".join(bits)
+ return text
+
+ @classmethod
+ def from_text(cls, tok: "dns.tokenizer.Tokenizer") -> "Bitmap":
+ rdtypes = []
+ for token in tok.get_remaining():
+ rdtype = dns.rdatatype.from_text(token.unescape().value)
+ if rdtype == 0:
+ raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0")
+ rdtypes.append(rdtype)
+ return cls.from_rdtypes(rdtypes)
+
+ @classmethod
+ def from_rdtypes(cls, rdtypes: List[dns.rdatatype.RdataType]) -> "Bitmap":
+ rdtypes = sorted(rdtypes)
+ window = 0
+ octets = 0
+ prior_rdtype = 0
+ bitmap = bytearray(b"\0" * 32)
+ windows = []
+ for rdtype in rdtypes:
+ if rdtype == prior_rdtype:
+ continue
+ prior_rdtype = rdtype
+ new_window = rdtype // 256
+ if new_window != window:
+ if octets != 0:
+ windows.append((window, bytes(bitmap[0:octets])))
+ bitmap = bytearray(b"\0" * 32)
+ window = new_window
+ offset = rdtype % 256
+ byte = offset // 8
+ bit = offset % 8
+ octets = byte + 1
+ bitmap[byte] = bitmap[byte] | (0x80 >> bit)
+ if octets != 0:
+ windows.append((window, bytes(bitmap[0:octets])))
+ return cls(windows)
+
+ def to_wire(self, file: Any) -> None:
+ for window, bitmap in self.windows:
+ file.write(struct.pack("!BB", window, len(bitmap)))
+ file.write(bitmap)
+
+ @classmethod
+ def from_wire_parser(cls, parser: "dns.wire.Parser") -> "Bitmap":
+ windows = []
+ while parser.remaining() > 0:
+ window = parser.get_uint8()
+ bitmap = parser.get_counted_bytes()
+ windows.append((window, bitmap))
+ return cls(windows)
+
+
+def _priority_table(items):
+ by_priority = collections.defaultdict(list)
+ for rdata in items:
+ by_priority[rdata._processing_priority()].append(rdata)
+ return by_priority
+
+
+def priority_processing_order(iterable):
+ items = list(iterable)
+ if len(items) == 1:
+ return items
+ by_priority = _priority_table(items)
+ ordered = []
+ for k in sorted(by_priority.keys()):
+ rdatas = by_priority[k]
+ random.shuffle(rdatas)
+ ordered.extend(rdatas)
+ return ordered
_no_weight = 0.1
+
+
+def weighted_processing_order(iterable):
+ items = list(iterable)
+ if len(items) == 1:
+ return items
+ by_priority = _priority_table(items)
+ ordered = []
+ for k in sorted(by_priority.keys()):
+ rdatas = by_priority[k]
+ total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas)
+ while len(rdatas) > 1:
+ r = random.uniform(0, total)
+ for n, rdata in enumerate(rdatas):
+ weight = rdata._processing_weight() or _no_weight
+ if weight > r:
+ break
+ r -= weight
+ total -= weight
+ ordered.append(rdata) # pylint: disable=undefined-loop-variable
+ del rdatas[n] # pylint: disable=undefined-loop-variable
+ ordered.append(rdatas[0])
+ return ordered
+
+
+def parse_formatted_hex(formatted, num_chunks, chunk_size, separator):
+ if len(formatted) != num_chunks * (chunk_size + 1) - 1:
+ raise ValueError("invalid formatted hex string")
+ value = b""
+ for _ in range(num_chunks):
+ chunk = formatted[0:chunk_size]
+ value += int(chunk, 16).to_bytes(chunk_size // 2, "big")
+ formatted = formatted[chunk_size:]
+ if len(formatted) > 0 and formatted[0] != separator:
+ raise ValueError("invalid formatted hex string")
+ formatted = formatted[1:]
+ return value
diff --git a/dns/renderer.py b/dns/renderer.py
index fbebb17..a77481f 100644
--- a/dns/renderer.py
+++ b/dns/renderer.py
@@ -1,17 +1,55 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Help for building DNS wire format messages"""
+
import contextlib
import io
import random
import struct
import time
+
import dns.exception
import dns.tsig
+
QUESTION = 0
ANSWER = 1
AUTHORITY = 2
ADDITIONAL = 3
+@contextlib.contextmanager
+def prefixed_length(output, length_length):
+ output.write(b"\00" * length_length)
+ start = output.tell()
+ yield
+ end = output.tell()
+ length = end - start
+ if length > 0:
+ try:
+ output.seek(start - length_length)
+ try:
+ output.write(length.to_bytes(length_length, "big"))
+ except OverflowError:
+ raise dns.exception.FormError
+ finally:
+ output.seek(end)
+
+
class Renderer:
"""Helper class for building DNS wire-format messages.
@@ -59,6 +97,7 @@ class Renderer:
def __init__(self, id=None, flags=0, max_size=65535, origin=None):
"""Initialize a new renderer."""
+
self.output = io.BytesIO()
if id is None:
self.id = random.randint(0, 65535)
@@ -70,8 +109,8 @@ class Renderer:
self.compress = {}
self.section = QUESTION
self.counts = [0, 0, 0, 0]
- self.output.write(b'\x00' * 12)
- self.mac = ''
+ self.output.write(b"\x00" * 12)
+ self.mac = ""
self.reserved = 0
self.was_padded = False
@@ -80,7 +119,15 @@ class Renderer:
compression table entries that pointed beyond the truncation
point.
"""
- pass
+
+ self.output.seek(where)
+ self.output.truncate()
+ keys_to_delete = []
+ for k, v in self.compress.items():
+ if v >= where:
+ keys_to_delete.append(k)
+ for k in keys_to_delete:
+ del self.compress[k]
def _set_section(self, section):
"""Set the renderer's current section.
@@ -91,11 +138,37 @@ class Renderer:
Raises dns.exception.FormError if an attempt was made to set
a section value less than the current section.
"""
- pass
+
+ if self.section != section:
+ if self.section > section:
+ raise dns.exception.FormError
+ self.section = section
+
+ @contextlib.contextmanager
+ def _track_size(self):
+ start = self.output.tell()
+ yield start
+ if self.output.tell() > self.max_size:
+ self._rollback(start)
+ raise dns.exception.TooBig
+
+ @contextlib.contextmanager
+ def _temporarily_seek_to(self, where):
+ current = self.output.tell()
+ try:
+ self.output.seek(where)
+ yield
+ finally:
+ self.output.seek(current)
def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
"""Add a question to the message."""
- pass
+
+ self._set_section(QUESTION)
+ with self._track_size():
+ qname.to_wire(self.output, self.compress, self.origin)
+ self.output.write(struct.pack("!HH", rdtype, rdclass))
+ self.counts[QUESTION] += 1
def add_rrset(self, section, rrset, **kw):
"""Add the rrset to the specified section.
@@ -103,7 +176,11 @@ class Renderer:
Any keyword arguments are passed on to the rdataset's to_wire()
routine.
"""
- pass
+
+ self._set_section(section)
+ with self._track_size():
+ n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
+ self.counts[section] += n
def add_rdataset(self, section, name, rdataset, **kw):
"""Add the rdataset to the specified section, using the specified
@@ -112,7 +189,11 @@ class Renderer:
Any keyword arguments are passed on to the rdataset's to_wire()
routine.
"""
- pass
+
+ self._set_section(section)
+ with self._track_size():
+ n = rdataset.to_wire(name, self.output, self.compress, self.origin, **kw)
+ self.counts[section] += n
def add_opt(self, opt, pad=0, opt_size=0, tsig_size=0):
"""Add *opt* to the additional section, applying padding if desired. The
@@ -121,19 +202,68 @@ class Renderer:
Note that we don't have reliable way of knowing how big a GSS-TSIG digest
might be, so we we might not get an even multiple of the pad in that case."""
- pass
+ if pad:
+ ttl = opt.ttl
+ assert opt_size >= 11
+ opt_rdata = opt[0]
+ size_without_padding = self.output.tell() + opt_size + tsig_size
+ remainder = size_without_padding % pad
+ if remainder:
+ pad = b"\x00" * (pad - remainder)
+ else:
+ pad = b""
+ options = list(opt_rdata.options)
+ options.append(dns.edns.GenericOption(dns.edns.OptionType.PADDING, pad))
+ opt = dns.message.Message._make_opt(ttl, opt_rdata.rdclass, options)
+ self.was_padded = True
+ self.add_rrset(ADDITIONAL, opt)
def add_edns(self, edns, ednsflags, payload, options=None):
"""Add an EDNS OPT record to the message."""
- pass
- def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
- request_mac, algorithm=dns.tsig.default_algorithm):
+ # make sure the EDNS version in ednsflags agrees with edns
+ ednsflags &= 0xFF00FFFF
+ ednsflags |= edns << 16
+ opt = dns.message.Message._make_opt(ednsflags, payload, options)
+ self.add_opt(opt)
+
+ def add_tsig(
+ self,
+ keyname,
+ secret,
+ fudge,
+ id,
+ tsig_error,
+ other_data,
+ request_mac,
+ algorithm=dns.tsig.default_algorithm,
+ ):
"""Add a TSIG signature to the message."""
- pass
- def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error,
- other_data, request_mac, algorithm=dns.tsig.default_algorithm):
+ s = self.output.getvalue()
+
+ if isinstance(secret, dns.tsig.Key):
+ key = secret
+ else:
+ key = dns.tsig.Key(keyname, secret, algorithm)
+ tsig = dns.message.Message._make_tsig(
+ keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data
+ )
+ (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), request_mac)
+ self._write_tsig(tsig, keyname)
+
+ def add_multi_tsig(
+ self,
+ ctx,
+ keyname,
+ secret,
+ fudge,
+ id,
+ tsig_error,
+ other_data,
+ request_mac,
+ algorithm=dns.tsig.default_algorithm,
+ ):
"""Add a TSIG signature to the message. Unlike add_tsig(), this can be
used for a series of consecutive DNS envelopes, e.g. for a zone
transfer over TCP [RFC2845, 4.4].
@@ -141,7 +271,39 @@ class Renderer:
For the first message in the sequence, give ctx=None. For each
subsequent message, give the ctx that was returned from the
add_multi_tsig() call for the previous message."""
- pass
+
+ s = self.output.getvalue()
+
+ if isinstance(secret, dns.tsig.Key):
+ key = secret
+ else:
+ key = dns.tsig.Key(keyname, secret, algorithm)
+ tsig = dns.message.Message._make_tsig(
+ keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data
+ )
+ (tsig, ctx) = dns.tsig.sign(
+ s, key, tsig[0], int(time.time()), request_mac, ctx, True
+ )
+ self._write_tsig(tsig, keyname)
+ return ctx
+
+ def _write_tsig(self, tsig, keyname):
+ if self.was_padded:
+ compress = None
+ else:
+ compress = self.compress
+ self._set_section(ADDITIONAL)
+ with self._track_size():
+ keyname.to_wire(self.output, compress, self.origin)
+ self.output.write(
+ struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0)
+ )
+ with prefixed_length(self.output, 2):
+ tsig.to_wire(self.output)
+
+ self.counts[ADDITIONAL] += 1
+ with self._temporarily_seek_to(10):
+ self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
def write_header(self):
"""Write the DNS message header.
@@ -150,16 +312,35 @@ class Renderer:
have been rendered, but before the optional TSIG signature
is added.
"""
- pass
+
+ with self._temporarily_seek_to(0):
+ self.output.write(
+ struct.pack(
+ "!HHHHHH",
+ self.id,
+ self.flags,
+ self.counts[0],
+ self.counts[1],
+ self.counts[2],
+ self.counts[3],
+ )
+ )
def get_wire(self):
"""Return the wire format message."""
- pass
- def reserve(self, size: int) ->None:
- """Reserve *size* bytes."""
- pass
+ return self.output.getvalue()
- def release_reserved(self) ->None:
+ def reserve(self, size: int) -> None:
+ """Reserve *size* bytes."""
+ if size < 0:
+ raise ValueError("reserved amount must be non-negative")
+ if size > self.max_size:
+ raise ValueError("cannot reserve more than the maximum size")
+ self.reserved += size
+ self.max_size -= size
+
+ def release_reserved(self) -> None:
"""Release the reserved bytes."""
- pass
+ self.max_size += self.reserved
+ self.reserved = 0
diff --git a/dns/resolver.py b/dns/resolver.py
index 116ab15..f08f824 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS stub resolver."""
+
import contextlib
import random
import socket
@@ -8,6 +26,7 @@ import time
import warnings
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from urllib.parse import urlparse
+
import dns._ddr
import dns.edns
import dns.exception
@@ -25,40 +44,70 @@ import dns.rdatatype
import dns.rdtypes.svcbbase
import dns.reversename
import dns.tsig
-if sys.platform == 'win32':
+
+if sys.platform == "win32":
import dns.win32util
class NXDOMAIN(dns.exception.DNSException):
"""The DNS query name does not exist."""
- supp_kwargs = {'qnames', 'responses'}
- fmt = None
+ supp_kwargs = {"qnames", "responses"}
+ fmt = None # we have our own __str__ implementation
+
+ # pylint: disable=arguments-differ
+
+ # We do this as otherwise mypy complains about unexpected keyword argument
+ # idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- def __str__(self) ->str:
- if 'qnames' not in self.kwargs:
+ def _check_kwargs(self, qnames, responses=None):
+ if not isinstance(qnames, (list, tuple, set)):
+ raise AttributeError("qnames must be a list, tuple or set")
+ if len(qnames) == 0:
+ raise AttributeError("qnames must contain at least one element")
+ if responses is None:
+ responses = {}
+ elif not isinstance(responses, dict):
+ raise AttributeError("responses must be a dict(qname=response)")
+ kwargs = dict(qnames=qnames, responses=responses)
+ return kwargs
+
+ def __str__(self) -> str:
+ if "qnames" not in self.kwargs:
return super().__str__()
- qnames = self.kwargs['qnames']
+ qnames = self.kwargs["qnames"]
if len(qnames) > 1:
- msg = 'None of DNS query names exist'
+ msg = "None of DNS query names exist"
else:
- msg = 'The DNS query name does not exist'
- qnames = ', '.join(map(str, qnames))
- return '{}: {}'.format(msg, qnames)
+ msg = "The DNS query name does not exist"
+ qnames = ", ".join(map(str, qnames))
+ return "{}: {}".format(msg, qnames)
@property
def canonical_name(self):
"""Return the unresolved canonical name."""
- pass
+ if "qnames" not in self.kwargs:
+ raise TypeError("parametrized exception required")
+ for qname in self.kwargs["qnames"]:
+ response = self.kwargs["responses"][qname]
+ try:
+ cname = response.canonical_name()
+ if cname != qname:
+ return cname
+ except Exception:
+ # We can just eat this exception as it means there was
+ # something wrong with the response.
+ pass
+ return self.kwargs["qnames"][0]
def __add__(self, e_nx):
"""Augment by results from another NXDOMAIN exception."""
- qnames0 = list(self.kwargs.get('qnames', []))
- responses0 = dict(self.kwargs.get('responses', {}))
- responses1 = e_nx.kwargs.get('responses', {})
- for qname1 in e_nx.kwargs.get('qnames', []):
+ qnames0 = list(self.kwargs.get("qnames", []))
+ responses0 = dict(self.kwargs.get("responses", {}))
+ responses1 = e_nx.kwargs.get("responses", {})
+ for qname1 in e_nx.kwargs.get("qnames", []):
if qname1 not in qnames0:
qnames0.append(qname1)
if qname1 in responses1:
@@ -70,7 +119,7 @@ class NXDOMAIN(dns.exception.DNSException):
Returns a list of ``dns.name.Name``.
"""
- pass
+ return self.kwargs["qnames"]
def responses(self):
"""A map from queried names to their NXDOMAIN responses.
@@ -78,51 +127,79 @@ class NXDOMAIN(dns.exception.DNSException):
Returns a dict mapping a ``dns.name.Name`` to a
``dns.message.Message``.
"""
- pass
+ return self.kwargs["responses"]
def response(self, qname):
"""The response for query *qname*.
Returns a ``dns.message.Message``.
"""
- pass
+ return self.kwargs["responses"][qname]
class YXDOMAIN(dns.exception.DNSException):
"""The DNS query name is too long after DNAME substitution."""
-ErrorTuple = Tuple[Optional[str], bool, int, Union[Exception, str],
- Optional[dns.message.Message]]
+ErrorTuple = Tuple[
+ Optional[str],
+ bool,
+ int,
+ Union[Exception, str],
+ Optional[dns.message.Message],
+]
-def _errors_to_text(errors: List[ErrorTuple]) ->List[str]:
+def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
"""Turn a resolution errors trace into a list of text."""
- pass
+ texts = []
+ for err in errors:
+ texts.append("Server {} answered {}".format(err[0], err[3]))
+ return texts
class LifetimeTimeout(dns.exception.Timeout):
"""The resolution lifetime expired."""
- msg = 'The resolution lifetime expired.'
- fmt = '%s after {timeout:.3f} seconds: {errors}' % msg[:-1]
- supp_kwargs = {'timeout', 'errors'}
+ msg = "The resolution lifetime expired."
+ fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1]
+ supp_kwargs = {"timeout", "errors"}
+
+ # We do this as otherwise mypy complains about unexpected keyword argument
+ # idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ def _fmt_kwargs(self, **kwargs):
+ srv_msgs = _errors_to_text(kwargs["errors"])
+ return super()._fmt_kwargs(
+ timeout=kwargs["timeout"], errors="; ".join(srv_msgs)
+ )
+
+# We added more detail to resolution timeouts, but they are still
+# subclasses of dns.exception.Timeout for backwards compatibility. We also
+# keep dns.resolver.Timeout defined for backwards compatibility.
Timeout = LifetimeTimeout
class NoAnswer(dns.exception.DNSException):
"""The DNS response does not contain an answer to the question."""
- fmt = (
- 'The DNS response does not contain an answer to the question: {query}')
- supp_kwargs = {'response'}
+ fmt = "The DNS response does not contain an answer to the question: {query}"
+ supp_kwargs = {"response"}
+
+ # We do this as otherwise mypy complains about unexpected keyword argument
+ # idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ def _fmt_kwargs(self, **kwargs):
+ return super()._fmt_kwargs(query=kwargs["response"].question)
+
+ def response(self):
+ return self.kwargs["response"]
+
class NoNameservers(dns.exception.DNSException):
"""All nameservers failed to answer the query.
@@ -132,13 +209,22 @@ class NoNameservers(dns.exception.DNSException):
[(server IP address, any object convertible to string)].
Non-empty errors list will add explanatory message ()
"""
- msg = 'All nameservers failed to answer the query.'
- fmt = '%s {query}: {errors}' % msg[:-1]
- supp_kwargs = {'request', 'errors'}
+ msg = "All nameservers failed to answer the query."
+ fmt = "%s {query}: {errors}" % msg[:-1]
+ supp_kwargs = {"request", "errors"}
+
+ # We do this as otherwise mypy complains about unexpected keyword argument
+ # idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ def _fmt_kwargs(self, **kwargs):
+ srv_msgs = _errors_to_text(kwargs["errors"])
+ return super()._fmt_kwargs(
+ query=kwargs["request"].question, errors="; ".join(srv_msgs)
+ )
+
class NotAbsolute(dns.exception.DNSException):
"""An absolute domain name is required but a relative name was provided."""
@@ -172,10 +258,15 @@ class Answer:
RRset's name might not be the query name.
"""
- def __init__(self, qname: dns.name.Name, rdtype: dns.rdatatype.
- RdataType, rdclass: dns.rdataclass.RdataClass, response: dns.
- message.QueryMessage, nameserver: Optional[str]=None, port:
- Optional[int]=None) ->None:
+ def __init__(
+ self,
+ qname: dns.name.Name,
+ rdtype: dns.rdatatype.RdataType,
+ rdclass: dns.rdataclass.RdataClass,
+ response: dns.message.QueryMessage,
+ nameserver: Optional[str] = None,
+ port: Optional[int] = None,
+ ) -> None:
self.qname = qname
self.rdtype = rdtype
self.rdclass = rdclass
@@ -183,25 +274,27 @@ class Answer:
self.nameserver = nameserver
self.port = port
self.chaining_result = response.resolve_chaining()
+ # Copy some attributes out of chaining_result for backwards
+ # compatibility and convenience.
self.canonical_name = self.chaining_result.canonical_name
self.rrset = self.chaining_result.answer
self.expiration = time.time() + self.chaining_result.minimum_ttl
- def __getattr__(self, attr):
- if attr == 'name':
+ def __getattr__(self, attr): # pragma: no cover
+ if attr == "name":
return self.rrset.name
- elif attr == 'ttl':
+ elif attr == "ttl":
return self.rrset.ttl
- elif attr == 'covers':
+ elif attr == "covers":
return self.rrset.covers
- elif attr == 'rdclass':
+ elif attr == "rdclass":
return self.rrset.rdclass
- elif attr == 'rdtype':
+ elif attr == "rdtype":
return self.rrset.rdtype
else:
raise AttributeError(attr)
- def __len__(self) ->int:
+ def __len__(self) -> int:
return self.rrset and len(self.rrset) or 0
def __iter__(self):
@@ -227,64 +320,127 @@ class HostAnswers(Answers):
type.
"""
+ @classmethod
+ def make(
+ cls,
+ v6: Optional[Answer] = None,
+ v4: Optional[Answer] = None,
+ add_empty: bool = True,
+ ) -> "HostAnswers":
+ answers = HostAnswers()
+ if v6 is not None and (add_empty or v6.rrset):
+ answers[dns.rdatatype.AAAA] = v6
+ if v4 is not None and (add_empty or v4.rrset):
+ answers[dns.rdatatype.A] = v4
+ return answers
+
+ # Returns pairs of (address, family) from this result, potentiallys
+ # filtering by address family.
+ def addresses_and_families(
+ self, family: int = socket.AF_UNSPEC
+ ) -> Iterator[Tuple[str, int]]:
+ if family == socket.AF_UNSPEC:
+ yield from self.addresses_and_families(socket.AF_INET6)
+ yield from self.addresses_and_families(socket.AF_INET)
+ return
+ elif family == socket.AF_INET6:
+ answer = self.get(dns.rdatatype.AAAA)
+ elif family == socket.AF_INET:
+ answer = self.get(dns.rdatatype.A)
+ else:
+ raise NotImplementedError(f"unknown address family {family}")
+ if answer:
+ for rdata in answer:
+ yield (rdata.address, family)
+
+ # Returns addresses from this result, potentially filtering by
+ # address family.
+ def addresses(self, family: int = socket.AF_UNSPEC) -> Iterator[str]:
+ return (pair[0] for pair in self.addresses_and_families(family))
+
+ # Returns the canonical name from this result.
+ def canonical_name(self) -> dns.name.Name:
+ answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A))
+ return answer.canonical_name
+
class CacheStatistics:
"""Cache Statistics"""
- def __init__(self, hits: int=0, misses: int=0) ->None:
+ def __init__(self, hits: int = 0, misses: int = 0) -> None:
self.hits = hits
self.misses = misses
+ def reset(self) -> None:
+ self.hits = 0
+ self.misses = 0
+
+ def clone(self) -> "CacheStatistics":
+ return CacheStatistics(self.hits, self.misses)
-class CacheBase:
- def __init__(self) ->None:
+class CacheBase:
+ def __init__(self) -> None:
self.lock = threading.Lock()
self.statistics = CacheStatistics()
- def reset_statistics(self) ->None:
+ def reset_statistics(self) -> None:
"""Reset all statistics to zero."""
- pass
+ with self.lock:
+ self.statistics.reset()
- def hits(self) ->int:
+ def hits(self) -> int:
"""How many hits has the cache had?"""
- pass
+ with self.lock:
+ return self.statistics.hits
- def misses(self) ->int:
+ def misses(self) -> int:
"""How many misses has the cache had?"""
- pass
+ with self.lock:
+ return self.statistics.misses
- def get_statistics_snapshot(self) ->CacheStatistics:
+ def get_statistics_snapshot(self) -> CacheStatistics:
"""Return a consistent snapshot of all the statistics.
If running with multiple threads, it's better to take a
snapshot than to call statistics methods such as hits() and
misses() individually.
"""
- pass
+ with self.lock:
+ return self.statistics.clone()
-CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.
- RdataClass]
+CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass]
class Cache(CacheBase):
"""Simple thread-safe DNS answer cache."""
- def __init__(self, cleaning_interval: float=300.0) ->None:
+ def __init__(self, cleaning_interval: float = 300.0) -> None:
"""*cleaning_interval*, a ``float`` is the number of seconds between
periodic cleanings.
"""
+
super().__init__()
self.data: Dict[CacheKey, Answer] = {}
self.cleaning_interval = cleaning_interval
self.next_cleaning: float = time.time() + self.cleaning_interval
- def _maybe_clean(self) ->None:
+ def _maybe_clean(self) -> None:
"""Clean the cache if it's time to do so."""
- pass
- def get(self, key: CacheKey) ->Optional[Answer]:
+ now = time.time()
+ if self.next_cleaning <= now:
+ keys_to_delete = []
+ for k, v in self.data.items():
+ if v.expiration <= now:
+ keys_to_delete.append(k)
+ for k in keys_to_delete:
+ del self.data[k]
+ now = time.time()
+ self.next_cleaning = now + self.cleaning_interval
+
+ def get(self, key: CacheKey) -> Optional[Answer]:
"""Get the answer associated with *key*.
Returns None if no answer is cached for the key.
@@ -294,9 +450,17 @@ class Cache(CacheBase):
Returns a ``dns.resolver.Answer`` or ``None``.
"""
- pass
- def put(self, key: CacheKey, value: Answer) ->None:
+ with self.lock:
+ self._maybe_clean()
+ v = self.data.get(key)
+ if v is None or v.expiration <= time.time():
+ self.statistics.misses += 1
+ return None
+ self.statistics.hits += 1
+ return v
+
+ def put(self, key: CacheKey, value: Answer) -> None:
"""Associate key and value in the cache.
*key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)``
@@ -304,9 +468,12 @@ class Cache(CacheBase):
*value*, a ``dns.resolver.Answer``, the answer.
"""
- pass
- def flush(self, key: Optional[CacheKey]=None) ->None:
+ with self.lock:
+ self._maybe_clean()
+ self.data[key] = value
+
+ def flush(self, key: Optional[CacheKey] = None) -> None:
"""Flush the cache.
If *key* is not ``None``, only that item is flushed. Otherwise the entire cache
@@ -315,7 +482,14 @@ class Cache(CacheBase):
*key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)``
tuple whose values are the query name, rdtype, and rdclass respectively.
"""
- pass
+
+ with self.lock:
+ if key is not None:
+ if key in self.data:
+ del self.data[key]
+ else:
+ self.data = {}
+ self.next_cleaning = time.time() + self.cleaning_interval
class LRUCacheNode:
@@ -328,6 +502,16 @@ class LRUCacheNode:
self.prev = self
self.next = self
+ def link_after(self, node: "LRUCacheNode") -> None:
+ self.prev = node
+ self.next = node.next
+ node.next.prev = self
+ node.next = self
+
+ def unlink(self) -> None:
+ self.next.prev = self.prev
+ self.prev.next = self.next
+
class LRUCache(CacheBase):
"""Thread-safe, bounded, least-recently-used DNS answer cache.
@@ -339,10 +523,11 @@ class LRUCache(CacheBase):
for a new one.
"""
- def __init__(self, max_size: int=100000) ->None:
+ def __init__(self, max_size: int = 100000) -> None:
"""*max_size*, an ``int``, is the maximum number of nodes to cache;
it must be greater than 0.
"""
+
super().__init__()
self.data: Dict[CacheKey, LRUCacheNode] = {}
self.set_max_size(max_size)
@@ -350,7 +535,12 @@ class LRUCache(CacheBase):
self.sentinel.prev = self.sentinel
self.sentinel.next = self.sentinel
- def get(self, key: CacheKey) ->Optional[Answer]:
+ def set_max_size(self, max_size: int) -> None:
+ if max_size < 1:
+ max_size = 1
+ self.max_size = max_size
+
+ def get(self, key: CacheKey) -> Optional[Answer]:
"""Get the answer associated with *key*.
Returns None if no answer is cached for the key.
@@ -360,13 +550,34 @@ class LRUCache(CacheBase):
Returns a ``dns.resolver.Answer`` or ``None``.
"""
- pass
- def get_hits_for_key(self, key: CacheKey) ->int:
+ with self.lock:
+ node = self.data.get(key)
+ if node is None:
+ self.statistics.misses += 1
+ return None
+ # Unlink because we're either going to move the node to the front
+ # of the LRU list or we're going to free it.
+ node.unlink()
+ if node.value.expiration <= time.time():
+ del self.data[node.key]
+ self.statistics.misses += 1
+ return None
+ node.link_after(self.sentinel)
+ self.statistics.hits += 1
+ node.hits += 1
+ return node.value
+
+ def get_hits_for_key(self, key: CacheKey) -> int:
"""Return the number of cache hits associated with the specified key."""
- pass
-
- def put(self, key: CacheKey, value: Answer) ->None:
+ with self.lock:
+ node = self.data.get(key)
+ if node is None or node.value.expiration <= time.time():
+ return 0
+ else:
+ return node.hits
+
+ def put(self, key: CacheKey, value: Answer) -> None:
"""Associate key and value in the cache.
*key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)``
@@ -374,9 +585,21 @@ class LRUCache(CacheBase):
*value*, a ``dns.resolver.Answer``, the answer.
"""
- pass
- def flush(self, key: Optional[CacheKey]=None) ->None:
+ with self.lock:
+ node = self.data.get(key)
+ if node is not None:
+ node.unlink()
+ del self.data[node.key]
+ while len(self.data) >= self.max_size:
+ gnode = self.sentinel.prev
+ gnode.unlink()
+ del self.data[gnode.key]
+ node = LRUCacheNode(key, value)
+ node.link_after(self.sentinel)
+ self.data[key] = node
+
+ def flush(self, key: Optional[CacheKey] = None) -> None:
"""Flush the cache.
If *key* is not ``None``, only that item is flushed. Otherwise the entire cache
@@ -385,7 +608,20 @@ class LRUCache(CacheBase):
*key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)``
tuple whose values are the query name, rdtype, and rdclass respectively.
"""
- pass
+
+ with self.lock:
+ if key is not None:
+ node = self.data.get(key)
+ if node is not None:
+ node.unlink()
+ del self.data[node.key]
+ else:
+ gnode = self.sentinel.next
+ while gnode != self.sentinel:
+ next = gnode.next
+ gnode.unlink()
+ gnode = next
+ self.data = {}
class _Resolution:
@@ -400,10 +636,16 @@ class _Resolution:
resolver data structures directly.
"""
- def __init__(self, resolver: 'BaseResolver', qname: Union[dns.name.Name,
- str], rdtype: Union[dns.rdatatype.RdataType, str], rdclass: Union[
- dns.rdataclass.RdataClass, str], tcp: bool, raise_on_no_answer:
- bool, search: Optional[bool]) ->None:
+ def __init__(
+ self,
+ resolver: "BaseResolver",
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ tcp: bool,
+ raise_on_no_answer: bool,
+ search: Optional[bool],
+ ) -> None:
if isinstance(qname, str):
qname = dns.name.from_text(qname, None)
rdtype = dns.rdatatype.RdataType.make(rdtype)
@@ -419,8 +661,8 @@ class _Resolution:
self.rdclass = rdclass
self.tcp = tcp
self.raise_on_no_answer = raise_on_no_answer
- self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage
- ] = {}
+ self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
+ # Initialize other things to help analysis tools
self.qname = dns.name.empty
self.nameservers: List[dns.nameserver.Nameserver] = []
self.current_nameservers: List[dns.nameserver.Nameserver] = []
@@ -431,18 +673,234 @@ class _Resolution:
self.request: Optional[dns.message.QueryMessage] = None
self.backoff = 0.0
- def next_request(self) ->Tuple[Optional[dns.message.QueryMessage],
- Optional[Answer]]:
+ def next_request(
+ self,
+ ) -> Tuple[Optional[dns.message.QueryMessage], Optional[Answer]]:
"""Get the next request to send, and check the cache.
Returns a (request, answer) tuple. At most one of request or
answer will not be None.
"""
- pass
+
+ # We return a tuple instead of Union[Message,Answer] as it lets
+ # the caller avoid isinstance().
+
+ while len(self.qnames) > 0:
+ self.qname = self.qnames.pop(0)
+
+ # Do we know the answer?
+ if self.resolver.cache:
+ answer = self.resolver.cache.get(
+ (self.qname, self.rdtype, self.rdclass)
+ )
+ if answer is not None:
+ if answer.rrset is None and self.raise_on_no_answer:
+ raise NoAnswer(response=answer.response)
+ else:
+ return (None, answer)
+ answer = self.resolver.cache.get(
+ (self.qname, dns.rdatatype.ANY, self.rdclass)
+ )
+ if answer is not None and answer.response.rcode() == dns.rcode.NXDOMAIN:
+ # cached NXDOMAIN; record it and continue to next
+ # name.
+ self.nxdomain_responses[self.qname] = answer.response
+ continue
+
+ # Build the request
+ request = dns.message.make_query(self.qname, self.rdtype, self.rdclass)
+ if self.resolver.keyname is not None:
+ request.use_tsig(
+ self.resolver.keyring,
+ self.resolver.keyname,
+ algorithm=self.resolver.keyalgorithm,
+ )
+ request.use_edns(
+ self.resolver.edns,
+ self.resolver.ednsflags,
+ self.resolver.payload,
+ options=self.resolver.ednsoptions,
+ )
+ if self.resolver.flags is not None:
+ request.flags = self.resolver.flags
+
+ self.nameservers = self.resolver._enrich_nameservers(
+ self.resolver._nameservers,
+ self.resolver.nameserver_ports,
+ self.resolver.port,
+ )
+ if self.resolver.rotate:
+ random.shuffle(self.nameservers)
+ self.current_nameservers = self.nameservers[:]
+ self.errors = []
+ self.nameserver = None
+ self.tcp_attempt = False
+ self.retry_with_tcp = False
+ self.request = request
+ self.backoff = 0.10
+
+ return (request, None)
+
+ #
+ # We've tried everything and only gotten NXDOMAINs. (We know
+ # it's only NXDOMAINs as anything else would have returned
+ # before now.)
+ #
+ raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses)
+
+ def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]:
+ if self.retry_with_tcp:
+ assert self.nameserver is not None
+ assert not self.nameserver.is_always_max_size()
+ self.tcp_attempt = True
+ self.retry_with_tcp = False
+ return (self.nameserver, True, 0)
+
+ backoff = 0.0
+ if not self.current_nameservers:
+ if len(self.nameservers) == 0:
+ # Out of things to try!
+ raise NoNameservers(request=self.request, errors=self.errors)
+ self.current_nameservers = self.nameservers[:]
+ backoff = self.backoff
+ self.backoff = min(self.backoff * 2, 2)
+
+ self.nameserver = self.current_nameservers.pop(0)
+ self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size()
+ return (self.nameserver, self.tcp_attempt, backoff)
+
+ def query_result(
+ self, response: Optional[dns.message.Message], ex: Optional[Exception]
+ ) -> Tuple[Optional[Answer], bool]:
+ #
+ # returns an (answer: Answer, end_loop: bool) tuple.
+ #
+ assert self.nameserver is not None
+ if ex:
+ # Exception during I/O or from_wire()
+ assert response is None
+ self.errors.append(
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ ex,
+ response,
+ )
+ )
+ if (
+ isinstance(ex, dns.exception.FormError)
+ or isinstance(ex, EOFError)
+ or isinstance(ex, OSError)
+ or isinstance(ex, NotImplementedError)
+ ):
+ # This nameserver is no good, take it out of the mix.
+ self.nameservers.remove(self.nameserver)
+ elif isinstance(ex, dns.message.Truncated):
+ if self.tcp_attempt:
+ # Truncation with TCP is no good!
+ self.nameservers.remove(self.nameserver)
+ else:
+ self.retry_with_tcp = True
+ return (None, False)
+ # We got an answer!
+ assert response is not None
+ assert isinstance(response, dns.message.QueryMessage)
+ rcode = response.rcode()
+ if rcode == dns.rcode.NOERROR:
+ try:
+ answer = Answer(
+ self.qname,
+ self.rdtype,
+ self.rdclass,
+ response,
+ self.nameserver.answer_nameserver(),
+ self.nameserver.answer_port(),
+ )
+ except Exception as e:
+ self.errors.append(
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ e,
+ response,
+ )
+ )
+ # The nameserver is no good, take it out of the mix.
+ self.nameservers.remove(self.nameserver)
+ return (None, False)
+ if self.resolver.cache:
+ self.resolver.cache.put((self.qname, self.rdtype, self.rdclass), answer)
+ if answer.rrset is None and self.raise_on_no_answer:
+ raise NoAnswer(response=answer.response)
+ return (answer, True)
+ elif rcode == dns.rcode.NXDOMAIN:
+ # Further validate the response by making an Answer, even
+ # if we aren't going to cache it.
+ try:
+ answer = Answer(
+ self.qname, dns.rdatatype.ANY, dns.rdataclass.IN, response
+ )
+ except Exception as e:
+ self.errors.append(
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ e,
+ response,
+ )
+ )
+ # The nameserver is no good, take it out of the mix.
+ self.nameservers.remove(self.nameserver)
+ return (None, False)
+ self.nxdomain_responses[self.qname] = response
+ if self.resolver.cache:
+ self.resolver.cache.put(
+ (self.qname, dns.rdatatype.ANY, self.rdclass), answer
+ )
+ # Make next_nameserver() return None, so caller breaks its
+ # inner loop and calls next_request().
+ return (None, True)
+ elif rcode == dns.rcode.YXDOMAIN:
+ yex = YXDOMAIN()
+ self.errors.append(
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ yex,
+ response,
+ )
+ )
+ raise yex
+ else:
+ #
+ # We got a response, but we're not happy with the
+ # rcode in it.
+ #
+ if rcode != dns.rcode.SERVFAIL or not self.resolver.retry_servfail:
+ self.nameservers.remove(self.nameserver)
+ self.errors.append(
+ (
+ str(self.nameserver),
+ self.tcp_attempt,
+ self.nameserver.answer_port(),
+ dns.rcode.to_text(rcode),
+ response,
+ )
+ )
+ return (None, False)
class BaseResolver:
"""DNS stub resolver."""
+
+ # We initialize in reset()
+ #
+ # pylint: disable=attribute-defined-outside-init
+
domain: dns.name.Name
nameserver_ports: Dict[str, int]
port: int
@@ -464,8 +922,9 @@ class BaseResolver:
ndots: Optional[int]
_nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]
- def __init__(self, filename: str='/etc/resolv.conf', configure: bool=True
- ) ->None:
+ def __init__(
+ self, filename: str = "/etc/resolv.conf", configure: bool = True
+ ) -> None:
"""*filename*, a ``str`` or file object, specifying a file
in standard /etc/resolv.conf format. This parameter is meaningful
only when *configure* is true and the platform is POSIX.
@@ -476,18 +935,41 @@ class BaseResolver:
/etc/resolv.conf file on POSIX systems and from the registry
on Windows systems.)
"""
+
self.reset()
if configure:
- if sys.platform == 'win32':
+ if sys.platform == "win32":
self.read_registry()
elif filename:
self.read_resolv_conf(filename)
- def reset(self) ->None:
+ def reset(self) -> None:
"""Reset all resolver configuration to the defaults."""
- pass
- def read_resolv_conf(self, f: Any) ->None:
+ self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
+ if len(self.domain) == 0:
+ self.domain = dns.name.root
+ self._nameservers = []
+ self.nameserver_ports = {}
+ self.port = 53
+ self.search = []
+ self.use_search_by_default = False
+ self.timeout = 2.0
+ self.lifetime = 5.0
+ self.keyring = None
+ self.keyname = None
+ self.keyalgorithm = dns.tsig.default_algorithm
+ self.edns = -1
+ self.ednsflags = 0
+ self.ednsoptions = None
+ self.payload = 0
+ self.cache = None
+ self.flags = None
+ self.retry_servfail = False
+ self.rotate = False
+ self.ndots = None
+
+ def read_resolv_conf(self, f: Any) -> None:
"""Process *f* as a file in the /etc/resolv.conf format. If f is
a ``str``, it is used as the name of the file to open; otherwise it
is treated as the file itself.
@@ -503,25 +985,160 @@ class BaseResolver:
- options - supported options are rotate, timeout, edns0, and ndots
"""
- pass
- def read_registry(self) ->None:
+ nameservers = []
+ if isinstance(f, str):
+ try:
+ cm: contextlib.AbstractContextManager = open(f)
+ except OSError:
+ # /etc/resolv.conf doesn't exist, can't be read, etc.
+ raise NoResolverConfiguration(f"cannot open {f}")
+ else:
+ cm = contextlib.nullcontext(f)
+ with cm as f:
+ for l in f:
+ if len(l) == 0 or l[0] == "#" or l[0] == ";":
+ continue
+ tokens = l.split()
+
+ # Any line containing less than 2 tokens is malformed
+ if len(tokens) < 2:
+ continue
+
+ if tokens[0] == "nameserver":
+ nameservers.append(tokens[1])
+ elif tokens[0] == "domain":
+ self.domain = dns.name.from_text(tokens[1])
+ # domain and search are exclusive
+ self.search = []
+ elif tokens[0] == "search":
+ # the last search wins
+ self.search = []
+ for suffix in tokens[1:]:
+ self.search.append(dns.name.from_text(suffix))
+ # We don't set domain as it is not used if
+ # len(self.search) > 0
+ elif tokens[0] == "options":
+ for opt in tokens[1:]:
+ if opt == "rotate":
+ self.rotate = True
+ elif opt == "edns0":
+ self.use_edns()
+ elif "timeout" in opt:
+ try:
+ self.timeout = int(opt.split(":")[1])
+ except (ValueError, IndexError):
+ pass
+ elif "ndots" in opt:
+ try:
+ self.ndots = int(opt.split(":")[1])
+ except (ValueError, IndexError):
+ pass
+ if len(nameservers) == 0:
+ raise NoResolverConfiguration("no nameservers")
+ # Assigning directly instead of appending means we invoke the
+ # setter logic, with additonal checking and enrichment.
+ self.nameservers = nameservers
+
+ def read_registry(self) -> None:
"""Extract resolver configuration from the Windows registry."""
- pass
-
- def use_tsig(self, keyring: Any, keyname: Optional[Union[dns.name.Name,
- str]]=None, algorithm: Union[dns.name.Name, str]=dns.tsig.
- default_algorithm) ->None:
+ try:
+ info = dns.win32util.get_dns_info() # type: ignore
+ if info.domain is not None:
+ self.domain = info.domain
+ self.nameservers = info.nameservers
+ self.search = info.search
+ except AttributeError:
+ raise NotImplementedError
+
+ def _compute_timeout(
+ self,
+ start: float,
+ lifetime: Optional[float] = None,
+ errors: Optional[List[ErrorTuple]] = None,
+ ) -> float:
+ lifetime = self.lifetime if lifetime is None else lifetime
+ now = time.time()
+ duration = now - start
+ if errors is None:
+ errors = []
+ if duration < 0:
+ if duration < -1:
+ # Time going backwards is bad. Just give up.
+ raise LifetimeTimeout(timeout=duration, errors=errors)
+ else:
+ # Time went backwards, but only a little. This can
+ # happen, e.g. under vmware with older linux kernels.
+ # Pretend it didn't happen.
+ duration = 0
+ if duration >= lifetime:
+ raise LifetimeTimeout(timeout=duration, errors=errors)
+ return min(lifetime - duration, self.timeout)
+
+ def _get_qnames_to_try(
+ self, qname: dns.name.Name, search: Optional[bool]
+ ) -> List[dns.name.Name]:
+ # This is a separate method so we can unit test the search
+ # rules without requiring the Internet.
+ if search is None:
+ search = self.use_search_by_default
+ qnames_to_try = []
+ if qname.is_absolute():
+ qnames_to_try.append(qname)
+ else:
+ abs_qname = qname.concatenate(dns.name.root)
+ if search:
+ if len(self.search) > 0:
+ # There is a search list, so use it exclusively
+ search_list = self.search[:]
+ elif self.domain != dns.name.root and self.domain is not None:
+ # We have some notion of a domain that isn't the root, so
+ # use it as the search list.
+ search_list = [self.domain]
+ else:
+ search_list = []
+ # Figure out the effective ndots (default is 1)
+ if self.ndots is None:
+ ndots = 1
+ else:
+ ndots = self.ndots
+ for suffix in search_list:
+ qnames_to_try.append(qname + suffix)
+ if len(qname) > ndots:
+ # The name has at least ndots dots, so we should try an
+ # absolute query first.
+ qnames_to_try.insert(0, abs_qname)
+ else:
+ # The name has less than ndots dots, so we should search
+ # first, then try the absolute name.
+ qnames_to_try.append(abs_qname)
+ else:
+ qnames_to_try.append(abs_qname)
+ return qnames_to_try
+
+ def use_tsig(
+ self,
+ keyring: Any,
+ keyname: Optional[Union[dns.name.Name, str]] = None,
+ algorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm,
+ ) -> None:
"""Add a TSIG signature to each query.
The parameters are passed to ``dns.message.Message.use_tsig()``;
see its documentation for details.
"""
- pass
- def use_edns(self, edns: Optional[Union[int, bool]]=0, ednsflags: int=0,
- payload: int=dns.message.DEFAULT_EDNS_PAYLOAD, options: Optional[
- List[dns.edns.Option]]=None) ->None:
+ self.keyring = keyring
+ self.keyname = keyname
+ self.keyalgorithm = algorithm
+
+ def use_edns(
+ self,
+ edns: Optional[Union[int, bool]] = 0,
+ ednsflags: int = 0,
+ payload: int = dns.message.DEFAULT_EDNS_PAYLOAD,
+ options: Optional[List[dns.edns.Option]] = None,
+ ) -> None:
"""Configure EDNS behavior.
*edns*, an ``int``, is the EDNS level to use. Specifying
@@ -538,18 +1155,72 @@ class BaseResolver:
*options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS
options.
"""
- pass
- def set_flags(self, flags: int) ->None:
+ if edns is None or edns is False:
+ edns = -1
+ elif edns is True:
+ edns = 0
+ self.edns = edns
+ self.ednsflags = ednsflags
+ self.payload = payload
+ self.ednsoptions = options
+
+ def set_flags(self, flags: int) -> None:
"""Overrides the default flags with your own.
*flags*, an ``int``, the message flags to use.
"""
- pass
+
+ self.flags = flags
+
+ @classmethod
+ def _enrich_nameservers(
+ cls,
+ nameservers: Sequence[Union[str, dns.nameserver.Nameserver]],
+ nameserver_ports: Dict[str, int],
+ default_port: int,
+ ) -> List[dns.nameserver.Nameserver]:
+ enriched_nameservers = []
+ if isinstance(nameservers, list):
+ for nameserver in nameservers:
+ enriched_nameserver: dns.nameserver.Nameserver
+ if isinstance(nameserver, dns.nameserver.Nameserver):
+ enriched_nameserver = nameserver
+ elif dns.inet.is_address(nameserver):
+ port = nameserver_ports.get(nameserver, default_port)
+ enriched_nameserver = dns.nameserver.Do53Nameserver(
+ nameserver, port
+ )
+ else:
+ try:
+ if urlparse(nameserver).scheme != "https":
+ raise NotImplementedError
+ except Exception:
+ raise ValueError(
+ f"nameserver {nameserver} is not a "
+ "dns.nameserver.Nameserver instance or text form, "
+ "IP address, nor a valid https URL"
+ )
+ enriched_nameserver = dns.nameserver.DoHNameserver(nameserver)
+ enriched_nameservers.append(enriched_nameserver)
+ else:
+ raise ValueError(
+ "nameservers must be a list or tuple (not a {})".format(
+ type(nameservers)
+ )
+ )
+ return enriched_nameservers
+
+ @property
+ def nameservers(
+ self,
+ ) -> Sequence[Union[str, dns.nameserver.Nameserver]]:
+ return self._nameservers
@nameservers.setter
- def nameservers(self, nameservers: Sequence[Union[str, dns.nameserver.
- Nameserver]]) ->None:
+ def nameservers(
+ self, nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]
+ ) -> None:
"""
*nameservers*, a ``list`` of nameservers, where a nameserver is either
a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver``
@@ -557,18 +1228,26 @@ class BaseResolver:
Raises ``ValueError`` if *nameservers* is not a list of nameservers.
"""
- pass
+ # We just call _enrich_nameservers() for checking
+ self._enrich_nameservers(nameservers, self.nameserver_ports, self.port)
+ self._nameservers = nameservers
class Resolver(BaseResolver):
"""DNS stub resolver."""
- def resolve(self, qname: Union[dns.name.Name, str], rdtype: Union[dns.
- rdatatype.RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.
- rdataclass.RdataClass, str]=dns.rdataclass.IN, tcp: bool=False,
- source: Optional[str]=None, raise_on_no_answer: bool=True,
- source_port: int=0, lifetime: Optional[float]=None, search:
- Optional[bool]=None) ->Answer:
+ def resolve(
+ self,
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ tcp: bool = False,
+ source: Optional[str] = None,
+ raise_on_no_answer: bool = True,
+ source_port: int = 0,
+ lifetime: Optional[float] = None,
+ search: Optional[bool] = None,
+ ) -> Answer: # pylint: disable=arguments-differ
"""Query nameservers to find the answer to the question.
The *qname*, *rdtype*, and *rdclass* parameters may be objects
@@ -619,13 +1298,57 @@ class Resolver(BaseResolver):
Returns a ``dns.resolver.Answer`` instance.
"""
- pass
- def query(self, qname: Union[dns.name.Name, str], rdtype: Union[dns.
- rdatatype.RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.
- rdataclass.RdataClass, str]=dns.rdataclass.IN, tcp: bool=False,
- source: Optional[str]=None, raise_on_no_answer: bool=True,
- source_port: int=0, lifetime: Optional[float]=None) ->Answer:
+ resolution = _Resolution(
+ self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
+ )
+ start = time.time()
+ while True:
+ (request, answer) = resolution.next_request()
+ # Note we need to say "if answer is not None" and not just
+ # "if answer" because answer implements __len__, and python
+ # will call that. We want to return if we have an answer
+ # object, including in cases where its length is 0.
+ if answer is not None:
+ # cache hit!
+ return answer
+ assert request is not None # needed for type checking
+ done = False
+ while not done:
+ (nameserver, tcp, backoff) = resolution.next_nameserver()
+ if backoff:
+ time.sleep(backoff)
+ timeout = self._compute_timeout(start, lifetime, resolution.errors)
+ try:
+ response = nameserver.query(
+ request,
+ timeout=timeout,
+ source=source,
+ source_port=source_port,
+ max_size=tcp,
+ )
+ except Exception as ex:
+ (_, done) = resolution.query_result(None, ex)
+ continue
+ (answer, done) = resolution.query_result(response, None)
+ # Note we need to say "if answer is not None" and not just
+ # "if answer" because answer implements __len__, and python
+ # will call that. We want to return if we have an answer
+ # object, including in cases where its length is 0.
+ if answer is not None:
+ return answer
+
+ def query(
+ self,
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ tcp: bool = False,
+ source: Optional[str] = None,
+ raise_on_no_answer: bool = True,
+ source_port: int = 0,
+ lifetime: Optional[float] = None,
+ ) -> Answer: # pragma: no cover
"""Query nameservers to find the answer to the question.
This method calls resolve() with ``search=True``, and is
@@ -633,9 +1356,24 @@ class Resolver(BaseResolver):
dnspython. See the documentation for the resolve() method for
further details.
"""
- pass
-
- def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Any) ->Answer:
+ warnings.warn(
+ "please use dns.resolver.Resolver.resolve() instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.resolve(
+ qname,
+ rdtype,
+ rdclass,
+ tcp,
+ source,
+ raise_on_no_answer,
+ source_port,
+ lifetime,
+ True,
+ )
+
+ def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Any) -> Answer:
"""Use a resolver to run a reverse query for PTR records.
This utilizes the resolve() method to perform a PTR lookup on the
@@ -648,10 +1386,23 @@ class Resolver(BaseResolver):
except for rdtype and rdclass are also supported by this
function.
"""
- pass
-
- def resolve_name(self, name: Union[dns.name.Name, str], family: int=
- socket.AF_UNSPEC, **kwargs: Any) ->HostAnswers:
+ # We make a modified kwargs for type checking happiness, as otherwise
+ # we get a legit warning about possibly having rdtype and rdclass
+ # in the kwargs more than once.
+ modified_kwargs: Dict[str, Any] = {}
+ modified_kwargs.update(kwargs)
+ modified_kwargs["rdtype"] = dns.rdatatype.PTR
+ modified_kwargs["rdclass"] = dns.rdataclass.IN
+ return self.resolve(
+ dns.reversename.from_address(ipaddr), *args, **modified_kwargs
+ )
+
+ def resolve_name(
+ self,
+ name: Union[dns.name.Name, str],
+ family: int = socket.AF_UNSPEC,
+ **kwargs: Any,
+ ) -> HostAnswers:
"""Use a resolver to query for address records.
This utilizes the resolve() method to perform A and/or AAAA lookups on
@@ -666,9 +1417,54 @@ class Resolver(BaseResolver):
except for rdtype and rdclass are also supported by this
function.
"""
- pass
-
- def canonical_name(self, name: Union[dns.name.Name, str]) ->dns.name.Name:
+ # We make a modified kwargs for type checking happiness, as otherwise
+ # we get a legit warning about possibly having rdtype and rdclass
+ # in the kwargs more than once.
+ modified_kwargs: Dict[str, Any] = {}
+ modified_kwargs.update(kwargs)
+ modified_kwargs.pop("rdtype", None)
+ modified_kwargs["rdclass"] = dns.rdataclass.IN
+
+ if family == socket.AF_INET:
+ v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs)
+ return HostAnswers.make(v4=v4)
+ elif family == socket.AF_INET6:
+ v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
+ return HostAnswers.make(v6=v6)
+ elif family != socket.AF_UNSPEC:
+ raise NotImplementedError(f"unknown address family {family}")
+
+ raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
+ lifetime = modified_kwargs.pop("lifetime", None)
+ start = time.time()
+ v6 = self.resolve(
+ name,
+ dns.rdatatype.AAAA,
+ raise_on_no_answer=False,
+ lifetime=self._compute_timeout(start, lifetime),
+ **modified_kwargs,
+ )
+ # Note that setting name ensures we query the same name
+ # for A as we did for AAAA. (This is just in case search lists
+ # are active by default in the resolver configuration and
+ # we might be talking to a server that says NXDOMAIN when it
+ # wants to say NOERROR no data.
+ name = v6.qname
+ v4 = self.resolve(
+ name,
+ dns.rdatatype.A,
+ raise_on_no_answer=False,
+ lifetime=self._compute_timeout(start, lifetime),
+ **modified_kwargs,
+ )
+ answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer)
+ if not answers:
+ raise NoAnswer(response=v6.response)
+ return answers
+
+ # pylint: disable=redefined-outer-name
+
+ def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
@@ -682,9 +1478,16 @@ class Resolver(BaseResolver):
Returns a ``dns.name.Name``.
"""
- pass
+ try:
+ answer = self.resolve(name, raise_on_no_answer=False)
+ canonical_name = answer.canonical_name
+ except dns.resolver.NXDOMAIN as e:
+ canonical_name = e.canonical_name
+ return canonical_name
+
+ # pylint: enable=redefined-outer-name
- def try_ddr(self, lifetime: float=5.0) ->None:
+ def try_ddr(self, lifetime: float = 5.0) -> None:
"""Try to update the resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
@@ -703,31 +1506,53 @@ class Resolver(BaseResolver):
the bootstrap nameserver is in the Subject Alternative Name field of the
TLS certficate.
"""
- pass
-
-
+ try:
+ expiration = time.time() + lifetime
+ answer = self.resolve(
+ dns._ddr._local_resolver_name, "SVCB", lifetime=lifetime
+ )
+ timeout = dns.query._remaining(expiration)
+ nameservers = dns._ddr._get_nameservers_sync(answer, timeout)
+ if len(nameservers) > 0:
+ self.nameservers = nameservers
+ except Exception:
+ pass
+
+
+#: The default resolver.
default_resolver: Optional[Resolver] = None
-def get_default_resolver() ->Resolver:
+def get_default_resolver() -> Resolver:
"""Get the default resolver, initializing it if necessary."""
- pass
+ if default_resolver is None:
+ reset_default_resolver()
+ assert default_resolver is not None
+ return default_resolver
-def reset_default_resolver() ->None:
+def reset_default_resolver() -> None:
"""Re-initialize default resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
systems) will be re-read immediately.
"""
- pass
-
-def resolve(qname: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.
- RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.rdataclass.
- RdataClass, str]=dns.rdataclass.IN, tcp: bool=False, source: Optional[
- str]=None, raise_on_no_answer: bool=True, source_port: int=0, lifetime:
- Optional[float]=None, search: Optional[bool]=None) ->Answer:
+ global default_resolver
+ default_resolver = Resolver()
+
+
+def resolve(
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ tcp: bool = False,
+ source: Optional[str] = None,
+ raise_on_no_answer: bool = True,
+ source_port: int = 0,
+ lifetime: Optional[float] = None,
+ search: Optional[bool] = None,
+) -> Answer: # pragma: no cover
"""Query nameservers to find the answer to the question.
This is a convenience function that uses the default resolver
@@ -736,14 +1561,30 @@ def resolve(qname: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.
See ``dns.resolver.Resolver.resolve`` for more information on the
parameters.
"""
- pass
-
-def query(qname: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.
- RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.rdataclass.
- RdataClass, str]=dns.rdataclass.IN, tcp: bool=False, source: Optional[
- str]=None, raise_on_no_answer: bool=True, source_port: int=0, lifetime:
- Optional[float]=None) ->Answer:
+ return get_default_resolver().resolve(
+ qname,
+ rdtype,
+ rdclass,
+ tcp,
+ source,
+ raise_on_no_answer,
+ source_port,
+ lifetime,
+ search,
+ )
+
+
+def query(
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ tcp: bool = False,
+ source: Optional[str] = None,
+ raise_on_no_answer: bool = True,
+ source_port: int = 0,
+ lifetime: Optional[float] = None,
+) -> Answer: # pragma: no cover
"""Query nameservers to find the answer to the question.
This method calls resolve() with ``search=True``, and is
@@ -751,50 +1592,71 @@ def query(qname: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.
dnspython. See the documentation for the resolve() method for
further details.
"""
- pass
-
-
-def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) ->Answer:
+ warnings.warn(
+ "please use dns.resolver.resolve() instead", DeprecationWarning, stacklevel=2
+ )
+ return resolve(
+ qname,
+ rdtype,
+ rdclass,
+ tcp,
+ source,
+ raise_on_no_answer,
+ source_port,
+ lifetime,
+ True,
+ )
+
+
+def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer:
"""Use a resolver to run a reverse query for PTR records.
See ``dns.resolver.Resolver.resolve_address`` for more information on the
parameters.
"""
- pass
+ return get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
-def resolve_name(name: Union[dns.name.Name, str], family: int=socket.
- AF_UNSPEC, **kwargs: Any) ->HostAnswers:
+
+def resolve_name(
+ name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
+) -> HostAnswers:
"""Use a resolver to query for address records.
See ``dns.resolver.Resolver.resolve_name`` for more information on the
parameters.
"""
- pass
+
+ return get_default_resolver().resolve_name(name, family, **kwargs)
-def canonical_name(name: Union[dns.name.Name, str]) ->dns.name.Name:
+def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
See ``dns.resolver.Resolver.canonical_name`` for more information on the
parameters and possible exceptions.
"""
- pass
+
+ return get_default_resolver().canonical_name(name)
-def try_ddr(lifetime: float=5.0) ->None:
+def try_ddr(lifetime: float = 5.0) -> None:
"""Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
"""
- pass
+ return get_default_resolver().try_ddr(lifetime)
-def zone_for_name(name: Union[dns.name.Name, str], rdclass: dns.rdataclass.
- RdataClass=dns.rdataclass.IN, tcp: bool=False, resolver: Optional[
- Resolver]=None, lifetime: Optional[float]=None) ->dns.name.Name:
+def zone_for_name(
+ name: Union[dns.name.Name, str],
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+ tcp: bool = False,
+ resolver: Optional[Resolver] = None,
+ lifetime: Optional[float] = None,
+) -> dns.name.Name:
"""Find the name of the zone which contains the specified name.
*name*, an absolute ``dns.name.Name`` or ``str``, the query name.
@@ -819,11 +1681,66 @@ def zone_for_name(name: Union[dns.name.Name, str], rdclass: dns.rdataclass.
Returns a ``dns.name.Name``.
"""
- pass
-
-def make_resolver_at(where: Union[dns.name.Name, str], port: int=53, family:
- int=socket.AF_UNSPEC, resolver: Optional[Resolver]=None) ->Resolver:
+ if isinstance(name, str):
+ name = dns.name.from_text(name, dns.name.root)
+ if resolver is None:
+ resolver = get_default_resolver()
+ if not name.is_absolute():
+ raise NotAbsolute(name)
+ start = time.time()
+ expiration: Optional[float]
+ if lifetime is not None:
+ expiration = start + lifetime
+ else:
+ expiration = None
+ while 1:
+ try:
+ rlifetime: Optional[float]
+ if expiration is not None:
+ rlifetime = expiration - time.time()
+ if rlifetime <= 0:
+ rlifetime = 0
+ else:
+ rlifetime = None
+ answer = resolver.resolve(
+ name, dns.rdatatype.SOA, rdclass, tcp, lifetime=rlifetime
+ )
+ assert answer.rrset is not None
+ if answer.rrset.name == name:
+ return name
+ # otherwise we were CNAMEd or DNAMEd and need to look higher
+ except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer) as e:
+ if isinstance(e, dns.resolver.NXDOMAIN):
+ response = e.responses().get(name)
+ else:
+ response = e.response() # pylint: disable=no-value-for-parameter
+ if response:
+ for rrs in response.authority:
+ if rrs.rdtype == dns.rdatatype.SOA and rrs.rdclass == rdclass:
+ (nr, _, _) = rrs.name.fullcompare(name)
+ if nr == dns.name.NAMERELN_SUPERDOMAIN:
+ # We're doing a proper superdomain check as
+ # if the name were equal we ought to have gotten
+ # it in the answer section! We are ignoring the
+ # possibility that the authority is insane and
+ # is including multiple SOA RRs for different
+ # authorities.
+ return rrs.name
+ # we couldn't extract anything useful from the response (e.g. it's
+ # a type 3 NXDOMAIN)
+ try:
+ name = name.parent()
+ except dns.name.NoParent:
+ raise NoRootSOA
+
+
+def make_resolver_at(
+ where: Union[dns.name.Name, str],
+ port: int = 53,
+ family: int = socket.AF_UNSPEC,
+ resolver: Optional[Resolver] = None,
+) -> Resolver:
"""Make a stub resolver using the specified destination as the full resolver.
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
@@ -841,16 +1758,34 @@ def make_resolver_at(where: Union[dns.name.Name, str], port: int=53, family:
Returns a ``dns.resolver.Resolver`` or raises an exception.
"""
- pass
-
-
-def resolve_at(where: Union[dns.name.Name, str], qname: Union[dns.name.Name,
- str], rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
- rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, tcp:
- bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True,
- source_port: int=0, lifetime: Optional[float]=None, search: Optional[
- bool]=None, port: int=53, family: int=socket.AF_UNSPEC, resolver:
- Optional[Resolver]=None) ->Answer:
+ if resolver is None:
+ resolver = get_default_resolver()
+ nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
+ if isinstance(where, str) and dns.inet.is_address(where):
+ nameservers.append(dns.nameserver.Do53Nameserver(where, port))
+ else:
+ for address in resolver.resolve_name(where, family).addresses():
+ nameservers.append(dns.nameserver.Do53Nameserver(address, port))
+ res = dns.resolver.Resolver(configure=False)
+ res.nameservers = nameservers
+ return res
+
+
+def resolve_at(
+ where: Union[dns.name.Name, str],
+ qname: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+ rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ tcp: bool = False,
+ source: Optional[str] = None,
+ raise_on_no_answer: bool = True,
+ source_port: int = 0,
+ lifetime: Optional[float] = None,
+ search: Optional[bool] = None,
+ port: int = 53,
+ family: int = socket.AF_UNSPEC,
+ resolver: Optional[Resolver] = None,
+) -> Answer:
"""Query nameservers to find the answer to the question.
This is a convenience function that calls ``dns.resolver.make_resolver_at()`` to
@@ -864,11 +1799,29 @@ def resolve_at(where: Union[dns.name.Name, str], qname: Union[dns.name.Name,
``dns.resolver.make_resolver_at()`` and then use that resolver for the queries
instead of calling ``resolve_at()`` multiple times.
"""
- pass
+ return make_resolver_at(where, port, family, resolver).resolve(
+ qname,
+ rdtype,
+ rdclass,
+ tcp,
+ source,
+ raise_on_no_answer,
+ source_port,
+ lifetime,
+ search,
+ )
+
+
+#
+# Support for overriding the system resolver for all python code in the
+# running process.
+#
+
+_protocols_for_socktype = {
+ socket.SOCK_DGRAM: [socket.SOL_UDP],
+ socket.SOCK_STREAM: [socket.SOL_TCP],
+}
-
-_protocols_for_socktype = {socket.SOCK_DGRAM: [socket.SOL_UDP], socket.
- SOCK_STREAM: [socket.SOL_TCP]}
_resolver = None
_original_getaddrinfo = socket.getaddrinfo
_original_getnameinfo = socket.getnameinfo
@@ -878,7 +1831,191 @@ _original_gethostbyname_ex = socket.gethostbyname_ex
_original_gethostbyaddr = socket.gethostbyaddr
-def override_system_resolver(resolver: Optional[Resolver]=None) ->None:
+def _getaddrinfo(
+ host=None, service=None, family=socket.AF_UNSPEC, socktype=0, proto=0, flags=0
+):
+ if flags & socket.AI_NUMERICHOST != 0:
+ # Short circuit directly into the system's getaddrinfo(). We're
+ # not adding any value in this case, and this avoids infinite loops
+ # because dns.query.* needs to call getaddrinfo() for IPv6 scoping
+ # reasons. We will also do this short circuit below if we
+ # discover that the host is an address literal.
+ return _original_getaddrinfo(host, service, family, socktype, proto, flags)
+ if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0:
+ # Not implemented. We raise a gaierror as opposed to a
+ # NotImplementedError as it helps callers handle errors more
+ # appropriately. [Issue #316]
+ #
+ # We raise EAI_FAIL as opposed to EAI_SYSTEM because there is
+ # no EAI_SYSTEM on Windows [Issue #416]. We didn't go for
+ # EAI_BADFLAGS as the flags aren't bad, we just don't
+ # implement them.
+ raise socket.gaierror(
+ socket.EAI_FAIL, "Non-recoverable failure in name resolution"
+ )
+ if host is None and service is None:
+ raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
+ addrs = []
+ canonical_name = None # pylint: disable=redefined-outer-name
+ # Is host None or an address literal? If so, use the system's
+ # getaddrinfo().
+ if host is None:
+ return _original_getaddrinfo(host, service, family, socktype, proto, flags)
+ try:
+ # We don't care about the result of af_for_address(), we're just
+ # calling it so it raises an exception if host is not an IPv4 or
+ # IPv6 address.
+ dns.inet.af_for_address(host)
+ return _original_getaddrinfo(host, service, family, socktype, proto, flags)
+ except Exception:
+ pass
+ # Something needs resolution!
+ try:
+ answers = _resolver.resolve_name(host, family)
+ addrs = answers.addresses_and_families()
+ canonical_name = answers.canonical_name().to_text(True)
+ except dns.resolver.NXDOMAIN:
+ raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
+ except Exception:
+ # We raise EAI_AGAIN here as the failure may be temporary
+ # (e.g. a timeout) and EAI_SYSTEM isn't defined on Windows.
+ # [Issue #416]
+ raise socket.gaierror(socket.EAI_AGAIN, "Temporary failure in name resolution")
+ port = None
+ try:
+ # Is it a port literal?
+ if service is None:
+ port = 0
+ else:
+ port = int(service)
+ except Exception:
+ if flags & socket.AI_NUMERICSERV == 0:
+ try:
+ port = socket.getservbyname(service)
+ except Exception:
+ pass
+ if port is None:
+ raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
+ tuples = []
+ if socktype == 0:
+ socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM]
+ else:
+ socktypes = [socktype]
+ if flags & socket.AI_CANONNAME != 0:
+ cname = canonical_name
+ else:
+ cname = ""
+ for addr, af in addrs:
+ for socktype in socktypes:
+ for proto in _protocols_for_socktype[socktype]:
+ addr_tuple = dns.inet.low_level_address_tuple((addr, port), af)
+ tuples.append((af, socktype, proto, cname, addr_tuple))
+ if len(tuples) == 0:
+ raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
+ return tuples
+
+
+def _getnameinfo(sockaddr, flags=0):
+ host = sockaddr[0]
+ port = sockaddr[1]
+ if len(sockaddr) == 4:
+ scope = sockaddr[3]
+ family = socket.AF_INET6
+ else:
+ scope = None
+ family = socket.AF_INET
+ tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0)
+ if len(tuples) > 1:
+ raise socket.error("sockaddr resolved to multiple addresses")
+ addr = tuples[0][4][0]
+ if flags & socket.NI_DGRAM:
+ pname = "udp"
+ else:
+ pname = "tcp"
+ qname = dns.reversename.from_address(addr)
+ if flags & socket.NI_NUMERICHOST == 0:
+ try:
+ answer = _resolver.resolve(qname, "PTR")
+ hostname = answer.rrset[0].target.to_text(True)
+ except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+ if flags & socket.NI_NAMEREQD:
+ raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
+ hostname = addr
+ if scope is not None:
+ hostname += "%" + str(scope)
+ else:
+ hostname = addr
+ if scope is not None:
+ hostname += "%" + str(scope)
+ if flags & socket.NI_NUMERICSERV:
+ service = str(port)
+ else:
+ service = socket.getservbyport(port, pname)
+ return (hostname, service)
+
+
+def _getfqdn(name=None):
+ if name is None:
+ name = socket.gethostname()
+ try:
+ (name, _, _) = _gethostbyaddr(name)
+ # Python's version checks aliases too, but our gethostbyname
+ # ignores them, so we do so here as well.
+ except Exception:
+ pass
+ return name
+
+
+def _gethostbyname(name):
+ return _gethostbyname_ex(name)[2][0]
+
+
+def _gethostbyname_ex(name):
+ aliases = []
+ addresses = []
+ tuples = _getaddrinfo(
+ name, 0, socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME
+ )
+ canonical = tuples[0][3]
+ for item in tuples:
+ addresses.append(item[4][0])
+ # XXX we just ignore aliases
+ return (canonical, aliases, addresses)
+
+
+def _gethostbyaddr(ip):
+ try:
+ dns.ipv6.inet_aton(ip)
+ sockaddr = (ip, 80, 0, 0)
+ family = socket.AF_INET6
+ except Exception:
+ try:
+ dns.ipv4.inet_aton(ip)
+ except Exception:
+ raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
+ sockaddr = (ip, 80)
+ family = socket.AF_INET
+ (name, _) = _getnameinfo(sockaddr, socket.NI_NAMEREQD)
+ aliases = []
+ addresses = []
+ tuples = _getaddrinfo(
+ name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME
+ )
+ canonical = tuples[0][3]
+ # We only want to include an address from the tuples if it's the
+ # same as the one we asked about. We do this comparison in binary
+ # to avoid any differences in text representations.
+ bin_ip = dns.inet.inet_pton(family, ip)
+ for item in tuples:
+ addr = item[4][0]
+ bin_addr = dns.inet.inet_pton(family, addr)
+ if bin_ip == bin_addr:
+ addresses.append(addr)
+ # XXX we just ignore aliases
+ return (canonical, aliases, addresses)
+
+
+def override_system_resolver(resolver: Optional[Resolver] = None) -> None:
"""Override the system resolver routines in the socket module with
versions which use dnspython's resolver.
@@ -891,9 +2028,27 @@ def override_system_resolver(resolver: Optional[Resolver]=None) ->None:
resolver, a ``dns.resolver.Resolver`` or ``None``, the resolver to use.
"""
- pass
+
+ if resolver is None:
+ resolver = get_default_resolver()
+ global _resolver
+ _resolver = resolver
+ socket.getaddrinfo = _getaddrinfo
+ socket.getnameinfo = _getnameinfo
+ socket.getfqdn = _getfqdn
+ socket.gethostbyname = _gethostbyname
+ socket.gethostbyname_ex = _gethostbyname_ex
+ socket.gethostbyaddr = _gethostbyaddr
-def restore_system_resolver() ->None:
+def restore_system_resolver() -> None:
"""Undo the effects of prior override_system_resolver()."""
- pass
+
+ global _resolver
+ _resolver = None
+ socket.getaddrinfo = _original_getaddrinfo
+ socket.getnameinfo = _original_getnameinfo
+ socket.getfqdn = _original_getfqdn
+ socket.gethostbyname = _original_gethostbyname
+ socket.gethostbyname_ex = _original_gethostbyname_ex
+ socket.gethostbyaddr = _original_gethostbyaddr
diff --git a/dns/reversename.py b/dns/reversename.py
index 416c57f..8236c71 100644
--- a/dns/reversename.py
+++ b/dns/reversename.py
@@ -1,14 +1,37 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2006-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Reverse Map Names."""
+
import binascii
+
import dns.ipv4
import dns.ipv6
import dns.name
-ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.')
-ipv6_reverse_domain = dns.name.from_text('ip6.arpa.')
+ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.")
+ipv6_reverse_domain = dns.name.from_text("ip6.arpa.")
-def from_address(text: str, v4_origin: dns.name.Name=ipv4_reverse_domain,
- v6_origin: dns.name.Name=ipv6_reverse_domain) ->dns.name.Name:
+
+def from_address(
+ text: str,
+ v4_origin: dns.name.Name = ipv4_reverse_domain,
+ v6_origin: dns.name.Name = ipv6_reverse_domain,
+) -> dns.name.Name:
"""Convert an IPv4 or IPv6 address in textual form into a Name object whose
value is the reverse-map domain name of the address.
@@ -27,11 +50,26 @@ def from_address(text: str, v4_origin: dns.name.Name=ipv4_reverse_domain,
Returns a ``dns.name.Name``.
"""
- pass
+
+ try:
+ v6 = dns.ipv6.inet_aton(text)
+ if dns.ipv6.is_mapped(v6):
+ parts = ["%d" % byte for byte in v6[12:]]
+ origin = v4_origin
+ else:
+ parts = [x for x in str(binascii.hexlify(v6).decode())]
+ origin = v6_origin
+ except Exception:
+ parts = ["%d" % byte for byte in dns.ipv4.inet_aton(text)]
+ origin = v4_origin
+ return dns.name.from_text(".".join(reversed(parts)), origin=origin)
-def to_address(name: dns.name.Name, v4_origin: dns.name.Name=
- ipv4_reverse_domain, v6_origin: dns.name.Name=ipv6_reverse_domain) ->str:
+def to_address(
+ name: dns.name.Name,
+ v4_origin: dns.name.Name = ipv4_reverse_domain,
+ v6_origin: dns.name.Name = ipv6_reverse_domain,
+) -> str:
"""Convert a reverse map domain name into textual address form.
*name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name
@@ -48,4 +86,20 @@ def to_address(name: dns.name.Name, v4_origin: dns.name.Name=
Returns a ``str``.
"""
- pass
+
+ if name.is_subdomain(v4_origin):
+ name = name.relativize(v4_origin)
+ text = b".".join(reversed(name.labels))
+ # run through inet_ntoa() to check syntax and make pretty.
+ return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))
+ elif name.is_subdomain(v6_origin):
+ name = name.relativize(v6_origin)
+ labels = list(reversed(name.labels))
+ parts = []
+ for i in range(0, len(labels), 4):
+ parts.append(b"".join(labels[i : i + 4]))
+ text = b":".join(parts)
+ # run through inet_ntoa() to check syntax and make pretty.
+ return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text))
+ else:
+ raise dns.exception.SyntaxError("unknown reverse-map address family")
diff --git a/dns/rrset.py b/dns/rrset.py
index f235f64..6f39b10 100644
--- a/dns/rrset.py
+++ b/dns/rrset.py
@@ -1,5 +1,24 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS RRsets (an RRset is a named rdataset)"""
+
from typing import Any, Collection, Dict, Optional, Union, cast
+
import dns.name
import dns.rdataclass
import dns.rdataset
@@ -15,29 +34,51 @@ class RRset(dns.rdataset.Rdataset):
arguments, reflecting the fact that RRsets always have an owner
name.
"""
- __slots__ = ['name', 'deleting']
- def __init__(self, name: dns.name.Name, rdclass: dns.rdataclass.
- RdataClass, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.
- RdataType=dns.rdatatype.NONE, deleting: Optional[dns.rdataclass.
- RdataClass]=None):
+ __slots__ = ["name", "deleting"]
+
+ def __init__(
+ self,
+ name: dns.name.Name,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ deleting: Optional[dns.rdataclass.RdataClass] = None,
+ ):
"""Create a new RRset."""
+
super().__init__(rdclass, rdtype, covers)
self.name = name
self.deleting = deleting
+ def _clone(self):
+ obj = super()._clone()
+ obj.name = self.name
+ obj.deleting = self.deleting
+ return obj
+
def __repr__(self):
if self.covers == 0:
- ctext = ''
+ ctext = ""
else:
- ctext = '(' + dns.rdatatype.to_text(self.covers) + ')'
+ ctext = "(" + dns.rdatatype.to_text(self.covers) + ")"
if self.deleting is not None:
- dtext = ' delete=' + dns.rdataclass.to_text(self.deleting)
+ dtext = " delete=" + dns.rdataclass.to_text(self.deleting)
else:
- dtext = ''
- return '<DNS ' + str(self.name) + ' ' + dns.rdataclass.to_text(self
- .rdclass) + ' ' + dns.rdatatype.to_text(self.rdtype
- ) + ctext + dtext + ' RRset: ' + self._rdata_repr() + '>'
+ dtext = ""
+ return (
+ "<DNS "
+ + str(self.name)
+ + " "
+ + dns.rdataclass.to_text(self.rdclass)
+ + " "
+ + dns.rdatatype.to_text(self.rdtype)
+ + ctext
+ + dtext
+ + " RRset: "
+ + self._rdata_repr()
+ + ">"
+ )
def __str__(self):
return self.to_text()
@@ -50,7 +91,7 @@ class RRset(dns.rdataset.Rdataset):
return False
return super().__eq__(other)
- def match(self, *args: Any, **kwargs: Any) ->bool:
+ def match(self, *args: Any, **kwargs: Any) -> bool: # type: ignore[override]
"""Does this rrset match the specified attributes?
Behaves as :py:func:`full_match()` if the first argument is a
@@ -62,18 +103,36 @@ class RRset(dns.rdataset.Rdataset):
makes RRsets matchable as Rdatasets while preserving backwards
compatibility.)
"""
- pass
-
- def full_match(self, name: dns.name.Name, rdclass: dns.rdataclass.
- RdataClass, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.
- RdataType, deleting: Optional[dns.rdataclass.RdataClass]=None) ->bool:
+ if isinstance(args[0], dns.name.Name):
+ return self.full_match(*args, **kwargs) # type: ignore[arg-type]
+ else:
+ return super().match(*args, **kwargs) # type: ignore[arg-type]
+
+ def full_match(
+ self,
+ name: dns.name.Name,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType,
+ deleting: Optional[dns.rdataclass.RdataClass] = None,
+ ) -> bool:
"""Returns ``True`` if this rrset matches the specified name, class,
type, covers, and deletion state.
"""
- pass
+ if not super().match(rdclass, rdtype, covers):
+ return False
+ if self.name != name or self.deleting != deleting:
+ return False
+ return True
- def to_text(self, origin: Optional[dns.name.Name]=None, relativize:
- bool=True, **kw: Dict[str, Any]) ->str:
+ # pylint: disable=arguments-differ
+
+ def to_text( # type: ignore[override]
+ self,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ **kw: Dict[str, Any],
+ ) -> str:
"""Convert the RRset into DNS zone file format.
See ``dns.name.Name.choose_relativity`` for more information
@@ -89,11 +148,18 @@ class RRset(dns.rdataset.Rdataset):
*relativize*, a ``bool``. If ``True``, names will be relativized
to *origin*.
"""
- pass
- def to_wire(self, file: Any, compress: Optional[dns.name.CompressType]=
- None, origin: Optional[dns.name.Name]=None, **kw: Dict[str, Any]
- ) ->int:
+ return super().to_text(
+ self.name, origin, relativize, self.deleting, **kw # type: ignore
+ )
+
+ def to_wire( # type: ignore[override]
+ self,
+ file: Any,
+ compress: Optional[dns.name.CompressType] = None, # type: ignore
+ origin: Optional[dns.name.Name] = None,
+ **kw: Dict[str, Any],
+ ) -> int:
"""Convert the RRset to wire format.
All keyword arguments are passed to ``dns.rdataset.to_wire()``; see
@@ -101,21 +167,32 @@ class RRset(dns.rdataset.Rdataset):
Returns an ``int``, the number of records emitted.
"""
- pass
- def to_rdataset(self) ->dns.rdataset.Rdataset:
+ return super().to_wire(
+ self.name, file, compress, origin, self.deleting, **kw # type:ignore
+ )
+
+ # pylint: enable=arguments-differ
+
+ def to_rdataset(self) -> dns.rdataset.Rdataset:
"""Convert an RRset into an Rdataset.
Returns a ``dns.rdataset.Rdataset``.
"""
- pass
-
-
-def from_text_list(name: Union[dns.name.Name, str], ttl: int, rdclass:
- Union[dns.rdataclass.RdataClass, str], rdtype: Union[dns.rdatatype.
- RdataType, str], text_rdatas: Collection[str], idna_codec: Optional[dns
- .name.IDNACodec]=None, origin: Optional[dns.name.Name]=None, relativize:
- bool=True, relativize_to: Optional[dns.name.Name]=None) ->RRset:
+ return dns.rdataset.from_rdata_list(self.ttl, list(self))
+
+
+def from_text_list(
+ name: Union[dns.name.Name, str],
+ ttl: int,
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ text_rdatas: Collection[str],
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = True,
+ relativize_to: Optional[dns.name.Name] = None,
+) -> RRset:
"""Create an RRset with the specified name, TTL, class, and type, and with
the specified list of rdatas in text format.
@@ -133,23 +210,45 @@ def from_text_list(name: Union[dns.name.Name, str], ttl: int, rdclass:
Returns a ``dns.rrset.RRset`` object.
"""
- pass
-
-def from_text(name: Union[dns.name.Name, str], ttl: int, rdclass: Union[dns
- .rdataclass.RdataClass, str], rdtype: Union[dns.rdatatype.RdataType,
- str], *text_rdatas: Any) ->RRset:
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None, idna_codec=idna_codec)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ r = RRset(name, rdclass, rdtype)
+ r.update_ttl(ttl)
+ for t in text_rdatas:
+ rd = dns.rdata.from_text(
+ r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec
+ )
+ r.add(rd)
+ return r
+
+
+def from_text(
+ name: Union[dns.name.Name, str],
+ ttl: int,
+ rdclass: Union[dns.rdataclass.RdataClass, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ *text_rdatas: Any,
+) -> RRset:
"""Create an RRset with the specified name, TTL, class, and type and with
the specified rdatas in text format.
Returns a ``dns.rrset.RRset`` object.
"""
- pass
+
+ return from_text_list(
+ name, ttl, rdclass, rdtype, cast(Collection[str], text_rdatas)
+ )
-def from_rdata_list(name: Union[dns.name.Name, str], ttl: int, rdatas:
- Collection[dns.rdata.Rdata], idna_codec: Optional[dns.name.IDNACodec]=None
- ) ->RRset:
+def from_rdata_list(
+ name: Union[dns.name.Name, str],
+ ttl: int,
+ rdatas: Collection[dns.rdata.Rdata],
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+) -> RRset:
"""Create an RRset with the specified name and TTL, and with
the specified list of rdata objects.
@@ -160,14 +259,27 @@ def from_rdata_list(name: Union[dns.name.Name, str], ttl: int, rdatas:
Returns a ``dns.rrset.RRset`` object.
"""
- pass
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None, idna_codec=idna_codec)
+
+ if len(rdatas) == 0:
+ raise ValueError("rdata list must not be empty")
+ r = None
+ for rd in rdatas:
+ if r is None:
+ r = RRset(name, rd.rdclass, rd.rdtype)
+ r.update_ttl(ttl)
+ r.add(rd)
+ assert r is not None
+ return r
-def from_rdata(name: Union[dns.name.Name, str], ttl: int, *rdatas: Any
- ) ->RRset:
+
+def from_rdata(name: Union[dns.name.Name, str], ttl: int, *rdatas: Any) -> RRset:
"""Create an RRset with the specified name and TTL, and with
the specified rdata objects.
Returns a ``dns.rrset.RRset`` object.
"""
- pass
+
+ return from_rdata_list(name, ttl, cast(Collection[dns.rdata.Rdata], rdatas))
diff --git a/dns/serial.py b/dns/serial.py
index 3d76f4d..3417299 100644
--- a/dns/serial.py
+++ b/dns/serial.py
@@ -1,14 +1,15 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
"""Serial Number Arthimetic from RFC 1982"""
class Serial:
-
- def __init__(self, value: int, bits: int=32):
- self.value = value % 2 ** bits
+ def __init__(self, value: int, bits: int = 32):
+ self.value = value % 2**bits
self.bits = bits
def __repr__(self):
- return f'dns.serial.Serial({self.value}, {self.bits})'
+ return f"dns.serial.Serial({self.value}, {self.bits})"
def __eq__(self, other):
if isinstance(other, int):
@@ -29,11 +30,11 @@ class Serial:
other = Serial(other, self.bits)
elif not isinstance(other, Serial) or other.bits != self.bits:
return NotImplemented
- if self.value < other.value and other.value - self.value < 2 ** (self
- .bits - 1):
+ if self.value < other.value and other.value - self.value < 2 ** (self.bits - 1):
return True
- elif self.value > other.value and self.value - other.value > 2 ** (self
- .bits - 1):
+ elif self.value > other.value and self.value - other.value > 2 ** (
+ self.bits - 1
+ ):
return True
else:
return False
@@ -46,11 +47,11 @@ class Serial:
other = Serial(other, self.bits)
elif not isinstance(other, Serial) or other.bits != self.bits:
return NotImplemented
- if self.value < other.value and other.value - self.value > 2 ** (self
- .bits - 1):
+ if self.value < other.value and other.value - self.value > 2 ** (self.bits - 1):
return True
- elif self.value > other.value and self.value - other.value < 2 ** (self
- .bits - 1):
+ elif self.value > other.value and self.value - other.value < 2 ** (
+ self.bits - 1
+ ):
return True
else:
return False
@@ -66,10 +67,10 @@ class Serial:
delta = other
else:
raise ValueError
- if abs(delta) > 2 ** (self.bits - 1) - 1:
+ if abs(delta) > (2 ** (self.bits - 1) - 1):
raise ValueError
v += delta
- v = v % 2 ** self.bits
+ v = v % 2**self.bits
return Serial(v, self.bits)
def __iadd__(self, other):
@@ -80,10 +81,10 @@ class Serial:
delta = other
else:
raise ValueError
- if abs(delta) > 2 ** (self.bits - 1) - 1:
+ if abs(delta) > (2 ** (self.bits - 1) - 1):
raise ValueError
v += delta
- v = v % 2 ** self.bits
+ v = v % 2**self.bits
self.value = v
return self
@@ -95,10 +96,10 @@ class Serial:
delta = other
else:
raise ValueError
- if abs(delta) > 2 ** (self.bits - 1) - 1:
+ if abs(delta) > (2 ** (self.bits - 1) - 1):
raise ValueError
v -= delta
- v = v % 2 ** self.bits
+ v = v % 2**self.bits
return Serial(v, self.bits)
def __isub__(self, other):
@@ -109,9 +110,9 @@ class Serial:
delta = other
else:
raise ValueError
- if abs(delta) > 2 ** (self.bits - 1) - 1:
+ if abs(delta) > (2 ** (self.bits - 1) - 1):
raise ValueError
v -= delta
- v = v % 2 ** self.bits
+ v = v % 2**self.bits
self.value = v
return self
diff --git a/dns/set.py b/dns/set.py
index d90e21f..f0fb0d5 100644
--- a/dns/set.py
+++ b/dns/set.py
@@ -1,3 +1,20 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
import itertools
@@ -9,38 +26,50 @@ class Set:
as these sets are based on lists and are thus indexable, and this
ability is widely used in dnspython applications.
"""
- __slots__ = ['items']
+
+ __slots__ = ["items"]
def __init__(self, items=None):
"""Initialize the set.
*items*, an iterable or ``None``, the initial set of items.
"""
+
self.items = dict()
if items is not None:
for item in items:
- self.add(item)
+ # This is safe for how we use set, but if other code
+ # subclasses it could be a legitimate issue.
+ self.add(item) # lgtm[py/init-calls-subclass]
def __repr__(self):
- return 'dns.set.Set(%s)' % repr(list(self.items.keys()))
+ return "dns.set.Set(%s)" % repr(list(self.items.keys()))
def add(self, item):
"""Add an item to the set."""
- pass
+
+ if item not in self.items:
+ self.items[item] = None
def remove(self, item):
"""Remove an item from the set."""
- pass
+
+ try:
+ del self.items[item]
+ except KeyError:
+ raise ValueError
def discard(self, item):
"""Remove an item from the set if present."""
- pass
+
+ self.items.pop(item, None)
def pop(self):
"""Remove an arbitrary item from the set."""
- pass
+ (k, _) = self.items.popitem()
+ return k
- def _clone(self) ->'Set':
+ def _clone(self) -> "Set":
"""Make a (shallow) copy of the set.
There is a 'clone protocol' that subclasses of this class
@@ -52,44 +81,87 @@ class Set:
return new instances (e.g. union) once, and keep using them in
subclasses.
"""
- pass
+
+ if hasattr(self, "_clone_class"):
+ cls = self._clone_class # type: ignore
+ else:
+ cls = self.__class__
+ obj = cls.__new__(cls)
+ obj.items = dict()
+ obj.items.update(self.items)
+ return obj
def __copy__(self):
"""Make a (shallow) copy of the set."""
+
return self._clone()
def copy(self):
"""Make a (shallow) copy of the set."""
- pass
+
+ return self._clone()
def union_update(self, other):
"""Update the set, adding any elements from other which are not
already in the set.
"""
- pass
+
+ if not isinstance(other, Set):
+ raise ValueError("other must be a Set instance")
+ if self is other: # lgtm[py/comparison-using-is]
+ return
+ for item in other.items:
+ self.add(item)
def intersection_update(self, other):
"""Update the set, removing any elements from other which are not
in both sets.
"""
- pass
+
+ if not isinstance(other, Set):
+ raise ValueError("other must be a Set instance")
+ if self is other: # lgtm[py/comparison-using-is]
+ return
+ # we make a copy of the list so that we can remove items from
+ # the list without breaking the iterator.
+ for item in list(self.items):
+ if item not in other.items:
+ del self.items[item]
def difference_update(self, other):
"""Update the set, removing any elements from other which are in
the set.
"""
- pass
+
+ if not isinstance(other, Set):
+ raise ValueError("other must be a Set instance")
+ if self is other: # lgtm[py/comparison-using-is]
+ self.items.clear()
+ else:
+ for item in other.items:
+ self.discard(item)
def symmetric_difference_update(self, other):
"""Update the set, retaining only elements unique to both sets."""
- pass
+
+ if not isinstance(other, Set):
+ raise ValueError("other must be a Set instance")
+ if self is other: # lgtm[py/comparison-using-is]
+ self.items.clear()
+ else:
+ overlap = self.intersection(other)
+ self.union_update(other)
+ self.difference_update(overlap)
def union(self, other):
"""Return a new set which is the union of ``self`` and ``other``.
Returns the same Set type as this set.
"""
- pass
+
+ obj = self._clone()
+ obj.union_update(other)
+ return obj
def intersection(self, other):
"""Return a new set which is the intersection of ``self`` and
@@ -97,7 +169,10 @@ class Set:
Returns the same Set type as this set.
"""
- pass
+
+ obj = self._clone()
+ obj.intersection_update(other)
+ return obj
def difference(self, other):
"""Return a new set which ``self`` - ``other``, i.e. the items
@@ -105,7 +180,10 @@ class Set:
Returns the same Set type as this set.
"""
- pass
+
+ obj = self._clone()
+ obj.difference_update(other)
+ return obj
def symmetric_difference(self, other):
"""Return a new set which (``self`` - ``other``) | (``other``
@@ -114,7 +192,10 @@ class Set:
Returns the same Set type as this set.
"""
- pass
+
+ obj = self._clone()
+ obj.symmetric_difference_update(other)
+ return obj
def __or__(self, other):
return self.union(other)
@@ -158,11 +239,13 @@ class Set:
*other*, the collection of items with which to update the set, which
may be any iterable type.
"""
- pass
+
+ for item in other:
+ self.add(item)
def clear(self):
"""Make the set empty."""
- pass
+ self.items.clear()
def __eq__(self, other):
return self.items == other.items
@@ -194,11 +277,31 @@ class Set:
Returns a ``bool``.
"""
- pass
+
+ if not isinstance(other, Set):
+ raise ValueError("other must be a Set instance")
+ for item in self.items:
+ if item not in other.items:
+ return False
+ return True
def issuperset(self, other):
"""Is this set a superset of *other*?
Returns a ``bool``.
"""
- pass
+
+ if not isinstance(other, Set):
+ raise ValueError("other must be a Set instance")
+ for item in other.items:
+ if item not in self.items:
+ return False
+ return True
+
+ def isdisjoint(self, other):
+ if not isinstance(other, Set):
+ raise ValueError("other must be a Set instance")
+ for item in other.items:
+ if item in self.items:
+ return False
+ return True
diff --git a/dns/tokenizer.py b/dns/tokenizer.py
index c6a389f..454cac4 100644
--- a/dns/tokenizer.py
+++ b/dns/tokenizer.py
@@ -1,12 +1,33 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""Tokenize DNS zone file format"""
+
import io
import sys
from typing import Any, List, Optional, Tuple
+
import dns.exception
import dns.name
import dns.ttl
-_DELIMITERS = {' ', '\t', '\n', ';', '(', ')', '"'}
+
+_DELIMITERS = {" ", "\t", "\n", ";", "(", ")", '"'}
_QUOTING_DELIMITERS = {'"'}
+
EOF = 0
EOL = 1
WHITESPACE = 2
@@ -28,14 +49,44 @@ class Token:
has_escape: Does the token value contain escapes?
"""
- def __init__(self, ttype: int, value: Any='', has_escape: bool=False,
- comment: Optional[str]=None):
+ def __init__(
+ self,
+ ttype: int,
+ value: Any = "",
+ has_escape: bool = False,
+ comment: Optional[str] = None,
+ ):
"""Initialize a token instance."""
+
self.ttype = ttype
self.value = value
self.has_escape = has_escape
self.comment = comment
+ def is_eof(self) -> bool:
+ return self.ttype == EOF
+
+ def is_eol(self) -> bool:
+ return self.ttype == EOL
+
+ def is_whitespace(self) -> bool:
+ return self.ttype == WHITESPACE
+
+ def is_identifier(self) -> bool:
+ return self.ttype == IDENTIFIER
+
+ def is_quoted_string(self) -> bool:
+ return self.ttype == QUOTED_STRING
+
+ def is_comment(self) -> bool:
+ return self.ttype == COMMENT
+
+ def is_delimiter(self) -> bool: # pragma: no cover (we don't return delimiters yet)
+ return self.ttype == DELIMITER
+
+ def is_eol_or_eof(self) -> bool:
+ return self.ttype == EOL or self.ttype == EOF
+
def __eq__(self, other):
if not isinstance(other, Token):
return False
@@ -49,6 +100,100 @@ class Token:
def __str__(self):
return '%d "%s"' % (self.ttype, self.value)
+ def unescape(self) -> "Token":
+ if not self.has_escape:
+ return self
+ unescaped = ""
+ l = len(self.value)
+ i = 0
+ while i < l:
+ c = self.value[i]
+ i += 1
+ if c == "\\":
+ if i >= l: # pragma: no cover (can't happen via get())
+ raise dns.exception.UnexpectedEnd
+ c = self.value[i]
+ i += 1
+ if c.isdigit():
+ if i >= l:
+ raise dns.exception.UnexpectedEnd
+ c2 = self.value[i]
+ i += 1
+ if i >= l:
+ raise dns.exception.UnexpectedEnd
+ c3 = self.value[i]
+ i += 1
+ if not (c2.isdigit() and c3.isdigit()):
+ raise dns.exception.SyntaxError
+ codepoint = int(c) * 100 + int(c2) * 10 + int(c3)
+ if codepoint > 255:
+ raise dns.exception.SyntaxError
+ c = chr(codepoint)
+ unescaped += c
+ return Token(self.ttype, unescaped)
+
+ def unescape_to_bytes(self) -> "Token":
+ # We used to use unescape() for TXT-like records, but this
+ # caused problems as we'd process DNS escapes into Unicode code
+ # points instead of byte values, and then a to_text() of the
+ # processed data would not equal the original input. For
+ # example, \226 in the TXT record would have a to_text() of
+ # \195\162 because we applied UTF-8 encoding to Unicode code
+ # point 226.
+ #
+ # We now apply escapes while converting directly to bytes,
+ # avoiding this double encoding.
+ #
+ # This code also handles cases where the unicode input has
+ # non-ASCII code-points in it by converting it to UTF-8. TXT
+ # records aren't defined for Unicode, but this is the best we
+ # can do to preserve meaning. For example,
+ #
+ # foo\u200bbar
+ #
+ # (where \u200b is Unicode code point 0x200b) will be treated
+ # as if the input had been the UTF-8 encoding of that string,
+ # namely:
+ #
+ # foo\226\128\139bar
+ #
+ unescaped = b""
+ l = len(self.value)
+ i = 0
+ while i < l:
+ c = self.value[i]
+ i += 1
+ if c == "\\":
+ if i >= l: # pragma: no cover (can't happen via get())
+ raise dns.exception.UnexpectedEnd
+ c = self.value[i]
+ i += 1
+ if c.isdigit():
+ if i >= l:
+ raise dns.exception.UnexpectedEnd
+ c2 = self.value[i]
+ i += 1
+ if i >= l:
+ raise dns.exception.UnexpectedEnd
+ c3 = self.value[i]
+ i += 1
+ if not (c2.isdigit() and c3.isdigit()):
+ raise dns.exception.SyntaxError
+ codepoint = int(c) * 100 + int(c2) * 10 + int(c3)
+ if codepoint > 255:
+ raise dns.exception.SyntaxError
+ unescaped += b"%c" % (codepoint)
+ else:
+ # Note that as mentioned above, if c is a Unicode
+ # code point outside of the ASCII range, then this
+ # += is converting that code point to its UTF-8
+ # encoding and appending multiple bytes to
+ # unescaped.
+ unescaped += c.encode()
+ else:
+ unescaped += c.encode()
+ return Token(self.ttype, bytes(unescaped))
+
class Tokenizer:
"""A DNS zone file format tokenizer.
@@ -83,8 +228,12 @@ class Tokenizer:
encoder/decoder is used.
"""
- def __init__(self, f: Any=sys.stdin, filename: Optional[str]=None,
- idna_codec: Optional[dns.name.IDNACodec]=None):
+ def __init__(
+ self,
+ f: Any = sys.stdin,
+ filename: Optional[str] = None,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ ):
"""Initialize a tokenizer instance.
f: The file to tokenize. The default is sys.stdin.
@@ -98,19 +247,21 @@ class Tokenizer:
encoder/decoder. If None, the default IDNA 2003
encoder/decoder is used.
"""
+
if isinstance(f, str):
f = io.StringIO(f)
if filename is None:
- filename = '<string>'
+ filename = "<string>"
elif isinstance(f, bytes):
f = io.StringIO(f.decode())
if filename is None:
- filename = '<string>'
- elif filename is None:
- if f is sys.stdin:
- filename = '<stdin>'
- else:
- filename = '<file>'
+ filename = "<string>"
+ else:
+ if filename is None:
+ if f is sys.stdin:
+ filename = "<stdin>"
+ else:
+ filename = "<file>"
self.file = f
self.ungotten_char: Optional[str] = None
self.ungotten_token: Optional[Token] = None
@@ -126,19 +277,33 @@ class Tokenizer:
else:
self.idna_codec = idna_codec
- def _get_char(self) ->str:
+ def _get_char(self) -> str:
"""Read a character from input."""
- pass
- def where(self) ->Tuple[str, int]:
+ if self.ungotten_char is None:
+ if self.eof:
+ c = ""
+ else:
+ c = self.file.read(1)
+ if c == "":
+ self.eof = True
+ elif c == "\n":
+ self.line_number += 1
+ else:
+ c = self.ungotten_char
+ self.ungotten_char = None
+ return c
+
+ def where(self) -> Tuple[str, int]:
"""Return the current location in the input.
Returns a (string, int) tuple. The first item is the filename of
the input, the second is the current line number.
"""
- pass
- def _unget_char(self, c: str) ->None:
+ return (self.filename, self.line_number)
+
+ def _unget_char(self, c: str) -> None:
"""Unget a character.
The unget buffer for characters is only one character large; it is
@@ -148,9 +313,13 @@ class Tokenizer:
c: the character to unget
raises UngetBufferFull: there is already an ungotten char
"""
- pass
- def skip_whitespace(self) ->int:
+ if self.ungotten_char is not None:
+ # this should never happen!
+ raise UngetBufferFull # pragma: no cover
+ self.ungotten_char = c
+
+ def skip_whitespace(self) -> int:
"""Consume input until a non-whitespace character is encountered.
The non-whitespace character is then ungotten, and the number of
@@ -160,9 +329,17 @@ class Tokenizer:
Returns the number of characters skipped.
"""
- pass
- def get(self, want_leading: bool=False, want_comment: bool=False) ->Token:
+ skipped = 0
+ while True:
+ c = self._get_char()
+ if c != " " and c != "\t":
+ if (c != "\n") or not self.multiline:
+ self._unget_char(c)
+ return skipped
+ skipped += 1
+
+ def get(self, want_leading: bool = False, want_comment: bool = False) -> Token:
"""Get the next token.
want_leading: If True, return a WHITESPACE token if the
@@ -177,9 +354,103 @@ class Tokenizer:
Returns a Token.
"""
- pass
- def unget(self, token: Token) ->None:
+ if self.ungotten_token is not None:
+ utoken = self.ungotten_token
+ self.ungotten_token = None
+ if utoken.is_whitespace():
+ if want_leading:
+ return utoken
+ elif utoken.is_comment():
+ if want_comment:
+ return utoken
+ else:
+ return utoken
+ skipped = self.skip_whitespace()
+ if want_leading and skipped > 0:
+ return Token(WHITESPACE, " ")
+ token = ""
+ ttype = IDENTIFIER
+ has_escape = False
+ while True:
+ c = self._get_char()
+ if c == "" or c in self.delimiters:
+ if c == "" and self.quoting:
+ raise dns.exception.UnexpectedEnd
+ if token == "" and ttype != QUOTED_STRING:
+ if c == "(":
+ self.multiline += 1
+ self.skip_whitespace()
+ continue
+ elif c == ")":
+ if self.multiline <= 0:
+ raise dns.exception.SyntaxError
+ self.multiline -= 1
+ self.skip_whitespace()
+ continue
+ elif c == '"':
+ if not self.quoting:
+ self.quoting = True
+ self.delimiters = _QUOTING_DELIMITERS
+ ttype = QUOTED_STRING
+ continue
+ else:
+ self.quoting = False
+ self.delimiters = _DELIMITERS
+ self.skip_whitespace()
+ continue
+ elif c == "\n":
+ return Token(EOL, "\n")
+ elif c == ";":
+ while 1:
+ c = self._get_char()
+ if c == "\n" or c == "":
+ break
+ token += c
+ if want_comment:
+ self._unget_char(c)
+ return Token(COMMENT, token)
+ elif c == "":
+ if self.multiline:
+ raise dns.exception.SyntaxError(
+ "unbalanced parentheses"
+ )
+ return Token(EOF, comment=token)
+ elif self.multiline:
+ self.skip_whitespace()
+ token = ""
+ continue
+ else:
+ return Token(EOL, "\n", comment=token)
+ else:
+ # This code exists in case we ever want a
+ # delimiter to be returned. It never produces
+ # a token currently.
+ token = c
+ ttype = DELIMITER
+ else:
+ self._unget_char(c)
+ break
+ elif self.quoting and c == "\n":
+ raise dns.exception.SyntaxError("newline in quoted string")
+ elif c == "\\":
+ #
+ # It's an escape. Put it and the next character into
+ # the token; it will be checked later for goodness.
+ #
+ token += c
+ has_escape = True
+ c = self._get_char()
+ if c == "" or (c == "\n" and not self.quoting):
+ raise dns.exception.UnexpectedEnd
+ token += c
+ if token == "" and ttype != QUOTED_STRING:
+ if self.multiline:
+ raise dns.exception.SyntaxError("unbalanced parentheses")
+ ttype = EOF
+ return Token(ttype, token, has_escape)
+
+ def unget(self, token: Token) -> None:
"""Unget a token.
The unget buffer for tokens is only one token large; it is
@@ -190,29 +461,45 @@ class Tokenizer:
Raises UngetBufferFull: there is already an ungotten token
"""
- pass
+
+ if self.ungotten_token is not None:
+ raise UngetBufferFull
+ self.ungotten_token = token
def next(self):
"""Return the next item in an iteration.
Returns a Token.
"""
- pass
+
+ token = self.get()
+ if token.is_eof():
+ raise StopIteration
+ return token
+
__next__ = next
def __iter__(self):
return self
- def get_int(self, base: int=10) ->int:
+ # Helpers
+
+ def get_int(self, base: int = 10) -> int:
"""Read the next token and interpret it as an unsigned integer.
Raises dns.exception.SyntaxError if not an unsigned integer.
Returns an int.
"""
- pass
- def get_uint8(self) ->int:
+ token = self.get().unescape()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError("expecting an identifier")
+ if not token.value.isdigit():
+ raise dns.exception.SyntaxError("expecting an integer")
+ return int(token.value, base)
+
+ def get_uint8(self) -> int:
"""Read the next token and interpret it as an 8-bit unsigned
integer.
@@ -220,9 +507,15 @@ class Tokenizer:
Returns an int.
"""
- pass
- def get_uint16(self, base: int=10) ->int:
+ value = self.get_int()
+ if value < 0 or value > 255:
+ raise dns.exception.SyntaxError(
+ "%d is not an unsigned 8-bit integer" % value
+ )
+ return value
+
+ def get_uint16(self, base: int = 10) -> int:
"""Read the next token and interpret it as a 16-bit unsigned
integer.
@@ -230,9 +523,20 @@ class Tokenizer:
Returns an int.
"""
- pass
- def get_uint32(self, base: int=10) ->int:
+ value = self.get_int(base=base)
+ if value < 0 or value > 65535:
+ if base == 8:
+ raise dns.exception.SyntaxError(
+ "%o is not an octal unsigned 16-bit integer" % value
+ )
+ else:
+ raise dns.exception.SyntaxError(
+ "%d is not an unsigned 16-bit integer" % value
+ )
+ return value
+
+ def get_uint32(self, base: int = 10) -> int:
"""Read the next token and interpret it as a 32-bit unsigned
integer.
@@ -240,9 +544,15 @@ class Tokenizer:
Returns an int.
"""
- pass
- def get_uint48(self, base: int=10) ->int:
+ value = self.get_int(base=base)
+ if value < 0 or value > 4294967295:
+ raise dns.exception.SyntaxError(
+ "%d is not an unsigned 32-bit integer" % value
+ )
+ return value
+
+ def get_uint48(self, base: int = 10) -> int:
"""Read the next token and interpret it as a 48-bit unsigned
integer.
@@ -250,9 +560,15 @@ class Tokenizer:
Returns an int.
"""
- pass
- def get_string(self, max_length: Optional[int]=None) ->str:
+ value = self.get_int(base=base)
+ if value < 0 or value > 281474976710655:
+ raise dns.exception.SyntaxError(
+ "%d is not an unsigned 48-bit integer" % value
+ )
+ return value
+
+ def get_string(self, max_length: Optional[int] = None) -> str:
"""Read the next token and interpret it as a string.
Raises dns.exception.SyntaxError if not a string.
@@ -261,27 +577,47 @@ class Tokenizer:
Returns a string.
"""
- pass
- def get_identifier(self) ->str:
+ token = self.get().unescape()
+ if not (token.is_identifier() or token.is_quoted_string()):
+ raise dns.exception.SyntaxError("expecting a string")
+ if max_length and len(token.value) > max_length:
+ raise dns.exception.SyntaxError("string too long")
+ return token.value
+
+ def get_identifier(self) -> str:
"""Read the next token, which should be an identifier.
Raises dns.exception.SyntaxError if not an identifier.
Returns a string.
"""
- pass
- def get_remaining(self, max_tokens: Optional[int]=None) ->List[Token]:
+ token = self.get().unescape()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError("expecting an identifier")
+ return token.value
+
+ def get_remaining(self, max_tokens: Optional[int] = None) -> List[Token]:
"""Return the remaining tokens on the line, until an EOL or EOF is seen.
max_tokens: If not None, stop after this number of tokens.
Returns a list of tokens.
"""
- pass
- def concatenate_remaining_identifiers(self, allow_empty: bool=False) ->str:
+ tokens = []
+ while True:
+ token = self.get()
+ if token.is_eol_or_eof():
+ self.unget(token)
+ break
+ tokens.append(token)
+ if len(tokens) == max_tokens:
+ break
+ return tokens
+
+ def concatenate_remaining_identifiers(self, allow_empty: bool = False) -> str:
"""Read the remaining tokens on the line, which should be identifiers.
Raises dns.exception.SyntaxError if there are no remaining tokens,
@@ -293,39 +629,71 @@ class Tokenizer:
Returns a string containing a concatenation of the remaining
identifiers.
"""
- pass
-
- def as_name(self, token: Token, origin: Optional[dns.name.Name]=None,
- relativize: bool=False, relativize_to: Optional[dns.name.Name]=None
- ) ->dns.name.Name:
+ s = ""
+ while True:
+ token = self.get().unescape()
+ if token.is_eol_or_eof():
+ self.unget(token)
+ break
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ s += token.value
+ if not (allow_empty or s):
+ raise dns.exception.SyntaxError("expecting another identifier")
+ return s
+
+ def as_name(
+ self,
+ token: Token,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = False,
+ relativize_to: Optional[dns.name.Name] = None,
+ ) -> dns.name.Name:
"""Try to interpret the token as a DNS name.
Raises dns.exception.SyntaxError if not a name.
Returns a dns.name.Name.
"""
- pass
-
- def get_name(self, origin: Optional[dns.name.Name]=None, relativize:
- bool=False, relativize_to: Optional[dns.name.Name]=None
- ) ->dns.name.Name:
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError("expecting an identifier")
+ name = dns.name.from_text(token.value, origin, self.idna_codec)
+ return name.choose_relativity(relativize_to or origin, relativize)
+
+ def get_name(
+ self,
+ origin: Optional[dns.name.Name] = None,
+ relativize: bool = False,
+ relativize_to: Optional[dns.name.Name] = None,
+ ) -> dns.name.Name:
"""Read the next token and interpret it as a DNS name.
Raises dns.exception.SyntaxError if not a name.
Returns a dns.name.Name.
"""
- pass
- def get_eol_as_token(self) ->Token:
+ token = self.get()
+ return self.as_name(token, origin, relativize, relativize_to)
+
+ def get_eol_as_token(self) -> Token:
"""Read the next token and raise an exception if it isn't EOL or
EOF.
Returns a string.
"""
- pass
- def get_ttl(self) ->int:
+ token = self.get()
+ if not token.is_eol_or_eof():
+ raise dns.exception.SyntaxError(
+ 'expected EOL or EOF, got %d "%s"' % (token.ttype, token.value)
+ )
+ return token
+
+ def get_eol(self) -> str:
+ return self.get_eol_as_token().value
+
+ def get_ttl(self) -> int:
"""Read the next token and interpret it as a DNS TTL.
Raises dns.exception.SyntaxError or dns.ttl.BadTTL if not an
@@ -333,4 +701,8 @@ class Tokenizer:
Returns an int.
"""
- pass
+
+ token = self.get().unescape()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError("expecting an identifier")
+ return dns.ttl.from_text(token.value)
diff --git a/dns/transaction.py b/dns/transaction.py
index 3acb2f4..84e54f7 100644
--- a/dns/transaction.py
+++ b/dns/transaction.py
@@ -1,5 +1,8 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import collections
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+
import dns.exception
import dns.name
import dns.node
@@ -12,12 +15,11 @@ import dns.ttl
class TransactionManager:
-
- def reader(self) ->'Transaction':
+ def reader(self) -> "Transaction":
"""Begin a read-only transaction."""
- pass
+ raise NotImplementedError # pragma: no cover
- def writer(self, replacement: bool=False) ->'Transaction':
+ def writer(self, replacement: bool = False) -> "Transaction":
"""Begin a writable transaction.
*replacement*, a ``bool``. If `True`, the content of the
@@ -25,10 +27,11 @@ class TransactionManager:
the default, then the content of the transaction updates the
existing content.
"""
- pass
+ raise NotImplementedError # pragma: no cover
- def origin_information(self) ->Tuple[Optional[dns.name.Name], bool,
- Optional[dns.name.Name]]:
+ def origin_information(
+ self,
+ ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
"""Returns a tuple
(absolute_origin, relativize, effective_origin)
@@ -51,15 +54,19 @@ class TransactionManager:
relativity).
"""
- pass
+ raise NotImplementedError # pragma: no cover
- def get_class(self) ->dns.rdataclass.RdataClass:
+ def get_class(self) -> dns.rdataclass.RdataClass:
"""The class of the transaction manager."""
- pass
+ raise NotImplementedError # pragma: no cover
- def from_wire_origin(self) ->Optional[dns.name.Name]:
+ def from_wire_origin(self) -> Optional[dns.name.Name]:
"""Origin to use in from_wire() calls."""
- pass
+ (absolute_origin, relativize, _) = self.origin_information()
+ if relativize:
+ return absolute_origin
+ else:
+ return None
class DeleteNotExact(dns.exception.DNSException):
@@ -74,17 +81,35 @@ class AlreadyEnded(dns.exception.DNSException):
"""Tried to use an already-ended transaction."""
-CheckPutRdatasetType = Callable[['Transaction', dns.name.Name, dns.rdataset
- .Rdataset], None]
-CheckDeleteRdatasetType = Callable[['Transaction', dns.name.Name, dns.
- rdatatype.RdataType, dns.rdatatype.RdataType], None]
-CheckDeleteNameType = Callable[['Transaction', dns.name.Name], None]
+def _ensure_immutable_rdataset(rdataset):
+ if rdataset is None or isinstance(rdataset, dns.rdataset.ImmutableRdataset):
+ return rdataset
+ return dns.rdataset.ImmutableRdataset(rdataset)
-class Transaction:
+def _ensure_immutable_node(node):
+ if node is None or node.is_immutable():
+ return node
+ return dns.node.ImmutableNode(node)
+
- def __init__(self, manager: TransactionManager, replacement: bool=False,
- read_only: bool=False):
+CheckPutRdatasetType = Callable[
+ ["Transaction", dns.name.Name, dns.rdataset.Rdataset], None
+]
+CheckDeleteRdatasetType = Callable[
+ ["Transaction", dns.name.Name, dns.rdatatype.RdataType, dns.rdatatype.RdataType],
+ None,
+]
+CheckDeleteNameType = Callable[["Transaction", dns.name.Name], None]
+
+
+class Transaction:
+ def __init__(
+ self,
+ manager: TransactionManager,
+ replacement: bool = False,
+ read_only: bool = False,
+ ):
self.manager = manager
self.replacement = replacement
self.read_only = read_only
@@ -93,24 +118,44 @@ class Transaction:
self._check_delete_rdataset: List[CheckDeleteRdatasetType] = []
self._check_delete_name: List[CheckDeleteNameType] = []
- def get(self, name: Optional[Union[dns.name.Name, str]], rdtype: Union[
- dns.rdatatype.RdataType, str], covers: Union[dns.rdatatype.
- RdataType, str]=dns.rdatatype.NONE) ->dns.rdataset.Rdataset:
+ #
+ # This is the high level API
+ #
+ # Note that we currently use non-immutable types in the return type signature to
+ # avoid covariance problems, e.g. if the caller has a List[Rdataset], mypy will be
+ # unhappy if we return an ImmutableRdataset.
+
+ def get(
+ self,
+ name: Optional[Union[dns.name.Name, str]],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ ) -> dns.rdataset.Rdataset:
"""Return the rdataset associated with *name*, *rdtype*, and *covers*,
or `None` if not found.
Note that the returned rdataset is immutable.
"""
- pass
-
- def get_node(self, name: dns.name.Name) ->Optional[dns.node.Node]:
+ self._check_ended()
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ rdataset = self._get_rdataset(name, rdtype, covers)
+ return _ensure_immutable_rdataset(rdataset)
+
+ def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
"""Return the node at *name*, if any.
Returns an immutable node or ``None``.
"""
- pass
+ return _ensure_immutable_node(self._get_node(name))
+
+ def _check_read_only(self) -> None:
+ if self.read_only:
+ raise ReadOnly
- def add(self, *args: Any) ->None:
+ def add(self, *args: Any) -> None:
"""Add records.
The arguments may be:
@@ -121,9 +166,11 @@ class Transaction:
- name, ttl, rdata...
"""
- pass
+ self._check_ended()
+ self._check_read_only()
+ self._add(False, args)
- def replace(self, *args: Any) ->None:
+ def replace(self, *args: Any) -> None:
"""Replace the existing rdataset at the name with the specified
rdataset, or add the specified rdataset if there was no existing
rdataset.
@@ -140,9 +187,11 @@ class Transaction:
a delete of the name followed by one or more calls to add() or
replace().
"""
- pass
+ self._check_ended()
+ self._check_read_only()
+ self._add(True, args)
- def delete(self, *args: Any) ->None:
+ def delete(self, *args: Any) -> None:
"""Delete records.
It is not an error if some of the records are not in the existing
@@ -160,9 +209,11 @@ class Transaction:
- name, rdata...
"""
- pass
+ self._check_ended()
+ self._check_read_only()
+ self._delete(False, args)
- def delete_exact(self, *args: Any) ->None:
+ def delete_exact(self, *args: Any) -> None:
"""Delete records.
The arguments may be:
@@ -181,14 +232,23 @@ class Transaction:
are not in the existing set.
"""
- pass
+ self._check_ended()
+ self._check_read_only()
+ self._delete(True, args)
- def name_exists(self, name: Union[dns.name.Name, str]) ->bool:
+ def name_exists(self, name: Union[dns.name.Name, str]) -> bool:
"""Does the specified name exist?"""
- pass
-
- def update_serial(self, value: int=1, relative: bool=True, name: dns.
- name.Name=dns.name.empty) ->None:
+ self._check_ended()
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ return self._name_exists(name)
+
+ def update_serial(
+ self,
+ value: int = 1,
+ relative: bool = True,
+ name: dns.name.Name = dns.name.empty,
+ ) -> None:
"""Update the serial number.
*value*, an `int`, is an increment if *relative* is `True`, or the
@@ -200,13 +260,30 @@ class Transaction:
so large that it would cause the new serial to be less than the
prior value.
"""
- pass
+ self._check_ended()
+ if value < 0:
+ raise ValueError("negative update_serial() value")
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ rdataset = self._get_rdataset(name, dns.rdatatype.SOA, dns.rdatatype.NONE)
+ if rdataset is None or len(rdataset) == 0:
+ raise KeyError
+ if relative:
+ serial = dns.serial.Serial(rdataset[0].serial) + value
+ else:
+ serial = dns.serial.Serial(value)
+ serial = serial.value # convert back to int
+ if serial == 0:
+ serial = 1
+ rdata = rdataset[0].replace(serial=serial)
+ new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata)
+ self.replace(name, new_rdataset)
def __iter__(self):
self._check_ended()
return self._iterate_rdatasets()
- def changed(self) ->bool:
+ def changed(self) -> bool:
"""Has this transaction changed anything?
For read-only transactions, the result is always `False`.
@@ -214,9 +291,10 @@ class Transaction:
For writable transactions, the result is `True` if at some time
during the life of the transaction, the content was changed.
"""
- pass
+ self._check_ended()
+ return self._changed()
- def commit(self) ->None:
+ def commit(self) -> None:
"""Commit the transaction.
Normally transactions are used as context managers and commit
@@ -227,9 +305,9 @@ class Transaction:
Raises an exception if the commit fails (in which case the transaction
is also rolled back.
"""
- pass
+ self._end(True)
- def rollback(self) ->None:
+ def rollback(self) -> None:
"""Rollback the transaction.
Normally transactions are used as context managers and commit
@@ -239,9 +317,9 @@ class Transaction:
Rollback cannot otherwise fail.
"""
- pass
+ self._end(False)
- def check_put_rdataset(self, check: CheckPutRdatasetType) ->None:
+ def check_put_rdataset(self, check: CheckPutRdatasetType) -> None:
"""Call *check* before putting (storing) an rdataset.
The function is called with the transaction, the name, and the rdataset.
@@ -251,9 +329,9 @@ class Transaction:
called. The check function should raise an exception if it objects to
the put, and otherwise should return ``None``.
"""
- pass
+ self._check_put_rdataset.append(check)
- def check_delete_rdataset(self, check: CheckDeleteRdatasetType) ->None:
+ def check_delete_rdataset(self, check: CheckDeleteRdatasetType) -> None:
"""Call *check* before deleting an rdataset.
The function is called with the transaction, the name, the rdatatype,
@@ -264,9 +342,9 @@ class Transaction:
called. The check function should raise an exception if it objects to
the put, and otherwise should return ``None``.
"""
- pass
+ self._check_delete_rdataset.append(check)
- def check_delete_name(self, check: CheckDeleteNameType) ->None:
+ def check_delete_name(self, check: CheckDeleteNameType) -> None:
"""Call *check* before putting (storing) an rdataset.
The function is called with the transaction and the name.
@@ -276,25 +354,206 @@ class Transaction:
called. The check function should raise an exception if it objects to
the put, and otherwise should return ``None``.
"""
- pass
+ self._check_delete_name.append(check)
- def iterate_rdatasets(self) ->Iterator[Tuple[dns.name.Name, dns.
- rdataset.Rdataset]]:
+ def iterate_rdatasets(
+ self,
+ ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
"""Iterate all the rdatasets in the transaction, returning
(`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
Note that as is usual with python iterators, adding or removing items
while iterating will invalidate the iterator and may raise `RuntimeError`
or fail to iterate over all entries."""
- pass
+ self._check_ended()
+ return self._iterate_rdatasets()
- def iterate_names(self) ->Iterator[dns.name.Name]:
+ def iterate_names(self) -> Iterator[dns.name.Name]:
"""Iterate all the names in the transaction.
Note that as is usual with python iterators, adding or removing names
while iterating will invalidate the iterator and may raise `RuntimeError`
or fail to iterate over all entries."""
- pass
+ self._check_ended()
+ return self._iterate_names()
+
+ #
+ # Helper methods
+ #
+
+ def _raise_if_not_empty(self, method, args):
+ if len(args) != 0:
+ raise TypeError(f"extra parameters to {method}")
+
+ def _rdataset_from_args(self, method, deleting, args):
+ try:
+ arg = args.popleft()
+ if isinstance(arg, dns.rrset.RRset):
+ rdataset = arg.to_rdataset()
+ elif isinstance(arg, dns.rdataset.Rdataset):
+ rdataset = arg
+ else:
+ if deleting:
+ ttl = 0
+ else:
+ if isinstance(arg, int):
+ ttl = arg
+ if ttl > dns.ttl.MAX_TTL:
+ raise ValueError(f"{method}: TTL value too big")
+ else:
+ raise TypeError(f"{method}: expected a TTL")
+ arg = args.popleft()
+ if isinstance(arg, dns.rdata.Rdata):
+ rdataset = dns.rdataset.from_rdata(ttl, arg)
+ else:
+ raise TypeError(f"{method}: expected an Rdata")
+ return rdataset
+ except IndexError:
+ if deleting:
+ return None
+ else:
+ # reraise
+ raise TypeError(f"{method}: expected more arguments")
+
+ def _add(self, replace, args):
+ try:
+ args = collections.deque(args)
+ if replace:
+ method = "replace()"
+ else:
+ method = "add()"
+ arg = args.popleft()
+ if isinstance(arg, str):
+ arg = dns.name.from_text(arg, None)
+ if isinstance(arg, dns.name.Name):
+ name = arg
+ rdataset = self._rdataset_from_args(method, False, args)
+ elif isinstance(arg, dns.rrset.RRset):
+ rrset = arg
+ name = rrset.name
+ # rrsets are also rdatasets, but they don't print the
+ # same and can't be stored in nodes, so convert.
+ rdataset = rrset.to_rdataset()
+ else:
+ raise TypeError(
+ f"{method} requires a name or RRset as the first argument"
+ )
+ if rdataset.rdclass != self.manager.get_class():
+ raise ValueError(f"{method} has objects of wrong RdataClass")
+ if rdataset.rdtype == dns.rdatatype.SOA:
+ (_, _, origin) = self._origin_information()
+ if name != origin:
+ raise ValueError(f"{method} has non-origin SOA")
+ self._raise_if_not_empty(method, args)
+ if not replace:
+ existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers)
+ if existing is not None:
+ if isinstance(existing, dns.rdataset.ImmutableRdataset):
+ trds = dns.rdataset.Rdataset(
+ existing.rdclass, existing.rdtype, existing.covers
+ )
+ trds.update(existing)
+ existing = trds
+ rdataset = existing.union(rdataset)
+ self._checked_put_rdataset(name, rdataset)
+ except IndexError:
+ raise TypeError(f"not enough parameters to {method}")
+
+ def _delete(self, exact, args):
+ try:
+ args = collections.deque(args)
+ if exact:
+ method = "delete_exact()"
+ else:
+ method = "delete()"
+ arg = args.popleft()
+ if isinstance(arg, str):
+ arg = dns.name.from_text(arg, None)
+ if isinstance(arg, dns.name.Name):
+ name = arg
+ if len(args) > 0 and (
+ isinstance(args[0], int) or isinstance(args[0], str)
+ ):
+ # deleting by type and (optionally) covers
+ rdtype = dns.rdatatype.RdataType.make(args.popleft())
+ if len(args) > 0:
+ covers = dns.rdatatype.RdataType.make(args.popleft())
+ else:
+ covers = dns.rdatatype.NONE
+ self._raise_if_not_empty(method, args)
+ existing = self._get_rdataset(name, rdtype, covers)
+ if existing is None:
+ if exact:
+ raise DeleteNotExact(f"{method}: missing rdataset")
+ else:
+ self._delete_rdataset(name, rdtype, covers)
+ return
+ else:
+ rdataset = self._rdataset_from_args(method, True, args)
+ elif isinstance(arg, dns.rrset.RRset):
+ rdataset = arg # rrsets are also rdatasets
+ name = rdataset.name
+ else:
+ raise TypeError(
+ f"{method} requires a name or RRset as the first argument"
+ )
+ self._raise_if_not_empty(method, args)
+ if rdataset:
+ if rdataset.rdclass != self.manager.get_class():
+ raise ValueError(f"{method} has objects of wrong RdataClass")
+ existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers)
+ if existing is not None:
+ if exact:
+ intersection = existing.intersection(rdataset)
+ if intersection != rdataset:
+ raise DeleteNotExact(f"{method}: missing rdatas")
+ rdataset = existing.difference(rdataset)
+ if len(rdataset) == 0:
+ self._checked_delete_rdataset(
+ name, rdataset.rdtype, rdataset.covers
+ )
+ else:
+ self._checked_put_rdataset(name, rdataset)
+ elif exact:
+ raise DeleteNotExact(f"{method}: missing rdataset")
+ else:
+ if exact and not self._name_exists(name):
+ raise DeleteNotExact(f"{method}: name not known")
+ self._checked_delete_name(name)
+ except IndexError:
+ raise TypeError(f"not enough parameters to {method}")
+
+ def _check_ended(self):
+ if self._ended:
+ raise AlreadyEnded
+
+ def _end(self, commit):
+ self._check_ended()
+ if self._ended:
+ raise AlreadyEnded
+ try:
+ self._end_transaction(commit)
+ finally:
+ self._ended = True
+
+ def _checked_put_rdataset(self, name, rdataset):
+ for check in self._check_put_rdataset:
+ check(self, name, rdataset)
+ self._put_rdataset(name, rdataset)
+
+ def _checked_delete_rdataset(self, name, rdtype, covers):
+ for check in self._check_delete_rdataset:
+ check(self, name, rdtype, covers)
+ self._delete_rdataset(name, rdtype, covers)
+
+ def _checked_delete_name(self, name):
+ for check in self._check_delete_name:
+ check(self, name)
+ self._delete_name(name)
+
+ #
+ # Transactions are context managers.
+ #
def __enter__(self):
return self
@@ -307,40 +566,45 @@ class Transaction:
self.rollback()
return False
+ #
+ # This is the low level API, which must be implemented by subclasses
+ # of Transaction.
+ #
+
def _get_rdataset(self, name, rdtype, covers):
"""Return the rdataset associated with *name*, *rdtype*, and *covers*,
or `None` if not found.
"""
- pass
+ raise NotImplementedError # pragma: no cover
def _put_rdataset(self, name, rdataset):
"""Store the rdataset."""
- pass
+ raise NotImplementedError # pragma: no cover
def _delete_name(self, name):
"""Delete all data associated with *name*.
It is not an error if the name does not exist.
"""
- pass
+ raise NotImplementedError # pragma: no cover
def _delete_rdataset(self, name, rdtype, covers):
"""Delete all data associated with *name*, *rdtype*, and *covers*.
It is not an error if the rdataset does not exist.
"""
- pass
+ raise NotImplementedError # pragma: no cover
def _name_exists(self, name):
"""Does name exist?
Returns a bool.
"""
- pass
+ raise NotImplementedError # pragma: no cover
def _changed(self):
"""Has this transaction changed anything?"""
- pass
+ raise NotImplementedError # pragma: no cover
def _end_transaction(self, commit):
"""End the transaction.
@@ -351,7 +615,7 @@ class Transaction:
If committing and the commit fails, then roll back and raise an
exception.
"""
- pass
+ raise NotImplementedError # pragma: no cover
def _set_origin(self, origin):
"""Set the origin.
@@ -360,19 +624,28 @@ class Transaction:
source, and an origin setting operation occurs (e.g. $ORIGIN
in a zone file).
"""
- pass
+ raise NotImplementedError # pragma: no cover
def _iterate_rdatasets(self):
"""Return an iterator that yields (name, rdataset) tuples."""
- pass
+ raise NotImplementedError # pragma: no cover
def _iterate_names(self):
"""Return an iterator that yields a name."""
- pass
+ raise NotImplementedError # pragma: no cover
def _get_node(self, name):
"""Return the node at *name*, if any.
Returns a node or ``None``.
"""
- pass
+ raise NotImplementedError # pragma: no cover
+
+ #
+ # Low-level API with a default implementation, in case a subclass needs
+ # to override.
+ #
+
+ def _origin_information(self):
+ # This is only used by _add()
+ return self.manager.origin_information()
diff --git a/dns/tsig.py b/dns/tsig.py
index 38ac6a5..780852e 100644
--- a/dns/tsig.py
+++ b/dns/tsig.py
@@ -1,8 +1,27 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS TSIG support."""
+
import base64
import hashlib
import hmac
import struct
+
import dns.exception
import dns.name
import dns.rcode
@@ -45,20 +64,33 @@ class PeerBadTruncation(PeerError):
"""The peer didn't like amount of truncation in the TSIG we sent"""
-HMAC_MD5 = dns.name.from_text('HMAC-MD5.SIG-ALG.REG.INT')
-HMAC_SHA1 = dns.name.from_text('hmac-sha1')
-HMAC_SHA224 = dns.name.from_text('hmac-sha224')
-HMAC_SHA256 = dns.name.from_text('hmac-sha256')
-HMAC_SHA256_128 = dns.name.from_text('hmac-sha256-128')
-HMAC_SHA384 = dns.name.from_text('hmac-sha384')
-HMAC_SHA384_192 = dns.name.from_text('hmac-sha384-192')
-HMAC_SHA512 = dns.name.from_text('hmac-sha512')
-HMAC_SHA512_256 = dns.name.from_text('hmac-sha512-256')
-GSS_TSIG = dns.name.from_text('gss-tsig')
+# TSIG Algorithms
+
+HMAC_MD5 = dns.name.from_text("HMAC-MD5.SIG-ALG.REG.INT")
+HMAC_SHA1 = dns.name.from_text("hmac-sha1")
+HMAC_SHA224 = dns.name.from_text("hmac-sha224")
+HMAC_SHA256 = dns.name.from_text("hmac-sha256")
+HMAC_SHA256_128 = dns.name.from_text("hmac-sha256-128")
+HMAC_SHA384 = dns.name.from_text("hmac-sha384")
+HMAC_SHA384_192 = dns.name.from_text("hmac-sha384-192")
+HMAC_SHA512 = dns.name.from_text("hmac-sha512")
+HMAC_SHA512_256 = dns.name.from_text("hmac-sha512-256")
+GSS_TSIG = dns.name.from_text("gss-tsig")
+
default_algorithm = HMAC_SHA256
-mac_sizes = {HMAC_SHA1: 20, HMAC_SHA224: 28, HMAC_SHA256: 32,
- HMAC_SHA256_128: 16, HMAC_SHA384: 48, HMAC_SHA384_192: 24, HMAC_SHA512:
- 64, HMAC_SHA512_256: 32, HMAC_MD5: 16, GSS_TSIG: 128}
+
+mac_sizes = {
+ HMAC_SHA1: 20,
+ HMAC_SHA224: 28,
+ HMAC_SHA256: 32,
+ HMAC_SHA256_128: 16,
+ HMAC_SHA384: 48,
+ HMAC_SHA384_192: 24,
+ HMAC_SHA512: 64,
+ HMAC_SHA512_256: 32,
+ HMAC_MD5: 16,
+ GSS_TSIG: 128, # This is what we assume to be the worst case!
+}
class GSSTSig:
@@ -73,12 +105,26 @@ class GSSTSig:
def __init__(self, gssapi_context):
self.gssapi_context = gssapi_context
- self.data = b''
- self.name = 'gss-tsig'
+ self.data = b""
+ self.name = "gss-tsig"
+ def update(self, data):
+ self.data += data
-class GSSTSigAdapter:
+ def sign(self):
+ # defer to the GSSAPI function to sign
+ return self.gssapi_context.get_signature(self.data)
+ def verify(self, expected):
+ try:
+ # defer to the GSSAPI function to verify
+ return self.gssapi_context.verify_signature(self.data, expected)
+ except Exception:
+ # note the usage of a bare exception
+ raise BadSignature
+
+
+class GSSTSigAdapter:
def __init__(self, keyring):
self.keyring = keyring
@@ -92,24 +138,49 @@ class GSSTSigAdapter:
else:
return None
+ @classmethod
+ def parse_tkey_and_step(cls, key, message, keyname):
+ # if the message is a TKEY type, absorb the key material
+ # into the context using step(); this is used to allow the
+ # client to complete the GSSAPI negotiation before attempting
+ # to verify the signed response to a TKEY message exchange
+ try:
+ rrset = message.find_rrset(
+ message.answer, keyname, dns.rdataclass.ANY, dns.rdatatype.TKEY
+ )
+ if rrset:
+ token = rrset[0].key
+ gssapi_context = key.secret
+ return gssapi_context.step(token)
+ except KeyError:
+ pass
+
class HMACTSig:
"""
HMAC TSIG implementation. This uses the HMAC python module to handle the
sign/verify operations.
"""
- _hashes = {HMAC_SHA1: hashlib.sha1, HMAC_SHA224: hashlib.sha224,
- HMAC_SHA256: hashlib.sha256, HMAC_SHA256_128: (hashlib.sha256, 128),
- HMAC_SHA384: hashlib.sha384, HMAC_SHA384_192: (hashlib.sha384, 192),
- HMAC_SHA512: hashlib.sha512, HMAC_SHA512_256: (hashlib.sha512, 256),
- HMAC_MD5: hashlib.md5}
+
+ _hashes = {
+ HMAC_SHA1: hashlib.sha1,
+ HMAC_SHA224: hashlib.sha224,
+ HMAC_SHA256: hashlib.sha256,
+ HMAC_SHA256_128: (hashlib.sha256, 128),
+ HMAC_SHA384: hashlib.sha384,
+ HMAC_SHA384_192: (hashlib.sha384, 192),
+ HMAC_SHA512: hashlib.sha512,
+ HMAC_SHA512_256: (hashlib.sha512, 256),
+ HMAC_MD5: hashlib.md5,
+ }
def __init__(self, key, algorithm):
try:
hashinfo = self._hashes[algorithm]
except KeyError:
- raise NotImplementedError(
- f'TSIG algorithm {algorithm} is not supported')
+ raise NotImplementedError(f"TSIG algorithm {algorithm} is not supported")
+
+ # create the HMAC context
if isinstance(hashinfo, tuple):
self.hmac_context = hmac.new(key, digestmod=hashinfo[0])
self.size = hashinfo[1]
@@ -118,17 +189,58 @@ class HMACTSig:
self.size = None
self.name = self.hmac_context.name
if self.size:
- self.name += f'-{self.size}'
+ self.name += f"-{self.size}"
+ def update(self, data):
+ return self.hmac_context.update(data)
+
+ def sign(self):
+ # defer to the HMAC digest() function for that digestmod
+ digest = self.hmac_context.digest()
+ if self.size:
+ digest = digest[: (self.size // 8)]
+ return digest
-def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None
- ):
+ def verify(self, expected):
+ # re-digest and compare the results
+ mac = self.sign()
+ if not hmac.compare_digest(mac, expected):
+ raise BadSignature
+
+
+def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None):
"""Return a context containing the TSIG rdata for the input parameters
@rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object
@raises ValueError: I{other_data} is too long
@raises NotImplementedError: I{algorithm} is not supported
"""
- pass
+
+ first = not (ctx and multi)
+ if first:
+ ctx = get_context(key)
+ if request_mac:
+ ctx.update(struct.pack("!H", len(request_mac)))
+ ctx.update(request_mac)
+ ctx.update(struct.pack("!H", rdata.original_id))
+ ctx.update(wire[2:])
+ if first:
+ ctx.update(key.name.to_digestable())
+ ctx.update(struct.pack("!H", dns.rdataclass.ANY))
+ ctx.update(struct.pack("!I", 0))
+ if time is None:
+ time = rdata.time_signed
+ upper_time = (time >> 32) & 0xFFFF
+ lower_time = time & 0xFFFFFFFF
+ time_encoded = struct.pack("!HIH", upper_time, lower_time, rdata.fudge)
+ other_len = len(rdata.other)
+ if other_len > 65535:
+ raise ValueError("TSIG Other Data is > 65535 bytes")
+ if first:
+ ctx.update(key.algorithm.to_digestable() + time_encoded)
+ ctx.update(struct.pack("!HH", rdata.error, other_len) + rdata.other)
+ else:
+ ctx.update(time_encoded)
+ return ctx
def _maybe_start_digest(key, mac, multi):
@@ -136,7 +248,13 @@ def _maybe_start_digest(key, mac, multi):
start a new context.
@rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object
"""
- pass
+ if multi:
+ ctx = get_context(key)
+ ctx.update(struct.pack("!H", len(mac)))
+ ctx.update(mac)
+ return ctx
+ else:
+ return None
def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
@@ -147,11 +265,17 @@ def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
@raises ValueError: I{other_data} is too long
@raises NotImplementedError: I{algorithm} is not supported
"""
- pass
+ ctx = _digest(wire, key, rdata, time, request_mac, ctx, multi)
+ mac = ctx.sign()
+ tsig = rdata.replace(time_signed=time, mac=mac)
-def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=
- None, multi=False):
+ return (tsig, _maybe_start_digest(key, mac, multi))
+
+
+def validate(
+ wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, multi=False
+):
"""Validate the specified TSIG rdata against the other input parameters.
@raises FormError: The TSIG is badly formed.
@@ -159,7 +283,32 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=
server.
@raises BadSignature: The TSIG signature did not validate
@rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object"""
- pass
+
+ (adcount,) = struct.unpack("!H", wire[10:12])
+ if adcount == 0:
+ raise dns.exception.FormError
+ adcount -= 1
+ new_wire = wire[0:10] + struct.pack("!H", adcount) + wire[12:tsig_start]
+ if rdata.error != 0:
+ if rdata.error == dns.rcode.BADSIG:
+ raise PeerBadSignature
+ elif rdata.error == dns.rcode.BADKEY:
+ raise PeerBadKey
+ elif rdata.error == dns.rcode.BADTIME:
+ raise PeerBadTime
+ elif rdata.error == dns.rcode.BADTRUNC:
+ raise PeerBadTruncation
+ else:
+ raise PeerError("unknown TSIG error code %d" % rdata.error)
+ if abs(rdata.time_signed - now) > rdata.fudge:
+ raise BadTime
+ if key.name != owner:
+ raise BadKey
+ if key.algorithm != rdata.algorithm:
+ raise BadAlgorithm
+ ctx = _digest(new_wire, key, rdata, None, request_mac, ctx, multi)
+ ctx.verify(rdata.mac)
+ return _maybe_start_digest(key, rdata.mac, multi)
def get_context(key):
@@ -168,11 +317,14 @@ def get_context(key):
@rtype: HMAC context
@raises NotImplementedError: I{algorithm} is not supported
"""
- pass
+ if key.algorithm == GSS_TSIG:
+ return GSSTSig(key.secret)
+ else:
+ return HMACTSig(key.secret, key.algorithm)
-class Key:
+class Key:
def __init__(self, name, secret, algorithm=default_algorithm):
if isinstance(name, str):
name = dns.name.from_text(name)
@@ -185,12 +337,16 @@ class Key:
self.algorithm = algorithm
def __eq__(self, other):
- return (isinstance(other, Key) and self.name == other.name and self
- .secret == other.secret and self.algorithm == other.algorithm)
+ return (
+ isinstance(other, Key)
+ and self.name == other.name
+ and self.secret == other.secret
+ and self.algorithm == other.algorithm
+ )
def __repr__(self):
r = f"<DNS key name='{self.name}', " + f"algorithm='{self.algorithm}'"
if self.algorithm != GSS_TSIG:
r += f", secret='{base64.b64encode(self.secret).decode()}'"
- r += '>'
+ r += ">"
return r
diff --git a/dns/tsigkeyring.py b/dns/tsigkeyring.py
index 83df7bd..1010a79 100644
--- a/dns/tsigkeyring.py
+++ b/dns/tsigkeyring.py
@@ -1,23 +1,68 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""A place to store TSIG keys."""
+
import base64
from typing import Any, Dict
+
import dns.name
import dns.tsig
-def from_text(textring: Dict[str, Any]) ->Dict[dns.name.Name, dns.tsig.Key]:
+def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
"""Convert a dictionary containing (textual DNS name, base64 secret)
pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or
a dictionary containing (textual DNS name, (algorithm, base64 secret))
pairs into a binary keyring which has (dns.name.Name, dns.tsig.Key) pairs.
@rtype: dict"""
- pass
+
+ keyring = {}
+ for name, value in textring.items():
+ kname = dns.name.from_text(name)
+ if isinstance(value, str):
+ keyring[kname] = dns.tsig.Key(kname, value).secret
+ else:
+ (algorithm, secret) = value
+ keyring[kname] = dns.tsig.Key(kname, secret, algorithm)
+ return keyring
-def to_text(keyring: Dict[dns.name.Name, Any]) ->Dict[str, Any]:
+def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]:
"""Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs
into a text keyring which has (textual DNS name, (textual algorithm,
base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes)
pairs into a text keyring which has (textual DNS name, base64 secret) pairs.
@rtype: dict"""
- pass
+
+ textring = {}
+
+ def b64encode(secret):
+ return base64.encodebytes(secret).decode().rstrip()
+
+ for name, key in keyring.items():
+ tname = name.to_text()
+ if isinstance(key, bytes):
+ textring[tname] = b64encode(key)
+ else:
+ if isinstance(key.secret, bytes):
+ text_secret = b64encode(key.secret)
+ else:
+ text_secret = str(key.secret)
+
+ textring[tname] = (key.algorithm.to_text(), text_secret)
+ return textring
diff --git a/dns/ttl.py b/dns/ttl.py
index 0ade6bc..264b033 100644
--- a/dns/ttl.py
+++ b/dns/ttl.py
@@ -1,14 +1,39 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS TTL conversion."""
+
from typing import Union
+
import dns.exception
-MAX_TTL = 2 ** 32 - 1
+
+# Technically TTLs are supposed to be between 0 and 2**31 - 1, with values
+# greater than that interpreted as 0, but we do not impose this policy here
+# as values > 2**31 - 1 occur in real world data.
+#
+# We leave it to applications to impose tighter bounds if desired.
+MAX_TTL = 2**32 - 1
class BadTTL(dns.exception.SyntaxError):
"""DNS TTL value is not well-formed."""
-def from_text(text: str) ->int:
+def from_text(text: str) -> int:
"""Convert the text form of a TTL to an integer.
The BIND 8 units syntax for TTLs (e.g. '1w6d4h3m10s') is supported.
@@ -19,4 +44,49 @@ def from_text(text: str) ->int:
Returns an ``int``.
"""
- pass
+
+ if text.isdigit():
+ total = int(text)
+ elif len(text) == 0:
+ raise BadTTL
+ else:
+ total = 0
+ current = 0
+ need_digit = True
+ for c in text:
+ if c.isdigit():
+ current *= 10
+ current += int(c)
+ need_digit = False
+ else:
+ if need_digit:
+ raise BadTTL
+ c = c.lower()
+ if c == "w":
+ total += current * 604800
+ elif c == "d":
+ total += current * 86400
+ elif c == "h":
+ total += current * 3600
+ elif c == "m":
+ total += current * 60
+ elif c == "s":
+ total += current
+ else:
+ raise BadTTL("unknown unit '%s'" % c)
+ current = 0
+ need_digit = True
+ if not current == 0:
+ raise BadTTL("trailing integer")
+ if total < 0 or total > MAX_TTL:
+ raise BadTTL("TTL should be between 0 and 2**32 - 1 (inclusive)")
+ return total
+
+
+def make(value: Union[int, str]) -> int:
+ if isinstance(value, int):
+ return value
+ elif isinstance(value, str):
+ return dns.ttl.from_text(value)
+ else:
+ raise ValueError("cannot convert value to TTL")
diff --git a/dns/update.py b/dns/update.py
index d53b842..bf1157a 100644
--- a/dns/update.py
+++ b/dns/update.py
@@ -1,5 +1,24 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Dynamic Update Support"""
+
from typing import Any, List, Optional, Union
+
import dns.message
import dns.name
import dns.opcode
@@ -12,20 +31,30 @@ import dns.tsig
class UpdateSection(dns.enum.IntEnum):
"""Update sections"""
+
ZONE = 0
PREREQ = 1
UPDATE = 2
ADDITIONAL = 3
+ @classmethod
+ def _maximum(cls):
+ return 3
-class UpdateMessage(dns.message.Message):
- _section_enum = UpdateSection
- def __init__(self, zone: Optional[Union[dns.name.Name, str]]=None,
- rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, keyring:
- Optional[Any]=None, keyname: Optional[dns.name.Name]=None,
- keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm,
- id: Optional[int]=None):
+class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
+ # ignore the mypy error here as we mean to use a different enum
+ _section_enum = UpdateSection # type: ignore
+
+ def __init__(
+ self,
+ zone: Optional[Union[dns.name.Name, str]] = None,
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+ keyring: Optional[Any] = None,
+ keyname: Optional[dns.name.Name] = None,
+ keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm,
+ id: Optional[int] = None,
+ ):
"""Initialize a new DNS Update object.
See the documentation of the Message class for a complete
@@ -49,29 +78,54 @@ class UpdateMessage(dns.message.Message):
rdclass = dns.rdataclass.RdataClass.make(rdclass)
self.zone_rdclass = rdclass
if self.origin:
- self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.
- SOA, create=True, force_unique=True)
+ self.find_rrset(
+ self.zone,
+ self.origin,
+ rdclass,
+ dns.rdatatype.SOA,
+ create=True,
+ force_unique=True,
+ )
if keyring is not None:
self.use_tsig(keyring, keyname, algorithm=keyalgorithm)
@property
- def zone(self) ->List[dns.rrset.RRset]:
+ def zone(self) -> List[dns.rrset.RRset]:
"""The zone section."""
- pass
+ return self.sections[0]
+
+ @zone.setter
+ def zone(self, v):
+ self.sections[0] = v
@property
- def prerequisite(self) ->List[dns.rrset.RRset]:
+ def prerequisite(self) -> List[dns.rrset.RRset]:
"""The prerequisite section."""
- pass
+ return self.sections[1]
+
+ @prerequisite.setter
+ def prerequisite(self, v):
+ self.sections[1] = v
@property
- def update(self) ->List[dns.rrset.RRset]:
+ def update(self) -> List[dns.rrset.RRset]:
"""The update section."""
- pass
+ return self.sections[2]
+
+ @update.setter
+ def update(self, v):
+ self.sections[2] = v
def _add_rr(self, name, ttl, rd, deleting=None, section=None):
"""Add a single RR to the update section."""
- pass
+
+ if section is None:
+ section = self.update
+ covers = rd.covers()
+ rrset = self.find_rrset(
+ section, name, self.zone_rdclass, rd.rdtype, covers, deleting, True, True
+ )
+ rrset.add(rd, ttl)
def _add(self, replace, section, name, *args):
"""Add records.
@@ -88,9 +142,32 @@ class UpdateMessage(dns.message.Message):
- ttl, rdtype, string...
"""
- pass
- def add(self, name: Union[dns.name.Name, str], *args: Any) ->None:
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ if isinstance(args[0], dns.rdataset.Rdataset):
+ for rds in args:
+ if replace:
+ self.delete(name, rds.rdtype)
+ for rd in rds:
+ self._add_rr(name, rds.ttl, rd, section=section)
+ else:
+ args = list(args)
+ ttl = int(args.pop(0))
+ if isinstance(args[0], dns.rdata.Rdata):
+ if replace:
+ self.delete(name, args[0].rdtype)
+ for rd in args:
+ self._add_rr(name, ttl, rd, section=section)
+ else:
+ rdtype = dns.rdatatype.RdataType.make(args.pop(0))
+ if replace:
+ self.delete(name, rdtype)
+ for s in args:
+ rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, self.origin)
+ self._add_rr(name, ttl, rd, section=section)
+
+ def add(self, name: Union[dns.name.Name, str], *args: Any) -> None:
"""Add records.
The first argument is always a name. The other
@@ -102,9 +179,10 @@ class UpdateMessage(dns.message.Message):
- ttl, rdtype, string...
"""
- pass
- def delete(self, name: Union[dns.name.Name, str], *args: Any) ->None:
+ self._add(False, self.update, name, *args)
+
+ def delete(self, name: Union[dns.name.Name, str], *args: Any) -> None:
"""Delete records.
The first argument is always a name. The other
@@ -118,9 +196,53 @@ class UpdateMessage(dns.message.Message):
- rdtype, [string...]
"""
- pass
- def replace(self, name: Union[dns.name.Name, str], *args: Any) ->None:
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ if len(args) == 0:
+ self.find_rrset(
+ self.update,
+ name,
+ dns.rdataclass.ANY,
+ dns.rdatatype.ANY,
+ dns.rdatatype.NONE,
+ dns.rdataclass.ANY,
+ True,
+ True,
+ )
+ elif isinstance(args[0], dns.rdataset.Rdataset):
+ for rds in args:
+ for rd in rds:
+ self._add_rr(name, 0, rd, dns.rdataclass.NONE)
+ else:
+ largs = list(args)
+ if isinstance(largs[0], dns.rdata.Rdata):
+ for rd in largs:
+ self._add_rr(name, 0, rd, dns.rdataclass.NONE)
+ else:
+ rdtype = dns.rdatatype.RdataType.make(largs.pop(0))
+ if len(largs) == 0:
+ self.find_rrset(
+ self.update,
+ name,
+ self.zone_rdclass,
+ rdtype,
+ dns.rdatatype.NONE,
+ dns.rdataclass.ANY,
+ True,
+ True,
+ )
+ else:
+ for s in largs:
+ rd = dns.rdata.from_text(
+ self.zone_rdclass,
+ rdtype,
+ s, # type: ignore[arg-type]
+ self.origin,
+ )
+ self._add_rr(name, 0, rd, dns.rdataclass.NONE)
+
+ def replace(self, name: Union[dns.name.Name, str], *args: Any) -> None:
"""Replace records.
The first argument is always a name. The other
@@ -135,9 +257,10 @@ class UpdateMessage(dns.message.Message):
Note that if you want to replace the entire node, you should do
a delete of the name followed by one or more calls to add.
"""
- pass
- def present(self, name: Union[dns.name.Name, str], *args: Any) ->None:
+ self._add(True, self.update, name, *args)
+
+ def present(self, name: Union[dns.name.Name, str], *args: Any) -> None:
"""Require that an owner name (and optionally an rdata type,
or specific rdataset) exists as a prerequisite to the
execution of the update.
@@ -151,17 +274,113 @@ class UpdateMessage(dns.message.Message):
- rdtype, string...
"""
- pass
- def absent(self, name: Union[dns.name.Name, str], rdtype: Optional[
- Union[dns.rdatatype.RdataType, str]]=None) ->None:
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ if len(args) == 0:
+ self.find_rrset(
+ self.prerequisite,
+ name,
+ dns.rdataclass.ANY,
+ dns.rdatatype.ANY,
+ dns.rdatatype.NONE,
+ None,
+ True,
+ True,
+ )
+ elif (
+ isinstance(args[0], dns.rdataset.Rdataset)
+ or isinstance(args[0], dns.rdata.Rdata)
+ or len(args) > 1
+ ):
+ if not isinstance(args[0], dns.rdataset.Rdataset):
+ # Add a 0 TTL
+ largs = list(args)
+ largs.insert(0, 0) # type: ignore[arg-type]
+ self._add(False, self.prerequisite, name, *largs)
+ else:
+ self._add(False, self.prerequisite, name, *args)
+ else:
+ rdtype = dns.rdatatype.RdataType.make(args[0])
+ self.find_rrset(
+ self.prerequisite,
+ name,
+ dns.rdataclass.ANY,
+ rdtype,
+ dns.rdatatype.NONE,
+ None,
+ True,
+ True,
+ )
+
+ def absent(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Optional[Union[dns.rdatatype.RdataType, str]] = None,
+ ) -> None:
"""Require that an owner name (and optionally an rdata type) does
not exist as a prerequisite to the execution of the update."""
- pass
-
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ if rdtype is None:
+ self.find_rrset(
+ self.prerequisite,
+ name,
+ dns.rdataclass.NONE,
+ dns.rdatatype.ANY,
+ dns.rdatatype.NONE,
+ None,
+ True,
+ True,
+ )
+ else:
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ self.find_rrset(
+ self.prerequisite,
+ name,
+ dns.rdataclass.NONE,
+ rdtype,
+ dns.rdatatype.NONE,
+ None,
+ True,
+ True,
+ )
+
+ def _get_one_rr_per_rrset(self, value):
+ # Updates are always one_rr_per_rrset
+ return True
+
+ def _parse_rr_header(self, section, name, rdclass, rdtype):
+ deleting = None
+ empty = False
+ if section == UpdateSection.ZONE:
+ if (
+ dns.rdataclass.is_metaclass(rdclass)
+ or rdtype != dns.rdatatype.SOA
+ or self.zone
+ ):
+ raise dns.exception.FormError
+ else:
+ if not self.zone:
+ raise dns.exception.FormError
+ if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE):
+ deleting = rdclass
+ rdclass = self.zone[0].rdclass
+ empty = (
+ deleting == dns.rdataclass.ANY or section == UpdateSection.PREREQ
+ )
+ return (rdclass, rdtype, deleting, empty)
+
+
+# backwards compatibility
Update = UpdateMessage
+
+### BEGIN generated UpdateSection constants
+
ZONE = UpdateSection.ZONE
PREREQ = UpdateSection.PREREQ
UPDATE = UpdateSection.UPDATE
ADDITIONAL = UpdateSection.ADDITIONAL
+
+### END generated UpdateSection constants
diff --git a/dns/version.py b/dns/version.py
index 6246edc..251f258 100644
--- a/dns/version.py
+++ b/dns/version.py
@@ -1,16 +1,58 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""dnspython release version information."""
+
+#: MAJOR
MAJOR = 2
+#: MINOR
MINOR = 6
+#: MICRO
MICRO = 1
-RELEASELEVEL = 15
+#: RELEASELEVEL
+RELEASELEVEL = 0x0F
+#: SERIAL
SERIAL = 0
-if RELEASELEVEL == 15:
- version = '%d.%d.%d' % (MAJOR, MINOR, MICRO)
-elif RELEASELEVEL == 0:
- version = '%d.%d.%ddev%d' % (MAJOR, MINOR, MICRO, SERIAL)
-elif RELEASELEVEL == 12:
- version = '%d.%d.%drc%d' % (MAJOR, MINOR, MICRO, SERIAL)
-else:
- version = '%d.%d.%d%x%d' % (MAJOR, MINOR, MICRO, RELEASELEVEL, SERIAL)
-hexversion = (MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 |
- SERIAL)
+
+if RELEASELEVEL == 0x0F: # pragma: no cover lgtm[py/unreachable-statement]
+ #: version
+ version = "%d.%d.%d" % (MAJOR, MINOR, MICRO) # lgtm[py/unreachable-statement]
+elif RELEASELEVEL == 0x00: # pragma: no cover lgtm[py/unreachable-statement]
+ version = "%d.%d.%ddev%d" % (
+ MAJOR,
+ MINOR,
+ MICRO,
+ SERIAL,
+ ) # lgtm[py/unreachable-statement]
+elif RELEASELEVEL == 0x0C: # pragma: no cover lgtm[py/unreachable-statement]
+ version = "%d.%d.%drc%d" % (
+ MAJOR,
+ MINOR,
+ MICRO,
+ SERIAL,
+ ) # lgtm[py/unreachable-statement]
+else: # pragma: no cover lgtm[py/unreachable-statement]
+ version = "%d.%d.%d%x%d" % (
+ MAJOR,
+ MINOR,
+ MICRO,
+ RELEASELEVEL,
+ SERIAL,
+ ) # lgtm[py/unreachable-statement]
+
+#: hexversion
+hexversion = MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 | SERIAL
diff --git a/dns/versioned.py b/dns/versioned.py
index d716a34..fd78e67 100644
--- a/dns/versioned.py
+++ b/dns/versioned.py
@@ -1,7 +1,11 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
"""DNS Versioned Zones."""
+
import collections
import threading
from typing import Callable, Deque, Optional, Set, Union
+
import dns.exception
import dns.immutable
import dns.name
@@ -17,6 +21,7 @@ class UseTransaction(dns.exception.DNSException):
"""To alter a versioned zone, use a transaction."""
+# Backwards compatibility
Node = dns.zone.VersionedNode
ImmutableNode = dns.zone.ImmutableVersionedNode
Version = dns.zone.Version
@@ -25,15 +30,26 @@ ImmutableVersion = dns.zone.ImmutableVersion
Transaction = dns.zone.Transaction
-class Zone(dns.zone.Zone):
- __slots__ = ['_versions', '_versions_lock', '_write_txn',
- '_write_waiters', '_write_event', '_pruning_policy', '_readers']
+class Zone(dns.zone.Zone): # lgtm[py/missing-equals]
+ __slots__ = [
+ "_versions",
+ "_versions_lock",
+ "_write_txn",
+ "_write_waiters",
+ "_write_event",
+ "_pruning_policy",
+ "_readers",
+ ]
+
node_factory = Node
- def __init__(self, origin: Optional[Union[dns.name.Name, str]], rdclass:
- dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=True,
- pruning_policy: Optional[Callable[['Zone', Version], Optional[bool]
- ]]=None):
+ def __init__(
+ self,
+ origin: Optional[Union[dns.name.Name, str]],
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+ relativize: bool = True,
+ pruning_policy: Optional[Callable[["Zone", Version], Optional[bool]]] = None,
+ ):
"""Initialize a versioned zone object.
*origin* is the origin of the zone. It may be a ``dns.name.Name``,
@@ -60,17 +76,145 @@ class Zone(dns.zone.Zone):
self._write_event: Optional[threading.Event] = None
self._write_waiters: Deque[threading.Event] = collections.deque()
self._readers: Set[Transaction] = set()
- self._commit_version_unlocked(None, WritableVersion(self,
- replacement=True), origin)
+ self._commit_version_unlocked(
+ None, WritableVersion(self, replacement=True), origin
+ )
+
+ def reader(
+ self, id: Optional[int] = None, serial: Optional[int] = None
+ ) -> Transaction: # pylint: disable=arguments-differ
+ if id is not None and serial is not None:
+ raise ValueError("cannot specify both id and serial")
+ with self._version_lock:
+ if id is not None:
+ version = None
+ for v in reversed(self._versions):
+ if v.id == id:
+ version = v
+ break
+ if version is None:
+ raise KeyError("version not found")
+ elif serial is not None:
+ if self.relativize:
+ oname = dns.name.empty
+ else:
+ assert self.origin is not None
+ oname = self.origin
+ version = None
+ for v in reversed(self._versions):
+ n = v.nodes.get(oname)
+ if n:
+ rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
+ if rds and rds[0].serial == serial:
+ version = v
+ break
+ if version is None:
+ raise KeyError("serial not found")
+ else:
+ version = self._versions[-1]
+ txn = Transaction(self, False, version)
+ self._readers.add(txn)
+ return txn
+
+ def writer(self, replacement: bool = False) -> Transaction:
+ event = None
+ while True:
+ with self._version_lock:
+ # Checking event == self._write_event ensures that either
+ # no one was waiting before we got lucky and found no write
+ # txn, or we were the one who was waiting and got woken up.
+ # This prevents "taking cuts" when creating a write txn.
+ if self._write_txn is None and event == self._write_event:
+ # Creating the transaction defers version setup
+ # (i.e. copying the nodes dictionary) until we
+ # give up the lock, so that we hold the lock as
+ # short a time as possible. This is why we call
+ # _setup_version() below.
+ self._write_txn = Transaction(
+ self, replacement, make_immutable=True
+ )
+ # give up our exclusive right to make a Transaction
+ self._write_event = None
+ break
+ # Someone else is writing already, so we will have to
+ # wait, but we want to do the actual wait outside the
+ # lock.
+ event = threading.Event()
+ self._write_waiters.append(event)
+ # wait (note we gave up the lock!)
+ #
+ # We only wake one sleeper at a time, so it's important
+ # that no event waiter can exit this method (e.g. via
+ # cancellation) without returning a transaction or waking
+ # someone else up.
+ #
+ # This is not a problem with Threading module threads as
+ # they cannot be canceled, but could be an issue with trio
+ # tasks when we do the async version of writer().
+ # I.e. we'd need to do something like:
+ #
+ # try:
+ # event.wait()
+ # except trio.Cancelled:
+ # with self._version_lock:
+ # self._maybe_wakeup_one_waiter_unlocked()
+ # raise
+ #
+ event.wait()
+ # Do the deferred version setup.
+ self._write_txn._setup_version()
+ return self._write_txn
+
+ def _maybe_wakeup_one_waiter_unlocked(self):
+ if len(self._write_waiters) > 0:
+ self._write_event = self._write_waiters.popleft()
+ self._write_event.set()
+
+ # pylint: disable=unused-argument
+ def _default_pruning_policy(self, zone, version):
+ return True
- def set_max_versions(self, max_versions: Optional[int]) ->None:
+ # pylint: enable=unused-argument
+
+ def _prune_versions_unlocked(self):
+ assert len(self._versions) > 0
+ # Don't ever prune a version greater than or equal to one that
+ # a reader has open. This pins versions in memory while the
+ # reader is open, and importantly lets the reader open a txn on
+ # a successor version (e.g. if generating an IXFR).
+ #
+ # Note our definition of least_kept also ensures we do not try to
+ # delete the greatest version.
+ if len(self._readers) > 0:
+ least_kept = min(txn.version.id for txn in self._readers)
+ else:
+ least_kept = self._versions[-1].id
+ while self._versions[0].id < least_kept and self._pruning_policy(
+ self, self._versions[0]
+ ):
+ self._versions.popleft()
+
+ def set_max_versions(self, max_versions: Optional[int]) -> None:
"""Set a pruning policy that retains up to the specified number
of versions
"""
- pass
+ if max_versions is not None and max_versions < 1:
+ raise ValueError("max versions must be at least 1")
+ if max_versions is None:
+
+ def policy(zone, _): # pylint: disable=unused-argument
+ return False
+
+ else:
+
+ def policy(zone, _):
+ return len(zone._versions) > max_versions
- def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version
- ], Optional[bool]]]) ->None:
+ self.set_pruning_policy(policy)
+
+ def set_pruning_policy(
+ self, policy: Optional[Callable[["Zone", Version], Optional[bool]]]
+ ) -> None:
"""Set the pruning policy for the zone.
The *policy* function takes a `Version` and returns `True` if
@@ -82,4 +226,93 @@ class Zone(dns.zone.Zone):
time the function returns `False`, the checking stops. I.e. the
retained versions are always a consecutive sequence.
"""
- pass
+ if policy is None:
+ policy = self._default_pruning_policy
+ with self._version_lock:
+ self._pruning_policy = policy
+ self._prune_versions_unlocked()
+
+ def _end_read(self, txn):
+ with self._version_lock:
+ self._readers.remove(txn)
+ self._prune_versions_unlocked()
+
+ def _end_write_unlocked(self, txn):
+ assert self._write_txn == txn
+ self._write_txn = None
+ self._maybe_wakeup_one_waiter_unlocked()
+
+ def _end_write(self, txn):
+ with self._version_lock:
+ self._end_write_unlocked(txn)
+
+ def _commit_version_unlocked(self, txn, version, origin):
+ self._versions.append(version)
+ self._prune_versions_unlocked()
+ self.nodes = version.nodes
+ if self.origin is None:
+ self.origin = origin
+ # txn can be None in __init__ when we make the empty version.
+ if txn is not None:
+ self._end_write_unlocked(txn)
+
+ def _commit_version(self, txn, version, origin):
+ with self._version_lock:
+ self._commit_version_unlocked(txn, version, origin)
+
+ def _get_next_version_id(self):
+ if len(self._versions) > 0:
+ id = self._versions[-1].id + 1
+ else:
+ id = 1
+ return id
+
+ def find_node(
+ self, name: Union[dns.name.Name, str], create: bool = False
+ ) -> dns.node.Node:
+ if create:
+ raise UseTransaction
+ return super().find_node(name)
+
+ def delete_node(self, name: Union[dns.name.Name, str]) -> None:
+ raise UseTransaction
+
+ def find_rdataset(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> dns.rdataset.Rdataset:
+ if create:
+ raise UseTransaction
+ rdataset = super().find_rdataset(name, rdtype, covers)
+ return dns.rdataset.ImmutableRdataset(rdataset)
+
+ def get_rdataset(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> Optional[dns.rdataset.Rdataset]:
+ if create:
+ raise UseTransaction
+ rdataset = super().get_rdataset(name, rdtype, covers)
+ if rdataset is not None:
+ return dns.rdataset.ImmutableRdataset(rdataset)
+ else:
+ return None
+
+ def delete_rdataset(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ ) -> None:
+ raise UseTransaction
+
+ def replace_rdataset(
+ self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset
+ ) -> None:
+ raise UseTransaction
diff --git a/dns/win32util.py b/dns/win32util.py
index aee6d5a..aaa7e93 100644
--- a/dns/win32util.py
+++ b/dns/win32util.py
@@ -1,52 +1,245 @@
import sys
+
import dns._features
-if sys.platform == 'win32':
+
+if sys.platform == "win32":
from typing import Any
+
import dns.name
+
_prefer_wmi = True
- import winreg
+
+ import winreg # pylint: disable=import-error
+
+ # Keep pylint quiet on non-windows.
try:
- WindowsError is None
+ WindowsError is None # pylint: disable=used-before-assignment
except KeyError:
WindowsError = Exception
- if dns._features.have('wmi'):
+
+ if dns._features.have("wmi"):
import threading
- import pythoncom
- import wmi
+
+ import pythoncom # pylint: disable=import-error
+ import wmi # pylint: disable=import-error
+
_have_wmi = True
else:
_have_wmi = False
+ def _config_domain(domain):
+ # Sometimes DHCP servers add a '.' prefix to the default domain, and
+ # Windows just stores such values in the registry (see #687).
+ # Check for this and fix it.
+ if domain.startswith("."):
+ domain = domain[1:]
+ return dns.name.from_text(domain)
class DnsInfo:
-
def __init__(self):
self.domain = None
self.nameservers = []
self.search = []
- if _have_wmi:
+ if _have_wmi:
class _WMIGetter(threading.Thread):
-
def __init__(self):
super().__init__()
self.info = DnsInfo()
- else:
+ def run(self):
+ pythoncom.CoInitialize()
+ try:
+ system = wmi.WMI()
+ for interface in system.Win32_NetworkAdapterConfiguration():
+ if interface.IPEnabled and interface.DNSServerSearchOrder:
+ self.info.nameservers = list(interface.DNSServerSearchOrder)
+ if interface.DNSDomain:
+ self.info.domain = _config_domain(interface.DNSDomain)
+ if interface.DNSDomainSuffixSearchOrder:
+ self.info.search = [
+ _config_domain(x)
+ for x in interface.DNSDomainSuffixSearchOrder
+ ]
+ break
+ finally:
+ pythoncom.CoUninitialize()
- class _WMIGetter:
- pass
+ def get(self):
+ # We always run in a separate thread to avoid any issues with
+ # the COM threading model.
+ self.start()
+ self.join()
+ return self.info
+ else:
- class _RegistryGetter:
+ class _WMIGetter: # type: ignore
+ pass
+ class _RegistryGetter:
def __init__(self):
self.info = DnsInfo()
+ def _determine_split_char(self, entry):
+ #
+ # The windows registry irritatingly changes the list element
+ # delimiter in between ' ' and ',' (and vice-versa) in various
+ # versions of windows.
+ #
+ if entry.find(" ") >= 0:
+ split_char = " "
+ elif entry.find(",") >= 0:
+ split_char = ","
+ else:
+ # probably a singleton; treat as a space-separated list.
+ split_char = " "
+ return split_char
+
+ def _config_nameservers(self, nameservers):
+ split_char = self._determine_split_char(nameservers)
+ ns_list = nameservers.split(split_char)
+ for ns in ns_list:
+ if ns not in self.info.nameservers:
+ self.info.nameservers.append(ns)
+
+ def _config_search(self, search):
+ split_char = self._determine_split_char(search)
+ search_list = search.split(split_char)
+ for s in search_list:
+ s = _config_domain(s)
+ if s not in self.info.search:
+ self.info.search.append(s)
+
+ def _config_fromkey(self, key, always_try_domain):
+ try:
+ servers, _ = winreg.QueryValueEx(key, "NameServer")
+ except WindowsError:
+ servers = None
+ if servers:
+ self._config_nameservers(servers)
+ if servers or always_try_domain:
+ try:
+ dom, _ = winreg.QueryValueEx(key, "Domain")
+ if dom:
+ self.info.domain = _config_domain(dom)
+ except WindowsError:
+ pass
+ else:
+ try:
+ servers, _ = winreg.QueryValueEx(key, "DhcpNameServer")
+ except WindowsError:
+ servers = None
+ if servers:
+ self._config_nameservers(servers)
+ try:
+ dom, _ = winreg.QueryValueEx(key, "DhcpDomain")
+ if dom:
+ self.info.domain = _config_domain(dom)
+ except WindowsError:
+ pass
+ try:
+ search, _ = winreg.QueryValueEx(key, "SearchList")
+ except WindowsError:
+ search = None
+ if search is None:
+ try:
+ search, _ = winreg.QueryValueEx(key, "DhcpSearchList")
+ except WindowsError:
+ search = None
+ if search:
+ self._config_search(search)
+
+ def _is_nic_enabled(self, lm, guid):
+ # Look in the Windows Registry to determine whether the network
+ # interface corresponding to the given guid is enabled.
+ #
+ # (Code contributed by Paul Marks, thanks!)
+ #
+ try:
+ # This hard-coded location seems to be consistent, at least
+ # from Windows 2000 through Vista.
+ connection_key = winreg.OpenKey(
+ lm,
+ r"SYSTEM\CurrentControlSet\Control\Network"
+ r"\{4D36E972-E325-11CE-BFC1-08002BE10318}"
+ r"\%s\Connection" % guid,
+ )
+
+ try:
+ # The PnpInstanceID points to a key inside Enum
+ (pnp_id, ttype) = winreg.QueryValueEx(
+ connection_key, "PnpInstanceID"
+ )
+
+ if ttype != winreg.REG_SZ:
+ raise ValueError # pragma: no cover
+
+ device_key = winreg.OpenKey(
+ lm, r"SYSTEM\CurrentControlSet\Enum\%s" % pnp_id
+ )
+
+ try:
+ # Get ConfigFlags for this device
+ (flags, ttype) = winreg.QueryValueEx(device_key, "ConfigFlags")
+
+ if ttype != winreg.REG_DWORD:
+ raise ValueError # pragma: no cover
+
+ # Based on experimentation, bit 0x1 indicates that the
+ # device is disabled.
+ #
+ # XXXRTH I suspect we really want to & with 0x03 so
+ # that CONFIGFLAGS_REMOVED devices are also ignored,
+ # but we're shifting to WMI as ConfigFlags is not
+ # supposed to be used.
+ return not flags & 0x1
+
+ finally:
+ device_key.Close()
+ finally:
+ connection_key.Close()
+ except Exception: # pragma: no cover
+ return False
+
def get(self):
"""Extract resolver configuration from the Windows registry."""
- pass
+
+ lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
+ try:
+ tcp_params = winreg.OpenKey(
+ lm, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
+ )
+ try:
+ self._config_fromkey(tcp_params, True)
+ finally:
+ tcp_params.Close()
+ interfaces = winreg.OpenKey(
+ lm,
+ r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces",
+ )
+ try:
+ i = 0
+ while True:
+ try:
+ guid = winreg.EnumKey(interfaces, i)
+ i += 1
+ key = winreg.OpenKey(interfaces, guid)
+ try:
+ if not self._is_nic_enabled(lm, guid):
+ continue
+ self._config_fromkey(key, False)
+ finally:
+ key.Close()
+ except EnvironmentError:
+ break
+ finally:
+ interfaces.Close()
+ finally:
+ lm.Close()
+ return self.info
+
_getter_class: Any
if _have_wmi and _prefer_wmi:
_getter_class = _WMIGetter
@@ -55,4 +248,5 @@ if sys.platform == 'win32':
def get_dns_info():
"""Extract resolver configuration."""
- pass
+ getter = _getter_class()
+ return getter.get()
diff --git a/dns/wire.py b/dns/wire.py
index 13bdace..9f9b157 100644
--- a/dns/wire.py
+++ b/dns/wire.py
@@ -1,16 +1,89 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
import contextlib
import struct
from typing import Iterator, Optional, Tuple
+
import dns.exception
import dns.name
class Parser:
-
- def __init__(self, wire: bytes, current: int=0):
+ def __init__(self, wire: bytes, current: int = 0):
self.wire = wire
self.current = 0
self.end = len(self.wire)
if current:
self.seek(current)
self.furthest = current
+
+ def remaining(self) -> int:
+ return self.end - self.current
+
+ def get_bytes(self, size: int) -> bytes:
+ assert size >= 0
+ if size > self.remaining():
+ raise dns.exception.FormError
+ output = self.wire[self.current : self.current + size]
+ self.current += size
+ self.furthest = max(self.furthest, self.current)
+ return output
+
+ def get_counted_bytes(self, length_size: int = 1) -> bytes:
+ length = int.from_bytes(self.get_bytes(length_size), "big")
+ return self.get_bytes(length)
+
+ def get_remaining(self) -> bytes:
+ return self.get_bytes(self.remaining())
+
+ def get_uint8(self) -> int:
+ return struct.unpack("!B", self.get_bytes(1))[0]
+
+ def get_uint16(self) -> int:
+ return struct.unpack("!H", self.get_bytes(2))[0]
+
+ def get_uint32(self) -> int:
+ return struct.unpack("!I", self.get_bytes(4))[0]
+
+ def get_uint48(self) -> int:
+ return int.from_bytes(self.get_bytes(6), "big")
+
+ def get_struct(self, format: str) -> Tuple:
+ return struct.unpack(format, self.get_bytes(struct.calcsize(format)))
+
+ def get_name(self, origin: Optional["dns.name.Name"] = None) -> "dns.name.Name":
+ name = dns.name.from_wire_parser(self)
+ if origin:
+ name = name.relativize(origin)
+ return name
+
+ def seek(self, where: int) -> None:
+ # Note that seeking to the end is OK! (If you try to read
+ # after such a seek, you'll get an exception as expected.)
+ if where < 0 or where > self.end:
+ raise dns.exception.FormError
+ self.current = where
+
+ @contextlib.contextmanager
+ def restrict_to(self, size: int) -> Iterator:
+ assert size >= 0
+ if size > self.remaining():
+ raise dns.exception.FormError
+ saved_end = self.end
+ try:
+ self.end = self.current + size
+ yield
+ # We make this check here and not in the finally as we
+ # don't want to raise if we're already raising for some
+ # other reason.
+ if self.current != self.end:
+ raise dns.exception.FormError
+ finally:
+ self.end = saved_end
+
+ @contextlib.contextmanager
+ def restore_furthest(self) -> Iterator:
+ try:
+ yield None
+ finally:
+ self.current = self.furthest
diff --git a/dns/xfr.py b/dns/xfr.py
index 3d6d66f..dd247d3 100644
--- a/dns/xfr.py
+++ b/dns/xfr.py
@@ -1,4 +1,22 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
from typing import Any, List, Optional, Tuple, Union
+
import dns.exception
import dns.message
import dns.name
@@ -15,7 +33,7 @@ class TransferError(dns.exception.DNSException):
"""A zone transfer response got a non-zero rcode."""
def __init__(self, rcode):
- message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode)
+ message = "Zone transfer error: %s" % dns.rcode.to_text(rcode)
super().__init__(message)
self.rcode = rcode
@@ -33,9 +51,13 @@ class Inbound:
State machine for zone transfers.
"""
- def __init__(self, txn_manager: dns.transaction.TransactionManager,
- rdtype: dns.rdatatype.RdataType=dns.rdatatype.AXFR, serial:
- Optional[int]=None, is_udp: bool=False):
+ def __init__(
+ self,
+ txn_manager: dns.transaction.TransactionManager,
+ rdtype: dns.rdatatype.RdataType = dns.rdatatype.AXFR,
+ serial: Optional[int] = None,
+ is_udp: bool = False,
+ ):
"""Initialize an inbound zone transfer.
*txn_manager* is a :py:class:`dns.transaction.TransactionManager`.
@@ -53,19 +75,18 @@ class Inbound:
self.rdtype = rdtype
if rdtype == dns.rdatatype.IXFR:
if serial is None:
- raise ValueError('a starting serial must be supplied for IXFRs'
- )
+ raise ValueError("a starting serial must be supplied for IXFRs")
elif is_udp:
- raise ValueError('is_udp specified for AXFR')
+ raise ValueError("is_udp specified for AXFR")
self.serial = serial
self.is_udp = is_udp
- _, _, self.origin = txn_manager.origin_information()
+ (_, _, self.origin) = txn_manager.origin_information()
self.soa_rdataset: Optional[dns.rdataset.Rdataset] = None
self.done = False
self.expecting_SOA = False
self.delete_mode = False
- def process_message(self, message: dns.message.Message) ->bool:
+ def process_message(self, message: dns.message.Message) -> bool:
"""Process one message in the transfer.
The message should have the same relativization as was specified when
@@ -74,7 +95,148 @@ class Inbound:
Returns `True` if the transfer is complete, and `False` otherwise.
"""
- pass
+ if self.txn is None:
+ replacement = self.rdtype == dns.rdatatype.AXFR
+ self.txn = self.txn_manager.writer(replacement)
+ rcode = message.rcode()
+ if rcode != dns.rcode.NOERROR:
+ raise TransferError(rcode)
+ #
+ # We don't require a question section, but if it is present is
+ # should be correct.
+ #
+ if len(message.question) > 0:
+ if message.question[0].name != self.origin:
+ raise dns.exception.FormError("wrong question name")
+ if message.question[0].rdtype != self.rdtype:
+ raise dns.exception.FormError("wrong question rdatatype")
+ answer_index = 0
+ if self.soa_rdataset is None:
+ #
+ # This is the first message. We're expecting an SOA at
+ # the origin.
+ #
+ if not message.answer or message.answer[0].name != self.origin:
+ raise dns.exception.FormError("No answer or RRset not for zone origin")
+ rrset = message.answer[0]
+ rdataset = rrset
+ if rdataset.rdtype != dns.rdatatype.SOA:
+ raise dns.exception.FormError("first RRset is not an SOA")
+ answer_index = 1
+ self.soa_rdataset = rdataset.copy()
+ if self.rdtype == dns.rdatatype.IXFR:
+ if self.soa_rdataset[0].serial == self.serial:
+ #
+ # We're already up-to-date.
+ #
+ self.done = True
+ elif dns.serial.Serial(self.soa_rdataset[0].serial) < self.serial:
+ # It went backwards!
+ raise SerialWentBackwards
+ else:
+ if self.is_udp and len(message.answer[answer_index:]) == 0:
+ #
+ # There are no more records, so this is the
+ # "truncated" response. Say to use TCP
+ #
+ raise UseTCP
+ #
+ # Note we're expecting another SOA so we can detect
+ # if this IXFR response is an AXFR-style response.
+ #
+ self.expecting_SOA = True
+ #
+ # Process the answer section (other than the initial SOA in
+ # the first message).
+ #
+ for rrset in message.answer[answer_index:]:
+ name = rrset.name
+ rdataset = rrset
+ if self.done:
+ raise dns.exception.FormError("answers after final SOA")
+ assert self.txn is not None # for mypy
+ if rdataset.rdtype == dns.rdatatype.SOA and name == self.origin:
+ #
+ # Every time we see an origin SOA delete_mode inverts
+ #
+ if self.rdtype == dns.rdatatype.IXFR:
+ self.delete_mode = not self.delete_mode
+ #
+ # If this SOA Rdataset is equal to the first we saw
+ # then we're finished. If this is an IXFR we also
+ # check that we're seeing the record in the expected
+ # part of the response.
+ #
+ if rdataset == self.soa_rdataset and (
+ self.rdtype == dns.rdatatype.AXFR
+ or (self.rdtype == dns.rdatatype.IXFR and self.delete_mode)
+ ):
+ #
+ # This is the final SOA
+ #
+ if self.expecting_SOA:
+ # We got an empty IXFR sequence!
+ raise dns.exception.FormError("empty IXFR sequence")
+ if (
+ self.rdtype == dns.rdatatype.IXFR
+ and self.serial != rdataset[0].serial
+ ):
+ raise dns.exception.FormError("unexpected end of IXFR sequence")
+ self.txn.replace(name, rdataset)
+ self.txn.commit()
+ self.txn = None
+ self.done = True
+ else:
+ #
+ # This is not the final SOA
+ #
+ self.expecting_SOA = False
+ if self.rdtype == dns.rdatatype.IXFR:
+ if self.delete_mode:
+ # This is the start of an IXFR deletion set
+ if rdataset[0].serial != self.serial:
+ raise dns.exception.FormError(
+ "IXFR base serial mismatch"
+ )
+ else:
+ # This is the start of an IXFR addition set
+ self.serial = rdataset[0].serial
+ self.txn.replace(name, rdataset)
+ else:
+ # We saw a non-final SOA for the origin in an AXFR.
+ raise dns.exception.FormError("unexpected origin SOA in AXFR")
+ continue
+ if self.expecting_SOA:
+ #
+ # We made an IXFR request and are expecting another
+ # SOA RR, but saw something else, so this must be an
+ # AXFR response.
+ #
+ self.rdtype = dns.rdatatype.AXFR
+ self.expecting_SOA = False
+ self.delete_mode = False
+ self.txn.rollback()
+ self.txn = self.txn_manager.writer(True)
+ #
+ # Note we are falling through into the code below
+ # so whatever rdataset this was gets written.
+ #
+ # Add or remove the data
+ if self.delete_mode:
+ self.txn.delete_exact(name, rdataset)
+ else:
+ self.txn.add(name, rdataset)
+ if self.is_udp and not self.done:
+ #
+ # This is a UDP IXFR and we didn't get to done, and we didn't
+ # get the proper "truncated" response
+ #
+ raise dns.exception.FormError("unexpected end of UDP IXFR")
+ return self.done
+
+ #
+ # Inbounds are context managers.
+ #
def __enter__(self):
return self
@@ -85,13 +247,18 @@ class Inbound:
return False
-def make_query(txn_manager: dns.transaction.TransactionManager, serial:
- Optional[int]=0, use_edns: Optional[Union[int, bool]]=None, ednsflags:
- Optional[int]=None, payload: Optional[int]=None, request_payload:
- Optional[int]=None, options: Optional[List[dns.edns.Option]]=None,
- keyring: Any=None, keyname: Optional[dns.name.Name]=None, keyalgorithm:
- Union[dns.name.Name, str]=dns.tsig.default_algorithm) ->Tuple[dns.
- message.QueryMessage, Optional[int]]:
+def make_query(
+ txn_manager: dns.transaction.TransactionManager,
+ serial: Optional[int] = 0,
+ use_edns: Optional[Union[int, bool]] = None,
+ ednsflags: Optional[int] = None,
+ payload: Optional[int] = None,
+ request_payload: Optional[int] = None,
+ options: Optional[List[dns.edns.Option]] = None,
+ keyring: Any = None,
+ keyname: Optional[dns.name.Name] = None,
+ keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm,
+) -> Tuple[dns.message.QueryMessage, Optional[int]]:
"""Make an AXFR or IXFR query.
*txn_manager* is a ``dns.transaction.TransactionManager``, typically a
@@ -111,10 +278,50 @@ def make_query(txn_manager: dns.transaction.TransactionManager, serial:
Returns a `(query, serial)` tuple.
"""
- pass
+ (zone_origin, _, origin) = txn_manager.origin_information()
+ if zone_origin is None:
+ raise ValueError("no zone origin")
+ if serial is None:
+ rdtype = dns.rdatatype.AXFR
+ elif not isinstance(serial, int):
+ raise ValueError("serial is not an integer")
+ elif serial == 0:
+ with txn_manager.reader() as txn:
+ rdataset = txn.get(origin, "SOA")
+ if rdataset:
+ serial = rdataset[0].serial
+ rdtype = dns.rdatatype.IXFR
+ else:
+ serial = None
+ rdtype = dns.rdatatype.AXFR
+ elif serial > 0 and serial < 4294967296:
+ rdtype = dns.rdatatype.IXFR
+ else:
+ raise ValueError("serial out-of-range")
+ rdclass = txn_manager.get_class()
+ q = dns.message.make_query(
+ zone_origin,
+ rdtype,
+ rdclass,
+ use_edns,
+ False,
+ ednsflags,
+ payload,
+ request_payload,
+ options,
+ )
+ if serial is not None:
+ rdata = dns.rdata.from_text(rdclass, "SOA", f". . {serial} 0 0 0 0")
+ rrset = q.find_rrset(
+ q.authority, zone_origin, rdclass, dns.rdatatype.SOA, create=True
+ )
+ rrset.add(rdata, 0)
+ if keyring is not None:
+ q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
+ return (q, serial)
-def extract_serial_from_query(query: dns.message.Message) ->Optional[int]:
+def extract_serial_from_query(query: dns.message.Message) -> Optional[int]:
"""Extract the SOA serial number from query if it is an IXFR and return
it, otherwise return None.
@@ -123,4 +330,14 @@ def extract_serial_from_query(query: dns.message.Message) ->Optional[int]:
Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have
an appropriate SOA RRset in the authority section.
"""
- pass
+ if not isinstance(query, dns.message.QueryMessage):
+ raise ValueError("query not a QueryMessage")
+ question = query.question[0]
+ if question.rdtype == dns.rdatatype.AXFR:
+ return None
+ elif question.rdtype != dns.rdatatype.IXFR:
+ raise ValueError("query is not an AXFR or IXFR")
+ soa = query.find_rrset(
+ query.authority, question.name, question.rdclass, dns.rdatatype.SOA
+ )
+ return soa[0].serial
diff --git a/dns/zone.py b/dns/zone.py
index 464b98d..844919e 100644
--- a/dns/zone.py
+++ b/dns/zone.py
@@ -1,9 +1,39 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Zones."""
+
import contextlib
import io
import os
import struct
-from typing import Any, Callable, Iterable, Iterator, List, MutableMapping, Optional, Set, Tuple, Union
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ Iterator,
+ List,
+ MutableMapping,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
+
import dns.exception
import dns.grange
import dns.immutable
@@ -55,6 +85,38 @@ class DigestVerificationFailure(dns.exception.DNSException):
"""The ZONEMD digest failed to verify."""
+def _validate_name(
+ name: dns.name.Name,
+ origin: Optional[dns.name.Name],
+ relativize: bool,
+) -> dns.name.Name:
+ # This name validation code is shared by Zone and Version
+ if origin is None:
+ # This should probably never happen as other code (e.g.
+ # _rr_line) will notice the lack of an origin before us, but
+ # we check just in case!
+ raise KeyError("no zone origin is defined")
+ if name.is_absolute():
+ if not name.is_subdomain(origin):
+ raise KeyError("name parameter must be a subdomain of the zone origin")
+ if relativize:
+ name = name.relativize(origin)
+ else:
+ # We have a relative name. Make sure that the derelativized name is
+ # not too long.
+ try:
+ abs_name = name.derelativize(origin)
+ except dns.name.NameTooLong:
+ # We map dns.name.NameTooLong to KeyError to be consistent with
+ # the other exceptions above.
+ raise KeyError("relative name too long for zone")
+ if not relativize:
+ # We have a relative name in a non-relative zone, so use the
+ # derelativized name.
+ name = abs_name
+ return name
+
+
class Zone(dns.transaction.TransactionManager):
"""A DNS zone.
@@ -65,16 +127,20 @@ class Zone(dns.transaction.TransactionManager):
if the name is relative it is treated as relative to the origin of
the zone.
"""
+
node_factory: Callable[[], dns.node.Node] = dns.node.Node
- map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]
- ] = dict
- writable_version_factory: Optional[Callable[[], 'WritableVersion']] = None
- immutable_version_factory: Optional[Callable[[], 'ImmutableVersion']
- ] = None
- __slots__ = ['rdclass', 'origin', 'nodes', 'relativize']
-
- def __init__(self, origin: Optional[Union[dns.name.Name, str]], rdclass:
- dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=True):
+ map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict
+ writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None
+ immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None
+
+ __slots__ = ["rdclass", "origin", "nodes", "relativize"]
+
+ def __init__(
+ self,
+ origin: Optional[Union[dns.name.Name, str]],
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+ relativize: bool = True,
+ ):
"""Initialize a zone object.
*origin* is the origin of the zone. It may be a ``dns.name.Name``,
@@ -86,18 +152,17 @@ class Zone(dns.transaction.TransactionManager):
*relativize*, a ``bool``, determine's whether domain names are
relativized to the zone's origin. The default is ``True``.
"""
+
if origin is not None:
if isinstance(origin, str):
origin = dns.name.from_text(origin)
elif not isinstance(origin, dns.name.Name):
- raise ValueError(
- 'origin parameter must be convertible to a DNS name')
+ raise ValueError("origin parameter must be convertible to a DNS name")
if not origin.is_absolute():
- raise ValueError('origin parameter must be an absolute name')
+ raise ValueError("origin parameter must be an absolute name")
self.origin = origin
self.rdclass = rdclass
- self.nodes: MutableMapping[dns.name.Name, dns.node.Node
- ] = self.map_factory()
+ self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory()
self.relativize = relativize
def __eq__(self, other):
@@ -106,10 +171,14 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``bool``.
"""
+
if not isinstance(other, Zone):
return False
- if (self.rdclass != other.rdclass or self.origin != other.origin or
- self.nodes != other.nodes):
+ if (
+ self.rdclass != other.rdclass
+ or self.origin != other.origin
+ or self.nodes != other.nodes
+ ):
return False
return True
@@ -118,8 +187,18 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``bool``.
"""
+
return not self.__eq__(other)
+ def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
+ # Note that any changes in this method should have corresponding changes
+ # made in the Version _validate_name() method.
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ elif not isinstance(name, dns.name.Name):
+ raise KeyError("name parameter must be convertible to a DNS name")
+ return _validate_name(name, self.origin, self.relativize)
+
def __getitem__(self, key):
key = self._validate_name(key)
return self.nodes[key]
@@ -135,12 +214,26 @@ class Zone(dns.transaction.TransactionManager):
def __iter__(self):
return self.nodes.__iter__()
+ def keys(self):
+ return self.nodes.keys()
+
+ def values(self):
+ return self.nodes.values()
+
+ def items(self):
+ return self.nodes.items()
+
+ def get(self, key):
+ key = self._validate_name(key)
+ return self.nodes.get(key)
+
def __contains__(self, key):
key = self._validate_name(key)
return key in self.nodes
- def find_node(self, name: Union[dns.name.Name, str], create: bool=False
- ) ->dns.node.Node:
+ def find_node(
+ self, name: Union[dns.name.Name, str], create: bool = False
+ ) -> dns.node.Node:
"""Find a node in the zone, possibly creating it.
*name*: the name of the node to find.
@@ -156,10 +249,19 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``dns.node.Node``.
"""
- pass
- def get_node(self, name: Union[dns.name.Name, str], create: bool=False
- ) ->Optional[dns.node.Node]:
+ name = self._validate_name(name)
+ node = self.nodes.get(name)
+ if node is None:
+ if not create:
+ raise KeyError
+ node = self.node_factory()
+ self.nodes[name] = node
+ return node
+
+ def get_node(
+ self, name: Union[dns.name.Name, str], create: bool = False
+ ) -> Optional[dns.node.Node]:
"""Get a node in the zone, possibly creating it.
This method is like ``find_node()``, except it returns None instead
@@ -176,9 +278,14 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``dns.node.Node`` or ``None``.
"""
- pass
- def delete_node(self, name: Union[dns.name.Name, str]) ->None:
+ try:
+ node = self.find_node(name, create)
+ except KeyError:
+ node = None
+ return node
+
+ def delete_node(self, name: Union[dns.name.Name, str]) -> None:
"""Delete the specified node if it exists.
*name*: the name of the node to find.
@@ -188,12 +295,18 @@ class Zone(dns.transaction.TransactionManager):
It is not an error if the node does not exist.
"""
- pass
- def find_rdataset(self, name: Union[dns.name.Name, str], rdtype: Union[
- dns.rdatatype.RdataType, str], covers: Union[dns.rdatatype.
- RdataType, str]=dns.rdatatype.NONE, create: bool=False
- ) ->dns.rdataset.Rdataset:
+ name = self._validate_name(name)
+ if name in self.nodes:
+ del self.nodes[name]
+
+ def find_rdataset(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> dns.rdataset.Rdataset:
"""Look for an rdataset with the specified name and type in the zone,
and return an rdataset encapsulating it.
@@ -227,12 +340,20 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``dns.rdataset.Rdataset``.
"""
- pass
- def get_rdataset(self, name: Union[dns.name.Name, str], rdtype: Union[
- dns.rdatatype.RdataType, str], covers: Union[dns.rdatatype.
- RdataType, str]=dns.rdatatype.NONE, create: bool=False) ->Optional[dns
- .rdataset.Rdataset]:
+ name = self._validate_name(name)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ node = self.find_node(name, create)
+ return node.find_rdataset(self.rdclass, rdtype, covers, create)
+
+ def get_rdataset(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> Optional[dns.rdataset.Rdataset]:
"""Look for an rdataset with the specified name and type in the zone.
This method is like ``find_rdataset()``, except it returns None instead
@@ -267,11 +388,19 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``dns.rdataset.Rdataset`` or ``None``.
"""
- pass
- def delete_rdataset(self, name: Union[dns.name.Name, str], rdtype:
- Union[dns.rdatatype.RdataType, str], covers: Union[dns.rdatatype.
- RdataType, str]=dns.rdatatype.NONE) ->None:
+ try:
+ rdataset = self.find_rdataset(name, rdtype, covers, create)
+ except KeyError:
+ rdataset = None
+ return rdataset
+
+ def delete_rdataset(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ ) -> None:
"""Delete the rdataset matching *rdtype* and *covers*, if it
exists at the node specified by *name*.
@@ -294,10 +423,19 @@ class Zone(dns.transaction.TransactionManager):
makes RRSIGs much easier to work with than if RRSIGs covering different rdata
types were aggregated into a single RRSIG rdataset.
"""
- pass
- def replace_rdataset(self, name: Union[dns.name.Name, str], replacement:
- dns.rdataset.Rdataset) ->None:
+ name = self._validate_name(name)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ node = self.get_node(name)
+ if node is not None:
+ node.delete_rdataset(self.rdclass, rdtype, covers)
+ if len(node) == 0:
+ self.delete_node(name)
+
+ def replace_rdataset(
+ self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset
+ ) -> None:
"""Replace an rdataset at name.
It is not an error if there is no rdataset matching I{replacement}.
@@ -315,11 +453,18 @@ class Zone(dns.transaction.TransactionManager):
*replacement*, a ``dns.rdataset.Rdataset``, the replacement rdataset.
"""
- pass
- def find_rrset(self, name: Union[dns.name.Name, str], rdtype: Union[dns
- .rdatatype.RdataType, str], covers: Union[dns.rdatatype.RdataType,
- str]=dns.rdatatype.NONE) ->dns.rrset.RRset:
+ if replacement.rdclass != self.rdclass:
+ raise ValueError("replacement.rdclass != zone.rdclass")
+ node = self.find_node(name, True)
+ node.replace_rdataset(replacement)
+
+ def find_rrset(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ ) -> dns.rrset.RRset:
"""Look for an rdataset with the specified name and type in the zone,
and return an RRset encapsulating it.
@@ -357,11 +502,21 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``dns.rrset.RRset`` or ``None``.
"""
- pass
- def get_rrset(self, name: Union[dns.name.Name, str], rdtype: Union[dns.
- rdatatype.RdataType, str], covers: Union[dns.rdatatype.RdataType,
- str]=dns.rdatatype.NONE) ->Optional[dns.rrset.RRset]:
+ vname = self._validate_name(name)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ rdataset = self.nodes[vname].find_rdataset(self.rdclass, rdtype, covers)
+ rrset = dns.rrset.RRset(vname, self.rdclass, rdtype, covers)
+ rrset.update(rdataset)
+ return rrset
+
+ def get_rrset(
+ self,
+ name: Union[dns.name.Name, str],
+ rdtype: Union[dns.rdatatype.RdataType, str],
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ ) -> Optional[dns.rrset.RRset]:
"""Look for an rdataset with the specified name and type in the zone,
and return an RRset encapsulating it.
@@ -395,12 +550,18 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``dns.rrset.RRset`` or ``None``.
"""
- pass
- def iterate_rdatasets(self, rdtype: Union[dns.rdatatype.RdataType, str]
- =dns.rdatatype.ANY, covers: Union[dns.rdatatype.RdataType, str]=dns
- .rdatatype.NONE) ->Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]
- ]:
+ try:
+ rrset = self.find_rrset(name, rdtype, covers)
+ except KeyError:
+ rrset = None
+ return rrset
+
+ def iterate_rdatasets(
+ self,
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.ANY,
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
"""Return a generator which yields (name, rdataset) tuples for
all rdatasets in the zone which have the specified *rdtype*
and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default,
@@ -418,11 +579,21 @@ class Zone(dns.transaction.TransactionManager):
covering different rdata types were aggregated into a single
RRSIG rdataset.
"""
- pass
- def iterate_rdatas(self, rdtype: Union[dns.rdatatype.RdataType, str]=
- dns.rdatatype.ANY, covers: Union[dns.rdatatype.RdataType, str]=dns.
- rdatatype.NONE) ->Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]:
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ for name, node in self.items():
+ for rds in node:
+ if rdtype == dns.rdatatype.ANY or (
+ rds.rdtype == rdtype and rds.covers == covers
+ ):
+ yield (name, rds)
+
+ def iterate_rdatas(
+ self,
+ rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.ANY,
+ covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+ ) -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]:
"""Return a generator which yields (name, ttl, rdata) tuples for
all rdatas in the zone which have the specified *rdtype*
and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default,
@@ -440,11 +611,26 @@ class Zone(dns.transaction.TransactionManager):
covering different rdata types were aggregated into a single
RRSIG rdataset.
"""
- pass
- def to_file(self, f: Any, sorted: bool=True, relativize: bool=True, nl:
- Optional[str]=None, want_comments: bool=False, want_origin: bool=False
- ) ->None:
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ covers = dns.rdatatype.RdataType.make(covers)
+ for name, node in self.items():
+ for rds in node:
+ if rdtype == dns.rdatatype.ANY or (
+ rds.rdtype == rdtype and rds.covers == covers
+ ):
+ for rdata in rds:
+ yield (name, rds.ttl, rdata)
+
+ def to_file(
+ self,
+ f: Any,
+ sorted: bool = True,
+ relativize: bool = True,
+ nl: Optional[str] = None,
+ want_comments: bool = False,
+ want_origin: bool = False,
+ ) -> None:
"""Write a zone to a file.
*f*, a file or `str`. If *f* is a string, it is treated
@@ -471,11 +657,68 @@ class Zone(dns.transaction.TransactionManager):
the start of the file. If ``False``, the default, do not emit
one.
"""
- pass
- def to_text(self, sorted: bool=True, relativize: bool=True, nl:
- Optional[str]=None, want_comments: bool=False, want_origin: bool=False
- ) ->str:
+ if isinstance(f, str):
+ cm: contextlib.AbstractContextManager = open(f, "wb")
+ else:
+ cm = contextlib.nullcontext(f)
+ with cm as f:
+ # must be in this way, f.encoding may contain None, or even
+ # attribute may not be there
+ file_enc = getattr(f, "encoding", None)
+ if file_enc is None:
+ file_enc = "utf-8"
+
+ if nl is None:
+ # binary mode, '\n' is not enough
+ nl_b = os.linesep.encode(file_enc)
+ nl = "\n"
+ elif isinstance(nl, str):
+ nl_b = nl.encode(file_enc)
+ else:
+ nl_b = nl
+ nl = nl.decode()
+
+ if want_origin:
+ assert self.origin is not None
+ l = "$ORIGIN " + self.origin.to_text()
+ l_b = l.encode(file_enc)
+ try:
+ f.write(l_b)
+ f.write(nl_b)
+ except TypeError: # textual mode
+ f.write(l)
+ f.write(nl)
+
+ if sorted:
+ names = list(self.keys())
+ names.sort()
+ else:
+ names = self.keys()
+ for n in names:
+ l = self[n].to_text(
+ n,
+ origin=self.origin,
+ relativize=relativize,
+ want_comments=want_comments,
+ )
+ l_b = l.encode(file_enc)
+
+ try:
+ f.write(l_b)
+ f.write(nl_b)
+ except TypeError: # textual mode
+ f.write(l)
+ f.write(nl)
+
+ def to_text(
+ self,
+ sorted: bool = True,
+ relativize: bool = True,
+ nl: Optional[str] = None,
+ want_comments: bool = False,
+ want_origin: bool = False,
+ ) -> str:
"""Return a zone's text as though it were written to a file.
*sorted*, a ``bool``. If True, the default, then the file
@@ -501,9 +744,13 @@ class Zone(dns.transaction.TransactionManager):
Returns a ``str``.
"""
- pass
+ temp_buffer = io.StringIO()
+ self.to_file(temp_buffer, sorted, relativize, nl, want_comments, want_origin)
+ return_value = temp_buffer.getvalue()
+ temp_buffer.close()
+ return return_value
- def check_origin(self) ->None:
+ def check_origin(self) -> None:
"""Do some simple checking of the zone's origin.
Raises ``dns.zone.NoSOA`` if there is no SOA RRset.
@@ -512,42 +759,220 @@ class Zone(dns.transaction.TransactionManager):
Raises ``KeyError`` if there is no origin node.
"""
- pass
-
- def get_soa(self, txn: Optional[dns.transaction.Transaction]=None
- ) ->dns.rdtypes.ANY.SOA.SOA:
+ if self.relativize:
+ name = dns.name.empty
+ else:
+ assert self.origin is not None
+ name = self.origin
+ if self.get_rdataset(name, dns.rdatatype.SOA) is None:
+ raise NoSOA
+ if self.get_rdataset(name, dns.rdatatype.NS) is None:
+ raise NoNS
+
+ def get_soa(
+ self, txn: Optional[dns.transaction.Transaction] = None
+ ) -> dns.rdtypes.ANY.SOA.SOA:
"""Get the zone SOA rdata.
Raises ``dns.zone.NoSOA`` if there is no SOA RRset.
Returns a ``dns.rdtypes.ANY.SOA.SOA`` Rdata.
"""
+ if self.relativize:
+ origin_name = dns.name.empty
+ else:
+ if self.origin is None:
+ # get_soa() has been called very early, and there must not be
+ # an SOA if there is no origin.
+ raise NoSOA
+ origin_name = self.origin
+ soa: Optional[dns.rdataset.Rdataset]
+ if txn:
+ soa = txn.get(origin_name, dns.rdatatype.SOA)
+ else:
+ soa = self.get_rdataset(origin_name, dns.rdatatype.SOA)
+ if soa is None:
+ raise NoSOA
+ return soa[0]
+
+ def _compute_digest(
+ self,
+ hash_algorithm: DigestHashAlgorithm,
+ scheme: DigestScheme = DigestScheme.SIMPLE,
+ ) -> bytes:
+ hashinfo = _digest_hashers.get(hash_algorithm)
+ if not hashinfo:
+ raise UnsupportedDigestHashAlgorithm
+ if scheme != DigestScheme.SIMPLE:
+ raise UnsupportedDigestScheme
+
+ if self.relativize:
+ origin_name = dns.name.empty
+ else:
+ assert self.origin is not None
+ origin_name = self.origin
+ hasher = hashinfo()
+ for name, node in sorted(self.items()):
+ rrnamebuf = name.to_digestable(self.origin)
+ for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)):
+ if name == origin_name and dns.rdatatype.ZONEMD in (
+ rdataset.rdtype,
+ rdataset.covers,
+ ):
+ continue
+ rrfixed = struct.pack(
+ "!HHI", rdataset.rdtype, rdataset.rdclass, rdataset.ttl
+ )
+ rdatas = [rdata.to_digestable(self.origin) for rdata in rdataset]
+ for rdata in sorted(rdatas):
+ rrlen = struct.pack("!H", len(rdata))
+ hasher.update(rrnamebuf + rrfixed + rrlen + rdata)
+ return hasher.digest()
+
+ def compute_digest(
+ self,
+ hash_algorithm: DigestHashAlgorithm,
+ scheme: DigestScheme = DigestScheme.SIMPLE,
+ ) -> dns.rdtypes.ANY.ZONEMD.ZONEMD:
+ serial = self.get_soa().serial
+ digest = self._compute_digest(hash_algorithm, scheme)
+ return dns.rdtypes.ANY.ZONEMD.ZONEMD(
+ self.rdclass, dns.rdatatype.ZONEMD, serial, scheme, hash_algorithm, digest
+ )
+
+ def verify_digest(
+ self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD] = None
+ ) -> None:
+ digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]]
+ if zonemd:
+ digests = [zonemd]
+ else:
+ assert self.origin is not None
+ rds = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD)
+ if rds is None:
+ raise NoDigest
+ digests = rds
+ for digest in digests:
+ try:
+ computed = self._compute_digest(digest.hash_algorithm, digest.scheme)
+ if computed == digest.digest:
+ return
+ except Exception:
+ pass
+ raise DigestVerificationFailure
+
+ # TransactionManager methods
+
+ def reader(self) -> "Transaction":
+ return Transaction(self, False, Version(self, 1, self.nodes, self.origin))
+
+ def writer(self, replacement: bool = False) -> "Transaction":
+ txn = Transaction(self, replacement)
+ txn._setup_version()
+ return txn
+
+ def origin_information(
+ self,
+ ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
+ effective: Optional[dns.name.Name]
+ if self.relativize:
+ effective = dns.name.empty
+ else:
+ effective = self.origin
+ return (self.origin, self.relativize, effective)
+
+ def get_class(self):
+ return self.rdclass
+
+ # Transaction methods
+
+ def _end_read(self, txn):
+ pass
+
+ def _end_write(self, txn):
pass
+ def _commit_version(self, _, version, origin):
+ self.nodes = version.nodes
+ if self.origin is None:
+ self.origin = origin
+
+ def _get_next_version_id(self):
+ # Versions are ephemeral and all have id 1
+ return 1
+
+
+# These classes used to be in dns.versioned, but have moved here so we can use
+# the copy-on-write transaction mechanism for both kinds of zones. In a
+# regular zone, the version only exists during the transaction, and the nodes
+# are regular dns.node.Nodes.
+
+# A node with a version id.
-class VersionedNode(dns.node.Node):
- __slots__ = ['id']
+
+class VersionedNode(dns.node.Node): # lgtm[py/missing-equals]
+ __slots__ = ["id"]
def __init__(self):
super().__init__()
+ # A proper id will get set by the Version
self.id = 0
@dns.immutable.immutable
class ImmutableVersionedNode(VersionedNode):
-
def __init__(self, node):
super().__init__()
self.id = node.id
- self.rdatasets = tuple([dns.rdataset.ImmutableRdataset(rds) for rds in
- node.rdatasets])
+ self.rdatasets = tuple(
+ [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+ )
+
+ def find_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> dns.rdataset.Rdataset:
+ if create:
+ raise TypeError("immutable")
+ return super().find_rdataset(rdclass, rdtype, covers, False)
+
+ def get_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> Optional[dns.rdataset.Rdataset]:
+ if create:
+ raise TypeError("immutable")
+ return super().get_rdataset(rdclass, rdtype, covers, False)
+
+ def delete_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ ) -> None:
+ raise TypeError("immutable")
+
+ def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
+ raise TypeError("immutable")
+
+ def is_immutable(self) -> bool:
+ return True
class Version:
-
- def __init__(self, zone: Zone, id: int, nodes: Optional[MutableMapping[
- dns.name.Name, dns.node.Node]]=None, origin: Optional[dns.name.Name
- ]=None):
+ def __init__(
+ self,
+ zone: Zone,
+ id: int,
+ nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None,
+ origin: Optional[dns.name.Name] = None,
+ ):
self.zone = zone
self.id = id
if nodes is not None:
@@ -556,48 +981,258 @@ class Version:
self.nodes = zone.map_factory()
self.origin = origin
+ def _validate_name(self, name: dns.name.Name) -> dns.name.Name:
+ return _validate_name(name, self.origin, self.zone.relativize)
+
+ def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
+ name = self._validate_name(name)
+ return self.nodes.get(name)
+
+ def get_rdataset(
+ self,
+ name: dns.name.Name,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType,
+ ) -> Optional[dns.rdataset.Rdataset]:
+ node = self.get_node(name)
+ if node is None:
+ return None
+ return node.get_rdataset(self.zone.rdclass, rdtype, covers)
+
+ def keys(self):
+ return self.nodes.keys()
+
+ def items(self):
+ return self.nodes.items()
-class WritableVersion(Version):
- def __init__(self, zone: Zone, replacement: bool=False):
+class WritableVersion(Version):
+ def __init__(self, zone: Zone, replacement: bool = False):
+ # The zone._versions_lock must be held by our caller in a versioned
+ # zone.
id = zone._get_next_version_id()
super().__init__(zone, id)
if not replacement:
+ # We copy the map, because that gives us a simple and thread-safe
+ # way of doing versions, and we have a garbage collector to help
+ # us. We only make new node objects if we actually change the
+ # node.
self.nodes.update(zone.nodes)
+ # We have to copy the zone origin as it may be None in the first
+ # version, and we don't want to mutate the zone until we commit.
self.origin = zone.origin
self.changed: Set[dns.name.Name] = set()
+ def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node:
+ name = self._validate_name(name)
+ node = self.nodes.get(name)
+ if node is None or name not in self.changed:
+ new_node = self.zone.node_factory()
+ if hasattr(new_node, "id"):
+ # We keep doing this for backwards compatibility, as earlier
+ # code used new_node.id != self.id for the "do we need to CoW?"
+ # test. Now we use the changed set as this works with both
+ # regular zones and versioned zones.
+ #
+ # We ignore the mypy error as this is safe but it doesn't see it.
+ new_node.id = self.id # type: ignore
+ if node is not None:
+ # moo! copy on write!
+ new_node.rdatasets.extend(node.rdatasets)
+ self.nodes[name] = new_node
+ self.changed.add(name)
+ return new_node
+ else:
+ return node
+
+ def delete_node(self, name: dns.name.Name) -> None:
+ name = self._validate_name(name)
+ if name in self.nodes:
+ del self.nodes[name]
+ self.changed.add(name)
+
+ def put_rdataset(
+ self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset
+ ) -> None:
+ node = self._maybe_cow(name)
+ node.replace_rdataset(rdataset)
+
+ def delete_rdataset(
+ self,
+ name: dns.name.Name,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType,
+ ) -> None:
+ node = self._maybe_cow(name)
+ node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+ if len(node) == 0:
+ del self.nodes[name]
+
@dns.immutable.immutable
class ImmutableVersion(Version):
-
def __init__(self, version: WritableVersion):
+ # We tell super() that it's a replacement as we don't want it
+ # to copy the nodes, as we're about to do that with an
+ # immutable Dict.
super().__init__(version.zone, True)
+ # set the right id!
self.id = version.id
+ # keep the origin
self.origin = version.origin
+ # Make changed nodes immutable
for name in version.changed:
node = version.nodes.get(name)
+ # it might not exist if we deleted it in the version
if node:
version.nodes[name] = ImmutableVersionedNode(node)
- self.nodes = dns.immutable.Dict(version.nodes, True, self.zone.
- map_factory)
+ # We're changing the type of the nodes dictionary here on purpose, so
+ # we ignore the mypy error.
+ self.nodes = dns.immutable.Dict(
+ version.nodes, True, self.zone.map_factory
+ ) # type: ignore
class Transaction(dns.transaction.Transaction):
-
def __init__(self, zone, replacement, version=None, make_immutable=False):
read_only = version is not None
super().__init__(zone, replacement, read_only)
self.version = version
self.make_immutable = make_immutable
+ @property
+ def zone(self):
+ return self.manager
+
+ def _setup_version(self):
+ assert self.version is None
+ factory = self.manager.writable_version_factory
+ if factory is None:
+ factory = WritableVersion
+ self.version = factory(self.zone, self.replacement)
+
+ def _get_rdataset(self, name, rdtype, covers):
+ return self.version.get_rdataset(name, rdtype, covers)
+
+ def _put_rdataset(self, name, rdataset):
+ assert not self.read_only
+ self.version.put_rdataset(name, rdataset)
+
+ def _delete_name(self, name):
+ assert not self.read_only
+ self.version.delete_node(name)
-def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None,
- rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=
- True, zone_factory: Any=Zone, filename: Optional[str]=None,
- allow_include: bool=False, check_origin: bool=True, idna_codec:
- Optional[dns.name.IDNACodec]=None, allow_directives: Union[bool,
- Iterable[str]]=True) ->Zone:
+ def _delete_rdataset(self, name, rdtype, covers):
+ assert not self.read_only
+ self.version.delete_rdataset(name, rdtype, covers)
+
+ def _name_exists(self, name):
+ return self.version.get_node(name) is not None
+
+ def _changed(self):
+ if self.read_only:
+ return False
+ else:
+ return len(self.version.changed) > 0
+
+ def _end_transaction(self, commit):
+ if self.read_only:
+ self.zone._end_read(self)
+ elif commit and len(self.version.changed) > 0:
+ if self.make_immutable:
+ factory = self.manager.immutable_version_factory
+ if factory is None:
+ factory = ImmutableVersion
+ version = factory(self.version)
+ else:
+ version = self.version
+ self.zone._commit_version(self, version, self.version.origin)
+ else:
+ # rollback
+ self.zone._end_write(self)
+
+ def _set_origin(self, origin):
+ if self.version.origin is None:
+ self.version.origin = origin
+
+ def _iterate_rdatasets(self):
+ for name, node in self.version.items():
+ for rdataset in node:
+ yield (name, rdataset)
+
+ def _iterate_names(self):
+ return self.version.keys()
+
+ def _get_node(self, name):
+ return self.version.get_node(name)
+
+ def _origin_information(self):
+ (absolute, relativize, effective) = self.manager.origin_information()
+ if absolute is None and self.version.origin is not None:
+ # No origin has been committed yet, but we've learned one as part of
+ # this txn. Use it.
+ absolute = self.version.origin
+ if relativize:
+ effective = dns.name.empty
+ else:
+ effective = absolute
+ return (absolute, relativize, effective)
+
+
+def _from_text(
+ text: Any,
+ origin: Optional[Union[dns.name.Name, str]] = None,
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+ relativize: bool = True,
+ zone_factory: Any = Zone,
+ filename: Optional[str] = None,
+ allow_include: bool = False,
+ check_origin: bool = True,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ allow_directives: Union[bool, Iterable[str]] = True,
+) -> Zone:
+ # See the comments for the public APIs from_text() and from_file() for
+ # details.
+
+ # 'text' can also be a file, but we don't publish that fact
+ # since it's an implementation detail. The official file
+ # interface is from_file().
+
+ if filename is None:
+ filename = "<string>"
+ zone = zone_factory(origin, rdclass, relativize=relativize)
+ with zone.writer(True) as txn:
+ tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
+ reader = dns.zonefile.Reader(
+ tok,
+ rdclass,
+ txn,
+ allow_include=allow_include,
+ allow_directives=allow_directives,
+ )
+ try:
+ reader.read()
+ except dns.zonefile.UnknownOrigin:
+ # for backwards compatibility
+ raise dns.zone.UnknownOrigin
+ # Now that we're done reading, do some basic checking of the zone.
+ if check_origin:
+ zone.check_origin()
+ return zone
+
+
+def from_text(
+ text: str,
+ origin: Optional[Union[dns.name.Name, str]] = None,
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+ relativize: bool = True,
+ zone_factory: Any = Zone,
+ filename: Optional[str] = None,
+ allow_include: bool = False,
+ check_origin: bool = True,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ allow_directives: Union[bool, Iterable[str]] = True,
+) -> Zone:
"""Build a zone object from a zone file format string.
*text*, a ``str``, the zone file format input.
@@ -646,15 +1281,32 @@ def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None,
Returns a subclass of ``dns.zone.Zone``.
"""
- pass
-
-
-def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None,
- rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=
- True, zone_factory: Any=Zone, filename: Optional[str]=None,
- allow_include: bool=True, check_origin: bool=True, idna_codec: Optional
- [dns.name.IDNACodec]=None, allow_directives: Union[bool, Iterable[str]]
- =True) ->Zone:
+ return _from_text(
+ text,
+ origin,
+ rdclass,
+ relativize,
+ zone_factory,
+ filename,
+ allow_include,
+ check_origin,
+ idna_codec,
+ allow_directives,
+ )
+
+
+def from_file(
+ f: Any,
+ origin: Optional[Union[dns.name.Name, str]] = None,
+ rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+ relativize: bool = True,
+ zone_factory: Any = Zone,
+ filename: Optional[str] = None,
+ allow_include: bool = True,
+ check_origin: bool = True,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ allow_directives: Union[bool, Iterable[str]] = True,
+) -> Zone:
"""Read a zone file and build a zone object.
*f*, a file or ``str``. If *f* is a string, it is treated
@@ -703,11 +1355,35 @@ def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None,
Returns a subclass of ``dns.zone.Zone``.
"""
- pass
-
-def from_xfr(xfr: Any, zone_factory: Any=Zone, relativize: bool=True,
- check_origin: bool=True) ->Zone:
+ if isinstance(f, str):
+ if filename is None:
+ filename = f
+ cm: contextlib.AbstractContextManager = open(f)
+ else:
+ cm = contextlib.nullcontext(f)
+ with cm as f:
+ return _from_text(
+ f,
+ origin,
+ rdclass,
+ relativize,
+ zone_factory,
+ filename,
+ allow_include,
+ check_origin,
+ idna_codec,
+ allow_directives,
+ )
+ assert False # make mypy happy lgtm[py/unreachable-statement]
+
+
+def from_xfr(
+ xfr: Any,
+ zone_factory: Any = Zone,
+ relativize: bool = True,
+ check_origin: bool = True,
+) -> Zone:
"""Convert the output of a zone transfer generator into a zone object.
*xfr*, a generator of ``dns.message.Message`` objects, typically
@@ -732,4 +1408,27 @@ def from_xfr(xfr: Any, zone_factory: Any=Zone, relativize: bool=True,
Returns a subclass of ``dns.zone.Zone``.
"""
- pass
+
+ z = None
+ for r in xfr:
+ if z is None:
+ if relativize:
+ origin = r.origin
+ else:
+ origin = r.answer[0].name
+ rdclass = r.answer[0].rdclass
+ z = zone_factory(origin, rdclass, relativize=relativize)
+ for rrset in r.answer:
+ znode = z.nodes.get(rrset.name)
+ if not znode:
+ znode = z.node_factory()
+ z.nodes[rrset.name] = znode
+ zrds = znode.find_rdataset(rrset.rdclass, rrset.rdtype, rrset.covers, True)
+ zrds.update_ttl(rrset.ttl)
+ for rd in rrset:
+ zrds.add(rd)
+ if z is None:
+ raise ValueError("empty transfer")
+ if check_origin:
+ z.check_origin()
+ return z
diff --git a/dns/zonefile.py b/dns/zonefile.py
index c5e5731..af064e7 100644
--- a/dns/zonefile.py
+++ b/dns/zonefile.py
@@ -1,7 +1,26 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
"""DNS Zones."""
+
import re
import sys
from typing import Any, Iterable, List, Optional, Set, Tuple, Union
+
import dns.exception
import dns.grange
import dns.name
@@ -24,22 +43,66 @@ class CNAMEAndOtherData(dns.exception.DNSException):
"""A node has a CNAME and other data"""
-SavedStateType = Tuple[dns.tokenizer.Tokenizer, Optional[dns.name.Name],
- Optional[dns.name.Name], Optional[Any], int, bool, int, bool]
+def _check_cname_and_other_data(txn, name, rdataset):
+ rdataset_kind = dns.node.NodeKind.classify_rdataset(rdataset)
+ node = txn.get_node(name)
+ if node is None:
+ # empty nodes are neutral.
+ return
+ node_kind = node.classify()
+ if (
+ node_kind == dns.node.NodeKind.CNAME
+ and rdataset_kind == dns.node.NodeKind.REGULAR
+ ):
+ raise CNAMEAndOtherData("rdataset type is not compatible with a CNAME node")
+ elif (
+ node_kind == dns.node.NodeKind.REGULAR
+ and rdataset_kind == dns.node.NodeKind.CNAME
+ ):
+ raise CNAMEAndOtherData(
+ "CNAME rdataset is not compatible with a regular data node"
+ )
+ # Otherwise at least one of the node and the rdataset is neutral, so
+ # adding the rdataset is ok
+
+
+SavedStateType = Tuple[
+ dns.tokenizer.Tokenizer,
+ Optional[dns.name.Name], # current_origin
+ Optional[dns.name.Name], # last_name
+ Optional[Any], # current_file
+ int, # last_ttl
+ bool, # last_ttl_known
+ int, # default_ttl
+ bool,
+] # default_ttl_known
+
+
+def _upper_dollarize(s):
+ s = s.upper()
+ if not s.startswith("$"):
+ s = "$" + s
+ return s
class Reader:
"""Read a DNS zone file into a transaction."""
- def __init__(self, tok: dns.tokenizer.Tokenizer, rdclass: dns.
- rdataclass.RdataClass, txn: dns.transaction.Transaction,
- allow_include: bool=False, allow_directives: Union[bool, Iterable[
- str]]=True, force_name: Optional[dns.name.Name]=None, force_ttl:
- Optional[int]=None, force_rdclass: Optional[dns.rdataclass.
- RdataClass]=None, force_rdtype: Optional[dns.rdatatype.RdataType]=
- None, default_ttl: Optional[int]=None):
+ def __init__(
+ self,
+ tok: dns.tokenizer.Tokenizer,
+ rdclass: dns.rdataclass.RdataClass,
+ txn: dns.transaction.Transaction,
+ allow_include: bool = False,
+ allow_directives: Union[bool, Iterable[str]] = True,
+ force_name: Optional[dns.name.Name] = None,
+ force_ttl: Optional[int] = None,
+ force_rdclass: Optional[dns.rdataclass.RdataClass] = None,
+ force_rdtype: Optional[dns.rdatatype.RdataType] = None,
+ default_ttl: Optional[int] = None,
+ ):
self.tok = tok
- self.zone_origin, self.relativize, _ = txn.manager.origin_information()
+ (self.zone_origin, self.relativize, _) = txn.manager.origin_information()
self.current_origin = self.zone_origin
self.last_ttl = 0
self.last_ttl_known = False
@@ -58,64 +121,547 @@ class Reader:
self.current_file: Optional[Any] = None
self.allowed_directives: Set[str]
if allow_directives is True:
- self.allowed_directives = {'$GENERATE', '$ORIGIN', '$TTL'}
+ self.allowed_directives = {"$GENERATE", "$ORIGIN", "$TTL"}
if allow_include:
- self.allowed_directives.add('$INCLUDE')
+ self.allowed_directives.add("$INCLUDE")
elif allow_directives is False:
+ # allow_include was ignored in earlier releases if allow_directives was
+ # False, so we continue that.
self.allowed_directives = set()
else:
- self.allowed_directives = set(_upper_dollarize(d) for d in
- allow_directives)
+ # Note that if directives are explicitly specified, then allow_include
+ # is ignored.
+ self.allowed_directives = set(_upper_dollarize(d) for d in allow_directives)
self.force_name = force_name
self.force_ttl = force_ttl
self.force_rdclass = force_rdclass
self.force_rdtype = force_rdtype
self.txn.check_put_rdataset(_check_cname_and_other_data)
+ def _eat_line(self):
+ while 1:
+ token = self.tok.get()
+ if token.is_eol_or_eof():
+ break
+
+ def _get_identifier(self):
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ return token
+
def _rr_line(self):
"""Process one line from a DNS zone file."""
- pass
+ token = None
+ # Name
+ if self.force_name is not None:
+ name = self.force_name
+ else:
+ if self.current_origin is None:
+ raise UnknownOrigin
+ token = self.tok.get(want_leading=True)
+ if not token.is_whitespace():
+ self.last_name = self.tok.as_name(token, self.current_origin)
+ else:
+ token = self.tok.get()
+ if token.is_eol_or_eof():
+ # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
+ return
+ self.tok.unget(token)
+ name = self.last_name
+ if not name.is_subdomain(self.zone_origin):
+ self._eat_line()
+ return
+ if self.relativize:
+ name = name.relativize(self.zone_origin)
+
+ # TTL
+ if self.force_ttl is not None:
+ ttl = self.force_ttl
+ self.last_ttl = ttl
+ self.last_ttl_known = True
+ else:
+ token = self._get_identifier()
+ ttl = None
+ try:
+ ttl = dns.ttl.from_text(token.value)
+ self.last_ttl = ttl
+ self.last_ttl_known = True
+ token = None
+ except dns.ttl.BadTTL:
+ self.tok.unget(token)
+
+ # Class
+ if self.force_rdclass is not None:
+ rdclass = self.force_rdclass
+ else:
+ token = self._get_identifier()
+ try:
+ rdclass = dns.rdataclass.from_text(token.value)
+ except dns.exception.SyntaxError:
+ raise
+ except Exception:
+ rdclass = self.zone_rdclass
+ self.tok.unget(token)
+ if rdclass != self.zone_rdclass:
+ raise dns.exception.SyntaxError("RR class is not zone's class")
+
+ if ttl is None:
+ # support for <class> <ttl> <type> syntax
+ token = self._get_identifier()
+ ttl = None
+ try:
+ ttl = dns.ttl.from_text(token.value)
+ self.last_ttl = ttl
+ self.last_ttl_known = True
+ token = None
+ except dns.ttl.BadTTL:
+ if self.default_ttl_known:
+ ttl = self.default_ttl
+ elif self.last_ttl_known:
+ ttl = self.last_ttl
+ self.tok.unget(token)
+
+ # Type
+ if self.force_rdtype is not None:
+ rdtype = self.force_rdtype
+ else:
+ token = self._get_identifier()
+ try:
+ rdtype = dns.rdatatype.from_text(token.value)
+ except Exception:
+ raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value)
+
+ try:
+ rd = dns.rdata.from_text(
+ rdclass,
+ rdtype,
+ self.tok,
+ self.current_origin,
+ self.relativize,
+ self.zone_origin,
+ )
+ except dns.exception.SyntaxError:
+ # Catch and reraise.
+ raise
+ except Exception:
+ # All exceptions that occur in the processing of rdata
+ # are treated as syntax errors. This is not strictly
+ # correct, but it is correct almost all of the time.
+ # We convert them to syntax errors so that we can emit
+ # helpful filename:line info.
+ (ty, va) = sys.exc_info()[:2]
+ raise dns.exception.SyntaxError(
+ "caught exception {}: {}".format(str(ty), str(va))
+ )
+
+ if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
+ # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
+ # TTL from the SOA minttl if no $TTL statement is present before the
+ # SOA is parsed.
+ self.default_ttl = rd.minimum
+ self.default_ttl_known = True
+ if ttl is None:
+ # if we didn't have a TTL on the SOA, set it!
+ ttl = rd.minimum
+
+ # TTL check. We had to wait until now to do this as the SOA RR's
+ # own TTL can be inferred from its minimum.
+ if ttl is None:
+ raise dns.exception.SyntaxError("Missing default TTL value")
+
+ self.txn.add(name, ttl, rd)
+
+ def _parse_modify(self, side: str) -> Tuple[str, str, int, int, str]:
+ # Here we catch everything in '{' '}' in a group so we can replace it
+ # with ''.
+ is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$")
+ is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$")
+ is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$")
+ # Sometimes there are modifiers in the hostname. These come after
+ # the dollar sign. They are in the form: ${offset[,width[,base]]}.
+ # Make names
+ g1 = is_generate1.match(side)
+ if g1:
+ mod, sign, offset, width, base = g1.groups()
+ if sign == "":
+ sign = "+"
+ g2 = is_generate2.match(side)
+ if g2:
+ mod, sign, offset = g2.groups()
+ if sign == "":
+ sign = "+"
+ width = 0
+ base = "d"
+ g3 = is_generate3.match(side)
+ if g3:
+ mod, sign, offset, width = g3.groups()
+ if sign == "":
+ sign = "+"
+ base = "d"
+
+ if not (g1 or g2 or g3):
+ mod = ""
+ sign = "+"
+ offset = 0
+ width = 0
+ base = "d"
+
+ offset = int(offset)
+ width = int(width)
+
+ if sign not in ["+", "-"]:
+ raise dns.exception.SyntaxError("invalid offset sign %s" % sign)
+ if base not in ["d", "o", "x", "X", "n", "N"]:
+ raise dns.exception.SyntaxError("invalid type %s" % base)
+
+ return mod, sign, offset, width, base
def _generate_line(self):
+ # range lhs [ttl] [class] type rhs [ comment ]
"""Process one line containing the GENERATE statement from a DNS
zone file."""
- pass
+ if self.current_origin is None:
+ raise UnknownOrigin
+
+ token = self.tok.get()
+ # Range (required)
+ try:
+ start, stop, step = dns.grange.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except Exception:
+ raise dns.exception.SyntaxError
+
+ # lhs (required)
+ try:
+ lhs = token.value
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except Exception:
+ raise dns.exception.SyntaxError
+
+ # TTL
+ try:
+ ttl = dns.ttl.from_text(token.value)
+ self.last_ttl = ttl
+ self.last_ttl_known = True
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.ttl.BadTTL:
+ if not (self.last_ttl_known or self.default_ttl_known):
+ raise dns.exception.SyntaxError("Missing default TTL value")
+ if self.default_ttl_known:
+ ttl = self.default_ttl
+ elif self.last_ttl_known:
+ ttl = self.last_ttl
+ # Class
+ try:
+ rdclass = dns.rdataclass.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.exception.SyntaxError:
+ raise dns.exception.SyntaxError
+ except Exception:
+ rdclass = self.zone_rdclass
+ if rdclass != self.zone_rdclass:
+ raise dns.exception.SyntaxError("RR class is not zone's class")
+ # Type
+ try:
+ rdtype = dns.rdatatype.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except Exception:
+ raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value)
- def read(self) ->None:
+ # rhs (required)
+ rhs = token.value
+
+ def _calculate_index(counter: int, offset_sign: str, offset: int) -> int:
+ """Calculate the index from the counter and offset."""
+ if offset_sign == "-":
+ offset *= -1
+ return counter + offset
+
+ def _format_index(index: int, base: str, width: int) -> str:
+ """Format the index with the given base, and zero-fill it
+ to the given width."""
+ if base in ["d", "o", "x", "X"]:
+ return format(index, base).zfill(width)
+
+ # base can only be n or N here
+ hexa = _format_index(index, "x", width)
+ nibbles = ".".join(hexa[::-1])[:width]
+ if base == "N":
+ nibbles = nibbles.upper()
+ return nibbles
+
+ lmod, lsign, loffset, lwidth, lbase = self._parse_modify(lhs)
+ rmod, rsign, roffset, rwidth, rbase = self._parse_modify(rhs)
+ for i in range(start, stop + 1, step):
+ # +1 because bind is inclusive and python is exclusive
+
+ lindex = _calculate_index(i, lsign, loffset)
+ rindex = _calculate_index(i, rsign, roffset)
+
+ lzfindex = _format_index(lindex, lbase, lwidth)
+ rzfindex = _format_index(rindex, rbase, rwidth)
+
+ name = lhs.replace("$%s" % (lmod), lzfindex)
+ rdata = rhs.replace("$%s" % (rmod), rzfindex)
+
+ self.last_name = dns.name.from_text(
+ name, self.current_origin, self.tok.idna_codec
+ )
+ name = self.last_name
+ if not name.is_subdomain(self.zone_origin):
+ self._eat_line()
+ return
+ if self.relativize:
+ name = name.relativize(self.zone_origin)
+
+ try:
+ rd = dns.rdata.from_text(
+ rdclass,
+ rdtype,
+ rdata,
+ self.current_origin,
+ self.relativize,
+ self.zone_origin,
+ )
+ except dns.exception.SyntaxError:
+ # Catch and reraise.
+ raise
+ except Exception:
+ # All exceptions that occur in the processing of rdata
+ # are treated as syntax errors. This is not strictly
+ # correct, but it is correct almost all of the time.
+ # We convert them to syntax errors so that we can emit
+ # helpful filename:line info.
+ (ty, va) = sys.exc_info()[:2]
+ raise dns.exception.SyntaxError(
+ "caught exception %s: %s" % (str(ty), str(va))
+ )
+
+ self.txn.add(name, ttl, rd)
+
+ def read(self) -> None:
"""Read a DNS zone file and build a zone object.
@raises dns.zone.NoSOA: No SOA RR was found at the zone origin
@raises dns.zone.NoNS: No NS RRset was found at the zone origin
"""
- pass
+ try:
+ while 1:
+ token = self.tok.get(True, True)
+ if token.is_eof():
+ if self.current_file is not None:
+ self.current_file.close()
+ if len(self.saved_state) > 0:
+ (
+ self.tok,
+ self.current_origin,
+ self.last_name,
+ self.current_file,
+ self.last_ttl,
+ self.last_ttl_known,
+ self.default_ttl,
+ self.default_ttl_known,
+ ) = self.saved_state.pop(-1)
+ continue
+ break
+ elif token.is_eol():
+ continue
+ elif token.is_comment():
+ self.tok.get_eol()
+ continue
+ elif token.value[0] == "$" and len(self.allowed_directives) > 0:
+ # Note that we only run directive processing code if at least
+ # one directive is allowed in order to be backwards compatible
+ c = token.value.upper()
+ if c not in self.allowed_directives:
+ raise dns.exception.SyntaxError(
+ f"zone file directive '{c}' is not allowed"
+ )
+ if c == "$TTL":
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError("bad $TTL")
+ self.default_ttl = dns.ttl.from_text(token.value)
+ self.default_ttl_known = True
+ self.tok.get_eol()
+ elif c == "$ORIGIN":
+ self.current_origin = self.tok.get_name()
+ self.tok.get_eol()
+ if self.zone_origin is None:
+ self.zone_origin = self.current_origin
+ self.txn._set_origin(self.current_origin)
+ elif c == "$INCLUDE":
+ token = self.tok.get()
+ filename = token.value
+ token = self.tok.get()
+ new_origin: Optional[dns.name.Name]
+ if token.is_identifier():
+ new_origin = dns.name.from_text(
+ token.value, self.current_origin, self.tok.idna_codec
+ )
+ self.tok.get_eol()
+ elif not token.is_eol_or_eof():
+ raise dns.exception.SyntaxError("bad origin in $INCLUDE")
+ else:
+ new_origin = self.current_origin
+ self.saved_state.append(
+ (
+ self.tok,
+ self.current_origin,
+ self.last_name,
+ self.current_file,
+ self.last_ttl,
+ self.last_ttl_known,
+ self.default_ttl,
+ self.default_ttl_known,
+ )
+ )
+ self.current_file = open(filename, "r")
+ self.tok = dns.tokenizer.Tokenizer(self.current_file, filename)
+ self.current_origin = new_origin
+ elif c == "$GENERATE":
+ self._generate_line()
+ else:
+ raise dns.exception.SyntaxError(
+ f"Unknown zone file directive '{c}'"
+ )
+ continue
+ self.tok.unget(token)
+ self._rr_line()
+ except dns.exception.SyntaxError as detail:
+ (filename, line_number) = self.tok.where()
+ if detail is None:
+ detail = "syntax error"
+ ex = dns.exception.SyntaxError(
+ "%s:%d: %s" % (filename, line_number, detail)
+ )
+ tb = sys.exc_info()[2]
+ raise ex.with_traceback(tb) from None
-class RRsetsReaderTransaction(dns.transaction.Transaction):
+class RRsetsReaderTransaction(dns.transaction.Transaction):
def __init__(self, manager, replacement, read_only):
assert not read_only
super().__init__(manager, replacement, read_only)
self.rdatasets = {}
+ def _get_rdataset(self, name, rdtype, covers):
+ return self.rdatasets.get((name, rdtype, covers))
-class RRSetsReaderManager(dns.transaction.TransactionManager):
+ def _get_node(self, name):
+ rdatasets = []
+ for (rdataset_name, _, _), rdataset in self.rdatasets.items():
+ if name == rdataset_name:
+ rdatasets.append(rdataset)
+ if len(rdatasets) == 0:
+ return None
+ node = dns.node.Node()
+ node.rdatasets = rdatasets
+ return node
+
+ def _put_rdataset(self, name, rdataset):
+ self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+
+ def _delete_name(self, name):
+ # First remove any changes involving the name
+ remove = []
+ for key in self.rdatasets:
+ if key[0] == name:
+ remove.append(key)
+ if len(remove) > 0:
+ for key in remove:
+ del self.rdatasets[key]
- def __init__(self, origin=dns.name.root, relativize=False, rdclass=dns.
- rdataclass.IN):
+ def _delete_rdataset(self, name, rdtype, covers):
+ try:
+ del self.rdatasets[(name, rdtype, covers)]
+ except KeyError:
+ pass
+
+ def _name_exists(self, name):
+ for n, _, _ in self.rdatasets:
+ if n == name:
+ return True
+ return False
+
+ def _changed(self):
+ return len(self.rdatasets) > 0
+
+ def _end_transaction(self, commit):
+ if commit and self._changed():
+ rrsets = []
+ for (name, _, _), rdataset in self.rdatasets.items():
+ rrset = dns.rrset.RRset(
+ name, rdataset.rdclass, rdataset.rdtype, rdataset.covers
+ )
+ rrset.update(rdataset)
+ rrsets.append(rrset)
+ self.manager.set_rrsets(rrsets)
+
+ def _set_origin(self, origin):
+ pass
+
+ def _iterate_rdatasets(self):
+ raise NotImplementedError # pragma: no cover
+
+ def _iterate_names(self):
+ raise NotImplementedError # pragma: no cover
+
+
+class RRSetsReaderManager(dns.transaction.TransactionManager):
+ def __init__(
+ self, origin=dns.name.root, relativize=False, rdclass=dns.rdataclass.IN
+ ):
self.origin = origin
self.relativize = relativize
self.rdclass = rdclass
self.rrsets = []
+ def reader(self): # pragma: no cover
+ raise NotImplementedError
+
+ def writer(self, replacement=False):
+ assert replacement is True
+ return RRsetsReaderTransaction(self, True, False)
+
+ def get_class(self):
+ return self.rdclass
+
+ def origin_information(self):
+ if self.relativize:
+ effective = dns.name.empty
+ else:
+ effective = self.origin
+ return (self.origin, self.relativize, effective)
+
+ def set_rrsets(self, rrsets):
+ self.rrsets = rrsets
+
-def read_rrsets(text: Any, name: Optional[Union[dns.name.Name, str]]=None,
- ttl: Optional[int]=None, rdclass: Optional[Union[dns.rdataclass.
- RdataClass, str]]=dns.rdataclass.IN, default_rdclass: Union[dns.
- rdataclass.RdataClass, str]=dns.rdataclass.IN, rdtype: Optional[Union[
- dns.rdatatype.RdataType, str]]=None, default_ttl: Optional[Union[int,
- str]]=None, idna_codec: Optional[dns.name.IDNACodec]=None, origin:
- Optional[Union[dns.name.Name, str]]=dns.name.root, relativize: bool=False
- ) ->List[dns.rrset.RRset]:
+def read_rrsets(
+ text: Any,
+ name: Optional[Union[dns.name.Name, str]] = None,
+ ttl: Optional[int] = None,
+ rdclass: Optional[Union[dns.rdataclass.RdataClass, str]] = dns.rdataclass.IN,
+ default_rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+ rdtype: Optional[Union[dns.rdatatype.RdataType, str]] = None,
+ default_ttl: Optional[Union[int, str]] = None,
+ idna_codec: Optional[dns.name.IDNACodec] = None,
+ origin: Optional[Union[dns.name.Name, str]] = dns.name.root,
+ relativize: bool = False,
+) -> List[dns.rrset.RRset]:
"""Read one or more rrsets from the specified text, possibly subject
to restrictions.
@@ -165,4 +711,36 @@ def read_rrsets(text: Any, name: Optional[Union[dns.name.Name, str]]=None,
if ``False`` then any relative names in the input are made absolute by
appending the *origin*.
"""
- pass
+ if isinstance(origin, str):
+ origin = dns.name.from_text(origin, dns.name.root, idna_codec)
+ if isinstance(name, str):
+ name = dns.name.from_text(name, origin, idna_codec)
+ if isinstance(ttl, str):
+ ttl = dns.ttl.from_text(ttl)
+ if isinstance(default_ttl, str):
+ default_ttl = dns.ttl.from_text(default_ttl)
+ if rdclass is not None:
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ else:
+ rdclass = None
+ default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
+ if rdtype is not None:
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ else:
+ rdtype = None
+ manager = RRSetsReaderManager(origin, relativize, default_rdclass)
+ with manager.writer(True) as txn:
+ tok = dns.tokenizer.Tokenizer(text, "<input>", idna_codec=idna_codec)
+ reader = Reader(
+ tok,
+ default_rdclass,
+ txn,
+ allow_directives=False,
+ force_name=name,
+ force_ttl=ttl,
+ force_rdclass=rdclass,
+ force_rdtype=rdtype,
+ default_ttl=default_ttl,
+ )
+ reader.read()
+ return manager.rrsets
diff --git a/dns/zonetypes.py b/dns/zonetypes.py
index 63dfed7..195ee2e 100644
--- a/dns/zonetypes.py
+++ b/dns/zonetypes.py
@@ -1,18 +1,37 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
"""Common zone-related types."""
+
+# This is a separate file to avoid import circularity between dns.zone and
+# the implementation of the ZONEMD type.
+
import hashlib
+
import dns.enum
class DigestScheme(dns.enum.IntEnum):
"""ZONEMD Scheme"""
+
SIMPLE = 1
+ @classmethod
+ def _maximum(cls):
+ return 255
+
class DigestHashAlgorithm(dns.enum.IntEnum):
"""ZONEMD Hash Algorithm"""
+
SHA384 = 1
SHA512 = 2
+ @classmethod
+ def _maximum(cls):
+ return 255
+
-_digest_hashers = {DigestHashAlgorithm.SHA384: hashlib.sha384,
- DigestHashAlgorithm.SHA512: hashlib.sha512}
+_digest_hashers = {
+ DigestHashAlgorithm.SHA384: hashlib.sha384,
+ DigestHashAlgorithm.SHA512: hashlib.sha512,
+}