Source code for fips.FIPS204.pack

from .auxilary import AUXILARY

class PACK:
    """Bit-packing routines of FIPS 204 (Algorithms 16-21).

    These functions efficiently translate an element ``w`` of ``R`` into a
    byte string and vice versa, under the assumption that the coefficients
    of ``w`` lie in a restricted range.
    """

    def __init__(self, parameter: dict[str, int]):
        """Store the parameter-set values used by the packing routines.

        Args:
            parameter (``dict[str, int]``): Parameter set; must provide the
                keys ``"N"``, ``"q"``, ``"omega"`` and ``"k"``. It is also
                forwarded to ``AUXILARY``.
        """
        self.auxilary = AUXILARY(parameter)
        self.N = parameter["N"]
        self.q = parameter["q"]
        self.omega = parameter["omega"]
        self.k = parameter["k"]

    @staticmethod
    def _pack_coefficients(values: list[int], bitlen: int) -> bytes:
        """Concatenate the ``bitlen``-bit little-endian encodings of ``values``.

        Every value must already lie in ``[0, 2**bitlen)``; coefficient ``i``
        occupies bits ``[i * bitlen, (i + 1) * bitlen)`` of the output.
        """
        accumulator = 0
        for position, value in enumerate(values):
            accumulator |= value << (position * bitlen)
        # Round the total bit count up to whole bytes; to_bytes zero-pads.
        num_bytes = (len(values) * bitlen + 7) // 8
        return accumulator.to_bytes(num_bytes, byteorder="little")

    def _unpack_coefficients(self, v: bytes, c: int) -> list[int]:
        """Split ``v`` into 256 consecutive ``c``-bit integers via ``AUXILARY``."""
        z = self.auxilary.BytesToBits(v)
        return [
            self.auxilary.BitsToInteger(z[i * c:(i + 1) * c], c)
            for i in range(256)
        ]

    def SimpleBitPack(self, w: list[int], b: int) -> bytes:
        """Algorithm 16: encode a polynomial ``w`` into a byte string.

        Args:
            w (``list[int]``): Polynomial coefficients, must be of length ``256``
                with every coefficient in ``[0, b]``.
            b (``int``): Upper bound of coefficients. Must be ``>= 1``.

        Returns:
            bytes (``bytes``): Byte string of length ``32 * bitlen(b)``.

        Raises:
            ValueError: If the length of ``w``, ``b``, or any coefficient is
                out of the expected bounds.
        """
        if len(w) != 256:
            raise ValueError("Polynomial w must have exactly 256 coefficients.")
        if b < 1:
            raise ValueError("b must be at least 1.")
        if any(c < 0 or c > b for c in w):
            raise ValueError(f"All coefficients must be in the range [0, {b}].")
        return self._pack_coefficients(w, b.bit_length())

    def BitPack(self, w: list[int], a: int, b: int) -> bytes:
        """Algorithm 17: encode a polynomial ``w`` with coefficients in ``[-a, b]``.

        Args:
            w (``list[int]``): List of ``256`` polynomial coefficients in ``[-a, b]``.
            a (``int``): Lower bound magnitude (non-negative).
            b (``int``): Upper bound (non-negative).

        Returns:
            bytes (``bytes``): Packed byte string of length ``32 * bitlen(a + b)``.

        Raises:
            ValueError: If ``a``/``b`` is negative, ``len(w) != 256``, or any
                coefficient is out of range.
        """
        if a < 0 or b < 0:
            raise ValueError("a and b must be non-negative integers.")
        if len(w) != 256:
            raise ValueError("w must have exactly 256 coefficients.")
        if min(w) < -a or max(w) > b:
            raise ValueError(f"Each coefficient must be in the range [-{a}, {b}].")
        # Map [-a, b] onto [0, a + b] so every stored value is non-negative.
        return self._pack_coefficients([b - c for c in w], (a + b).bit_length())

    def SimpleBitUnpack(self, v: bytes, b: int) -> list[int]:
        """Algorithm 18: reverse ``SimpleBitPack``.

        Args:
            v (``bytes``): Byte string of length ``32 * bitlen(b)``.
            b (``int``): Upper bound for the original coefficients (``>= 1``).

        Returns:
            polynomial (``list[int]``): ``256`` unpacked coefficients, each an
            integer of ``bitlen(b)`` bits.

        Raises:
            TypeError: If ``v`` is not ``bytes``/``bytearray``.
            ValueError: If ``b < 1`` or ``v`` has the wrong length.
        """
        if not isinstance(v, (bytes, bytearray)):
            raise TypeError("Input v must be a byte string.")
        if b < 1:
            raise ValueError("b must be a positive integer.")
        c = b.bit_length()
        expected_length = 32 * c
        if len(v) != expected_length:
            raise ValueError(f"Input length must be {expected_length} bytes for b = {b}.")
        return self._unpack_coefficients(v, c)

    def BitUnpack(self, v: bytes, a: int, b: int) -> list[int]:
        """Algorithm 19: reverse ``BitPack``.

        Args:
            v (``bytes``): Packed byte string of length ``32 * bitlen(a + b)``.
            a (``int``): Non-negative integer (lower bound magnitude).
            b (``int``): Non-negative integer (upper bound).

        Returns:
            polynomial (``list[int]``): ``256`` coefficients, each computed as
            ``b - x`` for the decoded ``bitlen(a + b)``-bit integer ``x``.

        Raises:
            ValueError: If ``a``/``b`` is negative or ``v`` has the wrong length.
        """
        if a < 0 or b < 0:
            raise ValueError("a and b must be non-negative.")
        c = (a + b).bit_length()
        expected_len = 32 * c
        if len(v) != expected_len:
            raise ValueError(
                f"Expected input length = {expected_len} bytes for a + b = {a + b}, got {len(v)}."
            )
        # Undo the b - c shift applied by BitPack.
        return [b - x for x in self._unpack_coefficients(v, c)]

    def HintBitPack(self, h: list[list[int]]) -> bytes:
        """Algorithm 20: encode a binary polynomial vector ``h``.

        The first ``omega`` bytes hold the positions of the 1-coefficients
        (polynomial by polynomial, zero-padded); byte ``omega + i`` holds the
        cumulative number of 1s in polynomials ``0..i``.

        Args:
            h (``list[list[int]]``): ``k`` polynomials of ``256`` coefficients in ``{0, 1}``.

        Returns:
            bytes (``bytes``): Byte string of length ``omega + k``.

        Raises:
            ValueError: If ``h`` is malformed or contains more than ``omega`` 1s.
        """
        if len(h) != self.k:
            raise ValueError(f"h must be a list of {self.k} polynomials.")
        if not all(len(poly) == 256 for poly in h):
            raise ValueError("Each polynomial in h must be a list of 256 binary coefficients.")
        if not all(c in (0, 1) for poly in h for c in poly):
            raise ValueError("All coefficients in h must be 0 or 1.")

        y = [0] * (self.omega + self.k)
        index = 0
        for i in range(self.k):
            for j in range(256):
                if h[i][j] == 1:
                    if index >= self.omega:
                        raise ValueError("Number of 1s in h exceeds omega.")
                    y[index] = j
                    index += 1
            if index > 255:
                raise ValueError("index exceeds 255; encoding would overflow a byte.")
            y[self.omega + i] = index
        return bytes(y)

    def HintBitUnpack(self, y: bytes) -> "list[list[int]] | None":
        """Algorithm 21: reverse ``HintBitPack``.

        Args:
            y (``bytes``): Byte string of length ``omega + k`` that encodes ``h``.

        Returns:
            Matrix (``list[list[int]]``): ``h``, a list of ``k`` polynomials of
            ``N`` coefficients in ``{0, 1}``; or ``None`` when ``y`` is not a
            well-formed encoding (the ``⊥`` output of the algorithm).

        Raises:
            TypeError: If ``y`` is not ``bytes``/``bytearray``.
            ValueError: If ``len(y) != omega + k``.
        """
        if not isinstance(y, (bytes, bytearray)):
            raise TypeError("Input y must be a byte string.")
        if len(y) != self.omega + self.k:
            raise ValueError(f"Invalid input length: expected {self.omega + self.k}, got {len(y)}")

        # k polynomials of N zero coefficients each.
        h = [[0] * self.N for _ in range(self.k)]
        index = 0
        for i in range(self.k):
            end = y[self.omega + i]
            # Cumulative counts must be non-decreasing and never exceed omega.
            if end < index or end > self.omega:
                return None
            first = index
            while index < end:
                # Positions within one polynomial must be strictly increasing,
                # which rejects non-canonical encodings of the same h.
                if index > first and y[index - 1] >= y[index]:
                    return None
                h[i][y[index]] = 1
                index += 1
        # Unused position slots must be zero padding.
        if any(y[j] != 0 for j in range(index, self.omega)):
            return None
        return h