diff --git a/cocotbext/axi/axi_master.py b/cocotbext/axi/axi_master.py index 992b327..3afa0b2 100644 --- a/cocotbext/axi/axi_master.py +++ b/cocotbext/axi/axi_master.py @@ -49,6 +49,92 @@ AxiReadRespCmd = namedtuple("AxiReadRespCmd", ["address", "length", "size", "cyc AxiReadResp = namedtuple("AxiReadResp", ["address", "data", "resp", "user"]) +class TagContext: + def __init__(self, manager): + self.current_tag = 0 + self._cmd_queue = Queue() + self._current_cmd = None + self._resp_queue = Queue() + self._cr = None + self._manager = manager + + async def get_resp(self): + return await self._resp_queue.get() + + def get_resp_nowait(self): + return self._resp_queue.get_nowait() + + def _start(self): + if self._cr is None: + self._cr = cocotb.fork(self._process_queue()) + + def _flush(self): + flushed_cmds = [] + if self._cr is not None: + self._cr.kill() + self._cr = None + self._manager._set_idle(self) + if self._current_cmd is not None: + flushed_cmds.append(self._current_cmd) + self._current_cmd = None + while not self._cmd_queue.empty(): + flushed_cmds.append(self._cmd_queue.get_nowait()) + while not self._resp_queue.empty(): + self._resp_queue.get_nowait() + return flushed_cmds + + async def _process_queue(self): + while True: + cmd = await self._cmd_queue.get() + self._current_cmd = cmd + await self._manager._process(self, cmd) + self._current_cmd = None + + if self._cmd_queue.empty() and self._resp_queue.empty(): + self._manager._set_idle(self) + + +class TagContextManager: + def __init__(self, process): + self._context_list = [] + self._context_idle_list = [] + self._context_mapping = {} + self._process = process + + def _get_context(self, tag): + if tag in self._context_mapping: + return self._context_mapping[tag] + elif self._context_idle_list: + context = self._context_idle_list.pop() + else: + context = TagContext(self) + self._context_list.append(context) + context._start() + context.current_tag = tag + self._context_mapping[tag] = context + return context + + def start_cmd(self, tag, cmd): + context = self._get_context(tag) + context._cmd_queue.put_nowait(cmd) + + def put_resp(self, tag, resp): + context = self._get_context(tag) + context._resp_queue.put_nowait(resp) + + def _set_idle(self, context): + if context.current_tag in self._context_mapping: + del self._context_mapping[context.current_tag] + self._context_idle_list.append(context) + context.current_tag = None + + def flush(self): + flushed_cmds = [] + for c in self._context_list: + flushed_cmds.extend(c._flush()) + return flushed_cmds + + class AxiMasterWrite(Reset): def __init__(self, bus, clock, reset=None, reset_active_level=True, max_burst_len=256): self.log = logging.getLogger(f"cocotb.{bus.aw._entity._name}.{bus.aw._name}") @@ -72,9 +158,7 @@ class AxiMasterWrite(Reset): self.cur_id = 0 self.active_id = Counter() - self.int_write_resp_command_queue = [Queue() for k in range(self.id_count)] - self.current_write_resp_command = [None for k in range(self.id_count)] - self.int_write_resp_queue_list = [Queue() for k in range(self.id_count)] + self.tag_context_manager = TagContextManager(self._process_write_resp_id) self.in_flight_operations = 0 self._idle = Event() @@ -104,7 +188,6 @@ class AxiMasterWrite(Reset): self._process_write_cr = None self._process_write_resp_cr = None - self._process_write_resp_id_cr = None self._init_reset(reset, reset_active_level) @@ -207,10 +290,6 @@ class AxiMasterWrite(Reset): if self._process_write_resp_cr is not None: self._process_write_resp_cr.kill() self._process_write_resp_cr = None - if self._process_write_resp_id_cr is not None: - for cr in self._process_write_resp_id_cr: - cr.kill() - self._process_write_resp_id_cr = None self.aw_channel.clear() self.w_channel.clear() @@ -230,20 +309,8 @@ class AxiMasterWrite(Reset): self.current_write_command = None flush_cmd(cmd) - for q in self.int_write_resp_command_queue: - while not q.empty(): - cmd = q.get_nowait() - flush_cmd(cmd) - - for k in range(len(self.current_write_resp_command)): - if self.current_write_resp_command[k]: - cmd = self.current_write_resp_command[k] - self.current_write_resp_command[k] = None - flush_cmd(cmd) - - for q in self.int_write_resp_queue_list: - while not q.empty(): - q.get_nowait() + for cmd in self.tag_context_manager.flush(): + flush_cmd(cmd) self.cur_id = 0 self.active_id = Counter() @@ -256,8 +323,6 @@ class AxiMasterWrite(Reset): self._process_write_cr = cocotb.fork(self._process_write()) if self._process_write_resp_cr is None: self._process_write_resp_cr = cocotb.fork(self._process_write_resp()) - if self._process_write_resp_id_cr is None: - self._process_write_resp_id_cr = [cocotb.fork(self._process_write_resp_id(i)) for i in range(self.id_count)] async def _process_write(self): while True: @@ -361,7 +426,7 @@ class AxiMasterWrite(Reset): cycle_offset = (cycle_offset + num_bytes) % self.byte_width resp_cmd = AxiWriteRespCmd(cmd.address, len(cmd.data), cmd.size, cycles, cmd.prot, burst_list, cmd.event) - await self.int_write_resp_command_queue[awid].put(resp_cmd) + self.tag_context_manager.start_cmd(awid, resp_cmd) self.current_write_command = None @@ -374,48 +439,44 @@ class AxiMasterWrite(Reset): if self.active_id[bid] <= 0: raise Exception(f"Unexpected burst ID {bid}") - await self.int_write_resp_queue_list[bid].put(b) + self.tag_context_manager.put_resp(bid, b) - async def _process_write_resp_id(self, bid): - while True: - cmd = await self.int_write_resp_command_queue[bid].get() - self.current_write_resp_command[bid] = cmd + async def _process_write_resp_id(self, context, cmd): + bid = context.current_tag - resp = AxiResp.OKAY - user = [] + resp = AxiResp.OKAY + user = [] - for burst_length in cmd.burst_list: - b = await self.int_write_resp_queue_list[bid].get() + for burst_length in cmd.burst_list: + b = await context.get_resp() - burst_resp = AxiResp(b.bresp) - burst_user = int(b.buser) + burst_resp = AxiResp(b.bresp) + burst_user = int(b.buser) - if burst_resp != AxiResp.OKAY: - resp = burst_resp + if burst_resp != AxiResp.OKAY: + resp = burst_resp - if burst_user is not None: - user.append(burst_user) + if burst_user is not None: + user.append(burst_user) - if self.active_id[bid] <= 0: - raise Exception(f"Unexpected burst ID {bid}") + if self.active_id[bid] <= 0: + raise Exception(f"Unexpected burst ID {bid}") - self.active_id[bid] -= 1 + self.active_id[bid] -= 1 - self.log.info("Write burst complete bid: 0x%x bresp: %s", bid, burst_resp) + self.log.info("Write burst complete bid: 0x%x bresp: %s", bid, burst_resp) - self.log.info("Write complete addr: 0x%08x prot: %s resp: %s length: %d", - cmd.address, cmd.prot, resp, cmd.length) + self.log.info("Write complete addr: 0x%08x prot: %s resp: %s length: %d", + cmd.address, cmd.prot, resp, cmd.length) - write_resp = AxiWriteResp(cmd.address, cmd.length, resp, user) + write_resp = AxiWriteResp(cmd.address, cmd.length, resp, user) - cmd.event.set(write_resp) + cmd.event.set(write_resp) - self.current_write_resp_command[bid] = None + self.in_flight_operations -= 1 - self.in_flight_operations -= 1 - - if self.in_flight_operations == 0: - self._idle.set() + if self.in_flight_operations == 0: + self._idle.set() class AxiMasterRead(Reset): @@ -439,9 +500,7 @@ class AxiMasterRead(Reset): self.cur_id = 0 self.active_id = Counter() - self.int_read_resp_command_queue = [Queue() for k in range(self.id_count)] - self.current_read_resp_command = [None for k in range(self.id_count)] - self.int_read_resp_queue_list = [Queue() for k in range(self.id_count)] + self.tag_context_manager = TagContextManager(self._process_read_resp_id) self.in_flight_operations = 0 self._idle = Event() @@ -469,7 +528,6 @@ class AxiMasterRead(Reset): self._process_read_cr = None self._process_read_resp_cr = None - self._process_read_resp_id_cr = None self._init_reset(reset, reset_active_level) @@ -567,10 +625,6 @@ class AxiMasterRead(Reset): if self._process_read_resp_cr is not None: self._process_read_resp_cr.kill() self._process_read_resp_cr = None - if self._process_read_resp_id_cr is not None: - for cr in self._process_read_resp_id_cr: - cr.kill() - self._process_read_resp_id_cr = None self.ar_channel.clear() self.r_channel.clear() @@ -589,20 +643,8 @@ class AxiMasterRead(Reset): self.current_read_command = None flush_cmd(cmd) - for q in self.int_read_resp_command_queue: - while not q.empty(): - cmd = q.get_nowait() - flush_cmd(cmd) - - for k in range(len(self.current_read_resp_command)): - if self.current_read_resp_command[k]: - cmd = self.current_read_resp_command[k] - self.current_read_resp_command[k] = None - flush_cmd(cmd) - - for q in self.int_read_resp_queue_list: - while not q.empty(): - q.get_nowait() + for cmd in self.tag_context_manager.flush(): + flush_cmd(cmd) self.cur_id = 0 self.active_id = Counter() @@ -615,8 +657,6 @@ class AxiMasterRead(Reset): self._process_read_cr = cocotb.fork(self._process_read()) if self._process_read_resp_cr is None: self._process_read_resp_cr = cocotb.fork(self._process_read_resp()) - if self._process_read_resp_id_cr is None: - self._process_read_resp_id_cr = [cocotb.fork(self._process_read_resp_id(i)) for i in range(self.id_count)] async def _process_read(self): while True: @@ -679,7 +719,7 @@ class AxiMasterRead(Reset): cur_addr += num_bytes resp_cmd = AxiReadRespCmd(cmd.address, cmd.length, cmd.size, cycles, cmd.prot, burst_list, cmd.event) - await self.int_read_resp_command_queue[arid].put(resp_cmd) + self.tag_context_manager.start_cmd(arid, resp_cmd) self.current_read_command = None @@ -702,79 +742,75 @@ class AxiMasterRead(Reset): cur_rid = rid if int(r.rlast): - await self.int_read_resp_queue_list[rid].put(burst) + self.tag_context_manager.put_resp(rid, burst) burst = [] cur_rid = None - async def _process_read_resp_id(self, rid): - while True: - cmd = await self.int_read_resp_command_queue[rid].get() - self.current_read_resp_command[rid] = cmd + async def _process_read_resp_id(self, context, cmd): + rid = context.current_tag - num_bytes = 2**cmd.size + num_bytes = 2**cmd.size - aligned_addr = (cmd.address // num_bytes) * num_bytes - word_addr = (cmd.address // self.byte_width) * self.byte_width + aligned_addr = (cmd.address // num_bytes) * num_bytes + word_addr = (cmd.address // self.byte_width) * self.byte_width - start_offset = cmd.address % self.byte_width + start_offset = cmd.address % self.byte_width - cycle_offset = aligned_addr - word_addr - data = bytearray() + cycle_offset = aligned_addr - word_addr + data = bytearray() - resp = AxiResp.OKAY - user = [] + resp = AxiResp.OKAY + user = [] - first = True + first = True - for burst_length in cmd.burst_list: - burst = await self.int_read_resp_queue_list[rid].get() + for burst_length in cmd.burst_list: + burst = await context.get_resp() - if len(burst) != burst_length: - raise Exception(f"Burst length incorrect (ID {rid}, expected {burst_length}, got {len(burst)}") + if len(burst) != burst_length: + raise Exception(f"Burst length incorrect (ID {rid}, expected {burst_length}, got {len(burst)}") - for r in burst: - cycle_data = int(r.rdata) - cycle_resp = AxiResp(r.rresp) - cycle_user = int(r.ruser) + for r in burst: + cycle_data = int(r.rdata) + cycle_resp = AxiResp(r.rresp) + cycle_user = int(r.ruser) - if cycle_resp != AxiResp.OKAY: - resp = cycle_resp + if cycle_resp != AxiResp.OKAY: + resp = cycle_resp - if cycle_user is not None: - user.append(cycle_user) + if cycle_user is not None: + user.append(cycle_user) - start = cycle_offset - stop = cycle_offset+num_bytes + start = cycle_offset + stop = cycle_offset+num_bytes - if first: - start = start_offset + if first: + start = start_offset - for j in range(start, stop): - data.append((cycle_data >> j*8) & 0xff) + for j in range(start, stop): + data.append((cycle_data >> j*8) & 0xff) - cycle_offset = (cycle_offset + num_bytes) % self.byte_width + cycle_offset = (cycle_offset + num_bytes) % self.byte_width - first = False + first = False - self.active_id[rid] -= 1 + self.active_id[rid] -= 1 - self.log.info("Read burst complete rid: 0x%x rresp: %s", rid, resp) + self.log.info("Read burst complete rid: 0x%x rresp: %s", rid, resp) - data = data[:cmd.length] + data = data[:cmd.length] - self.log.info("Read complete addr: 0x%08x prot: %s resp: %s data: %s", - cmd.address, cmd.prot, resp, ' '.join((f'{c:02x}' for c in data))) + self.log.info("Read complete addr: 0x%08x prot: %s resp: %s data: %s", + cmd.address, cmd.prot, resp, ' '.join((f'{c:02x}' for c in data))) - read_resp = AxiReadResp(cmd.address, data, resp, user) + read_resp = AxiReadResp(cmd.address, data, resp, user) - cmd.event.set(read_resp) + cmd.event.set(read_resp) - self.current_read_resp_command[rid] = None + self.in_flight_operations -= 1 - self.in_flight_operations -= 1 - - if self.in_flight_operations == 0: - self._idle.set() + if self.in_flight_operations == 0: + self._idle.set() class AxiMaster: