diff --git a/ChaCha20_Poly1305_64/doc/friendly_modular_mult.drawio b/ChaCha20_Poly1305_64/doc/friendly_modular_mult.drawio new file mode 100644 index 0000000..98e9b71 --- /dev/null +++ b/ChaCha20_Poly1305_64/doc/friendly_modular_mult.drawio @@ -0,0 +1,132 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ChaCha20_Poly1305_64/sim/do_poly_1305.py b/ChaCha20_Poly1305_64/sim/do_poly_1305.py index ee5d1f0..39d51dc 100644 --- a/ChaCha20_Poly1305_64/sim/do_poly_1305.py +++ b/ChaCha20_Poly1305_64/sim/do_poly_1305.py @@ -1,4 +1,5 @@ from typing import List +import random from modulo_theory import friendly_modular_mult, friendly_modulo @@ -111,11 +112,19 @@ def test_on_long_string(): print(f"{regular_result:x}") print(f"{parallel_result:x}") +def test_random(): + r = mask_r(random.randint(0, 2**128-1)) + s = random.randint(0, 2**128-1) + + msg = random.randbytes(random.randint(16, 1500)) + + parallel_poly1305(msg, r, s, 8) def main(): test_regular() test_parallel() test_on_long_string() + test_random() if __name__ == "__main__": main() diff --git a/ChaCha20_Poly1305_64/sim/modulo_theory.py b/ChaCha20_Poly1305_64/sim/modulo_theory.py index 84cf82c..dd8929c 100644 --- a/ChaCha20_Poly1305_64/sim/modulo_theory.py +++ b/ChaCha20_Poly1305_64/sim/modulo_theory.py @@ -57,6 +57,9 @@ def friendly_modular_mult(value_a: int, value_b: int) -> int: mods = [friendly_modulo(prod, 26*i) for i, prod in enumerate(prods)] + if sum(mods) >= 2*PRIME: + print("Saw greater than 2x prime!!!") + mod_sum = friendly_modulo(sum(mods), 0) diff --git a/ChaCha20_Poly1305_64/sim/poly1305.yaml b/ChaCha20_Poly1305_64/sim/poly1305.yaml index ce869e6..62e3618 100644 --- a/ChaCha20_Poly1305_64/sim/poly1305.yaml +++ b/ChaCha20_Poly1305_64/sim/poly1305.yaml @@ -5,9 +5,15 @@ tests: - "poly1305_core" sources: "sources.list" waves: True - - name: "friendly_modulo" + - name: "poly1305_friendly_modulo" toplevel: "poly1305_friendly_modulo" modules: - "poly1305_friendly_modulo" sources: sources.list + waves: True + - name: "poly1305_friendly_modular_mult" + toplevel: "poly1305_friendly_modular_mult" + modules: + - "poly1305_friendly_modular_mult" + sources: sources.list waves: True \ No newline at end of file diff --git a/ChaCha20_Poly1305_64/sim/poly1305_friendly_modular_mult.py b/ChaCha20_Poly1305_64/sim/poly1305_friendly_modular_mult.py new file mode 100644 index 0000000..180f00b --- /dev/null +++ b/ChaCha20_Poly1305_64/sim/poly1305_friendly_modular_mult.py @@ -0,0 +1,96 @@ +import logging + + +import cocotb +from cocotb.clock import Clock +from cocotb.triggers import Timer, RisingEdge, FallingEdge +from cocotb.queue import Queue + +from cocotbext.axi import AxiStreamBus, AxiStreamSource + +import random + +PRIME = 2**130-5 + +CLK_PERIOD = 4 + + +class TB: + def __init__(self, dut): + self.dut = dut + + self.log = logging.getLogger("cocotb.tb") + self.log.setLevel(logging.INFO) + + self.input_queue = Queue() + + self.expected_queue = Queue() + self.output_queue = Queue() + + cocotb.start_soon(Clock(self.dut.i_clk, CLK_PERIOD, units="ns").start()) + + cocotb.start_soon(self.run_input()) + cocotb.start_soon(self.run_output()) + + async def cycle_reset(self): + await self._cycle_reset(self.dut.i_rst, self.dut.i_clk) + + async def _cycle_reset(self, rst, clk): + rst.setimmediatevalue(0) + await RisingEdge(clk) + await RisingEdge(clk) + rst.value = 1 + await RisingEdge(clk) + await RisingEdge(clk) + rst.value = 0 + await RisingEdge(clk) + await RisingEdge(clk) + + async def write_input(self, data: int, h: int): + await self.input_queue.put((data, h)) + await self.expected_queue.put((data * h) % PRIME) + + async def run_input(self): + while True: + data, h = await self.input_queue.get() + self.dut.i_valid.value = 1 + self.dut.i_data.value = data + self.dut.i_accumulator.value = h + while True: + await RisingEdge(self.dut.i_clk) + if (self.dut.o_ready.value == 1): + break + self.dut.i_valid.value = 0 + self.dut.i_data.value = 0 + self.dut.i_accumulator.value = 0 + + async def run_output(self): + while True: + await RisingEdge(self.dut.i_clk) + if self.dut.o_valid.value: + await self.output_queue.put(self.dut.o_result.value.integer) + +@cocotb.test +async def test_sanity(dut): + tb = TB(dut) + + await tb.cycle_reset() + + count = 1 + + for _ in range(count): + await tb.write_input(random.randint(1,2**128-1), random.randint(0, 2**130-6)) + + fail = False + + for _ in range(count): + sim_val = await tb.expected_queue.get() + dut_val = await tb.output_queue.get() + + if sim_val != dut_val: + tb.log.info(f"{sim_val:x} -> {dut_val:x}") + fail = True + + # assert not fail + + await Timer(1, "us") \ No newline at end of file diff --git a/ChaCha20_Poly1305_64/src/poly1305_friendly_modular_mult.sv b/ChaCha20_Poly1305_64/src/poly1305_friendly_modular_mult.sv new file mode 100644 index 0000000..7a9dc1e --- /dev/null +++ b/ChaCha20_Poly1305_64/src/poly1305_friendly_modular_mult.sv @@ -0,0 +1,101 @@ +module poly1305_friendly_modular_mult #( + parameter DATA_WIDTH = 128, + parameter ACC_WIDTH = 130 +) ( + input logic i_clk, + input logic i_rst, + + input logic i_valid, + output logic o_ready, + input logic [DATA_WIDTH-1:0] i_data, + input logic [ACC_WIDTH-1:0] i_accumulator, + + output logic o_valid, + output logic [ACC_WIDTH-1:0] o_result +); + +localparam [129:0] PRIME = (1 << 130) - 5; + +logic [2:0] state_counter, state_counter_next; + +logic [2:0] state_counter_p [4]; + +logic [ACC_WIDTH-1:0] accumulator, accumulator_next; // accumulator is outgoing + +logic [DATA_WIDTH-1:0] data, data_next; +logic [ACC_WIDTH-1:0] h, h_next; // h is incoming + +logic [DATA_WIDTH+26-1:0] mult_product, mult_product_next; + +logic [ACC_WIDTH-1:0] modulo_result; + +assign o_ready = state_counter >= 3'h4; +assign o_result = accumulator; + +always_ff @(posedge i_clk) begin + if (i_rst) begin + state_counter <= 3'h5; + state_counter_p <= '{default: 3'h5}; + end else begin + state_counter <= state_counter_next; + accumulator <= accumulator_next; + data <= data_next; + h <= h_next; + mult_product <= mult_product_next; + + state_counter_p[0] <= state_counter; + + o_valid <= state_counter_p[3] == 3'h4; + + for (int i = 1; i < 4; i++) begin + state_counter_p[i] <= state_counter_p[i-1]; + end + end +end + +always_comb begin + data_next = data; + h_next = h; + + state_counter_next = state_counter; + + accumulator_next = '0; + mult_product_next = '0; + + + if (state_counter >= 3'h4 && i_valid) begin + data_next = i_data; + h_next = i_accumulator; + state_counter_next = '0; + end + + if (state_counter < 3'h5) begin + mult_product_next = h[state_counter*26 +: 26] * data; + state_counter_next = state_counter + 1; + end + + if (state_counter_p[3] == '0) begin + accumulator_next = modulo_result; + end else begin + if (accumulator + modulo_result > PRIME) begin + accumulator_next = accumulator + modulo_result - PRIME; + end else begin + accumulator_next = accumulator + modulo_result; + end + end + +end + +poly1305_friendly_modulo u_modulo ( + .i_clk (i_clk), + .i_rst (i_rst), + + .i_valid ('1), + .i_val ((2*ACC_WIDTH)'(mult_product)), + .i_shift_amount (state_counter_p[0]), + + .o_valid (), + .o_result (modulo_result) +); + +endmodule \ No newline at end of file diff --git a/ChaCha20_Poly1305_64/src/sources.list b/ChaCha20_Poly1305_64/src/sources.list index 4aac61c..2954a74 100644 --- a/ChaCha20_Poly1305_64/src/sources.list +++ b/ChaCha20_Poly1305_64/src/sources.list @@ -4,4 +4,5 @@ chacha20_pipelined_round.sv chacha20_pipelined_block.sv poly1305_core.sv -poly1305_friendly_modulo.sv \ No newline at end of file +poly1305_friendly_modulo.sv +poly1305_friendly_modular_mult.sv \ No newline at end of file