From f75a9692c7892b80fd25ec5cfc1f7b1dce6ea332 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 17:34:28 -0700 Subject: [PATCH 01/10] feat(wip): ntt --- algebra/poly/ntt.py | 315 ++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + uv.lock | 23 ++++ 3 files changed, 339 insertions(+) create mode 100644 algebra/poly/ntt.py diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py new file mode 100644 index 0000000..f934c66 --- /dev/null +++ b/algebra/poly/ntt.py @@ -0,0 +1,315 @@ +import random +import numpy as np +from sympy import isprime +from tinygrad.tensor import Tensor + +class NTT: + def isInteger(self, M): + return isinstance(M, int) + + def isPrime(self, M): + assert self.isInteger(M), 'Not an integer.' + return isprime(M) + + # modular exponential algorithm + # complexity is O(log N) + def modExponent(self, base, power, M): + """ + Standard implementation for single values + """ + result = 1 + power = int(power) + base = base % M + while power > 0: + if power & 1: + result = (result * base) % M + base = (base * base) % M + power = power >> 1 + return result + + # Vectorized version of modular exponentiation + def vector_modExponent(self, bases, power, M): + """ + Vectorized implementation using tinygrad + For an array of bases raised to the same power + + Parameters: + bases (array-like): Array of base values + power (int): Power to raise each base to + M (int): Modulus + + Returns: + Tensor: Result of modular exponentiation for each base + """ + # Convert inputs to Tensor + if not isinstance(bases, Tensor): + bases_tensor = Tensor(bases) + else: + bases_tensor = bases + + result = Tensor.ones(*bases_tensor.shape) + power_val = int(power) + bases_tensor = bases_tensor % M + + while power_val > 0: + if power_val & 1: + result = (result * bases_tensor) % M + bases_tensor = (bases_tensor * bases_tensor) % M + power_val = power_val >> 1 + + return result + + # calculate x^(-1) mod M + def modInv(self, x, M): + """ + Extended Euclidean Algorithm for modular inverse + """ + t, new_t, r, new_r = 0, 1, M, x + + while new_r != 0: + quotient = r // new_r + t, new_t = new_t, (t - quotient * new_t) + r, new_r = new_r, (r % new_r) + + if r > 1: + return "x is not invertible." + if t < 0: + t = t + M + return t + + # check if r^k = 1 (mod M), k> shift_bits) << shift_bits + + # Compute twiddle factors - use numpy directly for simplicity + w_P = np.array([self.modExponent(w, P, M) for P in P_values]) + + # Extract even and odd elements + even = poly_np[even_indices] + odd = poly_np[odd_indices] + + # Apply twiddle factors to odd elements + odd_twisted = (odd * w_P) % M + + # Compute DFT results + points1 = (even + odd_twisted) % M + points2 = (even - odd_twisted) % M + + # Combine results + poly_np = np.concatenate([points1, points2]) + poly_tensor = Tensor(poly_np) + + return poly_tensor + + # Inverse NTT implementation + def intt(self, points, M, N, w): + """ + Inverse number theoretic transform algorithm + + Parameters: + points (Tensor or list): Input points from the frequency domain + M (int): Modulus for the NTT + N (int): Size of the transform (power of 2) + w (int): Nth root of unity modulo M + + Returns: + Tensor: Result of the inverse NTT transform + """ + # Convert input to Tensor if it's not already + if not isinstance(points, Tensor): + points = Tensor(points) + + inv_w = self.modInv(w, M) + inv_N = self.modInv(N, M) + + # Perform normal NTT with inverse of w + poly = self.ntt(points, M, N, inv_w) + + # Apply scaling factor - using numpy for simplicity + poly_np = poly.numpy() + poly_np = (poly_np * inv_N) % M + + return Tensor(poly_np) + +def main(): + # Initialize the NTT class + ntt = NTT() + + # Choose parameters for NTT + # We need a prime modulus M such that (M-1) is divisible by N + N = 8 # Power of 2 + M = 17 # A prime where (M-1) is divisible by N + + # Verify our parameters + if not ntt.isPrime(M) or (M-1) % N != 0: + print(f"Invalid parameters: M={M} must be prime and (M-1) must be divisible by N={N}") + return + + # Generate a primitive Nth root of unity + w = ntt.NthRootOfUnity(M, N) + print(f"Primitive {N}th root of unity modulo {M}: {w}") + + # Create a sample polynomial (coefficients) + poly = Tensor([3, 1, 4, 1, 5, 9, 2, 6]) + print(f"Original polynomial: {poly.numpy()}") + + # Perform forward NTT + freq_domain = ntt.ntt(poly, M, N, w) + print(f"After NTT: {freq_domain.numpy()}") + + # Perform inverse NTT + time_domain = ntt.intt(freq_domain, M, N, w) + print(f"After INTT (should match original): {time_domain.numpy()}") + + # Verify the transformation is correct + if np.allclose(poly.numpy(), time_domain.numpy()): + print("✓ NTT transformation is correct!") + else: + print("✗ NTT transformation failed!") + + # Example of polynomial multiplication using NTT + # (a*b = INTT(NTT(a) * NTT(b))) + poly1 = Tensor([1, 2, 3, 4, 0, 0, 0, 0]) # 4x^3 + 3x^2 + 2x + 1 + poly2 = Tensor([5, 6, 7, 0, 0, 0, 0, 0]) # 7x^2 + 6x + 5 + + print(f"\nPolynomial 1: {poly1.numpy()}") + print(f"Polynomial 2: {poly2.numpy()}") + + # Transform to frequency domain + freq1 = ntt.ntt(poly1, M, N, w) + freq2 = ntt.ntt(poly2, M, N, w) + + # Multiply in frequency domain (element-wise) + freq_prod = (freq1 * freq2) % M + + # Transform back to time domain + poly_prod = ntt.intt(freq_prod, M, N, w) + print(f"Polynomial product: {poly_prod.numpy()}") + + # Let's calculate the expected result by direct convolution + expected = np.zeros(8, dtype=int) + for i in range(4): # poly1 has 4 non-zero terms + for j in range(3): # poly2 has 3 non-zero terms + expected[i+j] = (expected[i+j] + poly1.numpy()[i] * poly2.numpy()[j]) % M + + expected_tensor = Tensor(expected) + print(f"Expected (direct convolution): {expected}") + + # Check if our computed result matches the expected polynomial product + if np.allclose(poly_prod.numpy(), expected_tensor.numpy()): + print("✓ Polynomial multiplication using NTT is correct!") + else: + print("✗ Polynomial multiplication using NTT failed!") + print(f"Expected: {expected}") + print(f"Got: {poly_prod.numpy()}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9f6e1a0..67ae1dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ authors = [{ name = "Pia", email = "gayeongparkk@gmail.com" }] requires-python = ">=3.10" dependencies = [ "numpy>=2.2.3", + "sympy>=1.13.3", "tinygrad>=0.10.2", ] diff --git a/uv.lock b/uv.lock index dc6fdd9..a490d78 100644 --- a/uv.lock +++ b/uv.lock @@ -8,6 +8,7 @@ version = "0.1.0" source = { virtual = "." } dependencies = [ { name = "numpy" }, + { name = "sympy" }, { name = "tinygrad" }, ] @@ -20,6 +21,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "numpy", specifier = ">=2.2.3" }, + { name = "sympy", specifier = ">=1.13.3" }, { name = "tinygrad", specifier = ">=0.10.2" }, ] @@ -56,6 +58,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, ] +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 }, +] + [[package]] name = "numpy" version = "2.2.3" @@ -178,6 +189,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/6a/aca01554949f3a401991dc32fe22837baeaccb8a0d868256cbb26a029778/ruff-0.9.7-py3-none-win_arm64.whl", hash = "sha256:b075a700b2533feb7a01130ff656a4ec0d5f340bb540ad98759b8401c32c2037", size = 10177763 }, ] +[[package]] +name = "sympy" +version = "1.13.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/8a/5a7fd6284fa8caac23a26c9ddf9c30485a48169344b4bd3b0f02fef1890f/sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9", size = 7533196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73", size = 6189483 }, +] + [[package]] name = "tinygrad" version = "0.10.2" From e1146668d3d5b4959f11499955a0fbc3dd12969d Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 17:40:46 -0700 Subject: [PATCH 02/10] fix(wip): functional --- algebra/poly/ntt.py | 464 +++++++++++++++++++++----------------------- 1 file changed, 223 insertions(+), 241 deletions(-) diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py index f934c66..b5f2c11 100644 --- a/algebra/poly/ntt.py +++ b/algebra/poly/ntt.py @@ -3,270 +3,263 @@ from sympy import isprime from tinygrad.tensor import Tensor -class NTT: - def isInteger(self, M): - return isinstance(M, int) +# Default parameters +DEFAULT_N = 8 # Size of transform (power of 2) +DEFAULT_M = 17 # Prime modulus where (M-1) is divisible by N +DEFAULT_W = None # Will be calculated on first use - def isPrime(self, M): - assert self.isInteger(M), 'Not an integer.' - return isprime(M) +# Utility functions +def is_integer(M): + return isinstance(M, int) - # modular exponential algorithm - # complexity is O(log N) - def modExponent(self, base, power, M): - """ - Standard implementation for single values - """ - result = 1 - power = int(power) - base = base % M - while power > 0: - if power & 1: - result = (result * base) % M - base = (base * base) % M - power = power >> 1 - return result - - # Vectorized version of modular exponentiation - def vector_modExponent(self, bases, power, M): - """ - Vectorized implementation using tinygrad - For an array of bases raised to the same power - - Parameters: - bases (array-like): Array of base values - power (int): Power to raise each base to - M (int): Modulus - - Returns: - Tensor: Result of modular exponentiation for each base - """ - # Convert inputs to Tensor - if not isinstance(bases, Tensor): - bases_tensor = Tensor(bases) - else: - bases_tensor = bases - - result = Tensor.ones(*bases_tensor.shape) - power_val = int(power) - bases_tensor = bases_tensor % M - - while power_val > 0: - if power_val & 1: - result = (result * bases_tensor) % M - bases_tensor = (bases_tensor * bases_tensor) % M - power_val = power_val >> 1 - - return result - - # calculate x^(-1) mod M - def modInv(self, x, M): - """ - Extended Euclidean Algorithm for modular inverse - """ - t, new_t, r, new_r = 0, 1, M, x +def is_prime(M): + assert is_integer(M), 'Not an integer.' + return isprime(M) - while new_r != 0: - quotient = r // new_r - t, new_t = new_t, (t - quotient * new_t) - r, new_r = new_r, (r % new_r) - - if r > 1: - return "x is not invertible." - if t < 0: - t = t + M - return t +# Modular arithmetic functions +def mod_exponent(base, power, M): + """ + Standard implementation of modular exponentiation for single values + """ + result = 1 + power = int(power) + base = base % M + while power > 0: + if power & 1: + result = (result * base) % M + base = (base * base) % M + power = power >> 1 + return result - # check if r^k = 1 (mod M), k 0: + if power_val & 1: + result = (result * bases_tensor) % M + bases_tensor = (bases_tensor * bases_tensor) % M + power_val = power_val >> 1 - Returns: - bool: True if r^k = 1 (mod M) for any k < N - """ - # For small N, we can just check directly without vectorization - for k in range(2, N): - if self.modExponent(r, k, M) == 1: - return True - return False + return result + +def mod_inv(x, M): + """ + Extended Euclidean Algorithm for modular inverse + """ + t, new_t, r, new_r = 0, 1, M, x - # generate primitive nth root of unity - def NthRootOfUnity(self, M, N): - assert self.isPrime(M), 'Not a prime.' # modulus should be a prime - assert (M - 1) % N == 0, 'N cannot divide phi(M)' - phi_M = M - 1 + while new_r != 0: + quotient = r // new_r + t, new_t = new_t, (t - quotient * new_t) + r, new_r = new_r, (r % new_r) - while True: - alpha = random.randrange(1, M) - beta = self.modExponent(alpha, phi_M // N, M) - # check if beta can be k th root of unity, k 1: + return "x is not invertible." + if t < 0: + t = t + M + return t - # verify B^N = 1 (mod M) - def isNthRootOfUnity(self, M, N, beta): - return self.modExponent(beta, N, M) == 1 +def exist_small_order(r, M, N): + """ + Check if r is a primitive Nth root of unity + """ + for k in range(2, N): + if mod_exponent(r, k, M) == 1: + return True + return False - def bitReverse(self, num, length): - """ - Efficient bit reverse using numpy - """ - # Convert to binary string, pad, reverse, and convert back to int - return int(bin(num)[2:].zfill(length)[::-1], 2) +def get_nth_root_of_unity(M=DEFAULT_M, N=DEFAULT_N): + """ + Generate primitive nth root of unity + """ + assert is_prime(M), 'Not a prime.' # modulus should be a prime + assert (M - 1) % N == 0, 'N cannot divide phi(M)' + phi_M = M - 1 - def orderReverse(self, poly, N_bit): - """ - Vectorized bit reversal permutation - - Parameters: - poly (Tensor or list): Input polynomial coefficients - N_bit (int): Bit length for the reversal - - Returns: - Tensor: Bit-reversed polynomial - """ - # Convert input to Tensor if it's not already - if not isinstance(poly, Tensor): - poly = Tensor(poly) - - n = poly.shape[0] - - # Create indices and their bit-reversed counterparts - indices = np.arange(n) - bit_reversed = np.array([self.bitReverse(i, N_bit) for i in indices]) - - # Create a new tensor with bit-reversed order - poly_np = poly.numpy() - result = Tensor(poly_np[bit_reversed]) - - return result + while True: + alpha = random.randrange(1, M) + beta = mod_exponent(alpha, phi_M // N, M) + # check if beta can be k th root of unity, k> shift_bits) << shift_bits - poly_tensor = bit_reversed_poly + # Compute twiddle factors + w_P = np.array([mod_exponent(w, P, M) for P in P_values]) - for i in range(N_bit): - # Create vectors for even and odd indices - even_indices = np.arange(0, N, 2) - odd_indices = np.arange(1, N, 2) - - # Get current polynomial coefficients as numpy array for indexing - poly_np = poly_tensor.numpy() - - # Create powers for twiddle factors - shift_bits = N_bit - 1 - i - P_values = (np.arange(N//2) >> shift_bits) << shift_bits - - # Compute twiddle factors - use numpy directly for simplicity - w_P = np.array([self.modExponent(w, P, M) for P in P_values]) - - # Extract even and odd elements - even = poly_np[even_indices] - odd = poly_np[odd_indices] - - # Apply twiddle factors to odd elements - odd_twisted = (odd * w_P) % M - - # Compute DFT results - points1 = (even + odd_twisted) % M - points2 = (even - odd_twisted) % M - - # Combine results - poly_np = np.concatenate([points1, points2]) - poly_tensor = Tensor(poly_np) - - return poly_tensor - - # Inverse NTT implementation - def intt(self, points, M, N, w): - """ - Inverse number theoretic transform algorithm + # Extract even and odd elements + even = poly_np[even_indices] + odd = poly_np[odd_indices] - Parameters: - points (Tensor or list): Input points from the frequency domain - M (int): Modulus for the NTT - N (int): Size of the transform (power of 2) - w (int): Nth root of unity modulo M + # Apply twiddle factors to odd elements + odd_twisted = (odd * w_P) % M - Returns: - Tensor: Result of the inverse NTT transform - """ - # Convert input to Tensor if it's not already - if not isinstance(points, Tensor): - points = Tensor(points) - - inv_w = self.modInv(w, M) - inv_N = self.modInv(N, M) + # Compute DFT results + points1 = (even + odd_twisted) % M + points2 = (even - odd_twisted) % M - # Perform normal NTT with inverse of w - poly = self.ntt(points, M, N, inv_w) + # Combine results + poly_np = np.concatenate([points1, points2]) + poly_tensor = Tensor(poly_np) - # Apply scaling factor - using numpy for simplicity - poly_np = poly.numpy() - poly_np = (poly_np * inv_N) % M + return poly_tensor + +def intt(points, N=DEFAULT_N, M=DEFAULT_M, w=DEFAULT_W): + """ + Inverse number theoretic transform algorithm + """ + global DEFAULT_W + + # If w is not provided, calculate it + if w is None: + if DEFAULT_W is None: + DEFAULT_W = get_nth_root_of_unity(M, N) + w = DEFAULT_W + + # Convert input to Tensor if it's not already + if not isinstance(points, Tensor): + points = Tensor(points) - return Tensor(poly_np) + inv_w = mod_inv(w, M) + inv_N = mod_inv(N, M) + + # Perform normal NTT with inverse of w + poly = ntt(points, N, M, inv_w) + + # Apply scaling factor + poly_np = poly.numpy() + poly_np = (poly_np * inv_N) % M + + return Tensor(poly_np) -def main(): - # Initialize the NTT class - ntt = NTT() +# Polynomial operations +def polynomial_multiply(poly1, poly2, N=DEFAULT_N, M=DEFAULT_M, w=DEFAULT_W): + """ + Multiply two polynomials using NTT + """ + # Check if inputs are Tensors + if not isinstance(poly1, Tensor): + poly1 = Tensor(poly1) + if not isinstance(poly2, Tensor): + poly2 = Tensor(poly2) + + # Transform to frequency domain + freq1 = ntt(poly1, N, M, w) + freq2 = ntt(poly2, N, M, w) - # Choose parameters for NTT - # We need a prime modulus M such that (M-1) is divisible by N - N = 8 # Power of 2 - M = 17 # A prime where (M-1) is divisible by N + # Multiply in frequency domain (element-wise) + freq_prod = (freq1 * freq2) % M + + # Transform back to time domain + poly_prod = intt(freq_prod, N, M, w) + return poly_prod + +# Example usage function +def example(): # Verify our parameters - if not ntt.isPrime(M) or (M-1) % N != 0: - print(f"Invalid parameters: M={M} must be prime and (M-1) must be divisible by N={N}") + if not is_prime(DEFAULT_M) or (DEFAULT_M-1) % DEFAULT_N != 0: + print(f"Invalid parameters: M={DEFAULT_M} must be prime and (M-1) must be divisible by N={DEFAULT_N}") return # Generate a primitive Nth root of unity - w = ntt.NthRootOfUnity(M, N) - print(f"Primitive {N}th root of unity modulo {M}: {w}") + w = get_nth_root_of_unity(DEFAULT_M, DEFAULT_N) + print(f"Primitive {DEFAULT_N}th root of unity modulo {DEFAULT_M}: {w}") # Create a sample polynomial (coefficients) poly = Tensor([3, 1, 4, 1, 5, 9, 2, 6]) print(f"Original polynomial: {poly.numpy()}") # Perform forward NTT - freq_domain = ntt.ntt(poly, M, N, w) + freq_domain = ntt(poly) print(f"After NTT: {freq_domain.numpy()}") # Perform inverse NTT - time_domain = ntt.intt(freq_domain, M, N, w) + time_domain = intt(freq_domain) print(f"After INTT (should match original): {time_domain.numpy()}") # Verify the transformation is correct @@ -276,40 +269,29 @@ def main(): print("✗ NTT transformation failed!") # Example of polynomial multiplication using NTT - # (a*b = INTT(NTT(a) * NTT(b))) poly1 = Tensor([1, 2, 3, 4, 0, 0, 0, 0]) # 4x^3 + 3x^2 + 2x + 1 poly2 = Tensor([5, 6, 7, 0, 0, 0, 0, 0]) # 7x^2 + 6x + 5 print(f"\nPolynomial 1: {poly1.numpy()}") print(f"Polynomial 2: {poly2.numpy()}") - # Transform to frequency domain - freq1 = ntt.ntt(poly1, M, N, w) - freq2 = ntt.ntt(poly2, M, N, w) - - # Multiply in frequency domain (element-wise) - freq_prod = (freq1 * freq2) % M - - # Transform back to time domain - poly_prod = ntt.intt(freq_prod, M, N, w) + # Multiply polynomials using our function + poly_prod = polynomial_multiply(poly1, poly2) print(f"Polynomial product: {poly_prod.numpy()}") # Let's calculate the expected result by direct convolution expected = np.zeros(8, dtype=int) for i in range(4): # poly1 has 4 non-zero terms for j in range(3): # poly2 has 3 non-zero terms - expected[i+j] = (expected[i+j] + poly1.numpy()[i] * poly2.numpy()[j]) % M + expected[i+j] = (expected[i+j] + poly1.numpy()[i] * poly2.numpy()[j]) % DEFAULT_M - expected_tensor = Tensor(expected) print(f"Expected (direct convolution): {expected}") # Check if our computed result matches the expected polynomial product - if np.allclose(poly_prod.numpy(), expected_tensor.numpy()): + if np.allclose(poly_prod.numpy(), expected): print("✓ Polynomial multiplication using NTT is correct!") else: print("✗ Polynomial multiplication using NTT failed!") - print(f"Expected: {expected}") - print(f"Got: {poly_prod.numpy()}") if __name__ == "__main__": - main() \ No newline at end of file + example() \ No newline at end of file From a16cb761f26b6872c42f0ef70eb6fea7890f5685 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 18:26:17 -0700 Subject: [PATCH 03/10] feat(wip): matrix ntt --- algebra/poly/ntt.py | 382 ++++++++++++++------------------------------ 1 file changed, 116 insertions(+), 266 deletions(-) diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py index b5f2c11..bf5a633 100644 --- a/algebra/poly/ntt.py +++ b/algebra/poly/ntt.py @@ -1,297 +1,147 @@ -import random -import numpy as np -from sympy import isprime from tinygrad.tensor import Tensor +from tinygrad import dtypes -# Default parameters -DEFAULT_N = 8 # Size of transform (power of 2) -DEFAULT_M = 17 # Prime modulus where (M-1) is divisible by N -DEFAULT_W = None # Will be calculated on first use - -# Utility functions -def is_integer(M): - return isinstance(M, int) - -def is_prime(M): - assert is_integer(M), 'Not an integer.' - return isprime(M) - -# Modular arithmetic functions -def mod_exponent(base, power, M): - """ - Standard implementation of modular exponentiation for single values - """ - result = 1 - power = int(power) - base = base % M - while power > 0: - if power & 1: - result = (result * base) % M - base = (base * base) % M - power = power >> 1 - return result - -def vector_mod_exponent(bases, power, M): - """ - Vectorized implementation using tinygrad - For an array of bases raised to the same power +def ntt_matrix(n, prime, primitive_root): """ - # Convert inputs to Tensor - if not isinstance(bases, Tensor): - bases_tensor = Tensor(bases) - else: - bases_tensor = bases - - result = Tensor.ones(*bases_tensor.shape) - power_val = int(power) - bases_tensor = bases_tensor % M + Generate the NTT matrix for a given size n, prime modulus, and primitive root. - while power_val > 0: - if power_val & 1: - result = (result * bases_tensor) % M - bases_tensor = (bases_tensor * bases_tensor) % M - power_val = power_val >> 1 - - return result - -def mod_inv(x, M): - """ - Extended Euclidean Algorithm for modular inverse - """ - t, new_t, r, new_r = 0, 1, M, x - - while new_r != 0: - quotient = r // new_r - t, new_t = new_t, (t - quotient * new_t) - r, new_r = new_r, (r % new_r) + Args: + n: Size of the transform + prime: The prime modulus + primitive_root: A primitive n-th root of unity modulo prime - if r > 1: - return "x is not invertible." - if t < 0: - t = t + M - return t - -def exist_small_order(r, M, N): + Returns: + n x n Tensor for the NTT transform """ - Check if r is a primitive Nth root of unity - """ - for k in range(2, N): - if mod_exponent(r, k, M) == 1: - return True - return False - -def get_nth_root_of_unity(M=DEFAULT_M, N=DEFAULT_N): - """ - Generate primitive nth root of unity - """ - assert is_prime(M), 'Not a prime.' # modulus should be a prime - assert (M - 1) % N == 0, 'N cannot divide phi(M)' - phi_M = M - 1 - - while True: - alpha = random.randrange(1, M) - beta = mod_exponent(alpha, phi_M // N, M) - # check if beta can be k th root of unity, k> shift_bits) << shift_bits + Returns: + n x n Tensor for the inverse NTT transform + """ + # Create omega as the primitive n-th root of unity + omega = pow(primitive_root, (prime - 1) // n, prime) + # Compute inverse of omega + omega_inv = pow(omega, prime - 2, prime) + + # Generate the inverse NTT matrix + matrix = Tensor.zeros(n, n, dtype=dtypes.uint32).contiguous() + + # Fill the matrix manually using elementwise operations + for i in range(n): + for j in range(n): + # Compute omega_inv^(i*j) mod prime + power = (i * j) % n + value = pow(omega_inv, power, prime) + # Assign value to matrix position (i,j) + row = matrix[i].numpy() + row[j] = value + matrix[i] = Tensor(row, dtype=dtypes.uint32) + + # Multiply by n_inv (modular multiplicative inverse of n) + n_inv = pow(n, prime - 2, prime) + matrix = (matrix * n_inv) % prime + + return matrix + +def ntt(polynomial, prime, primitive_root): + """ + Compute the Number Theoretic Transform of a polynomial. + + Args: + polynomial: Coefficient vector of the polynomial (Tensor or list/array) + prime: The prime modulus + primitive_root: A primitive root modulo prime - # Compute twiddle factors - w_P = np.array([mod_exponent(w, P, M) for P in P_values]) - - # Extract even and odd elements - even = poly_np[even_indices] - odd = poly_np[odd_indices] - - # Apply twiddle factors to odd elements - odd_twisted = (odd * w_P) % M - - # Compute DFT results - points1 = (even + odd_twisted) % M - points2 = (even - odd_twisted) % M - - # Combine results - poly_np = np.concatenate([points1, points2]) - poly_tensor = Tensor(poly_np) - - return poly_tensor - -def intt(points, N=DEFAULT_N, M=DEFAULT_M, w=DEFAULT_W): + Returns: + The NTT of the polynomial as a Tensor """ - Inverse number theoretic transform algorithm - """ - global DEFAULT_W + # Convert polynomial to Tensor if it's not already + if not isinstance(polynomial, Tensor): + polynomial = Tensor(polynomial, dtype=dtypes.uint32) - # If w is not provided, calculate it - if w is None: - if DEFAULT_W is None: - DEFAULT_W = get_nth_root_of_unity(M, N) - w = DEFAULT_W + n = len(polynomial) + matrix = ntt_matrix(n, prime, primitive_root) - # Convert input to Tensor if it's not already - if not isinstance(points, Tensor): - points = Tensor(points) - - inv_w = mod_inv(w, M) - inv_N = mod_inv(N, M) + # Ensure polynomial coefficients are within the field + polynomial = polynomial % prime - # Perform normal NTT with inverse of w - poly = ntt(points, N, M, inv_w) + # Compute the matrix multiplication (dot product) and take modulo prime + result = matrix.matmul(polynomial) % prime - # Apply scaling factor - poly_np = poly.numpy() - poly_np = (poly_np * inv_N) % M - - return Tensor(poly_np) + return result -# Polynomial operations -def polynomial_multiply(poly1, poly2, N=DEFAULT_N, M=DEFAULT_M, w=DEFAULT_W): +def intt(transformed, prime, primitive_root): """ - Multiply two polynomials using NTT + Compute the Inverse Number Theoretic Transform. + + Args: + transformed: The transformed polynomial (Tensor or list/array) + prime: The prime modulus + primitive_root: A primitive root modulo prime + + Returns: + The original polynomial coefficients as a Tensor """ - # Check if inputs are Tensors - if not isinstance(poly1, Tensor): - poly1 = Tensor(poly1) - if not isinstance(poly2, Tensor): - poly2 = Tensor(poly2) + # Convert transformed to Tensor if it's not already + if not isinstance(transformed, Tensor): + transformed = Tensor(transformed, dtype=dtypes.uint32) - # Transform to frequency domain - freq1 = ntt(poly1, N, M, w) - freq2 = ntt(poly2, N, M, w) + n = len(transformed) + matrix = intt_matrix(n, prime, primitive_root) - # Multiply in frequency domain (element-wise) - freq_prod = (freq1 * freq2) % M + # Ensure transformed values are within the field + transformed = transformed % prime - # Transform back to time domain - poly_prod = intt(freq_prod, N, M, w) + # Compute the matrix multiplication (dot product) and take modulo prime + result = matrix.matmul(transformed) % prime - return poly_prod + return result -# Example usage function -def example(): - # Verify our parameters - if not is_prime(DEFAULT_M) or (DEFAULT_M-1) % DEFAULT_N != 0: - print(f"Invalid parameters: M={DEFAULT_M} must be prime and (M-1) must be divisible by N={DEFAULT_N}") - return - - # Generate a primitive Nth root of unity - w = get_nth_root_of_unity(DEFAULT_M, DEFAULT_N) - print(f"Primitive {DEFAULT_N}th root of unity modulo {DEFAULT_M}: {w}") - - # Create a sample polynomial (coefficients) - poly = Tensor([3, 1, 4, 1, 5, 9, 2, 6]) - print(f"Original polynomial: {poly.numpy()}") - - # Perform forward NTT - freq_domain = ntt(poly) - print(f"After NTT: {freq_domain.numpy()}") - - # Perform inverse NTT - time_domain = intt(freq_domain) - print(f"After INTT (should match original): {time_domain.numpy()}") - - # Verify the transformation is correct - if np.allclose(poly.numpy(), time_domain.numpy()): - print("✓ NTT transformation is correct!") - else: - print("✗ NTT transformation failed!") - - # Example of polynomial multiplication using NTT - poly1 = Tensor([1, 2, 3, 4, 0, 0, 0, 0]) # 4x^3 + 3x^2 + 2x + 1 - poly2 = Tensor([5, 6, 7, 0, 0, 0, 0, 0]) # 7x^2 + 6x + 5 - - print(f"\nPolynomial 1: {poly1.numpy()}") - print(f"Polynomial 2: {poly2.numpy()}") +# Example usage +if __name__ == "__main__": + # Parameters for a small example + n = 8 # Must be a power of 2 + prime = 17 # A prime number where (prime - 1) is divisible by n + primitive_root = 3 # A primitive root modulo prime - # Multiply polynomials using our function - poly_prod = polynomial_multiply(poly1, poly2) - print(f"Polynomial product: {poly_prod.numpy()}") + # Example polynomial: x^3 + 2x^2 + 3x + 4 + polynomial = Tensor([4, 3, 2, 1, 0, 0, 0, 0], dtype=dtypes.uint32) - # Let's calculate the expected result by direct convolution - expected = np.zeros(8, dtype=int) - for i in range(4): # poly1 has 4 non-zero terms - for j in range(3): # poly2 has 3 non-zero terms - expected[i+j] = (expected[i+j] + poly1.numpy()[i] * poly2.numpy()[j]) % DEFAULT_M + # Perform NTT + transformed = ntt(polynomial, prime, primitive_root) + print("NTT result:", transformed.numpy()) - print(f"Expected (direct convolution): {expected}") + # Perform INTT to get back the original polynomial + original = intt(transformed, prime, primitive_root) + print("INTT result (should match original polynomial):", original.numpy()) - # Check if our computed result matches the expected polynomial product - if np.allclose(poly_prod.numpy(), expected): - print("✓ Polynomial multiplication using NTT is correct!") - else: - print("✗ Polynomial multiplication using NTT failed!") - -if __name__ == "__main__": - example() \ No newline at end of file + # Verify the result + assert (polynomial.numpy() == original.numpy()).all(), "INTT(NTT(polynomial)) != polynomial" + print("Verification successful: INTT(NTT(polynomial)) = polynomial") \ No newline at end of file From e10914d8accec9b96a59bd5697ecc932149bdcce Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 19:05:55 -0700 Subject: [PATCH 04/10] feat(wip): vectorized ntt matrices --- algebra/poly/ntt.py | 98 +++++++++++++++++++-------------------------- 1 file changed, 41 insertions(+), 57 deletions(-) diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py index bf5a633..dfc7025 100644 --- a/algebra/poly/ntt.py +++ b/algebra/poly/ntt.py @@ -2,73 +2,57 @@ from tinygrad import dtypes def ntt_matrix(n, prime, primitive_root): - """ - Generate the NTT matrix for a given size n, prime modulus, and primitive root. - - Args: - n: Size of the transform - prime: The prime modulus - primitive_root: A primitive n-th root of unity modulo prime - - Returns: - n x n Tensor for the NTT transform - """ # Create omega as the primitive n-th root of unity omega = pow(primitive_root, (prime - 1) // n, prime) - # Generate the NTT matrix using tinygrad Tensor - matrix = Tensor.zeros(n, n, dtype=dtypes.uint32).contiguous() - - # Fill the matrix manually using elementwise operations - for i in range(n): - for j in range(n): - # Compute omega^(i*j) mod prime - power = (i * j) % n - value = pow(omega, power, prime) - # Assign value to matrix position (i,j) - # We need to use a temporary tensor and update row by row - row = matrix[i].numpy() - row[j] = value - matrix[i] = Tensor(row, dtype=dtypes.uint32) - - return matrix + # Create row and column indices for broadcasting + i = Tensor.arange(n, dtype=dtypes.int32).reshape(n, 1) + j = Tensor.arange(n, dtype=dtypes.int32).reshape(1, n) + + # Compute powers matrix (i*j) % n using tinygrad operations + powers = (i * j) % n + + # Precompute all powers of omega as a tensor directly + omega_values = [] + current = 1 + for k in range(n): + omega_values.append(current) + current = (current * omega) % prime + + omega_tensor = Tensor(omega_values, dtype=dtypes.int32) + + # Using advanced indexing with the powers tensor to get omega^(i*j) % prime + ntt_matrix = omega_tensor[powers] + + return ntt_matrix def intt_matrix(n, prime, primitive_root): - """ - Generate the inverse NTT matrix for a given size n, prime modulus, and primitive root. - - Args: - n: Size of the transform - prime: The prime modulus - primitive_root: A primitive n-th root of unity modulo prime - - Returns: - n x n Tensor for the inverse NTT transform - """ # Create omega as the primitive n-th root of unity omega = pow(primitive_root, (prime - 1) // n, prime) - # Compute inverse of omega omega_inv = pow(omega, prime - 2, prime) - # Generate the inverse NTT matrix - matrix = Tensor.zeros(n, n, dtype=dtypes.uint32).contiguous() - - # Fill the matrix manually using elementwise operations - for i in range(n): - for j in range(n): - # Compute omega_inv^(i*j) mod prime - power = (i * j) % n - value = pow(omega_inv, power, prime) - # Assign value to matrix position (i,j) - row = matrix[i].numpy() - row[j] = value - matrix[i] = Tensor(row, dtype=dtypes.uint32) - - # Multiply by n_inv (modular multiplicative inverse of n) - n_inv = pow(n, prime - 2, prime) - matrix = (matrix * n_inv) % prime + # Create row and column indices for broadcasting + i = Tensor.arange(n, dtype=dtypes.int32).reshape(n, 1) + j = Tensor.arange(n, dtype=dtypes.int32).reshape(1, n) + + # Compute powers matrix (i*j) % n using tinygrad operations + powers = (i * j) % n + + # Precompute all powers of omega as a tensor directly + omega_values = [] + current = 1 + for k in range(n): + omega_values.append(current) + current = (current * omega_inv) % prime - return matrix + omega_tensor = Tensor(omega_values, dtype=dtypes.int32) + + # Using advanced indexing with the powers tensor to get omega^(i*j) % prime + ntt_matrix = omega_tensor[powers] + n_inv = pow(n, prime - 2, prime) + ntt_matrix = (ntt_matrix * n_inv) % prime + + return ntt_matrix def ntt(polynomial, prime, primitive_root): """ From 75901286e6070931982d0033ddc4b72e59c357c6 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 19:09:27 -0700 Subject: [PATCH 05/10] chore: clen --- algebra/poly/ntt.py | 132 ++++++++++++++------------------------------ 1 file changed, 41 insertions(+), 91 deletions(-) diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py index dfc7025..b6b3ca1 100644 --- a/algebra/poly/ntt.py +++ b/algebra/poly/ntt.py @@ -1,58 +1,6 @@ from tinygrad.tensor import Tensor from tinygrad import dtypes - -def ntt_matrix(n, prime, primitive_root): - # Create omega as the primitive n-th root of unity - omega = pow(primitive_root, (prime - 1) // n, prime) - - # Create row and column indices for broadcasting - i = Tensor.arange(n, dtype=dtypes.int32).reshape(n, 1) - j = Tensor.arange(n, dtype=dtypes.int32).reshape(1, n) - - # Compute powers matrix (i*j) % n using tinygrad operations - powers = (i * j) % n - - # Precompute all powers of omega as a tensor directly - omega_values = [] - current = 1 - for k in range(n): - omega_values.append(current) - current = (current * omega) % prime - - omega_tensor = Tensor(omega_values, dtype=dtypes.int32) - - # Using advanced indexing with the powers tensor to get omega^(i*j) % prime - ntt_matrix = omega_tensor[powers] - - return ntt_matrix - -def intt_matrix(n, prime, primitive_root): - # Create omega as the primitive n-th root of unity - omega = pow(primitive_root, (prime - 1) // n, prime) - omega_inv = pow(omega, prime - 2, prime) - - # Create row and column indices for broadcasting - i = Tensor.arange(n, dtype=dtypes.int32).reshape(n, 1) - j = Tensor.arange(n, dtype=dtypes.int32).reshape(1, n) - - # Compute powers matrix (i*j) % n using tinygrad operations - powers = (i * j) % n - - # Precompute all powers of omega as a tensor directly - omega_values = [] - current = 1 - for k in range(n): - omega_values.append(current) - current = (current * omega_inv) % prime - - omega_tensor = Tensor(omega_values, dtype=dtypes.int32) - - # Using advanced indexing with the powers tensor to get omega^(i*j) % prime - ntt_matrix = omega_tensor[powers] - n_inv = pow(n, prime - 2, prime) - ntt_matrix = (ntt_matrix * n_inv) % prime - - return ntt_matrix +import numpy as np def ntt(polynomial, prime, primitive_root): """ @@ -62,24 +10,11 @@ def ntt(polynomial, prime, primitive_root): polynomial: Coefficient vector of the polynomial (Tensor or list/array) prime: The prime modulus primitive_root: A primitive root modulo prime - + Returns: The NTT of the polynomial as a Tensor """ - # Convert polynomial to Tensor if it's not already - if not isinstance(polynomial, Tensor): - polynomial = Tensor(polynomial, dtype=dtypes.uint32) - - n = len(polynomial) - matrix = ntt_matrix(n, prime, primitive_root) - - # Ensure polynomial coefficients are within the field - polynomial = polynomial % prime - - # Compute the matrix multiplication (dot product) and take modulo prime - result = matrix.matmul(polynomial) % prime - - return result + return _transform(polynomial, prime, primitive_root, inverse=False) def intt(transformed, prime, primitive_root): """ @@ -89,43 +24,58 @@ def intt(transformed, prime, primitive_root): transformed: The transformed polynomial (Tensor or list/array) prime: The prime modulus primitive_root: A primitive root modulo prime - + Returns: The original polynomial coefficients as a Tensor """ - # Convert transformed to Tensor if it's not already - if not isinstance(transformed, Tensor): - transformed = Tensor(transformed, dtype=dtypes.uint32) + return _transform(transformed, prime, primitive_root, inverse=True) + +def _transform(x, prime, primitive_root, inverse=False): + """Internal function to perform NTT or INTT transformation""" + if not isinstance(x, Tensor): + x = Tensor(x, dtype=dtypes.uint32) - n = len(transformed) - matrix = intt_matrix(n, prime, primitive_root) + dtype = x.dtype + n = len(x) - # Ensure transformed values are within the field - transformed = transformed % prime + # Create powers matrix (i*j) % n + i = Tensor.arange(n, dtype=dtype).reshape(n, 1) + j = Tensor.arange(n, dtype=dtype).reshape(1, n) + powers = (i * j) % n - # Compute the matrix multiplication (dot product) and take modulo prime - result = matrix.matmul(transformed) % prime + # Get omega (or inverse omega for INTT) + omega = pow(primitive_root, (prime - 1) // n, prime) + if inverse: + omega = pow(omega, prime - 2, prime) # Modular inverse - return result + # Compute all powers of omega + omega_powers = np.ones(n, dtype=np.uint32) + current = 1 + for k in range(n): + omega_powers[k] = current + current = (current * omega) % prime + + # Create transformation matrix + matrix = Tensor(omega_powers, dtype=dtypes.uint32)[powers] + + # For inverse, apply scaling factor (1/n mod prime) + if inverse: + n_inv = pow(n, prime - 2, prime) + matrix = (matrix * n_inv) % prime + + # Perform the transformation + return (matrix @ (x % prime)) % prime -# Example usage if __name__ == "__main__": - # Parameters for a small example - n = 8 # Must be a power of 2 - prime = 17 # A prime number where (prime - 1) is divisible by n - primitive_root = 3 # A primitive root modulo prime - - # Example polynomial: x^3 + 2x^2 + 3x + 4 + # Example + n, prime, primitive_root = 8, 17, 3 polynomial = Tensor([4, 3, 2, 1, 0, 0, 0, 0], dtype=dtypes.uint32) - # Perform NTT transformed = ntt(polynomial, prime, primitive_root) print("NTT result:", transformed.numpy()) - # Perform INTT to get back the original polynomial original = intt(transformed, prime, primitive_root) - print("INTT result (should match original polynomial):", original.numpy()) + print("INTT result:", original.numpy()) - # Verify the result assert (polynomial.numpy() == original.numpy()).all(), "INTT(NTT(polynomial)) != polynomial" - print("Verification successful: INTT(NTT(polynomial)) = polynomial") \ No newline at end of file + print("Verification successful") \ No newline at end of file From b54a37d8020948681e4ac46168fdcfef999df6ec Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 20:54:25 -0700 Subject: [PATCH 06/10] feat(wip): ntt --- algebra/ff/babybear.py | 1 + algebra/ff/m31.py | 3 +- algebra/ff/prime_field.py | 1 + algebra/ff/test.py | 6 ---- algebra/poly/ntt.py | 57 +++++++++++++++++++++++++++++--------- algebra/poly/univariate.py | 31 +++++++++++++++------ test.py | 16 +++++++++++ tests/test_poly.py | 6 ++++ 8 files changed, 92 insertions(+), 29 deletions(-) delete mode 100644 algebra/ff/test.py create mode 100644 test.py diff --git a/algebra/ff/babybear.py b/algebra/ff/babybear.py index e0950f6..07f45e5 100644 --- a/algebra/ff/babybear.py +++ b/algebra/ff/babybear.py @@ -4,3 +4,4 @@ class BabyBear(PrimeField): # BabyBear prime: 2^31 - 2^27 + 1 P = 2013265921 + w = 31 diff --git a/algebra/ff/m31.py b/algebra/ff/m31.py index 1e61c5a..d5bd675 100644 --- a/algebra/ff/m31.py +++ b/algebra/ff/m31.py @@ -3,4 +3,5 @@ class M31(PrimeField): # Mersenne31 prime: 2^31 - 1 - P = 2147483647 + P = 2**31 - 1 + w = 7 \ No newline at end of file diff --git a/algebra/ff/prime_field.py b/algebra/ff/prime_field.py index 504593f..f8fbded 100644 --- a/algebra/ff/prime_field.py +++ b/algebra/ff/prime_field.py @@ -4,6 +4,7 @@ class PrimeField: P: int = None + w: int = None def __init__(self, x): if isinstance(x, (int, float, list, Tensor)): diff --git a/algebra/ff/test.py b/algebra/ff/test.py deleted file mode 100644 index e4760fd..0000000 --- a/algebra/ff/test.py +++ /dev/null @@ -1,6 +0,0 @@ -from tinygrad import Tensor - -a = Tensor([1, 2, 3]) -b = Tensor([4, 5, 6]) - -print((a % 3).numpy()) diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py index b6b3ca1..5ed4e69 100644 --- a/algebra/poly/ntt.py +++ b/algebra/poly/ntt.py @@ -2,7 +2,33 @@ from tinygrad import dtypes import numpy as np -def ntt(polynomial, prime, primitive_root): +def _next_valid_n(l, p): + """Find next n >= l that divides p-1 efficiently""" + p_minus_1 = p - 1 + for n in range(l, min(p, 2*l)): + if p_minus_1 % n == 0: + return n + + bit_length = l.bit_length() + power_of_two = 1 << bit_length + if power_of_two < l: + power_of_two = power_of_two << 1 + + while power_of_two < p: + if p_minus_1 % power_of_two == 0: + return power_of_two + power_of_two = power_of_two << 1 + + for divisor in range(l, int(p_minus_1**0.5) + 1): + if p_minus_1 % divisor == 0: + quotient = p_minus_1 // divisor + if quotient >= l: + return quotient + return divisor + + return p - 1 + +def ntt(polynomial: Tensor, prime: int, primitive_root: int) -> Tensor: """ Compute the Number Theoretic Transform of a polynomial. @@ -16,7 +42,7 @@ def ntt(polynomial, prime, primitive_root): """ return _transform(polynomial, prime, primitive_root, inverse=False) -def intt(transformed, prime, primitive_root): +def intt(transformed: Tensor, prime: int, primitive_root: int) -> Tensor: """ Compute the Inverse Number Theoretic Transform. @@ -30,17 +56,22 @@ def intt(transformed, prime, primitive_root): """ return _transform(transformed, prime, primitive_root, inverse=True) -def _transform(x, prime, primitive_root, inverse=False): +def _transform(x, prime: int, primitive_root: int, inverse: bool = False) -> Tensor: """Internal function to perform NTT or INTT transformation""" if not isinstance(x, Tensor): - x = Tensor(x, dtype=dtypes.uint32) + x = Tensor(x, dtype=dtypes.uint64) dtype = x.dtype - n = len(x) + n = _next_valid_n(len(x), prime) + + if len(x) < n: + padded_np = Tensor.zeros(n, dtype=dtype).contiguous() + padded_np[:len(x)] = x + x = padded_np # Create powers matrix (i*j) % n - i = Tensor.arange(n, dtype=dtype).reshape(n, 1) - j = Tensor.arange(n, dtype=dtype).reshape(1, n) + i = Tensor.arange(n, dtype=dtypes.uint64).reshape(n, 1) + j = Tensor.arange(n, dtype=dtypes.uint64).reshape(1, n) powers = (i * j) % n # Get omega (or inverse omega for INTT) @@ -49,14 +80,14 @@ def _transform(x, prime, primitive_root, inverse=False): omega = pow(omega, prime - 2, prime) # Modular inverse # Compute all powers of omega - omega_powers = np.ones(n, dtype=np.uint32) + omega_powers = np.ones(n, dtype=np.uint64) current = 1 for k in range(n): omega_powers[k] = current current = (current * omega) % prime # Create transformation matrix - matrix = Tensor(omega_powers, dtype=dtypes.uint32)[powers] + matrix = Tensor(omega_powers, dtype=dtypes.uint64)[powers] # For inverse, apply scaling factor (1/n mod prime) if inverse: @@ -67,9 +98,9 @@ def _transform(x, prime, primitive_root, inverse=False): return (matrix @ (x % prime)) % prime if __name__ == "__main__": - # Example - n, prime, primitive_root = 8, 17, 3 - polynomial = Tensor([4, 3, 2, 1, 0, 0, 0, 0], dtype=dtypes.uint32) + from algebra.ff.m31 import M31 + n, prime, primitive_root = 8, M31.P, M31.w + polynomial = Tensor([1, 2, 3], dtype=dtypes.uint64) transformed = ntt(polynomial, prime, primitive_root) print("NTT result:", transformed.numpy()) @@ -77,5 +108,5 @@ def _transform(x, prime, primitive_root, inverse=False): original = intt(transformed, prime, primitive_root) print("INTT result:", original.numpy()) - assert (polynomial.numpy() == original.numpy()).all(), "INTT(NTT(polynomial)) != polynomial" + assert (polynomial.numpy() == original.numpy()[:len(polynomial)]).all(), "INTT(NTT(polynomial)) != polynomial" print("Verification successful") \ No newline at end of file diff --git a/algebra/poly/univariate.py b/algebra/poly/univariate.py index ca5c83c..9afc4de 100644 --- a/algebra/poly/univariate.py +++ b/algebra/poly/univariate.py @@ -1,7 +1,8 @@ from tinygrad.tensor import Tensor from tinygrad import dtypes from algebra.ff.prime_field import PrimeField as PF - +from algebra.poly.ntt import ntt, intt +import numpy as np class Polynomial: PrimeField = None @@ -14,17 +15,15 @@ class Polynomial: """ def __init__(self, coeffs: list[int] | Tensor, prime_field: PF = None): - """ - Initialize the polynomial. - - coeffs: a list of field elements (instances of a PrimeField subclass) - prime_field: optional prime field class to use for modular arithmetic - """ + """Initialize polynomial with coefficients and optional prime field.""" self.PrimeField = prime_field if isinstance(coeffs, list): + coeffs = np.trim_zeros(coeffs, 'b') self.coeffs = Tensor(coeffs, dtype=dtypes.int32) - elif isinstance(coeffs, Tensor): - self.coeffs = coeffs + else: + coeffs_np = coeffs.numpy() + coeffs_np = np.trim_zeros(coeffs_np, 'b') + self.coeffs = Tensor(coeffs_np, dtype=coeffs.dtype) def degree(self) -> int: """ @@ -49,6 +48,20 @@ def evaluate(self, x: int | Tensor): result = result * x + coeff return result + def ntt(self): + """ + Compute the Number Theoretic Transform of the polynomial. + """ + p_ntt = ntt(self.coeffs.cast(dtypes.uint32), self.PrimeField.P, self.PrimeField.w) + return Polynomial(p_ntt, self.PrimeField) + + def intt(self): + """ + Compute the Inverse Number Theoretic Transform of the polynomial. + """ + p_intt = intt(self.coeffs.cast(dtypes.uint32), self.PrimeField.P, self.PrimeField.w) + return Polynomial(p_intt, self.PrimeField) + def __evaluate_all(self, xs: Tensor): """ Evaluate the polynomial at all elements in xs using Horner's method. diff --git a/test.py b/test.py new file mode 100644 index 0000000..4f8c604 --- /dev/null +++ b/test.py @@ -0,0 +1,16 @@ +from algebra.ff.m31 import M31 +from algebra.poly.univariate import Polynomial +from random import randint +from tinygrad import Tensor, dtypes + +p1 = Polynomial(Tensor([randint(0, 10) for _ in range(10)], dtype=dtypes.uint32), M31) +print(f'p1: {p1.coeffs.numpy()}') + +p1_ntt = p1.ntt() +print(f'p1_ntt: {p1_ntt.coeffs.numpy()}') + +p1_intt = p1_ntt.intt() +print(f'p1_intt: {p1_intt.coeffs.numpy()}') + +assert (p1_intt.coeffs.numpy() == p1.coeffs.numpy()).all() + diff --git a/tests/test_poly.py b/tests/test_poly.py index aa8753c..ee33f3d 100644 --- a/tests/test_poly.py +++ b/tests/test_poly.py @@ -2,6 +2,7 @@ from algebra.ff.m31 import M31 from algebra.ff.babybear import BabyBear from tinygrad import Tensor +from random import randint def test_polynomial_operations_m31(): @@ -70,3 +71,8 @@ def test_polynomial_operations_babybear(): # Test evaluate_all result = p1(Tensor([1, 2, 3])).numpy() assert (result == [6, 17, 34]).all() + + p7 = Polynomial([randint(0, 100) for _ in range(10)], M31) + p7_ntt = p7.ntt() + p7_intt = p7_ntt.intt() + assert (p7_intt.coeffs.numpy() == p7.coeffs.numpy()).all() \ No newline at end of file From 823f1e04d2c2b4653803f6014631da767d08a8bf Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 21:00:24 -0700 Subject: [PATCH 07/10] feat(wip): passing test --- algebra/poly/univariate.py | 4 ++-- test.py | 16 ---------------- tests/test_poly.py | 2 +- 3 files changed, 3 insertions(+), 19 deletions(-) delete mode 100644 test.py diff --git a/algebra/poly/univariate.py b/algebra/poly/univariate.py index 9afc4de..f5350a2 100644 --- a/algebra/poly/univariate.py +++ b/algebra/poly/univariate.py @@ -52,14 +52,14 @@ def ntt(self): """ Compute the Number Theoretic Transform of the polynomial. """ - p_ntt = ntt(self.coeffs.cast(dtypes.uint32), self.PrimeField.P, self.PrimeField.w) + p_ntt = ntt(self.coeffs.cast(dtypes.uint64), self.PrimeField.P, self.PrimeField.w).cast(self.coeffs.dtype) return Polynomial(p_ntt, self.PrimeField) def intt(self): """ Compute the Inverse Number Theoretic Transform of the polynomial. """ - p_intt = intt(self.coeffs.cast(dtypes.uint32), self.PrimeField.P, self.PrimeField.w) + p_intt = intt(self.coeffs.cast(dtypes.uint64), self.PrimeField.P, self.PrimeField.w).cast(self.coeffs.dtype) return Polynomial(p_intt, self.PrimeField) def __evaluate_all(self, xs: Tensor): diff --git a/test.py b/test.py deleted file mode 100644 index 4f8c604..0000000 --- a/test.py +++ /dev/null @@ -1,16 +0,0 @@ -from algebra.ff.m31 import M31 -from algebra.poly.univariate import Polynomial -from random import randint -from tinygrad import Tensor, dtypes - -p1 = Polynomial(Tensor([randint(0, 10) for _ in range(10)], dtype=dtypes.uint32), M31) -print(f'p1: {p1.coeffs.numpy()}') - -p1_ntt = p1.ntt() -print(f'p1_ntt: {p1_ntt.coeffs.numpy()}') - -p1_intt = p1_ntt.intt() -print(f'p1_intt: {p1_intt.coeffs.numpy()}') - -assert (p1_intt.coeffs.numpy() == p1.coeffs.numpy()).all() - diff --git a/tests/test_poly.py b/tests/test_poly.py index ee33f3d..9e9dc63 100644 --- a/tests/test_poly.py +++ b/tests/test_poly.py @@ -72,7 +72,7 @@ def test_polynomial_operations_babybear(): result = p1(Tensor([1, 2, 3])).numpy() assert (result == [6, 17, 34]).all() - p7 = Polynomial([randint(0, 100) for _ in range(10)], M31) + p7 = Polynomial([randint(0, 100) for _ in range(8)], M31) p7_ntt = p7.ntt() p7_intt = p7_ntt.intt() assert (p7_intt.coeffs.numpy() == p7.coeffs.numpy()).all() \ No newline at end of file From 3471fcb06f0c1b1a8349497d043bd38110d8b1c9 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 21:01:33 -0700 Subject: [PATCH 08/10] chore: ruff --- algebra/ff/babybear.py | 2 +- algebra/ff/m31.py | 2 +- algebra/poly/ntt.py | 204 +++++++++++++++++++------------------ algebra/poly/univariate.py | 5 +- tests/test_poly.py | 2 +- 5 files changed, 111 insertions(+), 104 deletions(-) diff --git a/algebra/ff/babybear.py b/algebra/ff/babybear.py index 07f45e5..b2ee162 100644 --- a/algebra/ff/babybear.py +++ b/algebra/ff/babybear.py @@ -4,4 +4,4 @@ class BabyBear(PrimeField): # BabyBear prime: 2^31 - 2^27 + 1 P = 2013265921 - w = 31 + w = 31 diff --git a/algebra/ff/m31.py b/algebra/ff/m31.py index d5bd675..9675479 100644 --- a/algebra/ff/m31.py +++ b/algebra/ff/m31.py @@ -4,4 +4,4 @@ class M31(PrimeField): # Mersenne31 prime: 2^31 - 1 P = 2**31 - 1 - w = 7 \ No newline at end of file + w = 7 diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py index 5ed4e69..cebcec4 100644 --- a/algebra/poly/ntt.py +++ b/algebra/poly/ntt.py @@ -2,111 +2,117 @@ from tinygrad import dtypes import numpy as np + def _next_valid_n(l, p): - """Find next n >= l that divides p-1 efficiently""" - p_minus_1 = p - 1 - for n in range(l, min(p, 2*l)): - if p_minus_1 % n == 0: - return n - - bit_length = l.bit_length() - power_of_two = 1 << bit_length - if power_of_two < l: - power_of_two = power_of_two << 1 - - while power_of_two < p: - if p_minus_1 % power_of_two == 0: - return power_of_two - power_of_two = power_of_two << 1 - - for divisor in range(l, int(p_minus_1**0.5) + 1): - if p_minus_1 % divisor == 0: - quotient = p_minus_1 // divisor - if quotient >= l: - return quotient - return divisor - - return p - 1 + """Find next n >= l that divides p-1 efficiently""" + p_minus_1 = p - 1 + for n in range(l, min(p, 2 * l)): + if p_minus_1 % n == 0: + return n + + bit_length = l.bit_length() + power_of_two = 1 << bit_length + if power_of_two < l: + power_of_two = power_of_two << 1 + + while power_of_two < p: + if p_minus_1 % power_of_two == 0: + return power_of_two + power_of_two = power_of_two << 1 + + for divisor in range(l, int(p_minus_1**0.5) + 1): + if p_minus_1 % divisor == 0: + quotient = p_minus_1 // divisor + if quotient >= l: + return quotient + return divisor + + return p - 1 + def ntt(polynomial: Tensor, prime: int, primitive_root: int) -> Tensor: - """ - Compute the Number Theoretic Transform of a polynomial. - - Args: - polynomial: Coefficient vector of the polynomial (Tensor or list/array) - prime: The prime modulus - primitive_root: A primitive root modulo prime - - Returns: - The NTT of the polynomial as a Tensor - """ - return _transform(polynomial, prime, primitive_root, inverse=False) + """ + Compute the Number Theoretic Transform of a polynomial. + + Args: + polynomial: Coefficient vector of the polynomial (Tensor or list/array) + prime: The prime modulus + primitive_root: A primitive root modulo prime + + Returns: + The NTT of the polynomial as a Tensor + """ + return _transform(polynomial, prime, primitive_root, inverse=False) + def intt(transformed: Tensor, prime: int, primitive_root: int) -> Tensor: - """ - Compute the Inverse Number Theoretic Transform. - - Args: - transformed: The transformed polynomial (Tensor or list/array) - prime: The prime modulus - primitive_root: A primitive root modulo prime - - Returns: - The original polynomial coefficients as a Tensor - """ - return _transform(transformed, prime, primitive_root, inverse=True) + """ + Compute the Inverse Number Theoretic Transform. + + Args: + transformed: The transformed polynomial (Tensor or list/array) + prime: The prime modulus + primitive_root: A primitive root modulo prime + + Returns: + The original polynomial coefficients as a Tensor + """ + return _transform(transformed, prime, primitive_root, inverse=True) + def _transform(x, prime: int, primitive_root: int, inverse: bool = False) -> Tensor: - """Internal function to perform NTT or INTT transformation""" - if not isinstance(x, Tensor): - x = Tensor(x, dtype=dtypes.uint64) - - dtype = x.dtype - n = _next_valid_n(len(x), prime) - - if len(x) < n: - padded_np = Tensor.zeros(n, dtype=dtype).contiguous() - padded_np[:len(x)] = x - x = padded_np - - # Create powers matrix (i*j) % n - i = Tensor.arange(n, dtype=dtypes.uint64).reshape(n, 1) - j = Tensor.arange(n, dtype=dtypes.uint64).reshape(1, n) - powers = (i * j) % n - - # Get omega (or inverse omega for INTT) - omega = pow(primitive_root, (prime - 1) // n, prime) - if inverse: - omega = pow(omega, prime - 2, prime) # Modular inverse - - # Compute all powers of omega - omega_powers = np.ones(n, dtype=np.uint64) - current = 1 - for k in range(n): - omega_powers[k] = current - current = (current * omega) % prime - - # Create transformation matrix - matrix = Tensor(omega_powers, dtype=dtypes.uint64)[powers] - - # For inverse, apply scaling factor (1/n mod prime) - if inverse: - n_inv = pow(n, prime - 2, prime) - matrix = (matrix * n_inv) % prime - - # Perform the transformation - return (matrix @ (x % prime)) % prime + """Internal function to perform NTT or INTT transformation""" + if not isinstance(x, Tensor): + x = Tensor(x, dtype=dtypes.uint64) + + dtype = x.dtype + n = _next_valid_n(len(x), prime) + + if len(x) < n: + padded_np = Tensor.zeros(n, dtype=dtype).contiguous() + padded_np[: len(x)] = x + x = padded_np + + # Create powers matrix (i*j) % n + i = Tensor.arange(n, dtype=dtypes.uint64).reshape(n, 1) + j = Tensor.arange(n, dtype=dtypes.uint64).reshape(1, n) + powers = (i * j) % n + + # Get omega (or inverse omega for INTT) + omega = pow(primitive_root, (prime - 1) // n, prime) + if inverse: + omega = pow(omega, prime - 2, prime) # Modular inverse + + # Compute all powers of omega + omega_powers = np.ones(n, dtype=np.uint64) + current = 1 + for k in range(n): + omega_powers[k] = current + current = (current * omega) % prime + + # Create transformation matrix + matrix = Tensor(omega_powers, dtype=dtypes.uint64)[powers] + + # For inverse, apply scaling factor (1/n mod prime) + if inverse: + n_inv = pow(n, prime - 2, prime) + matrix = (matrix * n_inv) % prime + + # Perform the transformation + return (matrix @ (x % prime)) % prime + if __name__ == "__main__": - from algebra.ff.m31 import M31 - n, prime, primitive_root = 8, M31.P, M31.w - polynomial = Tensor([1, 2, 3], dtype=dtypes.uint64) - - transformed = ntt(polynomial, prime, primitive_root) - print("NTT result:", transformed.numpy()) - - original = intt(transformed, prime, primitive_root) - print("INTT result:", original.numpy()) - - assert (polynomial.numpy() == original.numpy()[:len(polynomial)]).all(), "INTT(NTT(polynomial)) != polynomial" - print("Verification successful") \ No newline at end of file + from algebra.ff.m31 import M31 + + n, prime, primitive_root = 8, M31.P, M31.w + polynomial = Tensor([1, 2, 3], dtype=dtypes.uint64) + + transformed = ntt(polynomial, prime, primitive_root) + print("NTT result:", transformed.numpy()) + + original = intt(transformed, prime, primitive_root) + print("INTT result:", original.numpy()) + + assert (polynomial.numpy() == original.numpy()[: len(polynomial)]).all(), "INTT(NTT(polynomial)) != polynomial" + print("Verification successful") diff --git a/algebra/poly/univariate.py b/algebra/poly/univariate.py index f5350a2..6867ca4 100644 --- a/algebra/poly/univariate.py +++ b/algebra/poly/univariate.py @@ -4,6 +4,7 @@ from algebra.poly.ntt import ntt, intt import numpy as np + class Polynomial: PrimeField = None @@ -18,11 +19,11 @@ def __init__(self, coeffs: list[int] | Tensor, prime_field: PF = None): """Initialize polynomial with coefficients and optional prime field.""" self.PrimeField = prime_field if isinstance(coeffs, list): - coeffs = np.trim_zeros(coeffs, 'b') + coeffs = np.trim_zeros(coeffs, "b") self.coeffs = Tensor(coeffs, dtype=dtypes.int32) else: coeffs_np = coeffs.numpy() - coeffs_np = np.trim_zeros(coeffs_np, 'b') + coeffs_np = np.trim_zeros(coeffs_np, "b") self.coeffs = Tensor(coeffs_np, dtype=coeffs.dtype) def degree(self) -> int: diff --git a/tests/test_poly.py b/tests/test_poly.py index 9e9dc63..b9b87dd 100644 --- a/tests/test_poly.py +++ b/tests/test_poly.py @@ -75,4 +75,4 @@ def test_polynomial_operations_babybear(): p7 = Polynomial([randint(0, 100) for _ in range(8)], M31) p7_ntt = p7.ntt() p7_intt = p7_ntt.intt() - assert (p7_intt.coeffs.numpy() == p7.coeffs.numpy()).all() \ No newline at end of file + assert (p7_intt.coeffs.numpy() == p7.coeffs.numpy()).all() From 3a2a933d2a614559659ac7e940cd648d397502c2 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 23:26:21 -0700 Subject: [PATCH 09/10] fix: modulus --- algebra/poly/ntt.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py index cebcec4..8fb7eee 100644 --- a/algebra/poly/ntt.py +++ b/algebra/poly/ntt.py @@ -67,6 +67,7 @@ def _transform(x, prime: int, primitive_root: int, inverse: bool = False) -> Ten dtype = x.dtype n = _next_valid_n(len(x), prime) + print(f"n: {n}, prime: {prime}, primitive_root: {primitive_root}") if len(x) < n: padded_np = Tensor.zeros(n, dtype=dtype).contiguous() @@ -92,21 +93,24 @@ def _transform(x, prime: int, primitive_root: int, inverse: bool = False) -> Ten # Create transformation matrix matrix = Tensor(omega_powers, dtype=dtypes.uint64)[powers] + result = (matrix @ (x % prime)) % prime # For inverse, apply scaling factor (1/n mod prime) if inverse: n_inv = pow(n, prime - 2, prime) - matrix = (matrix * n_inv) % prime + result = (result * n_inv) % prime # Perform the transformation - return (matrix @ (x % prime)) % prime + return result if __name__ == "__main__": from algebra.ff.m31 import M31 + from random import randint - n, prime, primitive_root = 8, M31.P, M31.w - polynomial = Tensor([1, 2, 3], dtype=dtypes.uint64) + n, prime, primitive_root = 10, M31.P, M31.w + polynomial = Tensor([randint(0, prime - 1) for _ in range(n)], dtype=dtypes.uint64) + print(f'polynomial: {polynomial.numpy()}') transformed = ntt(polynomial, prime, primitive_root) print("NTT result:", transformed.numpy()) From cb69007f977fcd91ef634927d3f40f4182b05445 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Tue, 11 Mar 2025 23:26:44 -0700 Subject: [PATCH 10/10] chore: formatting --- algebra/poly/ntt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algebra/poly/ntt.py b/algebra/poly/ntt.py index 8fb7eee..e2d7542 100644 --- a/algebra/poly/ntt.py +++ b/algebra/poly/ntt.py @@ -110,7 +110,7 @@ def _transform(x, prime: int, primitive_root: int, inverse: bool = False) -> Ten n, prime, primitive_root = 10, M31.P, M31.w polynomial = Tensor([randint(0, prime - 1) for _ in range(n)], dtype=dtypes.uint64) - print(f'polynomial: {polynomial.numpy()}') + print(f"polynomial: {polynomial.numpy()}") transformed = ntt(polynomial, prime, primitive_root) print("NTT result:", transformed.numpy())