Compare commits

..

4 Commits

Author SHA1 Message Date
Byron Lathi
d5035c6c81 Add poly1305 stage 2025-11-01 20:53:02 -07:00
Byron Lathi
2102cb41f4 Fix bug where top 2 bits were getting lost in the modulo 2025-10-30 22:34:52 -07:00
Byron Lathi
d6a062baa0 Make modular mult work 2025-10-28 21:59:28 -07:00
Byron Lathi
ad257f4220 Add mult, but it doesn't quite work 2025-10-28 08:27:36 -07:00
13 changed files with 728 additions and 5 deletions

View File

@@ -0,0 +1,132 @@
<mxfile host="Electron" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/26.2.2 Chrome/134.0.6998.178 Electron/35.1.2 Safari/537.36" version="26.2.2">
<diagram name="Page-1" id="b4c9RxKzofB-lxyaVzG6">
<mxGraphModel dx="289" dy="195" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="yBq3zbYGeky0_LNz2CMc-4" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-1" target="yBq3zbYGeky0_LNz2CMc-2">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="680" y="100" />
<mxPoint x="580" y="100" />
</Array>
</mxGeometry>
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-28" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="760" y="140" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-29" value="result" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="yBq3zbYGeky0_LNz2CMc-28">
<mxGeometry x="0.7628" y="1" relative="1" as="geometry">
<mxPoint x="25" y="1" as="offset" />
</mxGeometry>
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-1" value="accumulator w/ wrap" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="640" y="120" width="80" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-3" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-2" target="yBq3zbYGeky0_LNz2CMc-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-2" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;" vertex="1" parent="1">
<mxGeometry x="560" y="120" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-6" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-5" target="yBq3zbYGeky0_LNz2CMc-2">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-5" value="Friendly Modulo" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="360" y="80" width="160" height="120" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-14" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-7" target="yBq3zbYGeky0_LNz2CMc-13">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-7" value="X" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;" vertex="1" parent="1">
<mxGeometry x="200" y="120" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-9" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-8" target="yBq3zbYGeky0_LNz2CMc-7">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-8" value="Data (128 bit)" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="40" y="80" width="120" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-11" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-10" target="yBq3zbYGeky0_LNz2CMc-7">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-10" value="h (26x5 bit)" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="40" y="160" width="120" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-13" target="yBq3zbYGeky0_LNz2CMc-5">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-13" value="pipe reg" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="280" y="120" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-25" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-16" target="yBq3zbYGeky0_LNz2CMc-22">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-16" value="pipe reg" style="whiteSpace=wrap;html=1;aspect=fixed;" vertex="1" parent="1">
<mxGeometry x="280" y="160" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-19" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-18" target="yBq3zbYGeky0_LNz2CMc-10">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-20" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-18" target="yBq3zbYGeky0_LNz2CMc-16">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="100" y="220" />
<mxPoint x="260" y="220" />
<mxPoint x="260" y="180" />
</Array>
</mxGeometry>
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-18" value="state counter" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="40" y="240" width="120" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-21" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.833;entryDx=0;entryDy=0;entryPerimeter=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-16" target="yBq3zbYGeky0_LNz2CMc-5">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-22" value="pipe reg" style="whiteSpace=wrap;html=1;aspect=fixed;" vertex="1" parent="1">
<mxGeometry x="360" y="240" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-23" value="pipe reg" style="whiteSpace=wrap;html=1;aspect=fixed;" vertex="1" parent="1">
<mxGeometry x="400" y="240" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-26" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-24" target="yBq3zbYGeky0_LNz2CMc-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-27" value="reset" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="yBq3zbYGeky0_LNz2CMc-26">
<mxGeometry x="-0.2" y="1" relative="1" as="geometry">
<mxPoint x="-10" y="-9" as="offset" />
</mxGeometry>
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-31" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-24" target="yBq3zbYGeky0_LNz2CMc-30">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="680" y="260" />
<mxPoint x="680" y="200" />
</Array>
</mxGeometry>
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-24" value="pipe reg" style="whiteSpace=wrap;html=1;aspect=fixed;" vertex="1" parent="1">
<mxGeometry x="440" y="240" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-32" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" edge="1" parent="1" source="yBq3zbYGeky0_LNz2CMc-30">
<mxGeometry relative="1" as="geometry">
<mxPoint x="800" y="200" as="targetPoint" />
<Array as="points">
<mxPoint x="800" y="200" />
</Array>
</mxGeometry>
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-33" value="done flag" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="yBq3zbYGeky0_LNz2CMc-32">
<mxGeometry x="0.1299" y="1" relative="1" as="geometry">
<mxPoint x="7" y="-19" as="offset" />
</mxGeometry>
</mxCell>
<mxCell id="yBq3zbYGeky0_LNz2CMc-30" value="==4?" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="700" y="180" width="60" height="40" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -33,4 +33,34 @@ we need
r\*r = r^2 r\*r = r^2
r\*r^2 = r^3 r^2\*r^2 = r^4 r\*r^2 = r^3 r^2\*r^2 = r^4
r^4\*r = r^5 r^2\*r^4 = r^6 r^3\*r^4 = r^7 r^4\*r^4 = r^8 r^4\*r = r^5 r^2\*r^4 = r^6 r^3\*r^4 = r^7 r^4\*r^4 = r^8
we can do all of these in parallel, so we 4 (n/2) multiply blocks that feed back
on themselves, with some kind of FSM to control it. This can be done while another
block is being hashed, but there will be a delay between when the key is ready from
the chacha block and when the powers are ready, so there needs to be a fifo in between.
Basically we have to wait until we see that the accumulator was written with our index.
At reset though, the acumulator is unwritten? So we need to pretend that it was written
Lets just write out what we want to happen:
1. The index starts at 0. We accept new data, and send it through the pipeline
2. We increment the index to 1.
3. We accept new data and send it through the pipeline
4. We increment the index to 2
5. We need to wait until the index 0 is written before we can say we are ready
6. If the index 1 is written then we still need to say we are ready though
7. We can just use the 1 to indicate that is a valid write then?
So in the shift register we just need to say whether it is a valid write or not,
so always 1?
But if we send in 0, then send in 1, then the current index will be 0
and eventually the final index will always be 0. We need to store what
the last written one is.
We can just say the last written one was 2 I guess
We also need an input that tells it to reset the accumulator

View File

@@ -0,0 +1,59 @@
<mxfile host="Electron" agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/26.2.2 Chrome/134.0.6998.178 Electron/35.1.2 Safari/537.36" version="26.2.2">
<diagram name="Page-1" id="b4c9RxKzofB-lxyaVzG6">
<mxGraphModel dx="616" dy="416" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-7" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=1;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-1" target="3x5Ie6wAwAZYy6GZGmB0-6">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="500" y="160" />
<mxPoint x="500" y="230" />
</Array>
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-17" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="560" y="160" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-1" value="Modular Multiplier (10 cycle latency)" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="280" y="120" width="200" height="80" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-11" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-6" target="3x5Ie6wAwAZYy6GZGmB0-10">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-6" value="H temp" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxGeometry x="345" y="210" width="80" height="40" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-12" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.75;entryDx=0;entryDy=0;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-10" target="3x5Ie6wAwAZYy6GZGmB0-1">
<mxGeometry relative="1" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-10" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;" vertex="1" parent="1">
<mxGeometry x="200" y="160" width="40" height="40" as="geometry" />
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-13" value="" style="endArrow=classic;html=1;rounded=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" target="3x5Ie6wAwAZYy6GZGmB0-10">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="160" y="180" as="sourcePoint" />
<mxPoint x="170" y="140" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-14" value="message" style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];" vertex="1" connectable="0" parent="3x5Ie6wAwAZYy6GZGmB0-13">
<mxGeometry x="-0.1731" y="1" relative="1" as="geometry">
<mxPoint x="-46" y="1" as="offset" />
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-16" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.25;entryDx=0;entryDy=0;" edge="1" parent="1" source="3x5Ie6wAwAZYy6GZGmB0-15" target="3x5Ie6wAwAZYy6GZGmB0-1">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="220" y="140" />
</Array>
</mxGeometry>
</mxCell>
<mxCell id="3x5Ie6wAwAZYy6GZGmB0-15" value="r" style="shape=trapezoid;perimeter=trapezoidPerimeter;whiteSpace=wrap;html=1;fixedSize=1;rotation=0;flipV=1;size=10;" vertex="1" parent="1">
<mxGeometry x="190" y="90" width="60" height="30" as="geometry" />
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -1,4 +1,5 @@
from typing import List from typing import List
import random
from modulo_theory import friendly_modular_mult, friendly_modulo from modulo_theory import friendly_modular_mult, friendly_modulo
@@ -111,11 +112,19 @@ def test_on_long_string():
print(f"{regular_result:x}") print(f"{regular_result:x}")
print(f"{parallel_result:x}") print(f"{parallel_result:x}")
def test_random():
r = mask_r(random.randint(0, 2**128-1))
s = random.randint(0, 2**128-1)
msg = random.randbytes(random.randint(16, 1500))
parallel_poly1305(msg, r, s, 8)
def main(): def main():
test_regular() test_regular()
test_parallel() test_parallel()
test_on_long_string() test_on_long_string()
test_random()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -57,7 +57,6 @@ def friendly_modular_mult(value_a: int, value_b: int) -> int:
mods = [friendly_modulo(prod, 26*i) for i, prod in enumerate(prods)] mods = [friendly_modulo(prod, 26*i) for i, prod in enumerate(prods)]
mod_sum = friendly_modulo(sum(mods), 0) mod_sum = friendly_modulo(sum(mods), 0)
return mod_sum return mod_sum

View File

@@ -5,9 +5,21 @@ tests:
- "poly1305_core" - "poly1305_core"
sources: "sources.list" sources: "sources.list"
waves: True waves: True
- name: "friendly_modulo" - name: "poly1305_friendly_modulo"
toplevel: "poly1305_friendly_modulo" toplevel: "poly1305_friendly_modulo"
modules: modules:
- "poly1305_friendly_modulo" - "poly1305_friendly_modulo"
sources: sources.list sources: sources.list
waves: True
- name: "poly1305_friendly_modular_mult"
toplevel: "poly1305_friendly_modular_mult"
modules:
- "poly1305_friendly_modular_mult"
sources: sources.list
waves: True
- name: "poly1305_stage"
toplevel: "poly1305_stage"
modules:
- "poly1305_stage"
sources: sources.list
waves: True waves: True

View File

@@ -0,0 +1,96 @@
import logging
import cocotb
from cocotb.clock import Clock
from cocotb.triggers import Timer, RisingEdge, FallingEdge
from cocotb.queue import Queue
from cocotbext.axi import AxiStreamBus, AxiStreamSource
from modulo_theory import friendly_modular_mult
import random
PRIME = 2**130-5
CLK_PERIOD = 4
class TB:
def __init__(self, dut):
self.dut = dut
self.log = logging.getLogger("cocotb.tb")
self.log.setLevel(logging.INFO)
self.input_queue = Queue()
self.expected_queue = Queue()
self.output_queue = Queue()
cocotb.start_soon(Clock(self.dut.i_clk, CLK_PERIOD, units="ns").start())
cocotb.start_soon(self.run_input())
cocotb.start_soon(self.run_output())
async def cycle_reset(self):
await self._cycle_reset(self.dut.i_rst, self.dut.i_clk)
async def _cycle_reset(self, rst, clk):
rst.setimmediatevalue(0)
await RisingEdge(clk)
await RisingEdge(clk)
rst.value = 1
await RisingEdge(clk)
await RisingEdge(clk)
rst.value = 0
await RisingEdge(clk)
await RisingEdge(clk)
async def write_input(self, data: int, h: int):
await self.input_queue.put((data, h))
await self.expected_queue.put(friendly_modular_mult(h, data))
async def run_input(self):
while True:
data, h = await self.input_queue.get()
self.dut.i_valid.value = 1
self.dut.i_data.value = data
self.dut.i_accumulator.value = h
while True:
await RisingEdge(self.dut.i_clk)
if (self.dut.o_ready.value == 1):
break
self.dut.i_valid.value = 0
self.dut.i_data.value = 0
self.dut.i_accumulator.value = 0
async def run_output(self):
while True:
await RisingEdge(self.dut.i_clk)
if self.dut.o_valid.value:
await self.output_queue.put(self.dut.o_result.value.integer)
@cocotb.test
async def test_sanity(dut):
tb = TB(dut)
await tb.cycle_reset()
count = 1
for _ in range(count):
await tb.write_input(random.randint(1,2**128-1), random.randint(0, 2**130-6))
fail = False
for _ in range(count):
sim_val = await tb.expected_queue.get()
dut_val = await tb.output_queue.get()
if sim_val != dut_val:
tb.log.info(f"{sim_val:x} -> {dut_val:x}")
fail = True
assert not fail

View File

@@ -88,4 +88,23 @@ async def test_sanity(dut):
tb.log.info(f"{sim_val:x} -> {dut_val:x}") tb.log.info(f"{sim_val:x} -> {dut_val:x}")
fail = True fail = True
assert not fail
@cocotb.test
async def test_directed(dut):
tb = TB(dut)
await tb.cycle_reset()
await tb.write_input(0x14C0D69391E7116E057E7AD833B00B706AA2390C, 4)
fail = False
sim_val = await tb.expected_queue.get()
dut_val = await tb.output_queue.get()
if sim_val != dut_val:
tb.log.info(f"{sim_val:x} -> {dut_val:x}")
fail = True
assert not fail assert not fail

View File

@@ -0,0 +1,110 @@
import logging
import cocotb
from cocotb.clock import Clock
from cocotb.triggers import Timer, RisingEdge, FallingEdge
from cocotb.queue import Queue
from cocotbext.axi import AxiStreamBus, AxiStreamSource
from modulo_theory import friendly_modular_mult
import random
PRIME = 2**130-5
CLK_PERIOD = 4
class TB:
def __init__(self, dut):
self.dut = dut
self.log = logging.getLogger("cocotb.tb")
self.log.setLevel(logging.INFO)
self.input_queue = Queue()
self.expected_queue = Queue()
self.output_queue = Queue()
cocotb.start_soon(Clock(self.dut.i_clk, CLK_PERIOD, units="ns").start())
cocotb.start_soon(self.run_input())
cocotb.start_soon(self.run_output())
self.index = 0
self.accumulators = [0, 0]
async def cycle_reset(self):
await self._cycle_reset(self.dut.i_rst, self.dut.i_clk)
async def _cycle_reset(self, rst, clk):
rst.setimmediatevalue(0)
await RisingEdge(clk)
await RisingEdge(clk)
rst.value = 1
await RisingEdge(clk)
await RisingEdge(clk)
rst.value = 0
await RisingEdge(clk)
await RisingEdge(clk)
async def write_input(self, msg: int, r_power: int, clear_acc: int):
await self.input_queue.put((msg, r_power, clear_acc))
if clear_acc:
expected_result = friendly_modular_mult((msg) % PRIME, r_power)
else:
expected_result = friendly_modular_mult((msg + self.accumulators[self.index]) % PRIME, r_power)
self.accumulators[self.index] = expected_result
await self.expected_queue.put(expected_result)
self.index = 1 if self.index == 0 else 0
async def run_input(self):
while True:
msg, r_power, clear_acc = await self.input_queue.get()
self.dut.i_valid.value = 1
self.dut.i_r_power.value = r_power
self.dut.i_message.value = msg
self.dut.i_clear_acc.value = clear_acc
while True:
await RisingEdge(self.dut.i_clk)
if (self.dut.o_ready.value == 1):
break
self.dut.i_valid.value = 0
self.dut.i_r_power.value = 0
self.dut.i_message.value = 0
self.dut.i_clear_acc.value = 0
async def run_output(self):
while True:
await RisingEdge(self.dut.i_clk)
if self.dut.o_valid.value:
await self.output_queue.put(self.dut.o_result.value.integer)
@cocotb.test
async def test_sanity(dut):
tb = TB(dut)
await tb.cycle_reset()
count = 1024
for _ in range(count):
clr = 1 if random.randint(0,10) == 0 else 0
await tb.write_input(random.randint(1,2**128-1), random.randint(0, 2**130-6), clr)
fail = False
for _ in range(count):
sim_val = await tb.expected_queue.get()
dut_val = await tb.output_queue.get()
if sim_val != dut_val:
tb.log.info(f"{_} {sim_val:x} -> {dut_val:x}")
fail = True
assert not fail

View File

@@ -0,0 +1,110 @@
module poly1305_friendly_modular_mult #(
parameter DATA_WIDTH = 130,
parameter ACC_WIDTH = 130
) (
input logic i_clk,
input logic i_rst,
input logic i_valid,
output logic o_ready,
input logic [DATA_WIDTH-1:0] i_data,
input logic [ACC_WIDTH-1:0] i_accumulator,
output logic o_valid,
output logic [ACC_WIDTH-1:0] o_result
);
localparam INT_ACC_WIDTH = ACC_WIDTH + 3; // $clog2(8)
localparam [129:0] PRIME = (1 << 130) - 5;
logic [2:0] state_counter, state_counter_next;
logic [2:0] state_counter_p [5];
logic [INT_ACC_WIDTH-1:0] accumulator, accumulator_next; // accumulator is outgoing
logic [INT_ACC_WIDTH+1:0] accumulator_intermediate;
logic [DATA_WIDTH-1:0] data, data_next;
logic [INT_ACC_WIDTH-1:0] h, h_next; // h is incoming
logic [DATA_WIDTH+26-1:0] mult_product, mult_product_next;
logic [ACC_WIDTH-1:0] modulo_result;
assign o_ready = state_counter >= 3'h4;
always_ff @(posedge i_clk) begin
if (i_rst) begin
state_counter <= 3'h5;
state_counter_p <= '{default: 3'h5};
end else begin
state_counter <= state_counter_next;
accumulator <= accumulator_next;
data <= data_next;
h <= h_next;
mult_product <= mult_product_next;
state_counter_p[0] <= state_counter;
for (int i = 1; i < 5; i++) begin
state_counter_p[i] <= state_counter_p[i-1];
end
end
end
always_comb begin
data_next = data; // If neccesary, we can remove this state?
h_next = h;
state_counter_next = state_counter;
accumulator_next = '0;
mult_product_next = '0;
accumulator_intermediate = '0;
if (state_counter < 3'h5) begin
mult_product_next = h[state_counter*26 +: 26] * data;
state_counter_next = state_counter + 1;
end
if (state_counter >= 3'h4 && i_valid) begin
data_next = i_data;
h_next = (INT_ACC_WIDTH)'(i_accumulator);
state_counter_next = '0;
end
if (state_counter_p[3] == '0) begin
accumulator_next = (INT_ACC_WIDTH)'(modulo_result);
end else begin
accumulator_next = accumulator + (INT_ACC_WIDTH)'(modulo_result);
end
end
poly1305_friendly_modulo u_mult_modulo (
.i_clk (i_clk),
.i_rst (i_rst),
.i_valid ('1),
.i_val ((2*ACC_WIDTH)'(mult_product)),
.i_shift_amount (state_counter_p[0]),
.o_valid (),
.o_result (modulo_result)
);
poly1305_friendly_modulo u_sum_modulo (
.i_clk (i_clk),
.i_rst (i_rst),
.i_valid (state_counter_p[4] == 3'h4),
.i_val ({127'b0, accumulator}),
.i_shift_amount ('0),
.o_valid (o_valid),
.o_result (o_result)
);
endmodule

View File

@@ -36,7 +36,7 @@ assign o_valid = valid_sr[2];
always_ff @(posedge i_clk) begin always_ff @(posedge i_clk) begin
valid_sr <= {valid_sr[1:0], i_valid}; valid_sr <= {valid_sr[1:0], i_valid};
high_part_1 <= WIDTH'({3'b0, i_val} >> (130 - (i_shift_amount*SHIFT_SIZE))) * MDIFF; high_part_1 <= WIDE_WIDTH'({3'b0, i_val} >> (130 - (i_shift_amount*SHIFT_SIZE))) * MDIFF;
low_part_1 <= WIDTH'(i_val << (i_shift_amount*SHIFT_SIZE)); low_part_1 <= WIDTH'(i_val << (i_shift_amount*SHIFT_SIZE));
high_part_2 <= (intermediate_val >> WIDTH) * 5; high_part_2 <= (intermediate_val >> WIDTH) * 5;

View File

@@ -0,0 +1,145 @@
module poly1305_stage #(
) (
input logic i_clk,
input logic i_rst,
input logic i_valid,
output logic o_ready,
input logic i_clear_acc,
input logic [129:0] i_r_power,
input logic [127:0] i_message,
output logic o_valid,
output logic [129:0] o_result
);
localparam [129:0] PRIME = (1 << 130) - 5;
logic mult_i_valid;
logic mult_o_ready;
logic [129:0] r_power, r_power_next;
logic [130:0] mult_accumulator, mult_accumulator_next;
logic mult_valid;
logic [129:0] mult_result;
logic [129:0] accumulators [2];
logic [129:0] accumulators_next [2];
logic [1:0] ops_in_flight, ops_in_flight_next;
logic [1:0] index, index_next;
logic [1:0] index_sr [16];
logic [1:0] index_sr_next;
enum logic [1:0] {SUM1, SUM2, MUL, STORE} state, state_next;
always_ff @(posedge i_clk) begin
if (i_rst) begin
state <= SUM1;
r_power <= '0;
mult_accumulator <= '0;
ops_in_flight <= 2'h0;
index <= 2'h2;
index_sr <= '{default: '0};
for (int i = 0; i < 2; i++) begin
accumulators[i] <= '0;
end
end else begin
state <= state_next;
r_power <= r_power_next;
mult_accumulator <= mult_accumulator_next;
index <= index_next;
ops_in_flight <= ops_in_flight_next;
for (int i = 0; i < 2; i++) begin
accumulators[i] <= accumulators_next[i];
end
index_sr[0] <= index;
for (int i = 1; i < 16; i++) begin
index_sr[i] <= index_sr[i-1];
end
end
end
logic [1:0] ops_in_flight_state;
assign ops_in_flight_state = {i_valid & o_ready, mult_valid};
always_comb begin
state_next = state;
o_ready = '0;
mult_accumulator_next = mult_accumulator;
mult_i_valid = '0;
index_next = index;
if (mult_valid) begin
accumulators_next[index_sr[12][0]] = mult_result;
end
case (ops_in_flight_state)
2'b00: ops_in_flight_next = ops_in_flight;
2'b01: ops_in_flight_next = ops_in_flight - 1;
2'b10: ops_in_flight_next = ops_in_flight + 1;
2'b11: ops_in_flight_next = ops_in_flight;
endcase
o_valid = mult_valid;
o_result = mult_result;
case (state)
SUM1: begin
o_ready = ops_in_flight < 2;
if (i_valid && o_ready) begin
r_power_next = i_r_power;
if (i_clear_acc) begin
mult_accumulator_next = {3'b0, i_message};
end else begin
mult_accumulator_next = accumulators[index[0]] + {3'b0, i_message};
end
state_next = SUM2;
end
end
SUM2: begin
mult_accumulator_next = mult_accumulator >= 131'(PRIME) ? 131'(mult_accumulator - PRIME) : mult_accumulator;
state_next = MUL;
end
MUL: begin
mult_i_valid = '1;
if (mult_o_ready) begin
index_next = index + 1;
state_next = SUM1;
end
end
default: begin
state_next = SUM1;
end
endcase
end
poly1305_friendly_modular_mult u_modular_mult (
.i_clk (i_clk),
.i_rst (i_rst),
.i_valid (mult_i_valid),
.o_ready (mult_o_ready),
.i_data (r_power),
.i_accumulator (mult_accumulator[129:0]),
.o_valid (mult_valid),
.o_result (mult_result)
);
endmodule

View File

@@ -4,4 +4,6 @@ chacha20_pipelined_round.sv
chacha20_pipelined_block.sv chacha20_pipelined_block.sv
poly1305_core.sv poly1305_core.sv
poly1305_friendly_modulo.sv poly1305_friendly_modulo.sv
poly1305_friendly_modular_mult.sv
poly1305_stage.sv