from typing import List

from modulo_theory import friendly_modular_mult, friendly_modulo

# Poly1305 prime, 2**130 - 5 (RFC 8439, section 2.5).
_P = 2**130 - 5
# The tag is the low 128 bits of (acc + s).
_TAG_MASK = 2**128 - 1
# Poly1305 consumes the message in 16-byte chunks.
_BLOCK_SIZE = 16


def mask_r(r: int) -> int:
    """Clamp the Poly1305 ``r`` key half (RFC 8439, section 2.5).

    ``r`` is laid out as 16 little-endian bytes. The top four bits of bytes
    3, 7, 11 and 15 and the bottom two bits of bytes 4, 8 and 12 are cleared.

    Raises:
        OverflowError: if ``r`` is negative or does not fit in 16 bytes.
    """
    clamped = bytearray(r.to_bytes(16, "little"))
    for index in (3, 7, 11, 15):
        clamped[index] &= 0x0F
    for index in (4, 8, 12):
        clamped[index] &= 0xFC
    return int.from_bytes(clamped, "little")


def _message_blocks(message: bytes) -> List[int]:
    """Split *message* into 16-byte chunks and return them as padded integers.

    Each chunk is read little-endian and a single 1 bit is added just above
    its last byte: 2**128 for a full chunk, 2**(8*k) for a final k-byte chunk.
    The position is taken from the chunk's byte count, not from the integer's
    bit length, so chunks ending in (or consisting only of) 0x00 bytes are
    padded correctly.
    """
    chunks = (message[i:i + _BLOCK_SIZE] for i in range(0, len(message), _BLOCK_SIZE))
    return [int.from_bytes(c, "little") + (1 << (8 * len(c))) for c in chunks]


def poly1305(message: bytes, r: int, s: int) -> int:
    """Compute the Poly1305 tag of *message* sequentially.

    Args:
        message: bytes to authenticate (may be empty).
        r: the ``r`` key half as an integer; it is clamped with ``mask_r``.
        s: the ``s`` key half as an integer, added after the polynomial.

    Returns:
        The 128-bit tag as an integer, ``(acc + s) mod 2**128``.
    """
    r = mask_r(r)
    acc = 0
    for block in _message_blocks(message):
        acc = ((acc + block) * r) % _P
    return (acc + s) & _TAG_MASK


def _r_powers(r: int, max_power: int) -> List[int]:
    """Return ``[r**0, r**1, ..., r**max_power]`` built with friendly_modular_mult.

    ``r**k`` (k >= 2) is formed as ``r**h * r**(k - h)``, where ``h`` is the
    largest power of two strictly below ``k``. Each new power therefore depends
    on two already-computed ones, giving a log-depth dependency chain.
    Assumes friendly_modular_mult(a, b) yields a*b reduced modulo 2**130 - 5 —
    NOTE(review): confirm against modulo_theory.
    """
    powers = [1, r]
    for k in range(2, max_power + 1):
        half = 1 << ((k - 1).bit_length() - 1)
        powers.append(friendly_modular_mult(powers[half], powers[k - half]))
    return powers


def parallel_poly1305(message: bytes, r: int, s: int, lanes: int) -> int:
    """Compute the Poly1305 tag of *message* using ``lanes`` interleaved accumulators.

    Block ``idx`` is absorbed by lane ``idx % lanes``. Every lane multiplies by
    ``r**lanes`` while later blocks remain in that lane. On its final block it
    multiplies by ``r**(n_blocks - idx)``, so each block ends up weighted by
    ``r**(n_blocks - idx)`` exactly as in the sequential algorithm. The lane
    sums are then added together.

    Args:
        message: bytes to authenticate (may be empty).
        r: the ``r`` key half as an integer; it is clamped with ``mask_r``.
        s: the ``s`` key half as an integer, added after the polynomial.
        lanes: number of parallel accumulators (any positive integer).

    Returns:
        The 128-bit tag as an integer.

    Raises:
        ValueError: if ``lanes`` is less than 1.
    """
    if lanes < 1:
        raise ValueError(f"lanes must be a positive integer, got {lanes}")
    r = mask_r(r)
    r_powers = _r_powers(r, lanes)

    blocks = _message_blocks(message)
    # NOTE(review): the block count is ceil(len(message) / 16). The original
    # author flagged this division as something a different (presumably
    # hardware) implementation could obtain some other way.
    n_blocks = len(blocks)

    acc = [0] * lanes
    for idx, block in enumerate(blocks):
        lane = idx % lanes
        # r**lanes while this lane still has later blocks; otherwise the
        # remaining distance to the end of the message.
        power = min(lanes, n_blocks - idx)
        acc[lane] = friendly_modular_mult(acc[lane] + block, r_powers[power])

    # NOTE(review): friendly_modulo(x, 0) is presumably a full reduction
    # modulo 2**130 - 5; the second argument's meaning lives in modulo_theory.
    combined_acc = friendly_modulo(sum(acc), 0)
    return (combined_acc + s) & _TAG_MASK


def test_regular():
    """Print the RFC 8439 section 2.5.2 reference tag next to poly1305's result."""
    r = 0xa806d542fe52447f336d555778bed685
    s = 0x1bf54941aff6bf4afdb20dfb8a800301
    golden_result = 0xa927010caf8b2bc2c6365130c11d06a8
    msg = b"Cryptographic Forum Research Group"
    result = poly1305(msg, r, s)
    print(f"{golden_result:x}")
    print(f"{result:x}")


def test_parallel():
    """Print the reference tag next to the 8-lane parallel_poly1305 result."""
    r = 0xa806d542fe52447f336d555778bed685
    s = 0x1bf54941aff6bf4afdb20dfb8a800301
    golden_result = 0xa927010caf8b2bc2c6365130c11d06a8
    msg = b"Cryptographic Forum Research Group"
    result = parallel_poly1305(msg, r, s, 8)
    print(f"{golden_result:x}")
    print(f"{result:x}")


def test_on_long_string():
    """Print the sequential and 8-lane tags of a multi-block message for comparison."""
    r = 0xa806d542fe52447f336d555778bed685
    s = 0x1bf54941aff6bf4afdb20dfb8a800301
    msg = b"Very long message with lots of words that is very long and requires a lot of cycles to complete because of how long it is"
    regular_result = poly1305(msg, r, s)
    parallel_result = parallel_poly1305(msg, r, s, 8)
    print(f"{regular_result:x}")
    print(f"{parallel_result:x}")


def main():
    """Run all comparison printouts."""
    test_regular()
    test_parallel()
    test_on_long_string()


if __name__ == "__main__":
    main()