Add poly1305 stage

This commit is contained in:
Byron Lathi
2025-11-01 20:53:02 -07:00
parent 2102cb41f4
commit d5035c6c81
8 changed files with 356 additions and 5 deletions

View File

@@ -34,3 +34,33 @@ we need
r\*r = r^2 r\*r = r^2
r\*r^2 = r^3 r^2\*r^2 = r^4 r\*r^2 = r^3 r^2\*r^2 = r^4
r^4\*r = r^5 r^2\*r^4 = r^6 r^3\*r^4 = r^7 r^4\*r^4 = r^8 r^4\*r = r^5 r^2\*r^4 = r^6 r^3\*r^4 = r^7 r^4\*r^4 = r^8
we can do all of these in parallel, so we 4 (n/2) multiply blocks that feed back
on themselves, with some kind of FSM to control it. This can be done while another
block is being hashed, but there will be a delay between when the key is ready from
the chacha block and when the powers are ready, so there needs to be a fifo in between.
Basically we have to wait until we see that the accumulator was written with our index.
At reset though, the acumulator is unwritten? So we need to pretend that it was written
Lets just write out what we want to happen:
1. The index starts at 0. We accept new data, and send it through the pipeline
2. We increment the index to 1.
3. We accept new data and send it through the pipeline
4. We increment the index to 2
5. We need to wait until the index 0 is written before we can say we are ready
6. If the index 1 is written then we still need to say we are ready though
7. We can just use the 1 to indicate that is a valid write then?
So in the shift register we just need to say whether it is a valid write or not,
so always 1?
But if we send in 0, then send in 1, then the current index will be 0
and eventually the final index will always be 0. We need to store what
the last written one is.
We can just say the last written one was 2 I guess
We also need an input that tells it to reset the accumulator

View File

@@ -0,0 +1,59 @@
<mxfile host="Electron" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/26.2.2 Chrome/134.0.6998.178 Electron/35.1.2 Safari/537.36" version="26.2.2">
<diagram name="Page-1" id="b4c9RxKzofB-lxyaVzG6">
<mxGraphModel dx="616" dy="416" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-7" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=1;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-1" target="3x5Ie6wAwAZYy6GZGmB0-6">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="500" y="160" />
<mxPoint x="500" y="230" />
</Array>
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-17" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="560" y="160" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-1" value="Modular Multiplier (10 cycle latency)" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="280" y="120" width="200" height="80" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-11" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-6" target="3x5Ie6wAwAZYy6GZGmB0-10">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-6" value="H temp" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="345" y="210" width="80" height="40" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-12" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.75;entryDx=0;entryDy=0;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-10" target="3x5Ie6wAwAZYy6GZGmB0-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-10" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;" vertex="1" parent="1">
<mxGeometry x="200" y="160" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-13" value="" style="endArrow=classic;html=1;rounded=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" target="3x5Ie6wAwAZYy6GZGmB0-10">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="160" y="180" as="sourcePoint" />
<mxPoint x="170" y="140" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-14" value="message" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="3x5Ie6wAwAZYy6GZGmB0-13">
<mxGeometry x="-0.1731" y="1" relative="1" as="geometry">
<mxPoint x="-46" y="1" as="offset" />
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-16" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.25;entryDx=0;entryDy=0;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-15" target="3x5Ie6wAwAZYy6GZGmB0-1">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="220" y="140" />
</Array>
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-15" value="r" style="shape=trapezoid;perimeter=trapezoidPerimeter;whiteSpace=wrap;html=1;fixedSize=1;rotation=0;flipV=1;size=10;" vertex="1" parent="1">
<mxGeometry x="190" y="90" width="60" height="30" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -17,3 +17,9 @@ tests:
- "poly1305_friendly_modular_mult" - "poly1305_friendly_modular_mult"
sources: sources.list sources: sources.list
waves: True waves: True
- name: "poly1305_stage"
toplevel: "poly1305_stage"
modules:
- "poly1305_stage"
sources: sources.list
waves: True

View File

@@ -78,7 +78,7 @@ async def test_sanity(dut):
await tb.cycle_reset() await tb.cycle_reset()
count = 1024 count = 1
for _ in range(count): for _ in range(count):
await tb.write_input(random.randint(1,2**128-1), random.randint(0, 2**130-6)) await tb.write_input(random.randint(1,2**128-1), random.randint(0, 2**130-6))

View File

@@ -0,0 +1,110 @@
import logging
import cocotb
from cocotb.clock import Clock
from cocotb.triggers import Timer, RisingEdge, FallingEdge
from cocotb.queue import Queue
from cocotbext.axi import AxiStreamBus, AxiStreamSource
from modulo_theory import friendly_modular_mult
import random
PRIME = 2**130-5
CLK_PERIOD = 4
class TB:
def __init__(self, dut):
self.dut = dut
self.log = logging.getLogger("cocotb.tb")
self.log.setLevel(logging.INFO)
self.input_queue = Queue()
self.expected_queue = Queue()
self.output_queue = Queue()
cocotb.start_soon(Clock(self.dut.i_clk, CLK_PERIOD, units="ns").start())
cocotb.start_soon(self.run_input())
cocotb.start_soon(self.run_output())
self.index = 0
self.accumulators = [0, 0]
async def cycle_reset(self):
await self._cycle_reset(self.dut.i_rst, self.dut.i_clk)
async def _cycle_reset(self, rst, clk):
rst.setimmediatevalue(0)
await RisingEdge(clk)
await RisingEdge(clk)
rst.value = 1
await RisingEdge(clk)
await RisingEdge(clk)
rst.value = 0
await RisingEdge(clk)
await RisingEdge(clk)
async def write_input(self, msg: int, r_power: int, clear_acc: int):
await self.input_queue.put((msg, r_power, clear_acc))
if clear_acc:
expected_result = friendly_modular_mult((msg) % PRIME, r_power)
else:
expected_result = friendly_modular_mult((msg + self.accumulators[self.index]) % PRIME, r_power)
self.accumulators[self.index] = expected_result
await self.expected_queue.put(expected_result)
self.index = 1 if self.index == 0 else 0
async def run_input(self):
while True:
msg, r_power, clear_acc = await self.input_queue.get()
self.dut.i_valid.value = 1
self.dut.i_r_power.value = r_power
self.dut.i_message.value = msg
self.dut.i_clear_acc.value = clear_acc
while True:
await RisingEdge(self.dut.i_clk)
if (self.dut.o_ready.value == 1):
break
self.dut.i_valid.value = 0
self.dut.i_r_power.value = 0
self.dut.i_message.value = 0
self.dut.i_clear_acc.value = 0
async def run_output(self):
while True:
await RisingEdge(self.dut.i_clk)
if self.dut.o_valid.value:
await self.output_queue.put(self.dut.o_result.value.integer)
@cocotb.test
async def test_sanity(dut):
tb = TB(dut)
await tb.cycle_reset()
count = 1024
for _ in range(count):
clr = 1 if random.randint(0,10) == 0 else 0
await tb.write_input(random.randint(1,2**128-1), random.randint(0, 2**130-6), clr)
fail = False
for _ in range(count):
sim_val = await tb.expected_queue.get()
dut_val = await tb.output_queue.get()
if sim_val != dut_val:
tb.log.info(f"{_} {sim_val:x} -> {dut_val:x}")
fail = True
assert not fail

View File

@@ -1,5 +1,5 @@
module poly1305_friendly_modular_mult #( module poly1305_friendly_modular_mult #(
parameter DATA_WIDTH = 131, parameter DATA_WIDTH = 130,
parameter ACC_WIDTH = 130 parameter ACC_WIDTH = 130
) ( ) (
input logic i_clk, input logic i_clk,
@@ -54,7 +54,7 @@ always_ff @(posedge i_clk) begin
end end
always_comb begin always_comb begin
data_next = data; data_next = data; // If neccesary, we can remove this state?
h_next = h; h_next = h;
state_counter_next = state_counter; state_counter_next = state_counter;

View File

@@ -0,0 +1,145 @@
module poly1305_stage #(
) (
input logic i_clk,
input logic i_rst,
input logic i_valid,
output logic o_ready,
input logic i_clear_acc,
input logic [129:0] i_r_power,
input logic [127:0] i_message,
output logic o_valid,
output logic [129:0] o_result
);
localparam [129:0] PRIME = (1 << 130) - 5;
logic mult_i_valid;
logic mult_o_ready;
logic [129:0] r_power, r_power_next;
logic [130:0] mult_accumulator, mult_accumulator_next;
logic mult_valid;
logic [129:0] mult_result;
logic [129:0] accumulators [2];
logic [129:0] accumulators_next [2];
logic [1:0] ops_in_flight, ops_in_flight_next;
logic [1:0] index, index_next;
logic [1:0] index_sr [16];
logic [1:0] index_sr_next;
enum logic [1:0] {SUM1, SUM2, MUL, STORE} state, state_next;
always_ff @(posedge i_clk) begin
if (i_rst) begin
state <= SUM1;
r_power <= '0;
mult_accumulator <= '0;
ops_in_flight <= 2'h0;
index <= 2'h2;
index_sr <= '{default: '0};
for (int i = 0; i < 2; i++) begin
accumulators[i] <= '0;
end
end else begin
state <= state_next;
r_power <= r_power_next;
mult_accumulator <= mult_accumulator_next;
index <= index_next;
ops_in_flight <= ops_in_flight_next;
for (int i = 0; i < 2; i++) begin
accumulators[i] <= accumulators_next[i];
end
index_sr[0] <= index;
for (int i = 1; i < 16; i++) begin
index_sr[i] <= index_sr[i-1];
end
end
end
logic [1:0] ops_in_flight_state;
assign ops_in_flight_state = {i_valid & o_ready, mult_valid};
always_comb begin
state_next = state;
o_ready = '0;
mult_accumulator_next = mult_accumulator;
mult_i_valid = '0;
index_next = index;
if (mult_valid) begin
accumulators_next[index_sr[12][0]] = mult_result;
end
case (ops_in_flight_state)
2'b00: ops_in_flight_next = ops_in_flight;
2'b01: ops_in_flight_next = ops_in_flight - 1;
2'b10: ops_in_flight_next = ops_in_flight + 1;
2'b11: ops_in_flight_next = ops_in_flight;
endcase
o_valid = mult_valid;
o_result = mult_result;
case (state)
SUM1: begin
o_ready = ops_in_flight < 2;
if (i_valid && o_ready) begin
r_power_next = i_r_power;
if (i_clear_acc) begin
mult_accumulator_next = {3'b0, i_message};
end else begin
mult_accumulator_next = accumulators[index[0]] + {3'b0, i_message};
end
state_next = SUM2;
end
end
SUM2: begin
mult_accumulator_next = mult_accumulator >= 131'(PRIME) ? 131'(mult_accumulator - PRIME) : mult_accumulator;
state_next = MUL;
end
MUL: begin
mult_i_valid = '1;
if (mult_o_ready) begin
index_next = index + 1;
state_next = SUM1;
end
end
default: begin
state_next = SUM1;
end
endcase
end
poly1305_friendly_modular_mult u_modular_mult (
.i_clk (i_clk),
.i_rst (i_rst),
.i_valid (mult_i_valid),
.o_ready (mult_o_ready),
.i_data (r_power),
.i_accumulator (mult_accumulator[129:0]),
.o_valid (mult_valid),
.o_result (mult_result)
);
endmodule

View File

@@ -6,3 +6,4 @@ chacha20_pipelined_block.sv
poly1305_core.sv poly1305_core.sv
poly1305_friendly_modulo.sv poly1305_friendly_modulo.sv
poly1305_friendly_modular_mult.sv poly1305_friendly_modular_mult.sv
poly1305_stage.sv