diff --git a/dns/_features.py b/dns/_features.py
index de4dcf5..7d624d3 100644
--- a/dns/_features.py
+++ b/dns/_features.py
@@ -3,6 +3,26 @@ import itertools
import string
from typing import Dict, List, Tuple
+def _tuple_from_text(text: str) -> Tuple[int, ...]:
+    """Convert version text like '1.2.3' or '1.2rc1' to a tuple of integers.
+
+    Each dot-separated component contributes the integer value of its
+    leading ASCII digits, so '1.2rc1' becomes (1, 2).  Parsing stops at the
+    first component without leading digits ('1.dev.3' becomes (1,)), so a
+    later component can never slide into an earlier position.  Empty (or
+    ``None``) text yields ``()``.
+    """
+    parts = []
+    for part in (text or '').split('.'):
+        # Only ASCII digits count: str.isdigit() also accepts characters
+        # such as '²' that int() cannot parse.
+        digits = ''.join(itertools.takewhile(lambda c: c in string.digits, part))
+        if not digits:
+            break
+        parts.append(int(digits))
+    return tuple(parts)
+
def _version_check(requirement: str) -> bool:
"""Is the requirement fulfilled?
@@ -10,7 +30,15 @@ def _version_check(requirement: str) -> bool:
package>=version
"""
- pass
+    # A malformed requirement (not exactly one '>=') or a package that is
+    # not installed is reported as "not fulfilled" rather than raising.
+    try:
+        package, minimum = requirement.split('>=')
+        installed_version = importlib.metadata.version(package.strip())
+    except (importlib.metadata.PackageNotFoundError, ValueError):
+        return False
+    # Compare only the leading numeric components.  Running int() over
+    # every component made versions such as '1.0.0rc1' or '2.0.post1'
+    # raise ValueError, so an installed, new-enough package looked absent.
+    installed = _tuple_from_text(installed_version)
+    required = _tuple_from_text(minimum.strip())
+    return installed >= required
_cache: Dict[str, bool] = {}
def have(feature: str) -> bool:
@@ -23,7 +51,13 @@ def have(feature: str) -> bool:
and ``False`` if it is not or if metadata is
missing.
"""
- pass
+    try:
+        # A previous check (or force()) already decided this feature.
+        return _cache[feature]
+    except KeyError:
+        pass
+    if feature not in _requirements:
+        # Unknown features are reported absent and deliberately not cached.
+        return False
+    # Every requirement of the feature must be met; stop at the first miss.
+    satisfied = all(map(_version_check, _requirements[feature]))
+    _cache[feature] = satisfied
+    return satisfied
def force(feature: str, enabled: bool) -> None:
"""Force the status of *feature* to be *enabled*.
@@ -31,5 +65,5 @@ def force(feature: str, enabled: bool) -> None:
This method is provided as a workaround for any cases
where importlib.metadata is ineffective, or for testing.
"""
- pass
+    # Overwrite any cached answer; have() consults the cache first.
+    _cache.update({feature: enabled})
_requirements: Dict[str, List[str]] = {'dnssec': ['cryptography>=41'], 'doh': ['httpcore>=1.0.0', 'httpx>=0.26.0', 'h2>=4.1.0'], 'doq': ['aioquic>=0.9.25'], 'idna': ['idna>=3.6'], 'trio': ['trio>=0.23'], 'wmi': ['wmi>=1.5.1']}
\ No newline at end of file
diff --git a/dns/_immutable_ctx.py b/dns/_immutable_ctx.py
index 2f94965..be8fd34 100644
--- a/dns/_immutable_ctx.py
+++ b/dns/_immutable_ctx.py
@@ -16,4 +16,17 @@ class _Immutable:
if _in__init__.get() is not self:
raise TypeError("object doesn't support attribute assignment")
else:
- super().__delattr__(name)
\ No newline at end of file
+ super().__delattr__(name)
+
+def _immutable_init(f):
+    """Wrap the initializer-like method *f* so that, while it runs, the
+    ``_in__init__`` context variable holds the instance being built, which
+    is what lets ``_Immutable.__setattr__``/``__delattr__`` allow writes.
+    """
+    import functools
+
+    # wraps() keeps __name__, __doc__ and inspect.signature() of *f*.
+    @functools.wraps(f)
+    def wrapped(self, *args, **kwargs):
+        token = _in__init__.set(self)
+        try:
+            return f(self, *args, **kwargs)
+        finally:
+            _in__init__.reset(token)
+
+    return wrapped
+
+
+def immutable(f):
+    """A decorator which makes objects immutable once initialized.
+
+    *f* may be:
+
+    * a method such as ``__init__``: it is wrapped so that attribute
+      assignment on ``self`` is allowed while it runs.  The object's class
+      has to inherit from the _Immutable class for assignments made outside
+      the method to be rejected.
+    * a class: ``_Immutable`` is mixed in (unless it is already an
+      ancestor) and ``__init__`` (plus ``__setstate__`` if present) is
+      wrapped.  The decorated name stays a class; previously a class was
+      replaced by a plain function.
+    """
+    if not isinstance(f, type):
+        return _immutable_init(f)
+    cls = f
+    if _Immutable in cls.__mro__:
+        # An ancestor already has the mixin; only follow the __init__ protocol.
+        cls.__init__ = _immutable_init(cls.__init__)
+        if hasattr(cls, '__setstate__'):
+            cls.__setstate__ = _immutable_init(cls.__setstate__)
+        return cls
+
+    class ncls(_Immutable, cls):
+        # Empty __slots__ so a slotted class does not gain a __dict__.
+        __slots__ = ()
+
+        @_immutable_init
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        if hasattr(cls, '__setstate__'):
+
+            @_immutable_init
+            def __setstate__(self, *args, **kwargs):
+                super().__setstate__(*args, **kwargs)
+
+    # Present the mixed-in class under the decorated class's identity.
+    ncls.__name__ = cls.__name__
+    ncls.__qualname__ = cls.__qualname__
+    ncls.__module__ = cls.__module__
+    return ncls
\ No newline at end of file
diff --git a/dns/query.py b/dns/query.py
index bad09ad..c030c7a 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -72,6 +72,116 @@ except ImportError:
pass
socket_factory = socket.socket
+def _compute_times(timeout: Optional[float]) -> Tuple[Optional[float], Optional[float]]:
+    """Compute ``(start, expiration)`` for an operation beginning now.
+
+    *start* is the current ``time.time()``.  *expiration* is
+    ``start + timeout``, or ``None`` when *timeout* is ``None`` (no deadline).
+    """
+    start = time.time()
+    expiration = None if timeout is None else start + timeout
+    return (start, expiration)
+
+def _remaining(expiration: Optional[float]) -> Optional[float]:
+    """Return the number of seconds left until *expiration*.
+
+    Returns ``None`` when *expiration* is ``None``: there is no deadline,
+    and ``None`` is also what ``socket.settimeout()`` takes to mean
+    "block".  Otherwise returns a float, clamped to ``0.0`` once the
+    deadline has passed.
+    """
+    if expiration is None:
+        # Returning 0.0 here would make "no deadline" look like "expired".
+        return None
+    remaining = expiration - time.time()
+    return remaining if remaining > 0.0 else 0.0
+
+def _matches_destination(af: socket.AddressFamily, from_address: Any, destination: Any, ignore_scope: bool=False) -> bool:
+    """Is the address we got a response from the same address we sent to?
+
+    Both addresses are low-level socket tuples: ``(host, port)`` for
+    ``AF_INET`` and ``(host, port, flowinfo, scope_id)`` for ``AF_INET6``.
+    Host and port must both match.  Hosts are compared in binary form, so
+    different spellings of one address ('::1' and '0::1') are equal.  For
+    ``AF_INET6`` the scope id must match too unless *ignore_scope* is true,
+    in which case it (and any '%zone' suffix on the host) is ignored.  Flow
+    info is never compared.
+
+    Returns ``True`` when *destination* is empty (nothing to check) and
+    ``False`` for any other address family.
+    """
+    if not destination:
+        return True
+    if af not in (socket.AF_INET, socket.AF_INET6):
+        return False
+
+    def _host(address):
+        text = address[0]
+        if af == socket.AF_INET6 and ignore_scope:
+            text = text.split('%', 1)[0]
+        try:
+            return socket.inet_pton(af, text)
+        except (OSError, ValueError):
+            # Unparseable (e.g. carries a zone suffix): compare as text.
+            return text
+
+    if _host(from_address) != _host(destination):
+        return False
+    if from_address[1] != destination[1]:
+        # Right host, wrong port: not the server we queried.
+        return False
+    if af == socket.AF_INET6 and not ignore_scope:
+        return tuple(from_address[3:4]) == tuple(destination[3:4])
+    return True
+
+def _destination_and_source(where: str, port: int, source: Optional[str], source_port: int, where_must_be_address: bool=True) -> Tuple[socket.AddressFamily, Any, Any]:
+    """Return a tuple of address family, destination, and source.
+
+    *where* is the server.  If it is not an IP address, ``ValueError`` is
+    raised when *where_must_be_address* is true.  Otherwise (a hostname or
+    URL) the destination is ``None`` and the family comes from *source* if
+    given, else ``AF_INET`` is assumed.
+
+    *source*, if not ``None``, is the local address to bind; its family must
+    agree with that of *where* when that is known.  If only *source_port*
+    is given, the family's wildcard address is used.
+
+    Returns ``(af, destination, source)``; the last two are low-level
+    socket address tuples or ``None``.
+
+    Raises ``ValueError`` if source and destination families differ.
+    """
+    destination = None
+    source_tuple = None
+    try:
+        af = dns.inet.af_for_address(where)
+    except ValueError:
+        if where_must_be_address:
+            raise
+        af = None  # not an address; the family is decided below
+    else:
+        destination = _lltuple((where, port))
+
+    if source is not None:
+        source_af = dns.inet.af_for_address(source)
+        if af is not None and source_af != af:
+            raise ValueError('different address families for source and destination')
+        af = source_af
+        source_tuple = _lltuple((source, source_port))
+    else:
+        if af is None:
+            # We assume AF_INET if we don't know and are using a hostname
+            af = socket.AF_INET
+        if source_port:
+            if af == socket.AF_INET:
+                source_tuple = ('0.0.0.0', source_port)
+            elif af == socket.AF_INET6:
+                source_tuple = ('::', source_port, 0, 0)
+            else:
+                raise ValueError('source_port specified but address family is unknown')
+
+    return (af, destination, source_tuple)
+
+def _make_dot_ssl_context(verify: Union[bool, str], server_hostname: Optional[str]=None) -> ssl.SSLContext:
+    """Create an SSL context for DoT.
+
+    *verify*, a ``bool`` or ``str``.  If ``True``, then TLS certificate
+    verification of the server is done using the default CA bundle; if
+    ``False``, then no verification is done; if a ``str``, then it specifies
+    the path to a certificate file or directory which will be used for
+    verification.
+
+    *server_hostname*, a ``str`` or ``None``, the server's hostname.  With
+    ``None`` the certificate may still be verified, but hostname checking
+    is turned off.
+
+    Raises ``ValueError`` if *verify* is not ``True``, ``False`` or a ``str``.
+    """
+    import os.path
+
+    if verify is True:
+        ctx = ssl.create_default_context()
+    elif verify is False:
+        # Start from the default (non-deprecated) context and relax it;
+        # check_hostname has to be cleared before CERT_NONE is allowed.
+        ctx = ssl.create_default_context()
+        ctx.check_hostname = False
+        ctx.verify_mode = ssl.CERT_NONE
+    elif isinstance(verify, str):
+        # A directory of CA certificates must go to capath, not cafile.
+        if os.path.isdir(verify):
+            ctx = ssl.create_default_context(capath=verify)
+        else:
+            ctx = ssl.create_default_context(cafile=verify)
+    else:
+        raise ValueError('verify must be True, False, or a string')
+    if server_hostname is None:
+        ctx.check_hostname = False
+    return ctx
+
+# NOTE(review): socket_factory is already assigned earlier in this module;
+# keep a single definition once the earlier one is confirmed to be at
+# module level.
+socket_factory = socket.socket
+
+class UDPMode(enum.IntEnum):
+    """What transport to use.
+
+    NEVER means always use TCP.
+    ONLY means always use UDP.
+    FIRST means try UDP and fall back to TCP if needed; TRY_FIRST is an
+    alias of FIRST.
+    """
+    NEVER = 0
+    ONLY = 1
+    FIRST = 2
+    # Alias for callers that spell this mode TRY_FIRST.  NOTE(review):
+    # confirm these numeric values against callers passing raw integers.
+    TRY_FIRST = 2
+
class UnexpectedSource(dns.exception.DNSException):
"""A DNS query response came from an unexpected address or port."""
@@ -149,14 +259,40 @@ def _udp_recv(sock, max_size, expiration):
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
- pass
+    # Set the timeout on every call.  A finite timeout left by an earlier
+    # call (or non-blocking mode, if the socket was created that way) must
+    # not apply when there is no expiration, which means "wait for data".
+    if expiration is None:
+        sock.settimeout(None)
+    else:
+        timeout = _remaining(expiration)
+        if timeout <= 0.0:
+            raise dns.exception.Timeout
+        sock.settimeout(timeout)
+    try:
+        # No EINTR loop needed: Python retries interrupted calls itself
+        # (PEP 475), recomputing the remaining timeout.
+        return sock.recvfrom(max_size)
+    except socket.timeout:
+        raise dns.exception.Timeout
def _udp_send(sock, data, destination, expiration):
"""Sends the specified datagram to destination over the socket.
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
- pass
+    # Set the timeout on every call so a stale timeout (or non-blocking
+    # mode) left on the socket cannot leak into this send; no expiration
+    # means block.
+    if expiration is None:
+        sock.settimeout(None)
+    else:
+        timeout = _remaining(expiration)
+        if timeout <= 0.0:
+            raise dns.exception.Timeout
+        sock.settimeout(timeout)
+    try:
+        # EINTR is retried by Python itself (PEP 475).
+        if destination is None:
+            # Connected socket: the peer is already fixed, sendto() would fail.
+            return sock.send(data)
+        return sock.sendto(data, destination)
+    except socket.timeout:
+        raise dns.exception.Timeout
def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: Any, expiration: Optional[float]=None) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
@@ -174,7 +310,11 @@ def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: An
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
- pass
+    # Accept a Message (rendered here) or already-rendered wire bytes.
+    wire = what.to_wire() if isinstance(what, dns.message.Message) else what
+    # The sent time is taken just before the datagram goes out.
+    sent_time = time.time()
+    return (_udp_send(sock, wire, destination, expiration), sent_time)
def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional[float]=None, ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac: Optional[bytes]=b'', ignore_trailing: bool=False, raise_on_truncation: bool=False, ignore_errors: bool=False, query: Optional[dns.message.Message]=None) -> Any:
"""Read a DNS message from a UDP socket.
@@ -225,7 +365,40 @@ def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional
*ignore_errors* is ``True``, check that the received message is a response
to this query, and if not keep listening for a valid response.
"""
- pass
+    while True:
+        (wire, from_address) = _udp_recv(sock, 65535, expiration)
+        received_time = time.time()
+        if expiration is not None and received_time > expiration:
+            raise dns.exception.Timeout
+        if destination and not _matches_destination(sock.family, from_address, destination, True):
+            # A datagram from somewhere other than *destination*.
+            # *ignore_unexpected* means "drop it and keep listening"; it must
+            # not mean "accept it without checking".
+            if ignore_unexpected or ignore_errors:
+                continue
+            raise UnexpectedSource('got a response from {} instead of {}'.format(from_address, destination))
+        try:
+            r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+                                      one_rr_per_rrset=one_rr_per_rrset,
+                                      ignore_trailing=ignore_trailing,
+                                      raise_on_truncation=raise_on_truncation)
+        except (dns.message.TrailingJunk, dns.exception.FormError):
+            if not ignore_errors:
+                raise
+            continue
+        if query is not None and query.id != r.id:
+            # Not a reply to *query*.
+            if not ignore_errors:
+                raise BadResponse
+            continue
+        if destination is None:
+            # The caller did not say who to expect, so report who answered.
+            return (r, received_time, from_address)
+        return (r, received_time)
def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, source: Optional[str]=None, source_port: int=0, ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, raise_on_truncation: bool=False, sock: Optional[Any]=None, ignore_errors: bool=False) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
@@ -323,14 +496,40 @@ def _net_read(sock, count, expiration):
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
- pass
+    # Collect chunks and join once; repeated bytes concatenation is quadratic.
+    chunks = []
+    if expiration is None:
+        # No deadline: clear any stale timeout (or non-blocking mode) left on
+        # the socket so recv() blocks until data arrives.
+        sock.settimeout(None)
+    while count > 0:
+        if expiration is not None:
+            timeout = _remaining(expiration)
+            if timeout <= 0.0:
+                raise dns.exception.Timeout
+            sock.settimeout(timeout)
+        try:
+            chunk = sock.recv(count)
+        except socket.timeout:
+            raise dns.exception.Timeout
+        if chunk == b'':
+            # The peer closed the connection before *count* bytes arrived.
+            raise EOFError
+        count -= len(chunk)
+        chunks.append(chunk)
+    return b''.join(chunks)
def _net_write(sock, data, expiration):
"""Write the specified data to the socket.
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
- pass
+    # A memoryview lets each retry send the unsent tail without copying it.
+    view = memoryview(data)
+    current = 0
+    if expiration is None:
+        # No deadline: clear any stale timeout (or non-blocking mode) left on
+        # the socket so send() blocks until it makes progress.
+        sock.settimeout(None)
+    while current < len(view):
+        if expiration is not None:
+            timeout = _remaining(expiration)
+            if timeout <= 0.0:
+                raise dns.exception.Timeout
+            sock.settimeout(timeout)
+        try:
+            current += sock.send(view[current:])
+        except socket.timeout:
+            raise dns.exception.Timeout
def send_tcp(sock: Any, what: Union[dns.message.Message, bytes], expiration: Optional[float]=None) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
@@ -345,7 +544,15 @@ def send_tcp(sock: Any, what: Union[dns.message.Message, bytes], expiration: Opt
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
- pass
+    if isinstance(what, dns.message.Message):
+        what = what.to_wire()
+    length = len(what)
+    # Prefix the 16-bit length in network byte order and hand both to a
+    # single write, so the length and the message are not sent as separate
+    # segments (RFC 7766, section 8).
+    tcpmsg = struct.pack('!H', length) + what
+    sent_time = time.time()
+    _net_write(sock, tcpmsg, expiration)
+    return (length + 2, sent_time)
def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset: bool=False, keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac: Optional[bytes]=b'', ignore_trailing: bool=False) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
@@ -372,7 +579,14 @@ def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset: b
Returns a ``(dns.message.Message, float)`` tuple of the received message
and the received time.
"""
- pass
+    # The message is preceded by its length as a big-endian unsigned
+    # 16-bit integer.
+    length = struct.unpack('!H', _net_read(sock, 2, expiration))[0]
+    wire = _net_read(sock, length, expiration)
+    received_time = time.time()
+    message = dns.message.from_wire(
+        wire,
+        keyring=keyring,
+        request_mac=request_mac,
+        one_rr_per_rrset=one_rr_per_rrset,
+        ignore_trailing=ignore_trailing,
+    )
+    return (message, received_time)
def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, source: Optional[str]=None, source_port: int=0, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock: Optional[Any]=None) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
diff --git a/dns/rdata.py b/dns/rdata.py
index 9da56a7..1a9844e 100644
--- a/dns/rdata.py
+++ b/dns/rdata.py
@@ -31,37 +31,56 @@ def _wordbreak(data, chunksize=_chunksize, separator=b' '):
"""Break a binary string into chunks of chunksize characters separated by
a space.
"""
- pass
+    # Chunks are joined with *separator* (a space by default).  A falsy
+    # chunksize (0 or None) means "do not break": range() rejects a zero
+    # step, and None is not a valid step at all.
+    if not chunksize:
+        return data
+    return separator.join(data[i:i + chunksize] for i in range(0, len(data), chunksize))
def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
"""Convert a binary string into its hex encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
- pass
+    # Returns bytes; extra keyword arguments are accepted and ignored.
+    # NOTE(review): callers must decode the result (GenericRdata.to_text
+    # does); confirm no caller formats it directly as a str.
+    return _wordbreak(binascii.hexlify(data), chunksize, separator)
def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw):
"""Convert a binary string into its base64 encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
- pass
+    # Returns bytes; extra keyword arguments are accepted and ignored.
+    return _wordbreak(base64.b64encode(data), chunksize, separator)
__escaped = b'"\\'
def _escapify(qstring):
"""Escape the characters in a quoted string which need it."""
- pass
+    # Accept str (encoded as UTF-8) or any bytes-like object; bytearray and
+    # memoryview have no .encode() and used to fail here.
+    if isinstance(qstring, str):
+        qstring = qstring.encode()
+    elif not isinstance(qstring, (bytes, bytearray)):
+        qstring = bytes(qstring)
+    # Build a list and join once instead of repeated string concatenation.
+    pieces = []
+    for c in qstring:
+        if c in __escaped:
+            # '"' and '\\' get a literal backslash escape.
+            pieces.append('\\' + chr(c))
+        elif 0x20 <= c < 0x7F:
+            pieces.append(chr(c))
+        else:
+            # Everything else uses the three-digit decimal \DDD form.
+            pieces.append('\\%03d' % c)
+    return ''.join(pieces)
def _truncate_bitmap(what):
"""Determine the index of greatest byte that isn't all zeros, and
return the bitmap that contains all the bytes less than that index.
"""
- pass
+    # Index of the last non-zero byte, scanning from the end; None when
+    # every byte is zero or *what* is empty.
+    last = next((i for i in reversed(range(len(what))) if what[i] != 0), None)
+    if last is None:
+        # Keep at most one (zero) byte; empty input stays empty.
+        return what[0:1]
+    return what[:last + 1]
_constify = dns.immutable.constify
-@dns.immutable.immutable
class Rdata:
"""Base class for all DNS rdata types."""
__slots__ = ['rdclass', 'rdtype', 'rdcomment']
+ @dns.immutable.immutable
def __init__(self, rdclass, rdtype):
"""Initialize an rdata.
@@ -96,7 +115,7 @@ class Rdata:
Returns a ``dns.rdatatype.RdataType``.
"""
- pass
+ return dns.rdatatype.NONE
def extended_rdatatype(self) -> int:
"""Return a 32-bit type value, the least significant 16 bits of
@@ -105,28 +124,32 @@ class Rdata:
Returns an ``int``.
"""
- pass
+        # covers() goes in the high 16 bits, the rdtype in the low 16 bits.
+        covered = self.covers()
+        return (covered << 16) | self.rdtype
def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw: Dict[str, Any]) -> str:
"""Convert an rdata to text format.
Returns a ``str``.
"""
- pass
+        # Abstract in the base class; subclasses (e.g. GenericRdata) override it.
+        raise NotImplementedError()
def to_wire(self, file: Optional[Any]=None, compress: Optional[dns.name.CompressType]=None, origin: Optional[dns.name.Name]=None, canonicalize: bool=False) -> bytes:
"""Convert an rdata to wire format.
Returns a ``bytes`` or ``None``.
"""
- pass
+        # Subclasses may override this method outright, or implement
+        # _to_wire(file, compress, origin, canonicalize), which writes into a
+        # file-like object.  In that case this base version provides both
+        # documented modes, which to_digestable() and to_generic() rely on.
+        writer = getattr(self, '_to_wire', None)
+        if writer is None:
+            raise NotImplementedError()
+        if file:
+            # Write into the caller's file and return None.
+            writer(file, compress, origin, canonicalize)
+            return None
+        import io
+        buffer = io.BytesIO()
+        writer(buffer, compress, origin, canonicalize)
+        return buffer.getvalue()
def to_generic(self, origin: Optional[dns.name.Name]=None) -> 'dns.rdata.GenericRdata':
"""Creates a dns.rdata.GenericRdata equivalent of this rdata.
Returns a ``dns.rdata.GenericRdata``.
"""
- pass
+        # Render into an in-memory file (rather than relying on to_wire()'s
+        # return value) and wrap the resulting bytes as generic rdata.
+        wire_file = io.BytesIO()
+        try:
+            self.to_wire(wire_file, None, origin)
+            wire = wire_file.getvalue()
+        finally:
+            wire_file.close()
+        return GenericRdata(self.rdclass, self.rdtype, wire)
def to_digestable(self, origin: Optional[dns.name.Name]=None) -> bytes:
"""Convert rdata to a format suitable for digesting in hashes. This
@@ -134,7 +157,7 @@ class Rdata:
Returns a ``bytes``.
"""
- pass
+        # Canonical wire form: no file, no compression, canonicalize=True.
+        digestable = self.to_wire(None, None, origin, True)
+        return digestable
def __repr__(self):
covers = self.covers()
@@ -163,7 +186,28 @@ class Rdata:
In the future, all ordering comparisons for rdata with
relative names will be disallowed.
"""
- pass
+ our_relative = False
+ their_relative = False
+ try:
+ our = self.to_digestable()
+ except dns.name.NeedAbsoluteNameOrOrigin:
+ our = self.to_digestable(dns.name.root)
+ our_relative = True
+ try:
+ their = other.to_digestable()
+ except dns.name.NeedAbsoluteNameOrOrigin:
+ their = other.to_digestable(dns.name.root)
+ their_relative = True
+ if our_relative and not their_relative:
+ return -1
+ if their_relative and not our_relative:
+ return 1
+ if our < their:
+ return -1
+ elif our > their:
+ return 1
+ else:
+ return 0
def __eq__(self, other):
if not isinstance(other, Rdata):
@@ -227,9 +271,22 @@ class Rdata:
Returns an instance of the same Rdata subclass as *self*.
"""
- pass
+ # Get all slots from the class hierarchy
+ slots = self._get_all_slots()
+
+ # Create a new instance with same rdclass and rdtype
+ new_instance = self.__class__(self.rdclass, self.rdtype)
+
+ # Copy all slot values from self to new instance, unless overridden in kwargs
+ for slot in slots:
+ if slot in kwargs:
+ value = kwargs[slot]
+ else:
+ value = getattr(self, slot)
+ object.__setattr__(new_instance, slot, value)
+
+ return new_instance
-@dns.immutable.immutable
class GenericRdata(Rdata):
"""Generic Rdata Class
@@ -238,9 +295,19 @@ class GenericRdata(Rdata):
"""
__slots__ = ['data']
+ @dns.immutable.immutable
def __init__(self, rdclass, rdtype, data):
super().__init__(rdclass, rdtype)
self.data = data
+
+    def to_text(self, origin=None, relativize=True, **kw):
+        """Return the generic (unknown-type) text form of this rdata.
+
+        The result is a backslash-hash token, the data length in octets
+        and, when there is data, its hex encoding.  Empty data yields the
+        token and ``0`` with no trailing space.
+        """
+        if not self.data:
+            return r'\# 0'
+        # _hexify() returns bytes, hence the decode.
+        return r'\# %u %s' % (len(self.data), _hexify(self.data).decode())
+
+    def to_wire(self, file=None, compress=None, origin=None, canonicalize=False):
+        # The stored data already is the wire form, so compression, origin
+        # and canonicalization have nothing to do.  With a file, write into
+        # it (returning None); otherwise hand the bytes back.
+        if not file:
+            return self.data
+        file.write(self.data)
_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = {}
_module_prefix = 'dns.rdtypes'
@@ -281,7 +348,34 @@ def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union[dns.
Returns an instance of the chosen Rdata subclass.
"""
- pass
+    rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    rdtype = dns.rdatatype.RdataType.make(rdtype)
+
+    if isinstance(tok, str):
+        tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec)
+
+    # Pick the implementing class.  An entry in _rdata_classes (from
+    # register_type() or an earlier lookup) wins, then the class-specific
+    # module, then the ANY module, and finally GenericRdata.
+    cls = _rdata_classes.get((rdclass, rdtype)) or _rdata_classes.get((dns.rdataclass.ANY, rdtype))
+    if cls is None:
+        # The same text names both the module and the class inside it, and
+        # '-' cannot appear in a Python class name, so map it to '_'.
+        rdtype_text = dns.rdatatype.to_text(rdtype).replace('-', '_')
+        for class_dir in (dns.rdataclass.to_text(rdclass), 'ANY'):
+            try:
+                mod = import_module(f"{_module_prefix}.{class_dir}.{rdtype_text}")
+                cls = getattr(mod, rdtype_text)
+            except (ImportError, AttributeError):
+                continue
+            # Cache the hit.  GenericRdata is deliberately not cached, so
+            # register_type() can still claim an unimplemented type.
+            _rdata_classes[(rdclass, rdtype)] = cls
+            break
+        else:
+            cls = GenericRdata
+
+    # No explicit relativize_to means "relativize to the origin".
+    if relativize_to is None:
+        relativize_to = origin
+    return cls.from_text(rdclass, rdtype, tok, origin, relativize,
+                         relativize_to)
def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union[dns.rdatatype.RdataType, str], parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None) -> Rdata:
"""Build an rdata object from wire format
@@ -306,7 +400,28 @@ def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Uni
Returns an instance of the chosen Rdata subclass.
"""
- pass
+    rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    rdtype = dns.rdatatype.RdataType.make(rdtype)
+
+    # Pick the implementing class.  An entry in _rdata_classes (from
+    # register_type() or an earlier lookup) wins, then the class-specific
+    # module, then the ANY module, and finally GenericRdata.
+    cls = _rdata_classes.get((rdclass, rdtype)) or _rdata_classes.get((dns.rdataclass.ANY, rdtype))
+    if cls is None:
+        # The same text names both the module and the class inside it, and
+        # '-' cannot appear in a Python class name, so map it to '_'.
+        rdtype_text = dns.rdatatype.to_text(rdtype).replace('-', '_')
+        for class_dir in (dns.rdataclass.to_text(rdclass), 'ANY'):
+            try:
+                mod = import_module(f"{_module_prefix}.{class_dir}.{rdtype_text}")
+                cls = getattr(mod, rdtype_text)
+            except (ImportError, AttributeError):
+                continue
+            # Cache the hit.  GenericRdata is deliberately not cached, so
+            # register_type() can still claim an unimplemented type.
+            _rdata_classes[(rdclass, rdtype)] = cls
+            break
+        else:
+            cls = GenericRdata
+
+    return cls.from_wire_parser(rdclass, rdtype, parser, origin)
def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union[dns.rdatatype.RdataType, str], wire: bytes, current: int, rdlen: int, origin: Optional[dns.name.Name]=None) -> Rdata:
"""Build an rdata object from wire format
@@ -335,7 +450,9 @@ def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union[dns.
Returns an instance of the chosen Rdata subclass.
"""
- pass
+    parser = dns.wire.Parser(wire, current)
+    # Bound the parser to the *rdlen* bytes of this rdata while decoding.
+    window = parser.restrict_to(rdlen)
+    with window:
+        return from_wire_parser(rdclass, rdtype, parser, origin)
class RdatatypeExists(dns.exception.DNSException):
"""DNS rdatatype already exists."""
@@ -358,4 +475,7 @@ def register_type(implementation: Any, rdtype: int, rdtype_text: str, is_singlet
*rdclass*, the rdataclass of the type, or ``dns.rdataclass.ANY`` if
it applies to all classes.
"""
- pass
\ No newline at end of file
+ existing_cls = _rdata_classes.get((rdclass, rdtype))
+ if existing_cls is not None:
+ raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
+ _rdata_classes[(rdclass, rdtype)] = implementation
\ No newline at end of file
diff --git a/dns/rdtypes/ANY/CDNSKEY.py b/dns/rdtypes/ANY/CDNSKEY.py
index e722829..32c824c 100644
--- a/dns/rdtypes/ANY/CDNSKEY.py
+++ b/dns/rdtypes/ANY/CDNSKEY.py
@@ -2,6 +2,5 @@ import dns.immutable
import dns.rdtypes.dnskeybase
from dns.rdtypes.dnskeybase import REVOKE, SEP, ZONE
-@dns.immutable.immutable
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""CDNSKEY record"""
\ No newline at end of file
diff --git a/dns/rdtypes/ANY/DNSKEY.py b/dns/rdtypes/ANY/DNSKEY.py
index bc52d7c..eca6b51 100644
--- a/dns/rdtypes/ANY/DNSKEY.py
+++ b/dns/rdtypes/ANY/DNSKEY.py
@@ -2,6 +2,5 @@ import dns.immutable
import dns.rdtypes.dnskeybase
from dns.rdtypes.dnskeybase import REVOKE, SEP, ZONE
-@dns.immutable.immutable
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""DNSKEY record"""
\ No newline at end of file
diff --git a/dns/rdtypes/ANY/EUI48.py b/dns/rdtypes/ANY/EUI48.py
index 65820f7..321cc48 100644
--- a/dns/rdtypes/ANY/EUI48.py
+++ b/dns/rdtypes/ANY/EUI48.py
@@ -1,7 +1,6 @@
import dns.immutable
import dns.rdtypes.euibase
-@dns.immutable.immutable
class EUI48(dns.rdtypes.euibase.EUIBase):
"""EUI48 record"""
byte_len = 6
diff --git a/dns/rdtypes/ANY/TXT.py b/dns/rdtypes/ANY/TXT.py
index eae71e0..aff2103 100644
--- a/dns/rdtypes/ANY/TXT.py
+++ b/dns/rdtypes/ANY/TXT.py
@@ -1,6 +1,5 @@
import dns.immutable
import dns.rdtypes.txtbase
-@dns.immutable.immutable
class TXT(dns.rdtypes.txtbase.TXTBase):
"""TXT record"""
\ No newline at end of file
diff --git a/dns/rdtypes/dnskeybase.py b/dns/rdtypes/dnskeybase.py
index bfb691a..2c1e562 100644
--- a/dns/rdtypes/dnskeybase.py
+++ b/dns/rdtypes/dnskeybase.py
@@ -12,11 +12,11 @@ class Flag(enum.IntFlag):
REVOKE = 128
ZONE = 256
-@dns.immutable.immutable
class DNSKEYBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DNSKEY record"""
__slots__ = ['flags', 'protocol', 'algorithm', 'key']
+ @dns.immutable.immutable
def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
super().__init__(rdclass, rdtype)
self.flags = Flag(self._as_uint16(flags))
diff --git a/dns/rdtypes/dsbase.py b/dns/rdtypes/dsbase.py
index 0b37fde..76a4f14 100644
--- a/dns/rdtypes/dsbase.py
+++ b/dns/rdtypes/dsbase.py
@@ -5,12 +5,12 @@ import dns.immutable
import dns.rdata
import dns.rdatatype
-@dns.immutable.immutable
class DSBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DS record"""
__slots__ = ['key_tag', 'algorithm', 'digest_type', 'digest']
_digest_length_by_type = {1: 20, 2: 32, 3: 32, 4: 48}
+ @dns.immutable.immutable
def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest):
super().__init__(rdclass, rdtype)
self.key_tag = self._as_uint16(key_tag)
diff --git a/dns/rdtypes/euibase.py b/dns/rdtypes/euibase.py
index 3056722..74c247f 100644
--- a/dns/rdtypes/euibase.py
+++ b/dns/rdtypes/euibase.py
@@ -2,11 +2,11 @@ import binascii
import dns.immutable
import dns.rdata
-@dns.immutable.immutable
class EUIBase(dns.rdata.Rdata):
"""EUIxx record"""
__slots__ = ['eui']
+ @dns.immutable.immutable
def __init__(self, rdclass, rdtype, eui):
super().__init__(rdclass, rdtype)
self.eui = self._as_bytes(eui)
diff --git a/dns/rdtypes/nsbase.py b/dns/rdtypes/nsbase.py
index f8a63f9..d98bdea 100644
--- a/dns/rdtypes/nsbase.py
+++ b/dns/rdtypes/nsbase.py
@@ -4,16 +4,15 @@ import dns.immutable
import dns.name
import dns.rdata
-@dns.immutable.immutable
class NSBase(dns.rdata.Rdata):
"""Base class for rdata that is like an NS record."""
__slots__ = ['target']
+ @dns.immutable.immutable
def __init__(self, rdclass, rdtype, target):
super().__init__(rdclass, rdtype)
self.target = self._as_name(target)
-@dns.immutable.immutable
class UncompressedNS(NSBase):
"""Base class for rdata that is like an NS record, but whose name
is not compressed when convert to DNS wire format, and whose
diff --git a/dns/rdtypes/svcbbase.py b/dns/rdtypes/svcbbase.py
index 5031566..eb267d5 100644
--- a/dns/rdtypes/svcbbase.py
+++ b/dns/rdtypes/svcbbase.py
@@ -33,20 +33,19 @@ class Emptiness(enum.IntEnum):
ALLOWED = 2
_escaped = b'",\\'
-@dns.immutable.immutable
class Param:
"""Abstract base class for SVCB parameters"""
-@dns.immutable.immutable
class GenericParam(Param):
"""Generic SVCB parameter"""
+ @dns.immutable.immutable
def __init__(self, value):
self.value = dns.rdata.Rdata._as_bytes(value, True)
-@dns.immutable.immutable
class MandatoryParam(Param):
+ @dns.immutable.immutable
def __init__(self, keys):
keys = sorted([_validate_key(key)[0] for key in keys])
prior_k = None
@@ -58,46 +57,45 @@ class MandatoryParam(Param):
raise ValueError('listed the mandatory key as mandatory')
self.keys = tuple(keys)
-@dns.immutable.immutable
class ALPNParam(Param):
+ @dns.immutable.immutable
def __init__(self, ids):
self.ids = dns.rdata.Rdata._as_tuple(ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False))
-@dns.immutable.immutable
class NoDefaultALPNParam(Param):
pass
-@dns.immutable.immutable
class PortParam(Param):
+ @dns.immutable.immutable
def __init__(self, port):
self.port = dns.rdata.Rdata._as_uint16(port)
-@dns.immutable.immutable
class IPv4HintParam(Param):
+ @dns.immutable.immutable
def __init__(self, addresses):
self.addresses = dns.rdata.Rdata._as_tuple(addresses, dns.rdata.Rdata._as_ipv4_address)
-@dns.immutable.immutable
class IPv6HintParam(Param):
+ @dns.immutable.immutable
def __init__(self, addresses):
self.addresses = dns.rdata.Rdata._as_tuple(addresses, dns.rdata.Rdata._as_ipv6_address)
-@dns.immutable.immutable
class ECHParam(Param):
+ @dns.immutable.immutable
def __init__(self, ech):
self.ech = dns.rdata.Rdata._as_bytes(ech, True)
_class_for_key = {ParamKey.MANDATORY: MandatoryParam, ParamKey.ALPN: ALPNParam, ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam, ParamKey.PORT: PortParam, ParamKey.IPV4HINT: IPv4HintParam, ParamKey.ECH: ECHParam, ParamKey.IPV6HINT: IPv6HintParam}
-@dns.immutable.immutable
class SVCBBase(dns.rdata.Rdata):
"""Base class for SVCB-like records"""
__slots__ = ['priority', 'target', 'params']
+ @dns.immutable.immutable
def __init__(self, rdclass, rdtype, priority, target, params):
super().__init__(rdclass, rdtype)
self.priority = self._as_uint16(priority)
diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py
index 271bb72..04487be 100644
--- a/dns/rdtypes/txtbase.py
+++ b/dns/rdtypes/txtbase.py
@@ -6,11 +6,11 @@ import dns.rdata
import dns.renderer
import dns.tokenizer
-@dns.immutable.immutable
class TXTBase(dns.rdata.Rdata):
"""Base class for rdata that is like a TXT record (see RFC 1035)."""
__slots__ = ['strings']
+ @dns.immutable.immutable
def __init__(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, strings: Iterable[Union[bytes, str]]):
"""Initialize a TXT-like rdata.
diff --git a/dns/resolver.py b/dns/resolver.py
index 1d1d63b..3424bb4 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -447,7 +447,34 @@ class BaseResolver:
def reset(self) -> None:
"""Reset all resolver configuration to the defaults."""
- pass
+        # Domain and search-list settings.
+        self.domain = dns.name.empty
+        self.nameserver_ports = {}
+        self.port = 53
+        self.search = []
+        # NOTE(review): this enables search-list use by default; confirm True
+        # (rather than False) is the intended default.
+        self.use_search_by_default = True
+        # Timeouts, presumably in seconds (per-server timeout and overall
+        # lifetime) -- confirm against the query code.
+        self.timeout = 2.0
+        self.lifetime = 5.0
+        # TSIG: no keyring or key name is configured by default.
+        self.keyring = None
+        self.keyname = None
+        # NOTE(review): HMAC-MD5 is a legacy TSIG algorithm; confirm it is the
+        # intended default.
+        self.keyalgorithm = dns.name.from_text('HMAC-MD5.SIG-ALG.REG.INT')
+        # EDNS settings; -1 presumably means "EDNS not in use" -- confirm.
+        self.edns = -1
+        self.ednsflags = 0
+        self.ednsoptions = None
+        self.payload = dns.message.DEFAULT_EDNS_PAYLOAD
+        self.cache = None
+        self.flags = None
+        self.retry_servfail = False
+        self.rotate = False
+        self.ndots = None
+        # Assigned to the backing attribute directly, bypassing the
+        # nameservers property setter.
+        self._nameservers = []
+
+    @property
+    def nameservers(self) -> Sequence[Union[str, dns.nameserver.Nameserver]]:
+        """The nameservers to use for queries.
+
+        This is the list stored by ``reset()`` (initially empty) or by the
+        setter.  It is returned as-is, not copied, and nothing is validated
+        on read, so an empty list is returned rather than raising
+        ``ValueError``.
+        """
+        return self._nameservers
def read_resolv_conf(self, f: Any) -> None:
"""Process *f* as a file in the /etc/resolv.conf format. If f is
@@ -507,14 +534,16 @@ class BaseResolver:
@nameservers.setter
def nameservers(self, nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]) -> None:
- """
+        """Set the nameservers to use for queries.
+
*nameservers*, a ``list`` of nameservers, where a nameserver is either
- a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver``
- instance.
+        a string interpretable as a nameserver (such as an IPv4 or IPv6
+        address), or a ``dns.nameserver.Nameserver`` instance.  A bare string
+        is rejected rather than being split into characters.
Raises ``ValueError`` if *nameservers* is not a list of nameservers.
"""
- pass
+        # A str is iterable too: list('192.0.2.1') would silently become a
+        # list of one-character "nameservers".
+        if isinstance(nameservers, (str, bytes)):
+            raise ValueError('nameservers must be a list of nameservers, not a string')
+        try:
+            candidates = list(nameservers)
+        except TypeError:
+            raise ValueError('nameservers must be a list of nameservers') from None
+        for nameserver in candidates:
+            if not isinstance(nameserver, (str, dns.nameserver.Nameserver)):
+                raise ValueError(f'nameserver {nameserver!r} is not a string or a dns.nameserver.Nameserver')
+        # NOTE(review): strings are stored as given (not parsed into
+        # Nameserver objects); confirm the query code expects that.
+        # The copy means later changes to the caller's list have no effect.
+        self._nameservers = candidates
+
+
class Resolver(BaseResolver):
"""DNS stub resolver."""