import logging
import random
import struct

import cocotb
from cocotb.clock import Clock
from cocotb.queue import Queue
from cocotb.triggers import FallingEdge, RisingEdge, Timer

from chacha_helpers import chacha_block

# Clock period passed to cocotb's Clock; interpreted in nanoseconds (see TB.__init__).
CLK_PERIOD = 4

# First four 32-bit words of the 16-word block input built in TB.write_input.
# Read as big-endian bytes, the words spell the ASCII text "expand 32-byte k".
# NOTE(review): RFC 8439 lists these words byte-swapped (0x61707865, ...);
# confirm that both the DUT and chacha_block expect the ordering used here.
CONSTANT = [0x65787061, 0x6e642033, 0x322d6279, 0x7465206b]


class TB:
    """Testbench harness around the DUT.

    Starts the clock and two background coroutines on construction:
    ``run_input`` drives queued (key, counter, nonce) tuples into the DUT and
    ``run_output`` captures ``o_state`` whenever ``o_valid`` is set.  Expected
    results computed with ``chacha_block`` are queued in ``expected_queue`` in
    the same order the inputs were submitted.
    """

    def __init__(self, dut):
        """Attach to *dut* and start the clock, input driver and output monitor."""
        self.dut = dut

        self.log = logging.getLogger("cocotb.tb")
        self.log.setLevel(logging.INFO)

        # (key, counter, nonce) tuples waiting to be driven into the DUT.
        self.input_queue = Queue()
        # Reference outputs from chacha_block, one per submitted input.
        self.expected_queue = Queue()
        # Tuples of 16 words captured from the DUT's o_state.
        self.output_queue = Queue()

        cocotb.start_soon(Clock(self.dut.i_clk, CLK_PERIOD, units="ns").start())
        cocotb.start_soon(self.run_input())
        cocotb.start_soon(self.run_output())

    async def cycle_reset(self):
        """Pulse the DUT's ``i_rst`` using ``i_clk`` as the timing reference."""
        await self._cycle_reset(self.dut.i_rst, self.dut.i_clk)

    async def _cycle_reset(self, rst, clk):
        """Drive *rst* low, high, then low again, holding each level for two
        rising edges of *clk*."""
        # Force the initial level immediately rather than at the next write phase.
        rst.setimmediatevalue(0)
        await RisingEdge(clk)
        await RisingEdge(clk)
        rst.value = 1
        await RisingEdge(clk)
        await RisingEdge(clk)
        rst.value = 0
        await RisingEdge(clk)
        await RisingEdge(clk)

    async def write_input(self, key, counter, nonce):
        """Queue one block request and its expected result.

        Args:
            key: Non-negative int that fits in 32 bytes.
            counter: Non-negative int that fits in 8 bytes.
            nonce: Non-negative int that fits in 8 bytes.

        Raises:
            OverflowError: If a value is negative or too large for its field
                (raised by ``int.to_bytes``).
        """
        await self.input_queue.put((key, counter, nonce))

        # Reference input: 4 constant words + 8 key words + 2 counter words
        # + 2 nonce words, least-significant word of each field first.
        # "<" pins little-endian decoding so it always matches the
        # little-endian bytes produced by to_bytes, whatever the host's
        # native byte order is.
        data_in = CONSTANT[:]
        data_in.extend(struct.unpack("<8I", key.to_bytes(32, "little")))
        data_in.extend(struct.unpack("<2I", counter.to_bytes(8, "little")))
        data_in.extend(struct.unpack("<2I", nonce.to_bytes(8, "little")))

        data_out = chacha_block(data_in)
        await self.expected_queue.put(data_out)

    async def run_input(self):
        """Background driver: present each queued input to the DUT.

        For every item, the inputs are applied with ``i_valid`` (and
        ``i_ready``) set to 1; ``i_valid`` is cleared after the first sampled
        rising edge at which ``o_ready`` equals 1.
        """
        while True:
            key, counter, nonce = await self.input_queue.get()
            self.dut.i_key.value = key
            self.dut.i_counter.value = counter
            self.dut.i_nonce.value = nonce
            self.dut.i_ready.value = 1
            self.dut.i_valid.value = 1
            await RisingEdge(self.dut.i_clk)
            # Keep the request asserted until the DUT reports o_ready == 1.
            while not self.dut.o_ready.value == 1:
                await RisingEdge(self.dut.i_clk)
            self.dut.i_valid.value = 0

    async def run_output(self):
        """Background monitor: capture ``o_state`` on every rising clock edge
        where ``o_valid`` is set and queue it as a tuple of 16 words."""
        while True:
            await RisingEdge(self.dut.i_clk)
            if self.dut.o_valid.value:
                state = self.dut.o_state.value.integer
                # Split the 512-bit value into 32-bit words, least-significant
                # word first; "<" keeps this independent of host byte order.
                state_bytes = state.to_bytes(64, "little")
                state_words = struct.unpack("<16I", state_bytes)
                await self.output_queue.put(state_words)
@cocotb.test
async def test_sanity(dut):
    """Drive 200 blocks through the DUT and compare each against the model.

    Every block uses key 0 and nonce 0; the counter runs from 0 to 199.  All
    mismatching words are logged before the test fails, so a single run
    reports every difference.
    """
    # Local import keeps this block self-contained.
    from itertools import zip_longest

    tb = TB(dut)
    await tb.cycle_reset()

    count = 200

    # NOTE(review): random key/nonce values used to be drawn here and were
    # immediately overwritten with 0, so only the counter ever varied.  The
    # dead draws are removed; restore randomisation if wider coverage is wanted.
    key = 0
    nonce = 0
    for counter in range(count):
        await tb.write_input(key, counter, nonce)

    # Marks a word that exists on one side only; zip() alone would silently
    # truncate to the shorter sequence and hide the discrepancy.
    missing = object()

    fail = False
    for _ in range(count):
        sim_vals = await tb.expected_queue.get()
        dut_vals = await tb.output_queue.get()
        paired = zip_longest(sim_vals, dut_vals, fillvalue=missing)
        for word, (sim_val, dut_val) in enumerate(paired):
            if sim_val is missing or dut_val is missing:
                tb.log.info(f"{word}: word count differs between model and DUT")
                fail = True
                break
            if sim_val != dut_val:
                tb.log.info(f"{word}: {sim_val:x} -> {dut_val:x}")
                fail = True

    # Let the simulation run a little longer before reporting the result.
    await Timer(1, "us")
    assert not fail, "DUT output differs from chacha_block reference; see log"