Source code for fips.FIPS204.auxilary

[docs] class AUXILARY: """ This class provides subroutines utilized by MLDSA, including function for data-type converstions and arithmetic. """ def __init__(self, parameter: dict[str, int]): self.q = parameter["q"] self.N = parameter["N"] self.eta = parameter["eta"]
[docs] def IntegerToBits(self, x: int, alpha: int): """ Algorithm 9 Computes the base-2 representation of x mod 2^alpha in ``little-endian`` order. Args: x (``int``): A ``non-negative`` integer. alpha (``int``): Number of ``bits`` to represent. Returns: str (``bits``): ``Bitstring`` of length ``alpha`` in ``little-endian`` order. Raises: ValueError: If x is ``negative`` or alpha is not a positive integer. ValueError: If x is ``too big`` to be represented in alpha bits. TypeError: If x or alpha is ``not an integer``. """ if x < 0: raise ValueError("x must be non-negative.") if alpha <= 0: raise ValueError("alpha must be a positive integer.") if x >= 2 ** alpha: raise ValueError(f"x = {x} cannot be represented in {alpha} bits.") x_mod = x bits:list[str] = [] for _ in range(alpha): bits.append(str(x_mod % 2)) x_mod //= 2 return ''.join(bits)
[docs] def BitsToInteger(self, y: str, alpha: int) -> int: """ Algorithm 10 Computes the ``integer`` value expressed by a bit string using ``little-endian`` order. Args: y (``bitstring``): ``Bitstring`` to convert to an integer. alpha (``int``): Number of ``bits`` to consider. Returns: integer (``int``): The ``integer`` value represented by the ``bitstring``. Raises: ValueError: If the length of ``y`` does not match alpha or if ``y`` contains invalid characters. TypeError: If ``y`` is not a ``string`` or alpha is not an ``integer``. """ if alpha <= 0: raise ValueError("alpha must be a positive integer.") if len(y) != alpha: raise ValueError(f"Bit string y must have exactly {alpha} bits.") if any(bit not in "01" for bit in y): raise ValueError("Bit string y must contain only '0' and '1' characters.") x = 0 for i in range(1, alpha + 1): bit = int(y[alpha - i]) x = 2 * x + bit return x
[docs] def IntegerToBytes(self, x: int, alpha: int) -> bytes: """ Algorithm 11 Computes a base-256 representation of x mod 256^alpha in ``little-endian`` order. Args: x (``int``): A ``non-negative`` integer. alpha (``int``): Number of ``bytes`` in the output. Returns: bytes (``bytes``): ``Bytestring`` of length alpha in ``little-endian`` order. Raises: ValueError: x = _ cannot be represented in _ bytes. """ try: return x.to_bytes(alpha, byteorder = "little") except OverflowError: raise ValueError(f"x = {x} cannot be represented in {alpha} bytes.")
[docs] def BitsToBytes(self, y: str) -> bytes: """ Algorithm 12 Converts a ``bitstring`` y into a ``bytestring`` using ``little-endian`` order. Args: y (``str``): ``Bitstring`` consisting of ``0`` and ``1``. Returns: bytes (``bytes``): ``Bytestring`` of length ceil ``(len(y) / 8)``. Raises: TypeError: If y is ``not a string``. ValueError: If y contains characters other than ``0`` or ``1``. """ if any(bit not in "01" for bit in y): raise ValueError("Bit string y must contain only '0' and '1'.") alpha = len(y) byte_len = (alpha + 7) // 8 # Equivalent to ceil(alpha / 8) z = [0] * byte_len for i in range(alpha): byte_index = i // 8 bit_index = i % 8 z[byte_index] |= int(y[i]) << bit_index return bytes(z)
[docs] def BytesToBits(self, z: bytes) -> str: """ Algorithm 13 Converts a ``bytestring`` ``z`` into a ``bitstring`` in ``little-endian`` order. Args: z (``bytes``): A ``bytestring``. Returns: str (bit string): A ``bitstring`` of length ``8 * len(z)``, in ``little-endian`` order. Raises: TypeError: If ``z`` is not a ``bytes`` object. """ if not isinstance(z, (bytes, bytearray)): raise TypeError("Input z must be a bytes or bytearray object.") bits: list[str] = [] for byte in z: for i in range(8): bits.append(str((byte >> i) & 1)) # Little-endian bit order return ''.join(bits)
[docs] def CoeffFromThreeBytes(self, b0: int, b1: int, b2:int) -> int | None: """ Algorithm 14 Generates an element of {``0``, ``1``, ``2``, ... , ``q - 1``} U { ``None`` } Args: b0 (``int``): first byte b1 (``int``): second byte b2 (``int``): third byte Returns: z (``int``): sampled coefficient or ``None`` if rejected. Raises: TypeError: if any of ``b0``, ``b1``, ``b2`` is ``not an integer``. """ # checks for validity of inputs. for i, b in enumerate((b0, b1, b2), start = 0): if not (0 <= b <= 255): raise ValueError (f"b{i} must be in the range 0 - 255.") # line 1: make a copy of b2. b2_prime = b2 # line 2 to 4: making sure b2_prime is 7 bits, not 8. if b2_prime > 127: b2_prime = b2_prime - 128 # line 5: evaluate z for sampling. z = (b2_prime << 16) + (b1 << 8) + b0 # line 6 to 8: reject the sample z if it's greater than q. if z < self.q: return z # accept sample else: return None # reject sample
[docs] def CoeffFromHalfByte(self, b: int) -> int | None: """ Algorithm 15 Let ``eta`` ∈ {2, 4}. Generates an element of {``-eta``, ``-eta + 1``, ... , ``eta``} U { ``None`` } Args: b (``int``): an integer in the range ``0 - 15``. Returns: z (``int``): sampled coefficient or ``None`` if rejected. Raises: TypeError: if b is ``not an integer``. ValueError: if b is ``not`` in the range ``0 - 15``. """ if not (0 <= b <= 15): raise ValueError (f"{b} must be in the range 0 - 15.") # line 1 and 2: rejection sampline from {-2, ... , 2 } if self.eta == 2 and b < 15: return 2 - (b % 5) # line 3: rejection sampline from {-4, ... , 4 } elif self.eta == 4 and b < 9: return 4 - b # line 4: sample is just rejected. else: return None
[docs] def CenteredModulus(self, z: int) -> int: """ Additional Helper Function 1 Computes the centered modulus ``z mod± q``. Maps each integer ``x`` to the unique ``r`` in :: [-(q-1)/2, (q-1)/2] such that :: x ≡ r (mod q). Args: z(``int``): An integer Returns: CenteredModulus(``int``) Raises: TypeError: If ``z`` is ``not an integer``. """ half_q = (self.q - 1) // 2 return (z + half_q) % self.q - half_q
[docs] def CenteredModulusList(self, z: list[int]) -> list[int]: """ Additional Helper Function 2 Computes the centered modulus ``z mod± q`` for a ``list`` of ``Integers``. Maps each integer ``x`` to the unique ``r`` in :: [-(q-1)/2, (q-1)/2] such that :: x ≡ r (mod q). Args: z(``list[int]``): A list of ``integers``. Returns: CenteredModulus(``list[int]``) Raises: TypeError: If ``z`` is not a ``list[int]``. """ half_q = (self.q - 1) // 2 return [(x + half_q) % self.q - half_q for x in z]
[docs] def CenteredModulusMatrix(self, z: list[list[int]]) -> list[list[int]]: """ Additional Helper Function 3 Computes the centered modulus ``z mod± q`` for a ``matrix`` of ``Integers``. Maps each integer ``x`` to the unique ``r`` in :: [-(q-1)/2, (q-1)/2] such that :: x ≡ r (mod q). Args: z(``list[list[int]]``): A matrix of ``integers``. Returns: CenteredModulus(``list[list[int]]``) Raises: TypeError: If ``z`` is not a ``list[list[int]]``. """ half_q = (self.q - 1) // 2 return [[(x + half_q) % self.q - half_q for x in z[k]] for k in range(len(z))]
[docs] def abs_for_list (self, z: list[int]) -> list[int]: """ Additional Helper Function 4 Computes the absolute values of a ``list`` of integers. Args: z (``list[int]``): A ``list`` of integers. Returns: |z| (``list[int]``): A ``list`` containing the ``absolute`` values of the input ``integers``. """ for p in range (len(z)): z[p] = abs(z[p]) return z
[docs] def InfinityNorm(self, z: list[list[int]]) -> int: """ Additional Helper Function 5 Compute the ``L-Infinity Norm`` of a ``Matrix``. Args: z (``list[list[int]]``): A ``Matrix`` of integers. Returns: infinity_norm (``int``): The infinity norm (``maximum`` absolute value) among all lists. Raises: TypeError: If ``z`` is not a list of lists of ``integers``. TypeError: If any element in the sublists is ``not an integer``. TypeError: If elements of ``z[x]`` are not lists. """ max_value = 0 for i in range (len(z)): if max_value < max(self.abs_for_list(self.CenteredModulusList(z[i]))): max_value = max(self.abs_for_list(self.CenteredModulusList(z[i]))) return max_value # returns the max value among all lists.
[docs] def CalcOnes(self, h: list[list[int]]) -> int: """ Additional Helper Function 6 Compute the number of ``1`` s inside a ``list[list[int]]`` . Args: h (``list[list[int]]``): A ``matrix`` containing ``0`` s and ``1`` s. Returns: Count(``int``): The count of ``1`` s in the ``matrix``. Raises: TypeError: If ``h`` is not a ``matrix`` of integers. TypeError: If any element in the sublists is not an integer . TypeError: If elements of ``h[x]`` are not ``0`` or ``1``. """ count = 0 for i in range(len(h)): for j in range(len(h[i])): if h[i][j] == 1: count = count + 1 return count