Calculate r powers ahead of time

This commit is contained in:
Byron Lathi
2025-10-26 15:43:58 -07:00
parent faef39c4d3
commit fd50ecc4f0

View File

@@ -4,13 +4,13 @@ def mask_r(r: int) -> int:
r_bytes = r.to_bytes(16, "little") r_bytes = r.to_bytes(16, "little")
r_masked = bytearray(r_bytes) r_masked = bytearray(r_bytes)
r_masked[3] &= 15; r_masked[3] &= 15
r_masked[7] &= 15; r_masked[7] &= 15
r_masked[11] &= 15; r_masked[11] &= 15
r_masked[15] &= 15; r_masked[15] &= 15
r_masked[4] &= 252; r_masked[4] &= 252
r_masked[8] &= 252; r_masked[8] &= 252
r_masked[12] &= 252; r_masked[12] &= 252
r_masked = int.from_bytes(r_masked, "little") r_masked = int.from_bytes(r_masked, "little")
@@ -39,6 +39,8 @@ def poly1305(message: bytes, r: int, s: int):
def parallel_poly1305(message: bytes, r: int, s: int, lanes: int): def parallel_poly1305(message: bytes, r: int, s: int, lanes: int):
r = mask_r(r) r = mask_r(r)
p = 2**130-5 p = 2**130-5
r_powers = [r**i % p for i in range(lanes+1)]
acc = [0]*lanes acc = [0]*lanes
@@ -54,7 +56,7 @@ def parallel_poly1305(message: bytes, r: int, s: int, lanes: int):
byte_length = (lane.bit_length() + 7) // 8 byte_length = (lane.bit_length() + 7) // 8
lane += 1 << (8*byte_length) lane += 1 << (8*byte_length)
acc[j] = ((acc[j] + lane)*(r**power)) % p acc[j] = ((acc[j] + lane)*(r_powers[power])) % p
combined_acc = sum(acc) % p combined_acc = sum(acc) % p
combined_acc += s combined_acc += s