diff --git a/ChaCha20_Poly1305_64/sim/do_poly_1305.py b/ChaCha20_Poly1305_64/sim/do_poly_1305.py index c648483..7a63286 100644 --- a/ChaCha20_Poly1305_64/sim/do_poly_1305.py +++ b/ChaCha20_Poly1305_64/sim/do_poly_1305.py @@ -36,8 +36,33 @@ def poly1305(message: bytes, r: int, s: int): return acc & (2**128-1) +def parallel_poly1305(message: bytes, r: int, s: int, lanes: int): + r = mask_r(r) + p = 2**130-5 -def main(): + acc = [0]*lanes + + blocks = [int.from_bytes(message[i:i+16], "little") for i in range(0, len(message), 16)] + + lane_blocks = [blocks[i:i+lanes] for i in range(0, len(blocks), lanes)] + + for i, lane_block in enumerate(lane_blocks): + for j, lane in enumerate(lane_block): + idx = i*lanes + j + power = min(lanes, len(blocks) - idx) + + byte_length = (lane.bit_length() + 7) // 8 + lane += 1 << (8*byte_length) + + acc[j] = ((acc[j] + lane)*(r**power)) % p + + combined_acc = sum(acc) % p + combined_acc += s + + return combined_acc & (2**128-1) + + +def test_regular(): r = 0xa806d542fe52447f336d555778bed685 s = 0x1bf54941aff6bf4afdb20dfb8a800301 @@ -50,5 +75,37 @@ def main(): print(f"{golden_result:x}") print(f"{result:x}") +def test_parallel(): + r = 0xa806d542fe52447f336d555778bed685 + s = 0x1bf54941aff6bf4afdb20dfb8a800301 + + golden_result = 0xa927010caf8b2bc2c6365130c11d06a8 + + msg = b"Cryptographic Forum Research Group" + + result = parallel_poly1305(msg, r, s, 8) + + print(f"{golden_result:x}") + print(f"{result:x}") + + +def test_on_long_string(): + r = 0xa806d542fe52447f336d555778bed685 + s = 0x1bf54941aff6bf4afdb20dfb8a800301 + + msg = b"Very long message with lots of words that is very long and requires a lot of cycles to complete because of how long it is" + + regular_result = poly1305(msg, r, s) + parallel_result = parallel_poly1305(msg, r, s, 8) + + print(f"{regular_result:x}") + print(f"{parallel_result:x}") + + +def main(): + test_regular() + test_parallel() + test_on_long_string() + if __name__ == "__main__": main()