[docs]
class AUXILARY:
"""
This class provides subroutines utilized by MLDSA, including function for data-type converstions and arithmetic.
"""
def __init__(self, parameter: dict[str, int]):
self.q = parameter["q"]
self.N = parameter["N"]
self.eta = parameter["eta"]
[docs]
def IntegerToBits(self, x: int, alpha: int):
"""
Algorithm 9
Computes the base-2 representation of x mod 2^alpha in ``little-endian`` order.
Args:
x (``int``): A ``non-negative`` integer.
alpha (``int``): Number of ``bits`` to represent.
Returns:
str (``bits``): ``Bitstring`` of length ``alpha`` in ``little-endian`` order.
Raises:
ValueError: If x is ``negative`` or alpha is not a positive integer.
ValueError: If x is ``too big`` to be represented in alpha bits.
TypeError: If x or alpha is ``not an integer``.
"""
if x < 0:
raise ValueError("x must be non-negative.")
if alpha <= 0:
raise ValueError("alpha must be a positive integer.")
if x >= 2 ** alpha:
raise ValueError(f"x = {x} cannot be represented in {alpha} bits.")
x_mod = x
bits:list[str] = []
for _ in range(alpha):
bits.append(str(x_mod % 2))
x_mod //= 2
return ''.join(bits)
[docs]
def BitsToInteger(self, y: str, alpha: int) -> int:
"""
Algorithm 10
Computes the ``integer`` value expressed by a bit string using ``little-endian`` order.
Args:
y (``bitstring``): ``Bitstring`` to convert to an integer.
alpha (``int``): Number of ``bits`` to consider.
Returns:
integer (``int``): The ``integer`` value represented by the ``bitstring``.
Raises:
ValueError: If the length of ``y`` does not match alpha or if ``y`` contains invalid characters.
TypeError: If ``y`` is not a ``string`` or alpha is not an ``integer``.
"""
if alpha <= 0:
raise ValueError("alpha must be a positive integer.")
if len(y) != alpha:
raise ValueError(f"Bit string y must have exactly {alpha} bits.")
if any(bit not in "01" for bit in y):
raise ValueError("Bit string y must contain only '0' and '1' characters.")
x = 0
for i in range(1, alpha + 1):
bit = int(y[alpha - i])
x = 2 * x + bit
return x
[docs]
def IntegerToBytes(self, x: int, alpha: int) -> bytes:
"""
Algorithm 11
Computes a base-256 representation of x mod 256^alpha in ``little-endian`` order.
Args:
x (``int``): A ``non-negative`` integer.
alpha (``int``): Number of ``bytes`` in the output.
Returns:
bytes (``bytes``): ``Bytestring`` of length alpha in ``little-endian`` order.
Raises:
ValueError: x = _ cannot be represented in _ bytes.
"""
try:
return x.to_bytes(alpha, byteorder = "little")
except OverflowError:
raise ValueError(f"x = {x} cannot be represented in {alpha} bytes.")
[docs]
def BitsToBytes(self, y: str) -> bytes:
"""
Algorithm 12
Converts a ``bitstring`` y into a ``bytestring`` using ``little-endian`` order.
Args:
y (``str``): ``Bitstring`` consisting of ``0`` and ``1``.
Returns:
bytes (``bytes``): ``Bytestring`` of length ceil ``(len(y) / 8)``.
Raises:
TypeError: If y is ``not a string``.
ValueError: If y contains characters other than ``0`` or ``1``.
"""
if any(bit not in "01" for bit in y):
raise ValueError("Bit string y must contain only '0' and '1'.")
alpha = len(y)
byte_len = (alpha + 7) // 8 # Equivalent to ceil(alpha / 8)
z = [0] * byte_len
for i in range(alpha):
byte_index = i // 8
bit_index = i % 8
z[byte_index] |= int(y[i]) << bit_index
return bytes(z)
[docs]
def BytesToBits(self, z: bytes) -> str:
"""
Algorithm 13
Converts a ``bytestring`` ``z`` into a ``bitstring`` in ``little-endian`` order.
Args:
z (``bytes``): A ``bytestring``.
Returns:
str (bit string): A ``bitstring`` of length ``8 * len(z)``, in ``little-endian`` order.
Raises:
TypeError: If ``z`` is not a ``bytes`` object.
"""
if not isinstance(z, (bytes, bytearray)):
raise TypeError("Input z must be a bytes or bytearray object.")
bits: list[str] = []
for byte in z:
for i in range(8):
bits.append(str((byte >> i) & 1)) # Little-endian bit order
return ''.join(bits)
[docs]
def CoeffFromThreeBytes(self, b0: int, b1: int, b2:int) -> int | None:
"""
Algorithm 14
Generates an element of {``0``, ``1``, ``2``, ... , ``q - 1``} U { ``None`` }
Args:
b0 (``int``): first byte
b1 (``int``): second byte
b2 (``int``): third byte
Returns:
z (``int``): sampled coefficient or ``None`` if rejected.
Raises:
TypeError: if any of ``b0``, ``b1``, ``b2`` is ``not an integer``.
"""
# checks for validity of inputs.
for i, b in enumerate((b0, b1, b2), start = 0):
if not (0 <= b <= 255):
raise ValueError (f"b{i} must be in the range 0 - 255.")
# line 1: make a copy of b2.
b2_prime = b2
# line 2 to 4: making sure b2_prime is 7 bits, not 8.
if b2_prime > 127:
b2_prime = b2_prime - 128
# line 5: evaluate z for sampling.
z = (b2_prime << 16) + (b1 << 8) + b0
# line 6 to 8: reject the sample z if it's greater than q.
if z < self.q:
return z # accept sample
else:
return None # reject sample
[docs]
def CoeffFromHalfByte(self, b: int) -> int | None:
"""
Algorithm 15
Let ``eta`` ∈ {2, 4}.
Generates an element of {``-eta``, ``-eta + 1``, ... , ``eta``} U { ``None`` }
Args:
b (``int``): an integer in the range ``0 - 15``.
Returns:
z (``int``): sampled coefficient or ``None`` if rejected.
Raises:
TypeError: if b is ``not an integer``.
ValueError: if b is ``not`` in the range ``0 - 15``.
"""
if not (0 <= b <= 15):
raise ValueError (f"{b} must be in the range 0 - 15.")
# line 1 and 2: rejection sampline from {-2, ... , 2 }
if self.eta == 2 and b < 15:
return 2 - (b % 5)
# line 3: rejection sampline from {-4, ... , 4 }
elif self.eta == 4 and b < 9:
return 4 - b
# line 4: sample is just rejected.
else:
return None
[docs]
def CenteredModulus(self, z: int) -> int:
"""
Additional Helper Function 1
Computes the centered modulus ``z mod± q``.
Maps each integer ``x`` to the unique ``r`` in ::
[-(q-1)/2, (q-1)/2]
such that ::
x ≡ r (mod q).
Args:
z(``int``): An integer
Returns:
CenteredModulus(``int``)
Raises:
TypeError: If ``z`` is ``not an integer``.
"""
half_q = (self.q - 1) // 2
return (z + half_q) % self.q - half_q
[docs]
def CenteredModulusList(self, z: list[int]) -> list[int]:
"""
Additional Helper Function 2
Computes the centered modulus ``z mod± q`` for a ``list`` of ``Integers``.
Maps each integer ``x`` to the unique ``r`` in ::
[-(q-1)/2, (q-1)/2]
such that ::
x ≡ r (mod q).
Args:
z(``list[int]``): A list of ``integers``.
Returns:
CenteredModulus(``list[int]``)
Raises:
TypeError: If ``z`` is not a ``list[int]``.
"""
half_q = (self.q - 1) // 2
return [(x + half_q) % self.q - half_q for x in z]
[docs]
def CenteredModulusMatrix(self, z: list[list[int]]) -> list[list[int]]:
"""
Additional Helper Function 3
Computes the centered modulus ``z mod± q`` for a ``matrix`` of ``Integers``.
Maps each integer ``x`` to the unique ``r`` in ::
[-(q-1)/2, (q-1)/2]
such that ::
x ≡ r (mod q).
Args:
z(``list[list[int]]``): A matrix of ``integers``.
Returns:
CenteredModulus(``list[list[int]]``)
Raises:
TypeError: If ``z`` is not a ``list[list[int]]``.
"""
half_q = (self.q - 1) // 2
return [[(x + half_q) % self.q - half_q for x in z[k]] for k in range(len(z))]
[docs]
def abs_for_list (self, z: list[int]) -> list[int]:
"""
Additional Helper Function 4
Computes the absolute values of a ``list`` of integers.
Args:
z (``list[int]``): A ``list`` of integers.
Returns:
|z| (``list[int]``): A ``list`` containing the ``absolute`` values of the input ``integers``.
"""
for p in range (len(z)):
z[p] = abs(z[p])
return z
[docs]
def InfinityNorm(self, z: list[list[int]]) -> int:
"""
Additional Helper Function 5
Compute the ``L-Infinity Norm`` of a ``Matrix``.
Args:
z (``list[list[int]]``): A ``Matrix`` of integers.
Returns:
infinity_norm (``int``): The infinity norm (``maximum`` absolute value) among all lists.
Raises:
TypeError: If ``z`` is not a list of lists of ``integers``.
TypeError: If any element in the sublists is ``not an integer``.
TypeError: If elements of ``z[x]`` are not lists.
"""
max_value = 0
for i in range (len(z)):
if max_value < max(self.abs_for_list(self.CenteredModulusList(z[i]))):
max_value = max(self.abs_for_list(self.CenteredModulusList(z[i])))
return max_value # returns the max value among all lists.
[docs]
def CalcOnes(self, h: list[list[int]]) -> int:
"""
Additional Helper Function 6
Compute the number of ``1`` s inside a ``list[list[int]]`` .
Args:
h (``list[list[int]]``): A ``matrix`` containing ``0`` s and ``1`` s.
Returns:
Count(``int``): The count of ``1`` s in the ``matrix``.
Raises:
TypeError: If ``h`` is not a ``matrix`` of integers.
TypeError: If any element in the sublists is not an integer .
TypeError: If elements of ``h[x]`` are not ``0`` or ``1``.
"""
count = 0
for i in range(len(h)):
for j in range(len(h[i])):
if h[i][j] == 1:
count = count + 1
return count