Add mult, but it doesn't quite work
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import List
|
||||
import random
|
||||
|
||||
from modulo_theory import friendly_modular_mult, friendly_modulo
|
||||
|
||||
@@ -111,11 +112,19 @@ def test_on_long_string():
|
||||
print(f"{regular_result:x}")
|
||||
print(f"{parallel_result:x}")
|
||||
|
||||
def test_random():
|
||||
r = mask_r(random.randint(0, 2**128-1))
|
||||
s = random.randint(0, 2**128-1)
|
||||
|
||||
msg = random.randbytes(random.randint(16, 1500))
|
||||
|
||||
parallel_poly1305(msg, r, s, 8)
|
||||
|
||||
def main():
|
||||
test_regular()
|
||||
test_parallel()
|
||||
test_on_long_string()
|
||||
test_random()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -57,6 +57,9 @@ def friendly_modular_mult(value_a: int, value_b: int) -> int:
|
||||
|
||||
mods = [friendly_modulo(prod, 26*i) for i, prod in enumerate(prods)]
|
||||
|
||||
if sum(mods) >= 2*PRIME:
|
||||
print("Saw greater than 2x prime!!!")
|
||||
|
||||
|
||||
mod_sum = friendly_modulo(sum(mods), 0)
|
||||
|
||||
|
||||
@@ -5,9 +5,15 @@ tests:
|
||||
- "poly1305_core"
|
||||
sources: "sources.list"
|
||||
waves: True
|
||||
- name: "friendly_modulo"
|
||||
- name: "poly1305_friendly_modulo"
|
||||
toplevel: "poly1305_friendly_modulo"
|
||||
modules:
|
||||
- "poly1305_friendly_modulo"
|
||||
sources: sources.list
|
||||
waves: True
|
||||
- name: "poly1305_friendly_modular_mult"
|
||||
toplevel: "poly1305_friendly_modular_mult"
|
||||
modules:
|
||||
- "poly1305_friendly_modular_mult"
|
||||
sources: sources.list
|
||||
waves: True
|
||||
96
ChaCha20_Poly1305_64/sim/poly1305_friendly_modular_mult.py
Normal file
96
ChaCha20_Poly1305_64/sim/poly1305_friendly_modular_mult.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import logging
|
||||
|
||||
|
||||
import cocotb
|
||||
from cocotb.clock import Clock
|
||||
from cocotb.triggers import Timer, RisingEdge, FallingEdge
|
||||
from cocotb.queue import Queue
|
||||
|
||||
from cocotbext.axi import AxiStreamBus, AxiStreamSource
|
||||
|
||||
import random
|
||||
|
||||
PRIME = 2**130-5
|
||||
|
||||
CLK_PERIOD = 4
|
||||
|
||||
|
||||
class TB:
|
||||
def __init__(self, dut):
|
||||
self.dut = dut
|
||||
|
||||
self.log = logging.getLogger("cocotb.tb")
|
||||
self.log.setLevel(logging.INFO)
|
||||
|
||||
self.input_queue = Queue()
|
||||
|
||||
self.expected_queue = Queue()
|
||||
self.output_queue = Queue()
|
||||
|
||||
cocotb.start_soon(Clock(self.dut.i_clk, CLK_PERIOD, units="ns").start())
|
||||
|
||||
cocotb.start_soon(self.run_input())
|
||||
cocotb.start_soon(self.run_output())
|
||||
|
||||
async def cycle_reset(self):
|
||||
await self._cycle_reset(self.dut.i_rst, self.dut.i_clk)
|
||||
|
||||
async def _cycle_reset(self, rst, clk):
|
||||
rst.setimmediatevalue(0)
|
||||
await RisingEdge(clk)
|
||||
await RisingEdge(clk)
|
||||
rst.value = 1
|
||||
await RisingEdge(clk)
|
||||
await RisingEdge(clk)
|
||||
rst.value = 0
|
||||
await RisingEdge(clk)
|
||||
await RisingEdge(clk)
|
||||
|
||||
async def write_input(self, data: int, h: int):
|
||||
await self.input_queue.put((data, h))
|
||||
await self.expected_queue.put((data * h) % PRIME)
|
||||
|
||||
async def run_input(self):
|
||||
while True:
|
||||
data, h = await self.input_queue.get()
|
||||
self.dut.i_valid.value = 1
|
||||
self.dut.i_data.value = data
|
||||
self.dut.i_accumulator.value = h
|
||||
while True:
|
||||
await RisingEdge(self.dut.i_clk)
|
||||
if (self.dut.o_ready.value == 1):
|
||||
break
|
||||
self.dut.i_valid.value = 0
|
||||
self.dut.i_data.value = 0
|
||||
self.dut.i_accumulator.value = 0
|
||||
|
||||
async def run_output(self):
|
||||
while True:
|
||||
await RisingEdge(self.dut.i_clk)
|
||||
if self.dut.o_valid.value:
|
||||
await self.output_queue.put(self.dut.o_result.value.integer)
|
||||
|
||||
@cocotb.test
|
||||
async def test_sanity(dut):
|
||||
tb = TB(dut)
|
||||
|
||||
await tb.cycle_reset()
|
||||
|
||||
count = 1
|
||||
|
||||
for _ in range(count):
|
||||
await tb.write_input(random.randint(1,2**128-1), random.randint(0, 2**130-6))
|
||||
|
||||
fail = False
|
||||
|
||||
for _ in range(count):
|
||||
sim_val = await tb.expected_queue.get()
|
||||
dut_val = await tb.output_queue.get()
|
||||
|
||||
if sim_val != dut_val:
|
||||
tb.log.info(f"{sim_val:x} -> {dut_val:x}")
|
||||
fail = True
|
||||
|
||||
# assert not fail
|
||||
|
||||
await Timer(1, "us")
|
||||
Reference in New Issue
Block a user