diff --git a/bitstring/bits.py b/bitstring/bits.py
index 41e6c7c..6925041 100644
--- a/bitstring/bits.py
+++ b/bitstring/bits.py
@@ -446,6 +446,331 @@ class Bits:
"""Reset the bitstring to have given unsigned int interpretation."""
pass
+    def _setuintle(self, uint: int, length: Optional[int]=None) -> None:
+        """Set the bitstring from an unsigned integer stored as little-endian bytes.
+
+        When length is None the smallest whole number of bytes (at least one)
+        that can hold uint is used. Raises ValueError if length is not a
+        multiple of 8, if uint is negative, or if uint does not fit in length bits.
+        """
+        if length is None:
+            # Ceiling division: round the bit count (minimum 1) up to whole bytes.
+            length = -(-max(uint.bit_length(), 1) // 8) * 8
+        if length % 8 != 0:
+            raise ValueError("Little-endian integers must be whole-byte. Length = {0} bits.".format(length))
+        if uint < 0:
+            raise ValueError("Little-endian unsigned integer cannot be negative.")
+        if uint >= 1 << length:
+            raise ValueError("Little-endian unsigned integer is too large for length {0}.".format(length))
+        self._setbytes(uint.to_bytes(length // 8, byteorder='little', signed=False))
+
+    def _setintle(self, value: int, length: Optional[int]=None) -> None:
+        """Set the bitstring from a two's complement signed integer in little-endian byte order.
+
+        When length is None the smallest whole number of bytes that can hold
+        value is used. Raises ValueError if length is not a multiple of 8 or if
+        value does not fit in a signed integer of that many bits.
+        """
+        if length is None:
+            # Two's complement needs the magnitude bits plus one sign bit. For a
+            # negative value the relevant magnitude is ~value (== -value - 1), so
+            # e.g. -128 fits in 8 bits instead of being rounded up to 16.
+            magnitude = value if value >= 0 else ~value
+            length = -(-(magnitude.bit_length() + 1) // 8) * 8
+        if length % 8:
+            raise ValueError("Little-endian integers must be whole-byte. Length = {0} bits.".format(length))
+        try:
+            byte_data = value.to_bytes(length // 8, byteorder='little', signed=True)
+        except OverflowError:
+            raise ValueError("Little-endian signed integer is too large for length {0}.".format(length)) from None
+        self._setbytes(byte_data)
+
+    def _setfloatbe(self, value: float, length: Optional[int]=None) -> None:
+        """Set the bitstring from a big-endian IEEE 754 float of 16, 32 or 64 bits.
+
+        Raises ValueError if length is missing or not 16/32/64, if value is too
+        large in magnitude for that width, or if value is not a real number.
+        """
+        if length is None:
+            raise ValueError("A length must be specified with float initialisation.")
+        fmt = {16: '>e', 32: '>f', 64: '>d'}.get(length)
+        if fmt is None:
+            raise ValueError("float length must be 16, 32 or 64 bits.")
+        try:
+            byte_data = struct.pack(fmt, value)
+        except OverflowError:
+            raise ValueError("Float is too large for length {0}.".format(length)) from None
+        except struct.error as err:
+            # struct.error means value is not a number at all, not that it overflowed.
+            raise ValueError("Cannot interpret {0!r} as a float: {1}".format(value, err)) from None
+        self._setbytes(byte_data)
+
+    def _setfloatle(self, value: float, length: Optional[int]=None) -> None:
+        """Set the bitstring from a little-endian IEEE 754 float of 16, 32 or 64 bits.
+
+        Raises ValueError if length is missing or not 16/32/64, if value is too
+        large in magnitude for that width, or if value is not a real number.
+        """
+        if length is None:
+            raise ValueError("A length must be specified with float initialisation.")
+        fmt = {16: '<e', 32: '<f', 64: '<d'}.get(length)
+        if fmt is None:
+            raise ValueError("float length must be 16, 32 or 64 bits.")
+        try:
+            byte_data = struct.pack(fmt, value)
+        except OverflowError:
+            raise ValueError("Float is too large for length {0}.".format(length)) from None
+        except struct.error as err:
+            # struct.error means value is not a number at all, not that it overflowed.
+            raise ValueError("Cannot interpret {0!r} as a float: {1}".format(value, err)) from None
+        self._setbytes(byte_data)
+
+    def _setbfloatbe(self, value: float, length: Optional[int]=None) -> None:
+        """Set the bitstring from a 16-bit bfloat stored big-endian.
+
+        value is packed as a float32 and only its most significant 16 bits are
+        kept, so the low mantissa bits are truncated rather than rounded.
+        Raises ValueError for a length other than 16, for values too large for
+        float32, and for non-numeric values.
+        """
+        if length is not None and length != 16:
+            raise ValueError("bfloat must be 16 bits.")
+        try:
+            byte_data = struct.pack('>f', value)
+        except OverflowError:
+            raise ValueError("Float is too large for bfloat format.") from None
+        except struct.error as err:
+            raise ValueError("Cannot interpret {0!r} as a float: {1}".format(value, err)) from None
+        # Big-endian: the sign, exponent and top mantissa bits are the first two bytes.
+        self._setbytes(byte_data[:2])
+
+    def _getbfloatbe(self) -> float:
+        """Interpret the 16 bits as a big-endian bfloat and return it as a Python float."""
+        if len(self) != 16:
+            raise bitstring.InterpretError("bfloat requires 16 bits.")
+        # A bfloat is the upper half of a float32, so zero-fill the lower half.
+        padded = b''.join((self._getbytes(), bytes(2)))
+        try:
+            (result,) = struct.unpack('>f', padded)
+        except struct.error:
+            raise bitstring.InterpretError("Cannot interpret as bfloat.")
+        return result
+
+    def _setbfloatle(self, value: float, length: Optional[int]=None) -> None:
+        """Set the bitstring from a 16-bit bfloat stored little-endian.
+
+        value is packed as a little-endian float32 and only its most significant
+        16 bits (the last two bytes) are kept, so low mantissa bits are
+        truncated. Raises ValueError for a length other than 16, for values too
+        large for float32, and for non-numeric values.
+        """
+        if length is not None and length != 16:
+            raise ValueError("bfloat must be 16 bits.")
+        try:
+            byte_data = struct.pack('<f', value)
+        except OverflowError:
+            raise ValueError("Float is too large for bfloat format.") from None
+        except struct.error as err:
+            raise ValueError("Cannot interpret {0!r} as a float: {1}".format(value, err)) from None
+        # Little-endian: the significant half of the float32 is the last two bytes.
+        self._setbytes(byte_data[2:])
+
+    def _getbfloatle(self) -> float:
+        """Interpret the 16 bits as a little-endian bfloat and return it as a Python float."""
+        if len(self) != 16:
+            raise bitstring.InterpretError("bfloat requires 16 bits.")
+        # In little-endian order the (zero) low half of the float32 comes first.
+        padded = b''.join((bytes(2), self._getbytes()))
+        try:
+            (result,) = struct.unpack('<f', padded)
+        except struct.error:
+            raise bitstring.InterpretError("Cannot interpret as bfloat.")
+        return result
+
+    def _setbits(self, bits: Bits, length: Optional[int]=None) -> None:
+        """Replace this bitstring's data with a copy of another Bits object's store.
+
+        If length is given it must equal len(bits), otherwise ValueError is raised.
+        """
+        if length is not None:
+            actual = len(bits)
+            if actual != length:
+                raise ValueError("Bits length {0} does not match required length {1}.".format(actual, length))
+        self._bitstore = bits._bitstore.copy()
+
+    def _getbits(self) -> Bits:
+        """Return this bitstring's contents as a separate Bits object (via self.copy())."""
+        duplicate = self.copy()
+        return duplicate
+
+    def _setbool(self, value: bool, length: Optional[int]=None) -> None:
+        """Make the bitstring a single bit: 1 when value is truthy, otherwise 0.
+
+        A length other than None or 1 raises ValueError.
+        """
+        if length is not None and length != 1:
+            raise ValueError("Boolean values must be 1 bit long.")
+        store = BitStore(1)
+        store.setall(1 if value else 0)
+        self._bitstore = store
+
+    def _getbool(self) -> bool:
+        """Return the bitstring's only bit as a bool; any other length is an InterpretError."""
+        if len(self) == 1:
+            return bool(self._bitstore.getindex(0))
+        raise bitstring.InterpretError("Cannot interpret as bool: length must be 1 bit.")
+
+    def _setse(self, value: int, length: Optional[int]=None) -> None:
+        """Set the bitstring to the signed exponential-Golomb code for value.
+
+        value is mapped to a code number k (v > 0 -> 2v - 1, v <= 0 -> -2v) and k
+        is written as an unsigned exponential-Golomb code: (n - 1) zeros followed
+        by the n-bit binary form of k + 1. For example 0 -> '1', 1 -> '010',
+        -1 -> '011', 2 -> '00100'. Raises ValueError if length is given.
+        """
+        if length is not None:
+            raise ValueError("Length cannot be specified for signed exponential-Golomb codes.")
+        code_num = 2 * value - 1 if value > 0 else -2 * value
+        code = bin(code_num + 1)[2:]
+        bits = '0' * (len(code) - 1) + code
+        self._bitstore = BitStore(len(bits))
+        self._bitstore.setall(0)
+        for i, c in enumerate(bits):
+            if c == '1':
+                self._bitstore.setindex(i, 1)
+
+    def _getse(self) -> int:
+        """Interpret the whole bitstring as one signed exponential-Golomb code.
+
+        The unsigned code number k is decoded first and then mapped back: odd k
+        gives (k + 1) // 2 and even k gives -(k // 2). Raises
+        bitstring.InterpretError if there is no 1 bit, or if the bitstring is
+        not exactly one complete code (truncated or with trailing bits).
+        """
+        n = len(self)
+        leading_zeros = 0
+        while leading_zeros < n and not self._bitstore.getindex(leading_zeros):
+            leading_zeros += 1
+        if leading_zeros == n:
+            raise bitstring.InterpretError("Cannot find any 1 bits in exponential-Golomb code.")
+        # A complete code is leading_zeros zeros, a 1, then leading_zeros data bits.
+        if 2 * leading_zeros + 1 != n:
+            raise bitstring.InterpretError("Bitstring is not a single exponential-Golomb code.")
+        code_num = 0
+        for j in range(leading_zeros, n):
+            code_num = (code_num << 1) | self._bitstore.getindex(j)
+        code_num -= 1  # the binary part encodes k + 1
+        return (code_num + 1) // 2 if code_num & 1 else -(code_num // 2)
+
+    def _setue(self, value: int, length: Optional[int]=None) -> None:
+        """Set the bitstring to the unsigned exponential-Golomb code for value.
+
+        The code is (n - 1) zero bits followed by the n-bit binary form of
+        value + 1, where n = (value + 1).bit_length(). For example 0 -> '1',
+        1 -> '010', 2 -> '011', 3 -> '00100'. Raises ValueError if length is
+        given or value is negative.
+        """
+        if length is not None:
+            raise ValueError("Length cannot be specified for unsigned exponential-Golomb codes.")
+        if value < 0:
+            raise ValueError("Unsigned exponential-Golomb codes cannot be negative.")
+        code = bin(value + 1)[2:]
+        bits = '0' * (len(code) - 1) + code
+        self._bitstore = BitStore(len(bits))
+        self._bitstore.setall(0)
+        for i, c in enumerate(bits):
+            if c == '1':
+                self._bitstore.setindex(i, 1)
+
+    def _getue(self) -> int:
+        """Interpret the whole bitstring as one unsigned exponential-Golomb code.
+
+        Raises bitstring.InterpretError if there is no 1 bit, or if the
+        bitstring is not exactly one complete code (truncated or with trailing bits).
+        """
+        n = len(self)
+        leading_zeros = 0
+        while leading_zeros < n and not self._bitstore.getindex(leading_zeros):
+            leading_zeros += 1
+        if leading_zeros == n:
+            raise bitstring.InterpretError("Cannot find any 1 bits in exponential-Golomb code.")
+        # A complete code is leading_zeros zeros, a 1, then leading_zeros data bits.
+        if 2 * leading_zeros + 1 != n:
+            raise bitstring.InterpretError("Bitstring is not a single exponential-Golomb code.")
+        code_num = 0
+        for j in range(leading_zeros, n):
+            code_num = (code_num << 1) | self._bitstore.getindex(j)
+        return code_num - 1  # the binary part encodes value + 1
+
+    def _setsie(self, value: int, length: Optional[int]=None) -> None:
+        """Set the bitstring to the signed interleaved exponential-Golomb code for value.
+
+        The code is the unsigned interleaved code for abs(value) followed by a
+        sign bit ('0' positive, '1' negative); zero has no sign bit. For example
+        0 -> '1', 1 -> '0010', -1 -> '0011'. Raises ValueError if length is given.
+        """
+        if length is not None:
+            raise ValueError("Length cannot be specified for signed interleaved exponential-Golomb codes.")
+        # Unsigned interleaved part: each bit of |value| + 1 after its leading 1
+        # becomes a '0' flag plus that data bit; a final '1' terminates.
+        bits = ''.join('0' + b for b in bin(abs(value) + 1)[3:]) + '1'
+        if value != 0:
+            bits += '1' if value < 0 else '0'
+        self._bitstore = BitStore(len(bits))
+        self._bitstore.setall(0)
+        for i, c in enumerate(bits):
+            if c == '1':
+                self._bitstore.setindex(i, 1)
+
+    def _getsie(self) -> int:
+        """Interpret the whole bitstring as one signed interleaved exponential-Golomb code.
+
+        Raises bitstring.InterpretError if the bits are not exactly one complete
+        code: an unsigned interleaved code, plus one sign bit when it is non-zero.
+        """
+        n = len(self)
+        code_num = 1
+        pos = 0
+        # Each '0' flag is followed by one data bit; a '1' flag ends the code.
+        while pos < n and not self._bitstore.getindex(pos):
+            if pos + 1 >= n:
+                raise bitstring.InterpretError("Bitstring is not a single interleaved exponential-Golomb code.")
+            code_num = (code_num << 1) | self._bitstore.getindex(pos + 1)
+            pos += 2
+        if pos >= n:
+            raise bitstring.InterpretError("Bitstring is not a single interleaved exponential-Golomb code.")
+        pos += 1  # consume the terminating '1'
+        magnitude = code_num - 1
+        # Zero carries no sign bit; any other value carries exactly one.
+        expected_length = pos if magnitude == 0 else pos + 1
+        if expected_length != n:
+            raise bitstring.InterpretError("Bitstring is not a single interleaved exponential-Golomb code.")
+        if magnitude == 0:
+            return 0
+        return -magnitude if self._bitstore.getindex(pos) else magnitude
+
+    def _setuie(self, value: int, length: Optional[int]=None) -> None:
+        """Set the bitstring to the unsigned interleaved exponential-Golomb code for value.
+
+        Each bit of value + 1 after its leading 1 is written as a '0' flag
+        followed by that data bit, and the code ends with a single '1'. For
+        example 0 -> '1', 1 -> '001', 2 -> '011', 3 -> '00001'. Raises
+        ValueError if length is given or value is negative.
+        """
+        if length is not None:
+            raise ValueError("Length cannot be specified for unsigned interleaved exponential-Golomb codes.")
+        if value < 0:
+            raise ValueError("Unsigned interleaved exponential-Golomb codes cannot be negative.")
+        # bin(value + 1)[3:] drops the '0b' prefix and the implicit leading 1.
+        bits = ''.join('0' + b for b in bin(value + 1)[3:]) + '1'
+        self._bitstore = BitStore(len(bits))
+        self._bitstore.setall(0)
+        for i, c in enumerate(bits):
+            if c == '1':
+                self._bitstore.setindex(i, 1)
+
+    def _getuie(self) -> int:
+        """Interpret the whole bitstring as one unsigned interleaved exponential-Golomb code.
+
+        Raises bitstring.InterpretError if the bits are not exactly one
+        complete code (missing terminator, missing data bit, or trailing bits).
+        """
+        n = len(self)
+        code_num = 1
+        pos = 0
+        # Each '0' flag is followed by one data bit; a '1' flag ends the code.
+        while pos < n and not self._bitstore.getindex(pos):
+            if pos + 1 >= n:
+                raise bitstring.InterpretError("Bitstring is not a single interleaved exponential-Golomb code.")
+            code_num = (code_num << 1) | self._bitstore.getindex(pos + 1)
+            pos += 2
+        if pos + 1 != n:
+            # Either no terminating '1' was found or bits follow it.
+            raise bitstring.InterpretError("Bitstring is not a single interleaved exponential-Golomb code.")
+        return code_num - 1
+
+    def _setpad(self, value: None, length: Optional[int]=None) -> None:
+        """Make the bitstring `length` zero bits of padding; pads take no value.
+
+        Raises ValueError if length is None or if a value other than None is given.
+        """
+        if length is None:
+            raise ValueError("Length must be specified for padding bits.")
+        if value is not None:
+            raise ValueError("Padding bits cannot have a value.")
+        zeros = BitStore(length)
+        zeros.setall(0)
+        self._bitstore = zeros
+
+    def _getpad(self) -> None:
+        """Padding carries no data, so interpreting it always yields None."""
+
+    def _setp3binary(self, value: float, length: Optional[int]=None) -> None:
+        """Store value as the single-byte p3binary encoding from p3binary_fmt."""
+        if length is not None and length != 8:
+            raise ValueError("p3binary must be 8 bits.")
+        code = p3binary_fmt.float_to_int8(value)
+        self._setbytes(code.to_bytes(1, 'big'))
+
+    def _getp3binary(self) -> float:
+        """Decode the single byte via p3binary_fmt's int-to-float lookup table."""
+        if len(self) != 8:
+            raise bitstring.InterpretError("p3binary requires 8 bits.")
+        (code,) = self._getbytes()
+        return p3binary_fmt.lut_binary8_to_float[code]
+
+    def _setp4binary(self, value: float, length: Optional[int]=None) -> None:
+        """Store value as the single-byte p4binary encoding from p4binary_fmt."""
+        if length is not None and length != 8:
+            raise ValueError("p4binary must be 8 bits.")
+        code = p4binary_fmt.float_to_int8(value)
+        self._setbytes(code.to_bytes(1, 'big'))
+
+    def _getp4binary(self) -> float:
+        """Decode the single byte via p4binary_fmt's int-to-float lookup table."""
+        if len(self) != 8:
+            raise bitstring.InterpretError("p4binary requires 8 bits.")
+        (code,) = self._getbytes()
+        return p4binary_fmt.lut_binary8_to_float[code]
+
+    def _sete4m3mxfp(self, value: float, length: Optional[int]=None) -> None:
+        """Set the bitstring to the 8-bit e4m3 MX floating point encoding of value.
+
+        bitstring/mxfp.py defines one e4m3 format per overflow mode, so the
+        instance is chosen from options.mxfp_overflow ('saturate' or otherwise).
+        Raises ValueError for any length other than 8.
+        """
+        if length is not None and length != 8:
+            raise ValueError("e4m3mxfp must be 8 bits.")
+        # Imported locally so this does not depend on how the module imports mxfp.
+        from bitstring.mxfp import e4m3mxfp_saturate_fmt, e4m3mxfp_overflow_fmt
+        # NOTE(review): assumes bitstring.options is the package's Options instance.
+        if bitstring.options.mxfp_overflow == 'saturate':
+            fmt = e4m3mxfp_saturate_fmt
+        else:
+            fmt = e4m3mxfp_overflow_fmt
+        self._setbytes(fmt.float_to_int(value).to_bytes(1, byteorder='big', signed=False))
+
+    def _gete4m3mxfp(self) -> float:
+        """Return the 8 bits interpreted as an e4m3 MX floating point value.
+
+        Raises bitstring.InterpretError unless the bitstring is exactly 8 bits.
+        """
+        if len(self) != 8:
+            raise bitstring.InterpretError("e4m3mxfp requires 8 bits.")
+        # The int->float table is built without reference to the overflow mode,
+        # so the saturating instance's table serves both modes.
+        from bitstring.mxfp import e4m3mxfp_saturate_fmt
+        return e4m3mxfp_saturate_fmt.lut_int_to_float[self._getbytes()[0]]
+
def _getuint(self) -> int:
"""Return data as an unsigned int."""
pass
diff --git a/bitstring/bitstore_helpers.py b/bitstring/bitstore_helpers.py
index 1a10140..8cd1bf9 100644
--- a/bitstring/bitstore_helpers.py
+++ b/bitstring/bitstore_helpers.py
@@ -12,6 +12,49 @@ CACHE_SIZE = 256
def tidy_input_string(s: str) -> str:
"""Return string made lowercase and with all whitespace and underscores removed."""
- pass
+    # str.split() with no argument splits on any run of whitespace (spaces, tabs,
+    # newlines), so rejoining removes all of it rather than only ' ' characters.
+    return ''.join(s.split()).lower().replace('_', '')
e8m0mxfp_allowed_values = [float(2 ** x) for x in range(-127, 128)]
-literal_bit_funcs: Dict[str, Callable[..., BitStore]] = {'0x': hex2bitstore, '0X': hex2bitstore, '0b': bin2bitstore, '0B': bin2bitstore, '0o': oct2bitstore, '0O': oct2bitstore}
\ No newline at end of file
+
+def hex2bitstore(s: str) -> BitStore:
+    """Convert a hexadecimal string to a BitStore, four bits per digit.
+
+    Case, whitespace and underscores are ignored and a leading '0x' is
+    optional, so '0xA_f' and 'af' give the same eight bits. Raises ValueError
+    if any remaining character is not a hexadecimal digit.
+    """
+    s = tidy_input_string(s)
+    if s.startswith('0x'):
+        s = s[2:]
+    try:
+        bits = ''.join(format(int(c, 16), '04b') for c in s)
+    except ValueError:
+        raise ValueError(f"Invalid symbol in hex string: {s!r}") from None
+    return BitStore(bitarray.bitarray(bits))
+
+def bin2bitstore(s: str) -> BitStore:
+    """Convert a binary string to a BitStore, one bit per digit.
+
+    Case, whitespace and underscores are ignored and a leading '0b' is
+    optional. Characters other than '0'/'1' are rejected by bitarray.bitarray.
+    """
+    s = tidy_input_string(s)
+    if s.startswith('0b'):
+        s = s[2:]
+    return BitStore(bitarray.bitarray(s))
+
+def oct2bitstore(s: str) -> BitStore:
+    """Convert an octal string to a BitStore, three bits per digit.
+
+    Case, whitespace and underscores are ignored and a leading '0o' is
+    optional. Raises ValueError if any remaining character is not an octal digit.
+    """
+    s = tidy_input_string(s)
+    if s.startswith('0o'):
+        s = s[2:]
+    try:
+        bits = ''.join(format(int(c, 8), '03b') for c in s)
+    except ValueError:
+        raise ValueError(f"Invalid symbol in octal string: {s!r}") from None
+    return BitStore(bitarray.bitarray(bits))
+
+# Literal prefix -> converter for that base (both letter cases are listed).
+# bitstore_from_token checks a tidied token against these prefixes in insertion order.
+literal_bit_funcs: Dict[str, Callable[..., BitStore]] = {'0x': hex2bitstore, '0X': hex2bitstore, '0b': bin2bitstore, '0B': bin2bitstore, '0o': oct2bitstore, '0O': oct2bitstore}
+
+def bitstore_from_token(token: str) -> BitStore:
+    """Build a BitStore from a literal token such as '0xff', '0b101' or '0o17'.
+
+    The token is tidied first (lowercased, whitespace and underscores removed),
+    then dispatched on its prefix through literal_bit_funcs. Raises ValueError
+    when no known prefix matches.
+    """
+    cleaned = tidy_input_string(token)
+    converter = next((fn for prefix, fn in literal_bit_funcs.items()
+                      if cleaned.startswith(prefix.lower())), None)
+    if converter is None:
+        raise ValueError(f"Invalid token format: {cleaned}. Must start with one of {list(literal_bit_funcs.keys())}")
+    return converter(cleaned)
\ No newline at end of file
diff --git a/bitstring/bitstring_options.py b/bitstring/bitstring_options.py
index 9df2b5a..dbf8265 100644
--- a/bitstring/bitstring_options.py
+++ b/bitstring/bitstring_options.py
@@ -7,13 +7,17 @@ class Options:
_instance = None
def __init__(self):
- self.set_lsb0(False)
+ # NOTE(review): this used to go through set_lsb0(False); assigning the flag directly
+ # skips anything else set_lsb0 may do now or later — confirm that is intended.
+ self._lsb0 = False
self._bytealigned = False
self.mxfp_overflow = 'saturate'
self.no_color = False
no_color = os.getenv('NO_COLOR')
self.no_color = True if no_color else False
+    def set_lsb0(self, value: bool) -> None:
+        """Choose the bit-numbering mode.
+
+        True selects LSB0 numbering (bit index 0 is the least significant bit);
+        False selects MSB0 numbering (bit index 0 is the most significant bit).
+        Only the private _lsb0 flag is updated, with value coerced to bool.
+        """
+        self._lsb0 = bool(value)
+
def __repr__(self) -> str:
attributes = {attr: getattr(self, attr) for attr in dir(self) if not attr.startswith('_') and (not callable(getattr(self, attr)))}
return '\n'.join((f'{attr}: {value!r}' for attr, value in attributes.items()))
diff --git a/bitstring/dtypes.py b/bitstring/dtypes.py
index 741d639..3f7552e 100644
--- a/bitstring/dtypes.py
+++ b/bitstring/dtypes.py
@@ -42,57 +42,57 @@ class Dtype:
@property
def scale(self) -> Union[int, float, None]:
"""The multiplicative scale applied when interpreting the data."""
- pass
+ # Each property in this group returns the underscore-prefixed attribute of the same name.
+ return self._scale
@property
def name(self) -> str:
"""A string giving the name of the data type."""
- pass
+ return self._name
@property
def length(self) -> int:
"""The length of the data type in units of bits_per_item. Set to None for variable length dtypes."""
- pass
+ return self._length
@property
def bitlength(self) -> Optional[int]:
"""The number of bits needed to represent a single instance of the data type. Set to None for variable length dtypes."""
- pass
+ return self._bitlength
@property
def bits_per_item(self) -> int:
"""The number of bits for each unit of length. Usually 1, but equals 8 for bytes type."""
- pass
+ return self._bits_per_item
@property
def variable_length(self) -> bool:
"""If True then the length of the data type depends on the data being interpreted, and must not be specified."""
- pass
+ return self._variable_length
@property
def return_type(self) -> Any:
"""The type of the value returned by the parse method, such as int, float or str."""
- pass
+ return self._return_type
@property
def is_signed(self) -> bool:
"""If True then the data type represents a signed quantity."""
- pass
+ return self._is_signed
@property
def set_fn(self) -> Optional[Callable]:
"""A function to set the value of the data type."""
- pass
+ return self._set_fn
@property
def get_fn(self) -> Callable:
"""A function to get the value of the data type."""
- pass
+ return self._get_fn
@property
def read_fn(self) -> Callable:
"""A function to read the value of the data type."""
- pass
+ return self._read_fn
def __hash__(self) -> int:
return hash((self._name, self._length))
@@ -102,13 +102,23 @@ class Dtype:
The value parameter should be of a type appropriate to the dtype.
"""
- pass
+ if self._set_fn is None:
+ raise bitstring.CreationError(f"The '{self._name}' dtype cannot be used to create bitstrings.")
+ if self._set_fn_needs_length:
+ if self._length is None:
+ raise bitstring.CreationError(f"Cannot create a bitstring from a '{self._name}' dtype without a length.")
+ return self._set_fn(value, self._length)
+ return self._set_fn(value)
def parse(self, b: BitsType, /) -> Any:
"""Parse a bitstring to find its value.
The b parameter should be a bitstring of the appropriate length, or an object that can be converted to a bitstring."""
- pass
+        # Anything Bits can be built from is accepted; real Bits pass through untouched.
+        bits = b if isinstance(b, bitstring.Bits) else bitstring.Bits(b)
+        value = self._get_fn(bits)
+        # A scale of None means the raw interpreted value is returned unscaled.
+        return value if self._scale is None else value * self._scale
def __str__(self) -> str:
if self._scale is not None:
@@ -164,6 +174,10 @@ class AllowedLengths:
return (other - self.values[0]) % (self.values[1] - self.values[0]) == 0
return other in self.values
+    def only_one_value(self) -> bool:
+        """True when self.values holds exactly one entry and that entry is not Ellipsis."""
+        if len(self.values) != 1:
+            return False
+        return self.values[-1] is not Ellipsis
+
class DtypeDefinition:
"""Represents a class of dtypes, such as uint or float, rather than a concrete dtype such as uint8.
Not (yet) part of the public interface."""
@@ -250,7 +264,7 @@ class Register:
del cls.names[name]
def __repr__(self) -> str:
- s = [f'{'key':<12}:{'name':^12}{'signed':^8}{'set_fn_needs_length':^23}{'allowed_lengths':^16}{'multiplier':^12}{'return_type':<13}']
+        # Double quotes outside keep the inner quoted literals valid before Python 3.12,
+        # while preserving the padded header that lines up with the table columns.
+        s = [f"{'key':<12}:{'name':^12}{'signed':^8}{'set_fn_needs_length':^23}{'allowed_lengths':^16}{'multiplier':^12}{'return_type':<13}"]
s.append('-' * 85)
for key in self.names:
m = self.names[key]
diff --git a/bitstring/fp8.py b/bitstring/fp8.py
index 07bbfe2..8f94e96 100644
--- a/bitstring/fp8.py
+++ b/bitstring/fp8.py
@@ -24,10 +24,48 @@ class Binary8Format:
def float_to_int8(self, f: float) -> int:
"""Given a Python float convert to the best float8 (expressed as an integer in 0-255 range)."""
- pass
+        # Rounds to the nearest representable value (ties to even) and encodes
+        # subnormals. Every exponent field value is treated as finite, so the
+        # largest magnitude code is 0x7f.
+        if math.isnan(f):
+            # NOTE(review): NaN is encoded as 0 (the +0 code); confirm whether the
+            # format reserves a dedicated NaN code that should be used instead.
+            return 0
+        if math.isinf(f):
+            return self.pos_clamp_value if f > 0 else self.neg_clamp_value
+        if f == 0:
+            return 0
+        mantissa_bits = 7 - self.exp_bits
+        sign = 0x80 if f < 0 else 0
+        x = abs(f)
+        # frexp is exact (x == m * 2**e with 0.5 <= m < 1), unlike floor(log2(x)),
+        # which can be off by one just below a power of two.
+        _, e = math.frexp(x)
+        biased_exp = e - 1 + self.bias
+        if biased_exp >= 1:
+            # Normal: scale so the integer part is the implicit 1 plus the mantissa.
+            scale_exp = e - 1 - mantissa_bits
+            offset = (biased_exp - 1) << mantissa_bits
+        else:
+            # Subnormal: fixed exponent 1 - bias and no implicit leading 1.
+            scale_exp = 1 - self.bias - mantissa_bits
+            offset = 0
+        # ldexp only shifts the exponent, so round() sees the exact value and rounds
+        # half to even; a mantissa carry propagates into the exponent field.
+        magnitude = offset + round(math.ldexp(x, -scale_exp))
+        if magnitude > 0x7f:
+            return self.neg_clamp_value if sign else self.pos_clamp_value
+        if magnitude == 0:
+            return 0  # underflow flushes to +0
+        return sign | magnitude
- def createLUT_for_binary8_to_float(self):
+    def createLUT_for_binary8_to_float(self) -> array.array:
"""Create a LUT to convert an int in range 0-255 representing a float8 into a Python float"""
- pass
+        # Layout: 1 sign bit, exp_bits exponent bits, 7 - exp_bits mantissa bits.
+        mantissa_bits = 7 - self.exp_bits
+        mantissa_scale = 1 << mantissa_bits
+        exp_mask = (1 << self.exp_bits) - 1
+        lut = array.array('f')
+        for i in range(256):
+            sign = -1.0 if i & 0x80 else 1.0
+            exp = (i >> mantissa_bits) & exp_mask
+            mantissa = i & (mantissa_scale - 1)
+            if exp == 0:
+                # Subnormal or zero: no implicit leading 1, and the exponent is
+                # fixed at 1 - bias (the same as the smallest normal exponent).
+                magnitude = mantissa / mantissa_scale * 2.0 ** (1 - self.bias)
+            else:
+                magnitude = (1 + mantissa / mantissa_scale) * 2.0 ** (exp - self.bias)
+            lut.append(sign * magnitude)
+        return lut
p4binary_fmt = Binary8Format(exp_bits=4, bias=8)
-p3binary_fmt = Binary8Format(exp_bits=5, bias=16)
\ No newline at end of file
+# p3binary_fmt: 5 exponent bits with bias 16 (p4binary_fmt above uses 4 and 8).
+p3binary_fmt = Binary8Format(exp_bits=5, bias=16)
+
+def decompress_luts() -> None:
+    """Populate any missing lookup tables on p4binary_fmt and p3binary_fmt.
+
+    lut_float16_to_binary8 is decompressed from binary8_luts_compressed;
+    lut_binary8_to_float is built with createLUT_for_binary8_to_float.
+    """
+    for fmt in (p4binary_fmt, p3binary_fmt):
+        if not hasattr(fmt, 'lut_float16_to_binary8'):
+            # NOTE(review): assumes binary8_luts_compressed is keyed by (exp_bits, bias)
+            # and element 0 is the float16->binary8 table. If Binary8Format.__init__
+            # pre-sets these attributes (e.g. to None), hasattr is always True and
+            # nothing is loaded — verify.
+            compressed = binary8_luts_compressed[(fmt.exp_bits, fmt.bias)][0]
+            fmt.lut_float16_to_binary8 = zlib.decompress(compressed)
+        if not hasattr(fmt, 'lut_binary8_to_float'):
+            fmt.lut_binary8_to_float = fmt.createLUT_for_binary8_to_float()
\ No newline at end of file
diff --git a/bitstring/mxfp.py b/bitstring/mxfp.py
index 2996509..0659099 100644
--- a/bitstring/mxfp.py
+++ b/bitstring/mxfp.py
@@ -37,19 +37,64 @@ class MXFPFormat:
def float_to_int(self, f: float) -> int:
"""Given a Python float convert to the best mxfp float (expressed as an int) that represents it."""
- pass
+        # Rounds to nearest (ties to even), encodes subnormals, and respects the
+        # special-value codes that the OCP MX spec reserves for E4M3 and E5M2.
+        e_bits, m_bits = self.exp_bits, self.mantissa_bits
+        sign_bit = 1 << (e_bits + m_bits)
+        # Largest finite magnitude code, the NaN code, and the code produced on
+        # overflow when not saturating. E4M3 reserves S.1111.111 for NaN (no Inf);
+        # E5M2 reserves exponent 11111 for Inf (mantissa 0) and NaN.
+        if (e_bits, m_bits) == (4, 3):
+            max_finite, nan_code, overflow_code = 0b1111110, 0b1111111, 0b1111111
+        elif (e_bits, m_bits) == (5, 2):
+            max_finite, nan_code, overflow_code = 0b1111011, 0b1111111, 0b1111100
+        else:
+            # NOTE(review): these formats have no NaN/Inf codes; NaN maps to 0 as before.
+            max_finite, nan_code, overflow_code = sign_bit - 1, 0, sign_bit - 1
+        if math.isnan(f):
+            return nan_code
+        if math.isinf(f):
+            return self.pos_clamp_value if f > 0 else self.neg_clamp_value
+        if f == 0:
+            return 0
+        sign = sign_bit if f < 0 else 0
+        x = abs(f)
+        # frexp is exact (x == m * 2**e, 0.5 <= m < 1), unlike floor(log2(x)).
+        _, e = math.frexp(x)
+        biased_exp = e - 1 + self.bias
+        if biased_exp >= 1:
+            # Normal: scale so the integer part is the implicit 1 plus the mantissa.
+            scale_exp = e - 1 - m_bits
+            offset = (biased_exp - 1) << m_bits
+        else:
+            # Subnormal: fixed exponent 1 - bias and no implicit leading 1.
+            scale_exp = 1 - self.bias - m_bits
+            offset = 0
+        # ldexp only shifts the exponent, so round() sees the exact value and rounds
+        # half to even; a mantissa carry propagates into the exponent field.
+        magnitude = offset + round(math.ldexp(x, -scale_exp))
+        if magnitude > max_finite:
+            if self.mxfp_overflow == 'saturate':
+                return sign | max_finite
+            return sign | overflow_code
+        if magnitude == 0:
+            return 0  # underflow flushes to +0
+        return sign | magnitude
def createLUT_for_int_to_float(self) -> array.array:
"""Create a LUT to convert an int in representing a MXFP float into a Python float"""
- pass
+        # Layout: 1 sign bit, exp_bits exponent bits, mantissa_bits mantissa bits.
+        e_bits, m_bits = self.exp_bits, self.mantissa_bits
+        m_scale = 1 << m_bits
+        exp_mask = (1 << e_bits) - 1
+        lut = array.array('f')
+        for i in range(1 << (1 + e_bits + m_bits)):
+            sign = -1.0 if i >> (e_bits + m_bits) else 1.0
+            exp = (i >> m_bits) & exp_mask
+            mantissa = i & (m_scale - 1)
+            if (e_bits, m_bits) == (4, 3) and exp == exp_mask and mantissa == m_scale - 1:
+                value = float('nan')  # OCP MX E4M3: S.1111.111 is NaN
+            elif (e_bits, m_bits) == (5, 2) and exp == exp_mask:
+                # OCP MX E5M2: the top exponent is Inf (mantissa 0) or NaN.
+                value = sign * float('inf') if mantissa == 0 else float('nan')
+            elif exp == 0:
+                # Subnormal or zero: no implicit leading 1, exponent fixed at 1 - bias.
+                value = sign * (mantissa / m_scale) * 2.0 ** (1 - self.bias)
+            else:
+                value = sign * (1 + mantissa / m_scale) * 2.0 ** (exp - self.bias)
+            lut.append(value)
+        return lut
def createLUT_for_float16_to_mxfp(self) -> bytes:
"""Create a LUT to convert a float16 into a MXFP format"""
- pass
+        # Reinterpret all 65536 bit patterns as float16 in a single pack/unpack pair.
+        # Both calls use the same native byte order, so entry i is the half whose
+        # bit pattern is i.
+        halves = struct.unpack('65536e', struct.pack('65536H', *range(65536)))
+        return bytes(self.float_to_int(h) for h in halves)
e2m1mxfp_fmt = MXFPFormat(exp_bits=2, mantissa_bits=1, bias=1, mxfp_overflow='saturate')
e2m3mxfp_fmt = MXFPFormat(exp_bits=2, mantissa_bits=3, bias=1, mxfp_overflow='saturate')
e3m2mxfp_fmt = MXFPFormat(exp_bits=3, mantissa_bits=2, bias=3, mxfp_overflow='saturate')
e4m3mxfp_saturate_fmt = MXFPFormat(exp_bits=4, mantissa_bits=3, bias=7, mxfp_overflow='saturate')
e5m2mxfp_saturate_fmt = MXFPFormat(exp_bits=5, mantissa_bits=2, bias=15, mxfp_overflow='saturate')
e4m3mxfp_overflow_fmt = MXFPFormat(exp_bits=4, mantissa_bits=3, bias=7, mxfp_overflow='overflow')
-e5m2mxfp_overflow_fmt = MXFPFormat(exp_bits=5, mantissa_bits=2, bias=15, mxfp_overflow='overflow')
\ No newline at end of file
+# E4M3 and E5M2 each have a 'saturate' and an 'overflow' instance; mxfp_overflow
+# controls how float_to_int encodes values beyond the format's range.
+e5m2mxfp_overflow_fmt = MXFPFormat(exp_bits=5, mantissa_bits=2, bias=15, mxfp_overflow='overflow')
+
+def decompress_luts() -> None:
+    """Fill in any missing lookup tables on the module's MXFP format instances.
+
+    lut_float16_to_mxfp is decompressed from mxfp_luts_compressed, keyed by
+    (exp_bits, mantissa_bits, bias, mxfp_overflow); lut_int_to_float is built
+    with createLUT_for_int_to_float. Tables that are not None are left alone.
+    """
+    formats = (e2m1mxfp_fmt, e2m3mxfp_fmt, e3m2mxfp_fmt, e4m3mxfp_saturate_fmt,
+               e5m2mxfp_saturate_fmt, e4m3mxfp_overflow_fmt, e5m2mxfp_overflow_fmt)
+    for fmt in formats:
+        if fmt.lut_float16_to_mxfp is None:
+            # NOTE(review): assumes element 0 of each entry is the float16->mxfp table.
+            key = (fmt.exp_bits, fmt.mantissa_bits, fmt.bias, fmt.mxfp_overflow)
+            fmt.lut_float16_to_mxfp = zlib.decompress(mxfp_luts_compressed[key][0])
+        if fmt.lut_int_to_float is None:
+            fmt.lut_int_to_float = fmt.createLUT_for_int_to_float()
\ No newline at end of file