Pull out common functionality into StreamBase and StreamPause

This commit is contained in:
Alex Forencich
2020-11-22 23:08:15 -08:00
parent f2995d716e
commit 306b09f967

View File

@@ -48,15 +48,19 @@ class StreamTransaction(object):
return f"{type(self).__name__}({', '.join(f'{s}={int(getattr(self, s))}' for s in self._signals)})"
class StreamSource(object):
class StreamBase(object):
_signals = ["data", "valid", "ready"]
_optional_signals = []
_signal_widths = {"valid": 1, "ready": 1}
_init_x = False
_valid_signal = "valid"
_valid_init = None
_ready_signal = "ready"
_ready_init = None
_transaction_obj = StreamTransaction
@@ -72,34 +76,89 @@ class StreamSource(object):
self.queue = deque()
self.queue_sync = Event()
self.drive_obj = None
self.drive_sync = Event()
self.ready = None
self.valid = None
if self._ready_signal is not None and hasattr(self.bus, self._ready_signal):
self.ready = getattr(self.bus, self._ready_signal)
if self._ready_init is not None:
self.ready.setimmediatevalue(self._ready_init)
if self._valid_signal is not None and hasattr(self.bus, self._valid_signal):
self.valid = getattr(self.bus, self._valid_signal)
self.valid.setimmediatevalue(0)
if self._valid_init is not None:
self.valid.setimmediatevalue(self._valid_init)
for sig in self._signals+self._optional_signals:
if hasattr(self.bus, sig):
if sig in self._signal_widths:
assert len(getattr(self.bus, sig)) == self._signal_widths[sig]
if sig not in (self._valid_signal, self._ready_signal):
if self._init_x and sig not in (self._valid_signal, self._ready_signal):
v = getattr(self.bus, sig).value
v.binstr = 'x'*len(v)
getattr(self.bus, sig).setimmediatevalue(v)
self.active = False
def count(self):
return len(self.queue)
def empty(self):
return not self.queue
def clear(self):
self.queue.clear()
class StreamPause(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pause = False
self._pause_generator = None
self._pause_cr = None
def set_pause_generator(self, generator=None):
if self._pause_cr is not None:
self._pause_cr.kill()
self._pause_cr = None
self._pause_generator = generator
if self._pause_generator is not None:
self._pause_cr = cocotb.fork(self._run_pause())
def clear_pause_generator(self):
self.set_pause_generator(None)
async def _run_pause(self):
for val in self._pause_generator:
self.pause = val
await RisingEdge(self.clock)
class StreamSource(StreamBase, StreamPause):
_signals = ["data", "valid", "ready"]
_optional_signals = []
_signal_widths = {"valid": 1, "ready": 1}
_init_x = True
_valid_signal = "valid"
_valid_init = 0
_ready_signal = "ready"
_ready_init = None
_transaction_obj = StreamTransaction
def __init__(self, entity, name, clock, reset=None, *args, **kwargs):
super().__init__(entity, name, clock, reset, *args, **kwargs)
self.drive_obj = None
self.drive_sync = Event()
self.active = False
cocotb.fork(self._run_source())
cocotb.fork(self._run())
@@ -114,12 +173,6 @@ class StreamSource(object):
self.queue.append(obj)
self.queue_sync.set()
def count(self):
return len(self.queue)
def empty(self):
return not self.queue
def idle(self):
return self.empty() and not self.active
@@ -128,23 +181,10 @@ class StreamSource(object):
await RisingEdge(self.clock)
def clear(self):
self.queue = deque()
self.queue.clear()
self.drive_obj = None
self.drive_sync.set()
def set_pause_generator(self, generator=None):
if self._pause_cr is not None:
self._pause_cr.kill()
self._pause_cr = None
self._pause_generator = generator
if self._pause_generator is not None:
self._pause_cr = cocotb.fork(self._run_pause())
def clear_pause_generator(self):
self.set_pause_generator(None)
async def _run_source(self):
while True:
await ReadOnly()
@@ -184,54 +224,25 @@ class StreamSource(object):
await self.drive(self.queue.popleft())
async def _run_pause(self):
for val in self._pause_generator:
self.pause = val
await RisingEdge(self.clock)
class StreamSink(object):
class StreamSink(StreamBase, StreamPause):
_signals = ["data", "valid", "ready"]
_optional_signals = []
_signal_widths = {"valid": 1, "ready": 1}
_init_x = False
_valid_signal = "valid"
_valid_init = None
_ready_signal = "ready"
_ready_init = 0
_transaction_obj = StreamTransaction
def __init__(self, entity, name, clock, reset=None, *args, **kwargs):
self.log = SimLog("cocotb.%s.%s" % (entity._name, name))
self.entity = entity
self.clock = clock
self.reset = reset
self.bus = Bus(self.entity, name, self._signals, optional_signals=self._optional_signals, **kwargs)
super().__init__(*args, **kwargs)
self.ready = None
self.valid = None
if self._ready_signal is not None and hasattr(self.bus, self._ready_signal):
self.ready = getattr(self.bus, self._ready_signal)
self.ready.setimmediatevalue(0)
if self._valid_signal is not None and hasattr(self.bus, self._valid_signal):
self.valid = getattr(self.bus, self._valid_signal)
for sig in self._signals+self._optional_signals:
if hasattr(self.bus, sig):
if sig in self._signal_widths:
assert len(getattr(self.bus, sig)) == self._signal_widths[sig]
self.queue = deque()
self.queue_sync = Event()
self.pause = False
self._pause_generator = None
self._pause_cr = None
super().__init__(entity, name, clock, reset, *args, **kwargs)
cocotb.fork(self._run_sink())
@@ -240,12 +251,6 @@ class StreamSink(object):
return self.queue.popleft()
return None
def count(self):
return len(self.queue)
def empty(self):
return not self.queue
async def wait(self, timeout=0, timeout_unit=None):
if not self.empty():
return
@@ -255,26 +260,10 @@ class StreamSink(object):
else:
await self.queue_sync.wait()
def clear(self):
self.queue = deque()
def callback(self, obj):
self.queue.append(obj)
self.queue_sync.set()
def set_pause_generator(self, generator=None):
if self._pause_cr is not None:
self._pause_cr.kill()
self._pause_cr = None
self._pause_generator = generator
if self._pause_generator is not None:
self._pause_cr = cocotb.fork(self._run_pause())
def clear_pause_generator(self):
self.set_pause_generator(None)
async def _run_sink(self):
while True:
await ReadOnly()
@@ -299,49 +288,25 @@ class StreamSink(object):
if self.ready is not None:
self.ready <= (not self.pause)
async def _run_pause(self):
for val in self._pause_generator:
self.pause = val
await RisingEdge(self.clock)
class StreamMonitor(object):
class StreamMonitor(StreamBase):
_signals = ["data", "valid", "ready"]
_optional_signals = []
_signal_widths = {"valid": 1, "ready": 1}
_init_x = False
_valid_signal = "valid"
_valid_init = None
_ready_signal = "ready"
_ready_init = None
_transaction_obj = StreamTransaction
def __init__(self, entity, name, clock, reset=None, *args, **kwargs):
self.log = SimLog("cocotb.%s.%s" % (entity._name, name))
self.entity = entity
self.clock = clock
self.reset = reset
self.bus = Bus(self.entity, name, self._signals, optional_signals=self._optional_signals, **kwargs)
super().__init__(*args, **kwargs)
self.ready = None
self.valid = None
if self._ready_signal is not None and hasattr(self.bus, self._ready_signal):
self.ready = getattr(self.bus, self._ready_signal)
if self._valid_signal is not None and hasattr(self.bus, self._valid_signal):
self.valid = getattr(self.bus, self._valid_signal)
for sig in self._signals+self._optional_signals:
if hasattr(self.bus, sig):
if sig in self._signal_widths:
assert len(getattr(self.bus, sig)) == self._signal_widths[sig]
self.queue = deque()
self.queue_sync = Event()
super().__init__(entity, name, clock, reset, *args, **kwargs)
cocotb.fork(self._run_monitor())
@@ -365,9 +330,6 @@ class StreamMonitor(object):
else:
await self.queue_sync.wait()
def clear(self):
self.queue = deque()
def callback(self, obj):
self.queue.append(obj)
self.queue_sync.set()