diff --git a/.gitignore b/.gitignore index 0f520ff..916f8e7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,9 +3,16 @@ __pycache__/ *.py[codz] *$py.class +# macOS +.DS_Store + # C extensions *.so *.dylib +*.dSYM/ + +# Generated PDFs (schematics, etc.) +*.pdf # Distribution / packaging .Python diff --git a/compiler/frontend/pycircuit/__init__.py b/compiler/frontend/pycircuit/__init__.py index 7a5933e..cfe81b3 100644 --- a/compiler/frontend/pycircuit/__init__.py +++ b/compiler/frontend/pycircuit/__init__.py @@ -1,6 +1,5 @@ from . import ct from . import hierarchical -from . import lib from . import logic from . import spec from . import wiring @@ -17,6 +16,23 @@ from .hw import Bundle, Circuit, ClockDomain, Pop, Reg, Vec, Wire, cat, unsigned from .jit import JitError, compile from .literals import LiteralValue, S, U, s, u +from .v5 import ( + CycleAwareCircuit, + CycleAwareDomain, + CycleAwareSignal, + CycleAwareTb, + StateSignal, + cas, + compile_cycle_aware, + log, + mux, + pyc_CircuitLogger, + pyc_CircuitModule, + pyc_ClockDomain, + pyc_Signal, + signal, +) +from . 
import lib from .probe import ProbeBuilder, ProbeError, ProbeRef, ProbeView, TbProbeHandle, TbProbes from .tb import Tb, sva from .testbench import TestbenchProgram @@ -25,6 +41,19 @@ probe = _probe_decorator __all__ = [ + "CycleAwareCircuit", + "CycleAwareDomain", + "CycleAwareSignal", + "CycleAwareTb", + "cas", + "compile_cycle_aware", + "log", + "mux", + "pyc_CircuitLogger", + "pyc_CircuitModule", + "pyc_ClockDomain", + "pyc_Signal", + "signal", "Connector", "ConnectorBundle", "ConnectorStruct", diff --git a/compiler/frontend/pycircuit/hw.py b/compiler/frontend/pycircuit/hw.py index 40b6b9a..6ae32c9 100644 --- a/compiler/frontend/pycircuit/hw.py +++ b/compiler/frontend/pycircuit/hw.py @@ -746,6 +746,13 @@ def scope(self, name: str) -> Iterator[None]: def domain(self, name: str) -> ClockDomain: return ClockDomain(clk=self.clock(f"{name}_clk"), rst=self.reset(f"{name}_rst")) + def create_domain(self, name: str, *, frequency_desc: str = "", reset_active_high: bool = False) -> Any: + """V5 cycle-aware domain (next/prev/push/pop); see `pycircuit.v5.CycleAwareDomain`.""" + from .v5 import CycleAwareDomain + + _ = (frequency_desc, reset_active_high) + return CycleAwareDomain(self, str(name)) + def input(self, name: str, *, width: int, signed: bool = False) -> Wire: # type: ignore[override] """Declare a module input port and return it as a `Wire`.""" return Wire(self, super().input(name, width=width), signed=bool(signed)) diff --git a/compiler/frontend/pycircuit/jit_cache.py b/compiler/frontend/pycircuit/jit_cache.py index 9265b21..be516cb 100644 --- a/compiler/frontend/pycircuit/jit_cache.py +++ b/compiler/frontend/pycircuit/jit_cache.py @@ -272,15 +272,24 @@ def get_function_meta(fn: Any, *, fn_name: str | None = None) -> FunctionMeta: if cached is not None and (fn_name is None or cached.fdef.name == fn_name): return cached - lines, start_line = inspect.getsourcelines(fn) - source = textwrap.dedent("".join(lines)) - tree = ast.parse(source) + synthetic = getattr(fn, 
"__pycircuit_jit_source__", None) + if isinstance(synthetic, str) and synthetic.strip(): + source = textwrap.dedent(synthetic).strip() + "\n" + start_line = int(getattr(fn, "__pycircuit_jit_start_line__", 1) or 1) + tree = ast.parse(source) + else: + lines, start_line = inspect.getsourcelines(fn) + source = textwrap.dedent("".join(lines)) + tree = ast.parse(source) name = fn_name if fn_name is not None else getattr(fn, "__name__", None) if not isinstance(name, str) or not name: raise RuntimeError(f"failed to infer function name for {fn!r}") fdef = _find_function_def(tree, name) - source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) + if isinstance(synthetic, str) and synthetic.strip(): + source_file = getattr(fn, "__pycircuit_jit_source_file__", None) or "" + else: + source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) source_stem = None try: if source_file: diff --git a/compiler/frontend/pycircuit/lib/cache.py b/compiler/frontend/pycircuit/lib/cache.py index d35c6ea..e5a5fe9 100644 --- a/compiler/frontend/pycircuit/lib/cache.py +++ b/compiler/frontend/pycircuit/lib/cache.py @@ -1,22 +1,17 @@ from __future__ import annotations -from ..connectors import Connector, ConnectorBundle -from ..design import module -from ..dsl import Signal -from ..hw import Circuit -from ..literals import u +from pycircuit.hw import Circuit, ClockDomain, Wire +from pycircuit.literals import u -@module(structural=True) def Cache( m: Circuit, - clk: Connector, - rst: Connector, - req_valid: Connector, - req_addr: Connector, - req_write: Connector, - req_wdata: Connector, - req_wmask: Connector, + cd: ClockDomain, + req_valid: Wire, + req_addr: Wire, + req_write: Wire, + req_wdata: Wire, + req_wmask: Wire, *, ways: int = 4, sets: int = 64, @@ -26,7 +21,7 @@ def Cache( write_back: bool = True, write_allocate: bool = True, replacement: str = "plru", -) -> ConnectorBundle: +): """Structural cache baseline. 
Default policy contract: @@ -39,19 +34,15 @@ def Cache( """ _ = (line_bytes, write_back, write_allocate, replacement) - clk_v = clk.read() if isinstance(clk, Connector) else clk - rst_v = rst.read() if isinstance(rst, Connector) else rst - - req_valid_v = req_valid.read() if isinstance(req_valid, Connector) else req_valid - req_addr_v = req_addr.read() if isinstance(req_addr, Connector) else req_addr - req_write_v = req_write.read() if isinstance(req_write, Connector) else req_write - req_wdata_v = req_wdata.read() if isinstance(req_wdata, Connector) else req_wdata - req_wmask_v = req_wmask.read() if isinstance(req_wmask, Connector) else req_wmask - req_valid_w = m.wire(req_valid_v) if isinstance(req_valid_v, Signal) else req_valid_v - req_addr_w = m.wire(req_addr_v) if isinstance(req_addr_v, Signal) else req_addr_v - req_write_w = m.wire(req_write_v) if isinstance(req_write_v, Signal) else req_write_v - req_wdata_w = m.wire(req_wdata_v) if isinstance(req_wdata_v, Signal) else req_wdata_v - _req_wmask_w = m.wire(req_wmask_v) if isinstance(req_wmask_v, Signal) else req_wmask_v + clk_v = cd.clk + rst_v = cd.rst + + req_valid_w = req_valid + req_addr_w = req_addr + req_write_w = req_write + req_wdata_w = req_wdata + _req_wmask_w = req_wmask + _ = _req_wmask_w ways_i = max(1, int(ways)) sets_i = max(1, int(sets)) set_bits = max(1, (sets_i - 1).bit_length()) @@ -59,11 +50,11 @@ def Cache( plru_bits = max(1, ways_i - 1) way_idx_bits = max(1, (ways_i - 1).bit_length()) - tags = [m.out(f"cache_tag_{i}", clk=clk_v, rst=rst_v, width=tag_bits, init=0) for i in range(ways_i)] - valids = [m.out(f"cache_valid_{i}", clk=clk_v, rst=rst_v, width=1, init=0) for i in range(ways_i)] - dirty = [m.out(f"cache_dirty_{i}", clk=clk_v, rst=rst_v, width=1, init=0) for i in range(ways_i)] - data = [m.out(f"cache_data_{i}", clk=clk_v, rst=rst_v, width=int(data_width), init=0) for i in range(ways_i)] - plru = m.out("cache_plru", clk=clk_v, rst=rst_v, width=plru_bits, init=0) + tags = 
[m.out(f"cache_tag_{i}", domain=cd, width=tag_bits, init=0) for i in range(ways_i)] + valids = [m.out(f"cache_valid_{i}", domain=cd, width=1, init=0) for i in range(ways_i)] + dirty = [m.out(f"cache_dirty_{i}", domain=cd, width=1, init=0) for i in range(ways_i)] + data = [m.out(f"cache_data_{i}", domain=cd, width=int(data_width), init=0) for i in range(ways_i)] + plru = m.out("cache_plru", domain=cd, width=plru_bits, init=0) req_tag = req_addr_w[set_bits : set_bits + tag_bits] @@ -73,8 +64,8 @@ def Cache( for i in range(ways_i): way_hit = valids[i].out() & (tags[i].out() == req_tag) - hit_data = data[i].out() if way_hit else hit_data - hit_way = i if way_hit else hit_way + hit_data = way_hit._select_internal(data[i].out(), hit_data) + hit_way = way_hit._select_internal(u(way_idx_bits, i), hit_way) hit = hit | way_hit victim_way = plru.out()[0:way_idx_bits] @@ -101,7 +92,7 @@ def Cache( resp_valid = req_valid_w resp_ready = req_valid_w resp_hit = hit - resp_data = hit_data if hit else u(int(data_width), 0) + resp_data = hit._select_internal(hit_data, u(int(data_width), 0)) miss = req_valid_w & (~hit) return m.bundle_connector( diff --git a/compiler/frontend/pycircuit/lib/mem2port.py b/compiler/frontend/pycircuit/lib/mem2port.py index e426d4a..e138aca 100644 --- a/compiler/frontend/pycircuit/lib/mem2port.py +++ b/compiler/frontend/pycircuit/lib/mem2port.py @@ -1,50 +1,44 @@ from __future__ import annotations -from ..connectors import Connector, ConnectorBundle, ConnectorError -from ..design import module -from ..dsl import Signal -from ..hw import Circuit +from pycircuit.dsl import Signal +from pycircuit.hw import Circuit, ClockDomain, Wire + + +class Mem2PortError(ValueError): + pass -@module(structural=True) def Mem2Port( m: Circuit, - clk: Connector, - rst: Connector, - ren0: Connector, - raddr0: Connector, - ren1: Connector, - raddr1: Connector, - wvalid: Connector, - waddr: Connector, - wdata: Connector, - wstrb: Connector, + cd: ClockDomain, + ren0: Wire, + 
raddr0: Wire, + ren1: Wire, + raddr1: Wire, + wvalid: Wire, + waddr: Wire, + wdata: Wire, + wstrb: Wire, *, depth: int, -) -> ConnectorBundle: - clk_v = clk.read() if isinstance(clk, Connector) else clk - rst_v = rst.read() if isinstance(rst, Connector) else rst +): + clk_v = cd.clk + rst_v = cd.rst if not isinstance(clk_v, Signal) or clk_v.ty != "!pyc.clock": - raise ConnectorError("Mem2Port.clk must be !pyc.clock") + raise Mem2PortError("Mem2Port domain clk must be !pyc.clock") if not isinstance(rst_v, Signal) or rst_v.ty != "!pyc.reset": - raise ConnectorError("Mem2Port.rst must be !pyc.reset") - - def wire_of(v): - vv = v.read() if isinstance(v, Connector) else v - if isinstance(vv, Signal): - return m.wire(vv) - return vv + raise Mem2PortError("Mem2Port domain rst must be !pyc.reset") - ren0_w = wire_of(ren0) - ren1_w = wire_of(ren1) - wvalid_w = wire_of(wvalid) - raddr0_w = wire_of(raddr0) - raddr1_w = wire_of(raddr1) - waddr_w = wire_of(waddr) - wdata_w = wire_of(wdata) - wstrb_w = wire_of(wstrb) + ren0_w = ren0 + ren1_w = ren1 + wvalid_w = wvalid + raddr0_w = raddr0 + raddr1_w = raddr1 + waddr_w = waddr + wdata_w = wdata + wstrb_w = wstrb if ren0_w.ty != "i1" or ren1_w.ty != "i1" or wvalid_w.ty != "i1": - raise ConnectorError("Mem2Port ren0/ren1/wvalid must be i1") + raise Mem2PortError("Mem2Port ren0/ren1/wvalid must be i1") rdata0, rdata1 = m.sync_mem_dp( clk_v, diff --git a/compiler/frontend/pycircuit/lib/picker.py b/compiler/frontend/pycircuit/lib/picker.py index d231afb..f2ab8a7 100644 --- a/compiler/frontend/pycircuit/lib/picker.py +++ b/compiler/frontend/pycircuit/lib/picker.py @@ -1,26 +1,18 @@ from __future__ import annotations -from ..connectors import Connector, ConnectorBundle, ConnectorError -from ..design import module -from ..dsl import Signal -from ..hw import Circuit -from ..literals import u +from pycircuit.hw import Circuit, Wire +from pycircuit.literals import u -@module(structural=True) def Picker( m: Circuit, - req: Connector, + req: 
Wire, *, width: int | None = None, -) -> ConnectorBundle: - req_v = req.read() if isinstance(req, Connector) else req - if isinstance(req_v, Signal): - req_w = m.wire(req_v) - else: - req_w = req_v +): + req_w = req if not hasattr(req_w, "ty") or not str(req_w.ty).startswith("i"): - raise ConnectorError("Picker.req must be an integer wire connector") + raise ValueError("Picker.req must be an integer wire") w = int(width) if width is not None else int(req_w.width) if w <= 0: raise ValueError("Picker width must be > 0") @@ -32,8 +24,8 @@ def Picker( for i in range(w): take = req_w[i] & ~found - grant = u(w, 1 << i) if take else grant - index = u(idx_w, i) if take else index + grant = take._select_internal(u(w, 1 << i), grant) + index = take._select_internal(u(idx_w, i), index) found = found | req_w[i] return m.bundle_connector( diff --git a/compiler/frontend/pycircuit/lib/queue.py b/compiler/frontend/pycircuit/lib/queue.py index e3ec5ec..9abca70 100644 --- a/compiler/frontend/pycircuit/lib/queue.py +++ b/compiler/frontend/pycircuit/lib/queue.py @@ -1,52 +1,39 @@ from __future__ import annotations -from ..connectors import Connector, ConnectorBundle, ConnectorError -from ..design import module -from ..dsl import Signal -from ..hw import Circuit, Wire +from pycircuit.dsl import Signal +from pycircuit.hw import Circuit, ClockDomain, Wire + + +class FIFOError(ValueError): + pass -@module(structural=True) def FIFO( m: Circuit, - clk: Connector, - rst: Connector, - in_valid: Connector, - in_data: Connector, - out_ready: Connector, + cd: ClockDomain, + in_valid: Wire, + in_data: Wire, + out_ready: Wire, *, depth: int = 2, -) -> ConnectorBundle: - clk_v = clk.read() if isinstance(clk, Connector) else clk - rst_v = rst.read() if isinstance(rst, Connector) else rst - in_valid_v = in_valid.read() if isinstance(in_valid, Connector) else in_valid - in_data_v = in_data.read() if isinstance(in_data, Connector) else in_data - out_ready_v = out_ready.read() if isinstance(out_ready, 
Connector) else out_ready - +): + clk_v = cd.clk + rst_v = cd.rst if not isinstance(clk_v, Signal) or clk_v.ty != "!pyc.clock": - raise ConnectorError("FIFO.clk must be !pyc.clock") + raise FIFOError("FIFO domain clk must be !pyc.clock") if not isinstance(rst_v, Signal) or rst_v.ty != "!pyc.reset": - raise ConnectorError("FIFO.rst must be !pyc.reset") + raise FIFOError("FIFO domain rst must be !pyc.reset") - if isinstance(in_valid_v, Signal): - in_valid_w = Wire(m, in_valid_v) - else: - in_valid_w = in_valid_v - if isinstance(in_data_v, Signal): - in_data_w = Wire(m, in_data_v) - else: - in_data_w = in_data_v - if isinstance(out_ready_v, Signal): - out_ready_w = Wire(m, out_ready_v) - else: - out_ready_w = out_ready_v + in_valid_w = in_valid + in_data_w = in_data + out_ready_w = out_ready if not isinstance(in_valid_w, Wire) or in_valid_w.ty != "i1": - raise ConnectorError("FIFO.in_valid must be i1") + raise FIFOError("FIFO.in_valid must be i1") if not isinstance(in_data_w, Wire): - raise ConnectorError("FIFO.in_data must be integer wire") + raise FIFOError("FIFO.in_data must be integer wire") if not isinstance(out_ready_w, Wire) or out_ready_w.ty != "i1": - raise ConnectorError("FIFO.out_ready must be i1") + raise FIFOError("FIFO.out_ready must be i1") in_ready, out_valid, out_data = m.fifo( clk_v, diff --git a/compiler/frontend/pycircuit/lib/regfile.py b/compiler/frontend/pycircuit/lib/regfile.py index 055acaa..7b9342c 100644 --- a/compiler/frontend/pycircuit/lib/regfile.py +++ b/compiler/frontend/pycircuit/lib/regfile.py @@ -1,27 +1,27 @@ from __future__ import annotations -from ..connectors import Connector, ConnectorBundle, ConnectorError -from ..design import module -from ..dsl import Signal -from ..hw import Circuit -from ..literals import u +from pycircuit.dsl import Signal +from pycircuit.hw import Circuit, ClockDomain, Wire +from pycircuit.literals import u + + +class RegFileError(ValueError): + """Invalid RegFile port wiring.""" -@module(structural=True) 
def RegFile( m: Circuit, - clk: Connector, - rst: Connector, - raddr_bus: Connector, - wen_bus: Connector, - waddr_bus: Connector, - wdata_bus: Connector, + cd: ClockDomain, + raddr_bus: Wire, + wen_bus: Wire, + waddr_bus: Wire, + wdata_bus: Wire, *, ptag_count: int = 256, const_count: int = 128, nr: int = 10, nw: int = 5, -) -> ConnectorBundle: +): ptag_n = int(ptag_count) const_n = int(const_count) nr_n = int(nr) @@ -36,36 +36,17 @@ def RegFile( raise ValueError("RegFile nw must be > 0") ptag_w = max(1, (ptag_n - 1).bit_length()) - clk_v = clk.read() if isinstance(clk, Connector) else clk - rst_v = rst.read() if isinstance(rst, Connector) else rst + clk_v = cd.clk + rst_v = cd.rst if not isinstance(clk_v, Signal) or clk_v.ty != "!pyc.clock": - raise ConnectorError("RegFile.clk must be !pyc.clock") + raise RegFileError("RegFile domain clk must be !pyc.clock") if not isinstance(rst_v, Signal) or rst_v.ty != "!pyc.reset": - raise ConnectorError("RegFile.rst must be !pyc.reset") - - raddr_bus_v = raddr_bus.read() if isinstance(raddr_bus, Connector) else raddr_bus - if isinstance(raddr_bus_v, Signal): - raddr_bus_w = m.wire(raddr_bus_v) - else: - raddr_bus_w = raddr_bus_v - - wen_bus_v = wen_bus.read() if isinstance(wen_bus, Connector) else wen_bus - if isinstance(wen_bus_v, Signal): - wen_bus_w = m.wire(wen_bus_v) - else: - wen_bus_w = wen_bus_v - - waddr_bus_v = waddr_bus.read() if isinstance(waddr_bus, Connector) else waddr_bus - if isinstance(waddr_bus_v, Signal): - waddr_bus_w = m.wire(waddr_bus_v) - else: - waddr_bus_w = waddr_bus_v - - wdata_bus_v = wdata_bus.read() if isinstance(wdata_bus, Connector) else wdata_bus - if isinstance(wdata_bus_v, Signal): - wdata_bus_w = m.wire(wdata_bus_v) - else: - wdata_bus_w = wdata_bus_v + raise RegFileError("RegFile domain rst must be !pyc.reset") + + raddr_bus_w = raddr_bus + wen_bus_w = wen_bus + waddr_bus_w = waddr_bus + wdata_bus_w = wdata_bus exp_raddr_w = nr_n * ptag_w exp_wen_w = nw_n @@ -73,17 +54,17 @@ def RegFile( 
exp_wdata_w = nw_n * 64 if raddr_bus_w.width != exp_raddr_w: - raise ConnectorError(f"RegFile.raddr_bus must be i{exp_raddr_w}") + raise RegFileError(f"RegFile.raddr_bus must be i{exp_raddr_w}") if wen_bus_w.width != exp_wen_w: - raise ConnectorError(f"RegFile.wen_bus must be i{exp_wen_w}") + raise RegFileError(f"RegFile.wen_bus must be i{exp_wen_w}") if waddr_bus_w.width != exp_waddr_w: - raise ConnectorError(f"RegFile.waddr_bus must be i{exp_waddr_w}") + raise RegFileError(f"RegFile.waddr_bus must be i{exp_waddr_w}") if wdata_bus_w.width != exp_wdata_w: - raise ConnectorError(f"RegFile.wdata_bus must be i{exp_wdata_w}") + raise RegFileError(f"RegFile.wdata_bus must be i{exp_wdata_w}") storage_depth = ptag_n - const_n - bank0 = [m.out(f"rf_bank0_{i}", clk=clk_v, rst=rst_v, width=32, init=u(32, 0)) for i in range(storage_depth)] - bank1 = [m.out(f"rf_bank1_{i}", clk=clk_v, rst=rst_v, width=32, init=u(32, 0)) for i in range(storage_depth)] + bank0 = [m.out(f"rf_bank0_{i}", domain=cd, width=32, init=u(32, 0)) for i in range(storage_depth)] + bank1 = [m.out(f"rf_bank1_{i}", domain=cd, width=32, init=u(32, 0)) for i in range(storage_depth)] raddr_lanes = [raddr_bus_w[i * ptag_w : (i + 1) * ptag_w] for i in range(nr_n)] wen_lanes = [wen_bus_w[i] for i in range(nw_n)] @@ -102,8 +83,8 @@ def RegFile( for lane in range(nw_n): hit = wen_lanes[lane] & (waddr_lanes[lane] == u(ptag_w, ptag)) we_any = we_any | hit - next_lo = wdata_lo[lane] if hit else next_lo - next_hi = wdata_hi[lane] if hit else next_hi + next_lo = hit._select_internal(wdata_lo[lane], next_lo) + next_hi = hit._select_internal(wdata_hi[lane], next_hi) bank0[sidx].set(next_lo, when=we_any) bank1[sidx].set(next_hi, when=we_any) @@ -126,12 +107,12 @@ def RegFile( for sidx in range(storage_depth): ptag = const_n + sidx hit = raddr_i == u(ptag_w, ptag) - store_lo = bank0[sidx].out() if hit else store_lo - store_hi = bank1[sidx].out() if hit else store_hi + store_lo = hit._select_internal(bank0[sidx].out(), 
store_lo) + store_hi = hit._select_internal(bank1[sidx].out(), store_hi) store64 = m.cat(store_hi, store_lo) - lane_data = const64 if is_const else store64 - lane_data = lane_data if is_valid else u(64, 0) + lane_data = is_const._select_internal(const64, store64) + lane_data = is_valid._select_internal(lane_data, u(64, 0)) rdata_lanes.append(lane_data) rdata_bus_out = rdata_lanes[0] diff --git a/compiler/frontend/pycircuit/lib/sram.py b/compiler/frontend/pycircuit/lib/sram.py index c2be4eb..95f67f3 100644 --- a/compiler/frontend/pycircuit/lib/sram.py +++ b/compiler/frontend/pycircuit/lib/sram.py @@ -1,46 +1,40 @@ from __future__ import annotations -from ..connectors import Connector, ConnectorBundle, ConnectorError -from ..design import module -from ..dsl import Signal -from ..hw import Circuit +from pycircuit.dsl import Signal +from pycircuit.hw import Circuit, ClockDomain, Wire + + +class SRAMError(ValueError): + pass -@module(structural=True) def SRAM( m: Circuit, - clk: Connector, - rst: Connector, - ren: Connector, - raddr: Connector, - wvalid: Connector, - waddr: Connector, - wdata: Connector, - wstrb: Connector, + cd: ClockDomain, + ren: Wire, + raddr: Wire, + wvalid: Wire, + waddr: Wire, + wdata: Wire, + wstrb: Wire, *, depth: int, -) -> ConnectorBundle: - clk_v = clk.read() if isinstance(clk, Connector) else clk - rst_v = rst.read() if isinstance(rst, Connector) else rst +): + clk_v = cd.clk + rst_v = cd.rst if not isinstance(clk_v, Signal) or clk_v.ty != "!pyc.clock": - raise ConnectorError("SRAM.clk must be !pyc.clock") + raise SRAMError("SRAM domain clk must be !pyc.clock") if not isinstance(rst_v, Signal) or rst_v.ty != "!pyc.reset": - raise ConnectorError("SRAM.rst must be !pyc.reset") - - def wire_of(v): - vv = v.read() if isinstance(v, Connector) else v - if isinstance(vv, Signal): - return m.wire(vv) - return vv - - ren_w = wire_of(ren) - wvalid_w = wire_of(wvalid) - raddr_w = wire_of(raddr) - waddr_w = wire_of(waddr) - wdata_w = wire_of(wdata) - 
wstrb_w = wire_of(wstrb) + raise SRAMError("SRAM domain rst must be !pyc.reset") + + ren_w = ren + wvalid_w = wvalid + raddr_w = raddr + waddr_w = waddr + wdata_w = wdata + wstrb_w = wstrb if ren_w.ty != "i1" or wvalid_w.ty != "i1": - raise ConnectorError("SRAM ren/wvalid must be i1") + raise SRAMError("SRAM ren/wvalid must be i1") rdata = m.sync_mem( clk_v, diff --git a/compiler/frontend/pycircuit/v5.py b/compiler/frontend/pycircuit/v5.py new file mode 100644 index 0000000..d3377b2 --- /dev/null +++ b/compiler/frontend/pycircuit/v5.py @@ -0,0 +1,949 @@ +"""PyCircuit V5 cycle-aware frontend (tutorial + Cycle-Aware API). + +Maps documented grammar onto the existing Circuit/Wire MLIR builder. Library and +top-level designs should use CycleAwareCircuit / CycleAwareDomain and +compile_cycle_aware() instead of @module + compile(). +""" + +from __future__ import annotations + +import ast +from contextlib import contextmanager +from dataclasses import dataclass, field +import inspect +import textwrap +import threading +from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar, Union + +from .dsl import Signal +from .hw import Circuit, ClockDomain, Reg, Wire +from .literals import LiteralValue, infer_literal_width +from .tb import Tb as _Tb + +F = TypeVar("F", bound=Callable[..., Any]) + +_tls = threading.local() + + +def _current_domain() -> "CycleAwareDomain | None": + return getattr(_tls, "domain", None) + + +def _set_current_domain(d: "CycleAwareDomain | None") -> None: + _tls.domain = d + + +@dataclass +class _ModuleCtx: + owner: "pyc_CircuitModule" + inputs: list[Any] + description: str + outputs: list[Any] = field(default_factory=list) + + +class CycleAwareCircuit(Circuit): + """V5 top-level builder; extends Circuit so m.out / m.cat / emit_mlir work unchanged.""" + + def create_domain(self, name: str, *, frequency_desc: str = "", reset_active_high: bool = False) -> "CycleAwareDomain": + _ = (frequency_desc, reset_active_high) + return 
CycleAwareDomain(self, str(name)) + + def const_signal(self, value: int, width: int, domain: "CycleAwareDomain") -> Wire: + return domain.create_const(int(value), width=int(width)) + + def input_signal(self, name: str, width: int, domain: "CycleAwareDomain") -> Wire: + return domain.create_signal(str(name), width=int(width)) + + +class CycleAwareDomain: + """Clock domain with logical occurrence index (tutorial: next/prev/push/pop/cycle).""" + + def __init__(self, circuit: Circuit, domain_name: str) -> None: + self._m = circuit + self._name = str(domain_name) + self._cd = _clock_domain_ports(circuit, self._name) + self._occurrence = 0 + self._stack: list[int] = [] + self._delay_serial = 0 + self._reg_serial = 0 + + @property + def clock_domain(self) -> ClockDomain: + """Underlying clk/rst pair for m.out(..., domain=...).""" + return self._cd + + @property + def circuit(self) -> Circuit: + return self._m + + def create_reset(self) -> Wire: + return Wire(self._m, self._cd.rst) + + def create_signal(self, port_name: str, *, width: int) -> Wire: + return self._m.input(str(port_name), width=int(width)) + + def create_const(self, value: int, *, width: int, name: str = "") -> Wire: + _ = name + return self._m.const(int(value), width=int(width)) + + def next(self) -> None: + self._occurrence += 1 + + def prev(self) -> None: + self._occurrence -= 1 + + def push(self) -> None: + self._stack.append(self._occurrence) + + def pop(self) -> None: + if not self._stack: + raise RuntimeError("clock_domain.pop() without matching push()") + self._occurrence = self._stack.pop() + + @property + def cycle_index(self) -> int: + return self._occurrence + + def cycle( + self, + sig: Union[Wire, Reg, "CycleAwareSignal"], + reset_value: int | None = None, + name: str = "", + ) -> Wire: + """Single-stage register (DFF); output is one logical cycle after the input value.""" + w = _as_wire(self._m, sig) + width = w.width + init = 0 if reset_value is None else int(reset_value) + reg_name = 
str(name).strip() or f"_v5_reg_{self._reg_serial}" + self._reg_serial += 1 + full = self._m.scoped_name(reg_name) + r = self._m.out(full, domain=self._cd, width=width, init=init) + r.set(w) + return r.q + + def state( + self, + *, + width: int, + reset_value: int = 0, + name: str = "", + ) -> "StateSignal": + """Declare a feedback state variable (register whose D depends on Q). + + Returns a :class:`StateSignal` that behaves like a ``CycleAwareSignal`` + (read its current value, use in expressions) and also supports + ``.set(next_val)`` to close the feedback loop. + + Typical pattern:: + + # Cycle 0: declare state and read current value + counter = domain.state(width=8, reset_value=0, name="cnt") + + domain.next() # → Cycle 1 + + # Cycle 1: conditionally update + counter.set(mux(enable, counter + 1, counter)) + """ + reg_name = str(name).strip() or f"_v5_reg_{self._reg_serial}" + self._reg_serial += 1 + full = self._m.scoped_name(reg_name) + reg = self._m.out(full, domain=self._cd, width=int(width), init=int(reset_value)) + return StateSignal(self, reg, self._occurrence) + + def delay_to(self, w: Wire, *, from_cycle: int, to_cycle: int, width: int) -> Wire: + """Insert (to_cycle - from_cycle) register stages for automatic cycle balancing.""" + if to_cycle <= from_cycle: + return w + d = to_cycle - from_cycle + cur: Wire = w + for _ in range(d): + self._delay_serial += 1 + nm = f"_v5_bal_{self._delay_serial}" + r = self._m.out(self._m.scoped_name(nm), domain=self._cd, width=width, init=0) + r.set(cur) + cur = r.q + return cur + + +def _clock_domain_ports(m: Circuit, name: str) -> ClockDomain: + if name == "clk": + return ClockDomain(clk=m.clock("clk"), rst=m.reset("rst")) + return m.domain(name) + + +def _as_wire(m: Circuit, sig: Union[Wire, Reg, "CycleAwareSignal", Signal]) -> Wire: + if isinstance(sig, CycleAwareSignal): + return sig.wire + if isinstance(sig, Reg): + return sig.q + if isinstance(sig, Wire): + return sig + if isinstance(sig, Signal): + return 
Wire(m, sig) + raise TypeError(f"expected Wire/Reg/CycleAwareSignal/Signal, got {type(sig).__name__}") + + +class StateSignal: + """Feedback register exposed as a cycle-aware value with deferred ``.set()``. + + Created by ``domain.state()``. Read it like any ``CycleAwareSignal``; + after ``domain.next()``, call ``.set(next_val)`` to close the feedback loop. + """ + + __slots__ = ("_domain", "_reg", "_cas") + + def __init__(self, domain: "CycleAwareDomain", reg: Reg, cycle: int) -> None: + self._domain = domain + self._reg = reg + self._cas = CycleAwareSignal(domain, reg.out(), cycle) + + def set( + self, + next_val: "Wire | Reg | CycleAwareSignal | StateSignal", + *, + when: "Wire | Reg | CycleAwareSignal | StateSignal | None" = None, + ) -> None: + """Connect the D input of the register (close the feedback loop).""" + w = _to_wire(next_val) + wh = _to_wire(when) if when is not None else None + if wh is not None: + self._reg.set(w, when=wh) + else: + self._reg.set(w) + + @property + def wire(self) -> Wire: + return self._cas.wire + + @property + def w(self) -> Wire: + return self._cas.wire + + @property + def sig(self) -> Signal: + return self._cas.sig + + @property + def cycle(self) -> int: + return self._cas.cycle + + @property + def domain(self) -> "CycleAwareDomain": + return self._domain + + def __getattr__(self, name: str) -> object: + return getattr(self._cas, name) + + def __add__(self, other: object) -> "CycleAwareSignal": + return self._cas.__add__(other) + + def __radd__(self, other: object) -> "CycleAwareSignal": + return self._cas.__radd__(other) + + def __sub__(self, other: object) -> "CycleAwareSignal": + return self._cas.__sub__(other) + + def __mul__(self, other: object) -> "CycleAwareSignal": + return self._cas.__mul__(other) + + def __and__(self, other: object) -> "CycleAwareSignal": + return self._cas.__and__(other) + + def __or__(self, other: object) -> "CycleAwareSignal": + if isinstance(other, str): + return self._cas + return 
class CycleAwareSignal:
    """Value with a logical cycle tag; operators align operands by delaying the earlier one.

    Every binary operator funnels through :meth:`_align`, which

    * converts the right-hand operand (signal, state signal, wire, register,
      ``int`` or ``LiteralValue``) into a wire plus a cycle tag,
    * delays whichever operand is earlier up to the later cycle via
      ``domain.delay_to``, and
    * width-promotes the pair with ``_promote_pair``.

    The result carries the later of the two cycles.  Unary and slicing
    operations keep the cycle of ``self``.

    Note: ``__eq__`` builds hardware instead of returning ``bool``, so
    instances are unhashable (Python clears ``__hash__`` when ``__eq__`` is
    overridden without it).
    """

    __slots__ = ("_domain", "_w", "_cycle")

    def __init__(self, domain: CycleAwareDomain, wire: Wire, cycle: int) -> None:
        """Wrap *wire* at logical *cycle*.

        Raises:
            ValueError: if *wire* was created by a different circuit than the
                one backing *domain*.
        """
        if wire.m is not domain._m:
            raise ValueError("Wire must belong to the same circuit as the domain")
        self._domain = domain
        self._w = wire
        self._cycle = int(cycle)

    # -- read-only views ----------------------------------------------------

    @property
    def wire(self) -> Wire:
        """Underlying wire."""
        return self._w

    @property
    def w(self) -> Wire:
        """Short alias of :attr:`wire`."""
        return self._w

    @property
    def cycle(self) -> int:
        """Logical cycle this value belongs to."""
        return self._cycle

    @property
    def domain(self) -> CycleAwareDomain:
        """Domain that owns this signal."""
        return self._domain

    @property
    def sig(self) -> Signal:
        """Low-level signal handle of the underlying wire."""
        return self._w.sig

    @property
    def name(self) -> str:
        """String form of the underlying wire."""
        return str(self._w)

    @property
    def signed(self) -> bool:
        """Signedness flag of the underlying wire."""
        return bool(self._w.signed)

    def __repr__(self) -> str:
        # Same layout as StateSignal.__repr__ so both print alike in debug output.
        return f"CycleAwareSignal({self._w}, cycle={self._cycle})"

    def _same_cycle(self, wire: Wire) -> "CycleAwareSignal":
        """Wrap *wire* in this signal's domain at this signal's cycle."""
        return CycleAwareSignal(self._domain, wire, self._cycle)

    def named(self, name: str) -> "CycleAwareSignal":
        """Return a copy whose wire was renamed through ``circuit.named``."""
        return self._same_cycle(self._domain._m.named(self._w, str(name)))

    # -- operand alignment --------------------------------------------------

    def _align(
        self,
        other: "CycleAwareSignal | StateSignal | Wire | Reg | int | LiteralValue",
    ) -> tuple[Wire, Wire, int]:
        """Return ``(self_wire, other_wire, cycle)`` aligned and width-promoted.

        Operands without a cycle tag (wires, registers, literals) are taken to
        live at the domain's current ``cycle_index``.

        Raises:
            ValueError: if *other* is a signal from a different domain.
            TypeError: if *other* is not a supported operand type.
        """
        dom = self._domain
        m = dom._m
        if isinstance(other, StateSignal):
            # StateSignal forwards its own operators to ``_cas``; unwrap it here so
            # ``signal <op> state`` works just like ``state <op> signal``.
            other = other._cas
        if isinstance(other, CycleAwareSignal):
            if other._domain is not dom:
                raise ValueError("CycleAwareSignal operands must share the same domain")
            oc = other._cycle
            ow = other._w
        elif isinstance(other, (Wire, Reg)):
            ow = other.q if isinstance(other, Reg) else other
            oc = dom.cycle_index
        elif isinstance(other, LiteralValue):
            value = int(other.value)
            if other.width is not None:
                lit_w = int(other.width)
            else:
                # Same rule as _mux_wire: an unspecified signedness is inferred
                # from the value so negative literals are sized as signed.
                is_signed = bool(other.signed) if other.signed is not None else value < 0
                lit_w = infer_literal_width(value, signed=is_signed)
            ow = m.const(value, width=int(lit_w))
            oc = dom.cycle_index
        elif isinstance(other, int):
            ow = m.const(other, width=max(1, infer_literal_width(other, signed=other < 0)))
            oc = dom.cycle_index
        else:
            raise TypeError(f"unsupported operand: {type(other).__name__}")
        mx = max(self._cycle, oc)
        aw = dom.delay_to(self._w, from_cycle=self._cycle, to_cycle=mx, width=self._w.width)
        bw = dom.delay_to(ow, from_cycle=oc, to_cycle=mx, width=ow.width)
        a2, b2 = _promote_pair(m, aw, bw)
        return a2, b2, mx

    # -- arithmetic / bitwise -----------------------------------------------

    def __add__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a + b, c)

    def __radd__(self, other: object) -> "CycleAwareSignal":
        return self.__add__(other)

    def __sub__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a - b, c)

    def __rsub__(self, other: object) -> "CycleAwareSignal":
        # ``other - self``: _align returns (self, other), so swap the operands.
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, b - a, c)

    def __mul__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a * b, c)

    def __rmul__(self, other: object) -> "CycleAwareSignal":
        return self.__mul__(other)

    def __and__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a & b, c)

    def __rand__(self, other: object) -> "CycleAwareSignal":
        return self.__and__(other)

    def __or__(self, other: object) -> "CycleAwareSignal":  # type: ignore[override]
        # ``signal | "description"`` is the tutorial annotation idiom: the text is
        # discarded and the signal is returned unchanged.
        if isinstance(other, str):
            return self
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a | b, c)

    def __ror__(self, other: object) -> "CycleAwareSignal":
        return self.__or__(other)

    def __xor__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a ^ b, c)

    def __rxor__(self, other: object) -> "CycleAwareSignal":
        return self.__xor__(other)

    def __invert__(self) -> "CycleAwareSignal":
        return self._same_cycle(~self._w)

    # -- comparisons (build hardware, do not return bool) ---------------------

    def __eq__(self, other: object) -> "CycleAwareSignal":  # type: ignore[override]
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a == b, c)

    def __ne__(self, other: object) -> "CycleAwareSignal":  # type: ignore[override]
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a != b, c)

    def __lt__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a < b, c)

    def __gt__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a > b, c)

    def __le__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a <= b, c)

    def __ge__(self, other: object) -> "CycleAwareSignal":
        a, b, c = self._align(other)  # type: ignore[arg-type]
        return CycleAwareSignal(self._domain, a >= b, c)

    # Named spellings of the comparison operators.

    def eq(self, other: object) -> "CycleAwareSignal":
        return self.__eq__(other)

    def ne(self, other: object) -> "CycleAwareSignal":
        return self.__ne__(other)

    def lt(self, other: object) -> "CycleAwareSignal":
        return self.__lt__(other)

    def gt(self, other: object) -> "CycleAwareSignal":
        return self.__gt__(other)

    def le(self, other: object) -> "CycleAwareSignal":
        return self.__le__(other)

    def ge(self, other: object) -> "CycleAwareSignal":
        return self.__ge__(other)

    # -- width / slicing / signedness -----------------------------------------

    def trunc(self, width: int) -> "CycleAwareSignal":
        """Truncate the wire to *width* bits."""
        return self._same_cycle(self._w.trunc(width=int(width)))

    def zext(self, width: int) -> "CycleAwareSignal":
        """Zero-extend the wire to *width* bits."""
        return self._same_cycle(self._w.zext(width=int(width)))

    def sext(self, width: int) -> "CycleAwareSignal":
        """Sign-extend the wire to *width* bits."""
        return self._same_cycle(self._w.sext(width=int(width)))

    def slice(self, high: int, low: int) -> "CycleAwareSignal":
        """Select bits ``high`` down to ``low`` (both inclusive)."""
        lo = int(low)
        hi = int(high)
        # The wire is indexed with a half-open Python slice, hence ``hi + 1``.
        return self._same_cycle(self._w[lo : hi + 1])

    def select(self, true_val: object, false_val: object) -> "CycleAwareSignal":
        """Use this signal as a mux condition: ``mux(self, true_val, false_val)``."""
        return mux(self, true_val, false_val)

    def as_signed(self) -> "CycleAwareSignal":
        """Reinterpret the same bits as signed."""
        return self._same_cycle(Wire(self._domain._m, self._w.sig, signed=True))

    def as_unsigned(self) -> "CycleAwareSignal":
        """Reinterpret the same bits as unsigned."""
        return self._same_cycle(Wire(self._domain._m, self._w.sig, signed=False))

    def __getitem__(self, idx: int | slice) -> "CycleAwareSignal":
        return self._same_cycle(self._w[idx])
def _mux_wire(
    cond: Union[Wire, Reg],
    a: Union[Wire, Reg, int, LiteralValue],
    b: Union[Wire, Reg, int, LiteralValue],
) -> Wire:
    """Build a plain (non cycle-aware) 2:1 mux: *a* when *cond* is 1, else *b*.

    Args:
        cond: 1-bit condition; a ``Reg`` is read through its ``q`` output.
        a: value selected when the condition is true.
        b: value selected when the condition is false.

    Bare Python ints adopt the width of the *other* branch when that width is
    known (wire, register, or ``LiteralValue`` with an explicit width), and are
    never made narrower than the value needs.  Previously the condition's width
    was used as the context, which is always 1 for a valid mux, so every int
    literal became a 1-bit constant.

    Raises:
        TypeError: if the condition is not backed by a ``Circuit``, if a branch
            has an unsupported type, or if the condition is not ``i1``.
    """
    c = cond.q if isinstance(cond, Reg) else cond
    m = c.m
    if not isinstance(m, Circuit):
        raise TypeError("mux(cond, ...) requires wires from a Circuit")

    def known_width(v: Union[Wire, Reg, int, LiteralValue]) -> int | None:
        """Width *v* already carries, or None for a bare int / unsized literal."""
        if isinstance(v, Reg):
            return int(v.q.width)
        if isinstance(v, Wire):
            return int(v.width)
        if isinstance(v, LiteralValue) and v.width is not None:
            return int(v.width)
        return None

    def as_wire(v: Union[Wire, Reg, int, LiteralValue], *, ctx_w: int | None) -> Wire:
        """Convert one branch to a wire; *ctx_w* is the sibling branch's width."""
        if isinstance(v, Reg):
            return v.q
        if isinstance(v, Wire):
            return v
        if isinstance(v, LiteralValue):
            if v.width is not None:
                lit_w = int(v.width)
            else:
                lit_w = infer_literal_width(
                    int(v.value),
                    signed=(bool(v.signed) if v.signed is not None else int(v.value) < 0),
                )
            return m.const(int(v.value), width=int(lit_w))
        if isinstance(v, int):
            natural = max(1, infer_literal_width(int(v), signed=(int(v) < 0)))
            # Take the sibling's width so e.g. -1 fills all of its bits, but never
            # go below what the value itself requires.
            w = natural if ctx_w is None else max(int(ctx_w), natural)
            return m.const(int(v), width=int(w))
        raise TypeError(f"mux: unsupported branch type {type(v).__name__}")

    aw = as_wire(a, ctx_w=known_width(b))
    bw = as_wire(b, ctx_w=known_width(a))
    aw, bw = _promote_pair(m, aw, bw)
    if c.ty != "i1":
        raise TypeError("mux condition must be i1")
    return c._select_internal(aw, bw)
dom._m + + def to_cas(x: Union[Wire, Reg, CycleAwareSignal, int, LiteralValue]) -> CycleAwareSignal: + if isinstance(x, CycleAwareSignal): + return x + if isinstance(x, Reg): + return CycleAwareSignal(dom, x.q, dom.cycle_index) + if isinstance(x, Wire): + return CycleAwareSignal(dom, x, dom.cycle_index) + if isinstance(x, int): + w = m.const(x, width=max(1, infer_literal_width(x, signed=x < 0))) + return CycleAwareSignal(dom, w, dom.cycle_index) + if isinstance(x, LiteralValue): + lw = x.width if x.width is not None else infer_literal_width(int(x.value), signed=bool(x.signed)) + w = m.const(int(x.value), width=int(lw)) + return CycleAwareSignal(dom, w, dom.cycle_index) + raise TypeError(f"mux: unsupported value {type(x).__name__}") + + c_cas = to_cas(cond) if not isinstance(cond, CycleAwareSignal) else cond + ca = to_cas(a) + cb = to_cas(b) + cc = c_cas._cycle + cw = c_cas._w + mx = max(cc, ca._cycle, cb._cycle) + cw2 = dom.delay_to(cw, from_cycle=cc, to_cycle=mx, width=cw.width) + aw = dom.delay_to(ca.wire, from_cycle=ca._cycle, to_cycle=mx, width=ca.wire.width) + bw = dom.delay_to(cb.wire, from_cycle=cb._cycle, to_cycle=mx, width=cb.wire.width) + aw, bw = _promote_pair(m, aw, bw) + if cw2.ty != "i1": + raise TypeError("mux condition must be i1") + out_w = cw2._select_internal(aw, bw) + return CycleAwareSignal(dom, out_w, mx) + + +def cas(domain: CycleAwareDomain, w: Wire, *, cycle: int | None = None) -> CycleAwareSignal: + c = domain.cycle_index if cycle is None else int(cycle) + return CycleAwareSignal(domain, w, c) + + +def _strip_domain_for_jit(fn: Callable[..., Any], *, domain_name: str) -> Callable[..., Any]: + """Drop the ``domain`` parameter for JIT and prepend ``domain = m.create_domain(...)``.""" + try: + source = textwrap.dedent(inspect.getsource(fn)) + except OSError as e: + raise TypeError( + "compile_cycle_aware(fn): need inspectable source for JIT; use eager=True or define fn in a .py file" + ) from e + tree = ast.parse(source) + name = getattr(fn, 
"__name__", None) + if not isinstance(name, str) or not name: + raise TypeError("compile_cycle_aware(fn): function must have a __name__") + fdef: ast.FunctionDef | None = None + for node in tree.body: + if isinstance(node, ast.FunctionDef) and node.name == name: + fdef = node + break + if fdef is None: + raise TypeError(f"compile_cycle_aware: could not find def {name!r} in source of {fn!r}") + pos = fdef.args.args + if len(pos) < 2: + raise TypeError("compile_cycle_aware(fn): source must declare at least (m, domain, ...)") + m_arg = pos[0].arg + if pos[1].arg != "domain": + raise TypeError( + "compile_cycle_aware(fn): second parameter must be named 'domain' for JIT (or use eager=True)" + ) + fdef.args.args.pop(1) + prelude = ast.Assign( + targets=[ast.Name(id="domain", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=m_arg, ctx=ast.Load()), + attr="create_domain", + ctx=ast.Load(), + ), + args=[ast.Constant(value=str(domain_name))], + keywords=[], + ), + ) + fdef.body.insert(0, prelude) + ast.fix_missing_locations(fdef) + new_src = ast.unparse(fdef) + "\n" + globs = dict(fn.__globals__) + exec(compile(ast.parse(new_src), "", "exec"), globs) + out: Callable[..., Any] = globs[name] + out.__pycircuit_jit_source__ = new_src + out.__pycircuit_jit_start_line__ = 1 + out.__pycircuit_jit_source_file__ = "" + setattr(out, "__pycircuit_kind__", "module") + setattr(out, "__pycircuit_inline__", False) + for attr in ("__pycircuit_name__", "__pycircuit_module_name__"): + if hasattr(fn, attr): + setattr(out, attr, getattr(fn, attr)) + return out + + +def compile_cycle_aware( + fn: F, + *, + name: str | None = None, + domain_name: str = "clk", + eager: bool = False, + structural: bool | None = None, + value_params: Mapping[str, str] | dict[str, str] | None = None, + **jit_params: Any, +) -> Any: + """Compile or execute ``fn(m, domain, **kwargs)``. 
def _register_implicit_outputs(m: Circuit, out: Any) -> None:
    """Expose the value returned by an eagerly executed design as output port(s).

    A single value becomes the port ``result``; the items of a tuple become
    ``result0``, ``result1``, ... in order.  Values that are not signals,
    wires or registers are skipped without error (best effort).
    """
    if isinstance(out, tuple):
        for i, x in enumerate(out):
            _register_implicit_outputs_single(m, f"result{i}", x)
        return
    _register_implicit_outputs_single(m, "result", out)


def _register_implicit_outputs_single(m: Circuit, port: str, x: Any) -> None:
    """Register *x* as output *port* when it is a supported hardware value.

    ``StateSignal`` is accepted alongside ``CycleAwareSignal`` (both expose
    ``.wire``); previously a returned ``StateSignal`` produced no port at all.
    Unsupported values are ignored on purpose.
    """
    if isinstance(x, (CycleAwareSignal, StateSignal)):
        m.output(port, x.wire)
    elif isinstance(x, Wire):
        m.output(port, x)
    elif isinstance(x, Reg):
        # A register is exported through its ``q`` output.
        m.output(port, x.q)
self.low = int(low) + self.width = self.high - self.low + 1 + + def __call__(self, *, value: Any = 0, name: str = "") -> CycleAwareSignal: + dom = _current_domain() + if dom is None: + raise RuntimeError("signal[...](...) requires an active pyc_CircuitModule.module() context") + w = _materialize_signal_value(dom, value, self.width, str(name)) + return CycleAwareSignal(dom, w, dom.cycle_index) + + +class _SignalMeta(type): + def __getitem__(cls, item: Any) -> _SignalSlice: + if isinstance(item, slice): + if item.step not in (None, 1): + raise ValueError("signal slice step must be 1") + hi, lo = item.start, item.stop + if hi is None or lo is None: + raise ValueError("signal[h:l] requires both high and low") + return _SignalSlice(int(hi), int(lo)) + if isinstance(item, str): + part = item.split(":", 1) + if len(part) != 2: + raise ValueError('signal["h:l"] expects one ":"') + return _SignalSlice(int(part[0].strip()), int(part[1].strip())) + raise TypeError("signal[...] expects slice like [7:0] or string '7:0'") + + def __call__(cls, *, value: Any = 0, name: str = "") -> CycleAwareSignal: + if cls is signal: + return _signal_plain(value=value, name=name) + return type.__call__(cls) + + +class signal(metaclass=_SignalMeta): + """Tutorial: ``signal[7:0](value=0) | \"desc\"`` and ``signal(value=...)``.""" + + +def _signal_plain(*, value: Any = 0, name: str = "") -> CycleAwareSignal: + dom = _current_domain() + if dom is None: + raise RuntimeError("signal(value=...) 
class CycleAwareTb:
    """Testbench facade that keeps its own cycle counter.

    Wraps a :class:`Tb` and supplies ``at=<current cycle>`` to ``drive``,
    ``expect``, ``print`` and ``finish`` so test code does not pass ``at=``
    itself.  The counter starts at 0 and grows by one per :meth:`next` call.
    Setup calls (``clock``, ``reset``, ``timeout``) and ``print_every``,
    ``sva_assert`` and ``random`` are forwarded unchanged.

    Example inside a ``@testbench`` function::

        @testbench
        def tb(t: Tb) -> None:
            tb = CycleAwareTb(t)
            tb.clock("clk")
            tb.reset("rst", cycles_asserted=2, cycles_deasserted=1)
            tb.timeout(64)

            tb.drive("enable", 1)    # cycle 0
            tb.expect("count", 1)

            tb.next()                # cycle 1
            tb.expect("count", 2)

            tb.finish()
    """

    __slots__ = ("_t", "_cycle")

    def __init__(self, t: _Tb) -> None:
        """Wrap *t*; raises ``TypeError`` when *t* is not a ``Tb`` instance."""
        if isinstance(t, _Tb):
            self._t = t
            self._cycle = 0
            return
        raise TypeError(
            f"CycleAwareTb requires a Tb instance, got {type(t).__name__}"
        )

    # ---- cycle bookkeeping ----

    @property
    def cycle(self) -> int:
        """Index of the cycle that stimulus and checks currently target."""
        return self._cycle

    def next(self) -> None:
        """Move on to the following clock cycle."""
        self._cycle = self._cycle + 1

    # ---- setup, forwarded as-is ----

    def clock(self, port: str, **kw: Any) -> None:
        """Forward to ``Tb.clock``."""
        self._t.clock(port, **kw)

    def reset(self, port: str, **kw: Any) -> None:
        """Forward to ``Tb.reset``."""
        self._t.reset(port, **kw)

    def timeout(self, cycles: int) -> None:
        """Forward to ``Tb.timeout``."""
        self._t.timeout(cycles)

    # ---- stimulus and checks at the current cycle ----

    def drive(self, port: str, value: int | bool) -> None:
        """Drive *port* with *value* at the current cycle."""
        now = self._cycle
        self._t.drive(port, value, at=now)

    def expect(
        self,
        port: str,
        value: int | bool,
        *,
        phase: str = "post",
        msg: str | None = None,
    ) -> None:
        """Check that *port* equals *value* at the current cycle."""
        now = self._cycle
        self._t.expect(port, value, at=now, phase=phase, msg=msg)

    def finish(self, *, at: int | None = None) -> None:
        """Finish at the current cycle, or at *at* when it is given."""
        if at is None:
            stop = self._cycle
        else:
            stop = int(at)
        self._t.finish(at=stop)

    # ---- printing ----

    def print(self, fmt: str, *, ports: Iterable[str] = ()) -> None:
        """Print *fmt* with *ports* at the current cycle."""
        now = self._cycle
        self._t.print(fmt, at=now, ports=ports)

    def print_every(self, fmt: str, **kw: Any) -> None:
        """Forward to ``Tb.print_every``; not tied to the cycle counter."""
        self._t.print_every(fmt, **kw)

    # ---- plain pass-through ----

    def sva_assert(self, expr: Any, **kw: Any) -> None:
        """Forward to ``Tb.sva_assert``."""
        self._t.sva_assert(expr, **kw)

    def random(self, port: str, **kw: Any) -> None:
        """Forward to ``Tb.random``."""
        self._t.random(port, **kw)
diff --git a/designs/BypassUnit/bypass_unit.py b/designs/BypassUnit/bypass_unit.py index 477433e..f9407a5 100644 --- a/designs/BypassUnit/bypass_unit.py +++ b/designs/BypassUnit/bypass_unit.py @@ -1,6 +1,13 @@ from __future__ import annotations -from pycircuit import Circuit, Tb, compile, function, module, testbench, u +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + Tb, + compile_cycle_aware, + mux, + testbench, +) PTYPE_C = 0 PTYPE_P = 1 @@ -8,15 +15,12 @@ PTYPE_U = 3 -@function -def _not1(m: Circuit, x): - _ = m - return u(1, 1) ^ x +def _not1(m, x): + return m.const(1, width=1) ^ x -@function def _select_stage( - m: Circuit, + m, *, src_valid, src_ptag, @@ -29,23 +33,22 @@ def _select_stage( lane_w: int, data_w: int, ): - has = u(1, 0) - sel_lane = u(int(lane_w), 0) - sel_data = u(int(data_w), 0) + has = m.const(0, width=1) + sel_lane = m.const(0, width=int(lane_w)) + sel_data = m.const(0, width=int(data_w)) for j in range(int(lanes)): match = src_valid & lane_valid[j] & (lane_ptag[j] == src_ptag) & (lane_ptype[j] == src_ptype) take = match & _not1(m, has) - sel_lane = (u(int(lane_w), j)) if take else sel_lane - sel_data = lane_data[j] if take else sel_data + sel_lane = mux(take, m.const(j, width=int(lane_w)), sel_lane) + sel_data = mux(take, lane_data[j], sel_data) has = has | match return has, sel_lane, sel_data -@function def _resolve_src( - m: Circuit, + m, *, src_valid, src_ptag, @@ -68,67 +71,40 @@ def _resolve_src( data_w: int, ): has_w1, lane_w1, data_w1 = _select_stage( - m, - src_valid=src_valid, - src_ptag=src_ptag, - src_ptype=src_ptype, - lane_valid=w1_valid, - lane_ptag=w1_ptag, - lane_ptype=w1_ptype, - lane_data=w1_data, - lanes=lanes, - lane_w=lane_w, - data_w=data_w, + m, src_valid=src_valid, src_ptag=src_ptag, src_ptype=src_ptype, + lane_valid=w1_valid, lane_ptag=w1_ptag, lane_ptype=w1_ptype, lane_data=w1_data, + lanes=lanes, lane_w=lane_w, data_w=data_w, ) has_w2, lane_w2, data_w2 = _select_stage( - m, - 
src_valid=src_valid, - src_ptag=src_ptag, - src_ptype=src_ptype, - lane_valid=w2_valid, - lane_ptag=w2_ptag, - lane_ptype=w2_ptype, - lane_data=w2_data, - lanes=lanes, - lane_w=lane_w, - data_w=data_w, + m, src_valid=src_valid, src_ptag=src_ptag, src_ptype=src_ptype, + lane_valid=w2_valid, lane_ptag=w2_ptag, lane_ptype=w2_ptype, lane_data=w2_data, + lanes=lanes, lane_w=lane_w, data_w=data_w, ) has_w3, lane_w3, data_w3 = _select_stage( - m, - src_valid=src_valid, - src_ptag=src_ptag, - src_ptype=src_ptype, - lane_valid=w3_valid, - lane_ptag=w3_ptag, - lane_ptype=w3_ptype, - lane_data=w3_data, - lanes=lanes, - lane_w=lane_w, - data_w=data_w, + m, src_valid=src_valid, src_ptag=src_ptag, src_ptype=src_ptype, + lane_valid=w3_valid, lane_ptag=w3_ptag, lane_ptype=w3_ptype, lane_data=w3_data, + lanes=lanes, lane_w=lane_w, data_w=data_w, ) - out_data = data_w3 if has_w3 else src_rf_data - out_hit = u(1, 1) if has_w3 else u(1, 0) - out_stage = u(2, 3) if has_w3 else u(2, 0) - out_lane = lane_w3 if has_w3 else u(int(lane_w), 0) + out_data = mux(has_w3, data_w3, src_rf_data) + out_hit = mux(has_w3, m.const(1, width=1), m.const(0, width=1)) + out_stage = mux(has_w3, m.const(3, width=2), m.const(0, width=2)) + out_lane = mux(has_w3, lane_w3, m.const(0, width=int(lane_w))) - out_data = data_w2 if has_w2 else out_data - out_hit = u(1, 1) if has_w2 else out_hit - out_stage = u(2, 2) if has_w2 else out_stage - out_lane = lane_w2 if has_w2 else out_lane + out_data = mux(has_w2, data_w2, out_data) + out_hit = mux(has_w2, m.const(1, width=1), out_hit) + out_stage = mux(has_w2, m.const(2, width=2), out_stage) + out_lane = mux(has_w2, lane_w2, out_lane) - out_data = data_w1 if has_w1 else out_data - out_hit = u(1, 1) if has_w1 else out_hit - out_stage = u(2, 1) if has_w1 else out_stage - out_lane = lane_w1 if has_w1 else out_lane + out_data = mux(has_w1, data_w1, out_data) + out_hit = mux(has_w1, m.const(1, width=1), out_hit) + out_stage = mux(has_w1, m.const(1, width=2), out_stage) + 
out_lane = mux(has_w1, lane_w1, out_lane) return out_data, out_hit, out_stage, out_lane -@module -def build( - m: Circuit, - *, +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, lanes: int = 8, data_width: int = 64, ptag_count: int = 256, @@ -154,10 +130,6 @@ def build( ptype_w = max(1, (ptype_n - 1).bit_length()) lane_w = max(1, (lanes_n - 1).bit_length()) - # Declared for pyCircuit testbench generation flow. - _clk = m.clock("clk") - _rst = m.reset("rst") - w_valid: dict[str, list] = {} w_ptag: dict[str, list] = {} w_ptype: dict[str, list] = {} @@ -177,25 +149,12 @@ def build( out_data, out_hit, out_stage, out_lane = _resolve_src( m, - src_valid=src_valid, - src_ptag=src_ptag, - src_ptype=src_ptype, + src_valid=src_valid, src_ptag=src_ptag, src_ptype=src_ptype, src_rf_data=src_rf_data, - w1_valid=w_valid["w1"], - w1_ptag=w_ptag["w1"], - w1_ptype=w_ptype["w1"], - w1_data=w_data["w1"], - w2_valid=w_valid["w2"], - w2_ptag=w_ptag["w2"], - w2_ptype=w_ptype["w2"], - w2_data=w_data["w2"], - w3_valid=w_valid["w3"], - w3_ptag=w_ptag["w3"], - w3_ptype=w_ptype["w3"], - w3_data=w_data["w3"], - lanes=lanes_n, - lane_w=lane_w, - data_w=data_w, + w1_valid=w_valid["w1"], w1_ptag=w_ptag["w1"], w1_ptype=w_ptype["w1"], w1_data=w_data["w1"], + w2_valid=w_valid["w2"], w2_ptag=w_ptag["w2"], w2_ptype=w_ptype["w2"], w2_data=w_data["w2"], + w3_valid=w_valid["w3"], w3_ptag=w_ptag["w3"], w3_ptype=w_ptype["w3"], w3_data=w_data["w3"], + lanes=lanes_n, lane_w=lane_w, data_w=data_w, ) m.output(f"i2{i}_{src}_data", out_data) @@ -217,12 +176,12 @@ def tb(t: Tb) -> None: if __name__ == "__main__": print( - compile( - build, + compile_cycle_aware(build, name="bypass_unit", + eager=True, lanes=8, data_width=64, ptag_count=256, ptype_count=4, - ).emit_mlir() + ).emit_mlir()[:500] ) diff --git a/designs/BypassUnit/tb_bypass_unit.py b/designs/BypassUnit/tb_bypass_unit.py index 9ec3b2a..8c0bf57 100644 --- a/designs/BypassUnit/tb_bypass_unit.py +++ b/designs/BypassUnit/tb_bypass_unit.py @@ 
-3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench from pycircuit.tb import sva _THIS_DIR = Path(__file__).resolve().parent @@ -60,45 +60,43 @@ def _resolve_expected(src: dict, wb: dict, *, lanes: int) -> tuple[int, int, int return src_rf_data, 0, 0, 0 -def _drive_cycle(t: Tb, cyc: int, spec: dict, *, lanes: int) -> None: +def _drive_cycle(tb: CycleAwareTb, spec: dict, *, lanes: int) -> None: wb = spec["wb"] i2 = spec["i2"] for stage in _STAGES: for lane in range(lanes): w = wb[stage][lane] - t.drive(f"{stage}{lane}_valid", int(w["valid"]), at=cyc) - t.drive(f"{stage}{lane}_ptag", int(w["ptag"]), at=cyc) - t.drive(f"{stage}{lane}_ptype", int(w["ptype"]), at=cyc) - t.drive(f"{stage}{lane}_data", int(w["data"]), at=cyc) + tb.drive(f"{stage}{lane}_valid", int(w["valid"])) + tb.drive(f"{stage}{lane}_ptag", int(w["ptag"])) + tb.drive(f"{stage}{lane}_ptype", int(w["ptype"])) + tb.drive(f"{stage}{lane}_data", int(w["data"])) for i in range(lanes): for src in _SRCS: s = i2[i][src] - t.drive(f"i2{i}_{src}_valid", int(s["valid"]), at=cyc) - t.drive(f"i2{i}_{src}_ptag", int(s["ptag"]), at=cyc) - t.drive(f"i2{i}_{src}_ptype", int(s["ptype"]), at=cyc) - t.drive(f"i2{i}_{src}_rf_data", int(s["rf_data"]), at=cyc) + tb.drive(f"i2{i}_{src}_valid", int(s["valid"])) + tb.drive(f"i2{i}_{src}_ptag", int(s["ptag"])) + tb.drive(f"i2{i}_{src}_ptype", int(s["ptype"])) + tb.drive(f"i2{i}_{src}_rf_data", int(s["rf_data"])) -def _expect_cycle(t: Tb, cyc: int, spec: dict, *, lanes: int) -> None: +def _expect_cycle(tb: CycleAwareTb, cyc: int, spec: dict, *, lanes: int) -> None: wb = spec["wb"] i2 = spec["i2"] for i in range(lanes): for src in _SRCS: exp_data, exp_hit, exp_stage, exp_lane = _resolve_expected(i2[i][src], wb, lanes=lanes) - t.expect(f"i2{i}_{src}_data", exp_data, at=cyc, msg=f"data mismatch lane={i} src={src} cycle={cyc}") - 
t.expect(f"i2{i}_{src}_hit", exp_hit, at=cyc, msg=f"hit mismatch lane={i} src={src} cycle={cyc}") - t.expect( + tb.expect(f"i2{i}_{src}_data", exp_data, msg=f"data mismatch lane={i} src={src} cycle={cyc}") + tb.expect(f"i2{i}_{src}_hit", exp_hit, msg=f"hit mismatch lane={i} src={src} cycle={cyc}") + tb.expect( f"i2{i}_{src}_sel_stage", exp_stage, - at=cyc, msg=f"sel_stage mismatch lane={i} src={src} cycle={cyc}", ) - t.expect( + tb.expect( f"i2{i}_{src}_sel_lane", exp_lane, - at=cyc, msg=f"sel_lane mismatch lane={i} src={src} cycle={cyc}", ) @@ -340,6 +338,7 @@ def _gen_random_stress(*, lanes: int, ptag_count: int, count: int, seed: int) -> @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) lanes = 8 ptag_count = 256 @@ -393,10 +392,10 @@ def tb(t: Tb) -> None: cycles.extend(_gen_invalid_source_sweep(lanes=lanes, ptag_count=ptag_count)) cycles.extend(_gen_random_stress(lanes=lanes, ptag_count=ptag_count, count=32, seed=0xD1CE_BA5E_F00D_CAFE)) - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(len(cycles) + 64) - t.print_every("bypass", start=0, every=32, ports=["i20_srcL_hit", "i20_srcR_hit"]) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(len(cycles) + 64) + tb.print_every("bypass", start=0, every=32, ports=["i20_srcL_hit", "i20_srcR_hit"]) for i in range(lanes): for src in _SRCS: @@ -405,7 +404,7 @@ def tb(t: Tb) -> None: for b in range(a + 1, lanes): match_a = _match_expr(stage, a, i, src) match_b = _match_expr(stage, b, i, src) - t.sva_assert( + tb.sva_assert( ~(match_a & match_b), clock="clk", reset="rst", @@ -413,17 +412,19 @@ def tb(t: Tb) -> None: msg=f"illegal same-stage multihit stage={stage} src={src} lane={i}", ) + # --- cycle 0 --- for cyc, spec in enumerate(cycles): - _drive_cycle(t, cyc, spec, lanes=lanes) - _expect_cycle(t, cyc, spec, lanes=lanes) + if cyc > 0: + tb.next() # --- advance to next cycle --- + _drive_cycle(tb, spec, lanes=lanes) + _expect_cycle(tb, cyc, 
spec, lanes=lanes) - t.finish(at=len(cycles) - 1) + tb.finish() if __name__ == "__main__": print( - compile( - build, + compile_cycle_aware(build, name="tb_bypass_unit_top", lanes=8, data_width=64, diff --git a/designs/IssueQueue/issq.py b/designs/IssueQueue/issq.py index 8a4cdc6..1f88c9a 100644 --- a/designs/IssueQueue/issq.py +++ b/designs/IssueQueue/issq.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any -from pycircuit import Circuit, compile, function, module, u +from pycircuit import Circuit, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, function, module, u _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -347,10 +347,7 @@ def _emit_debug_and_ready( m.output("issued_total", issued_total_q.out()) -@module -def build( - m: Circuit, - *, +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, entries: int = 16, ptag_count: int = 64, payload_width: int = 32, @@ -377,9 +374,9 @@ def build( occ_w = int(cfg.occupancy_width) issue_cnt_w = int(cfg.issue_count_width) issued_total_w = int(cfg.issued_total_width) - - clk = m.clock("clk") - rst = m.reset("rst") + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst uop_spec = _uop_spec(m, cfg) entry_spec = _entry_spec(m, cfg) @@ -388,27 +385,26 @@ def build( enq_uops = [m.inputs(uop_spec, prefix=f"enq{k}_") for k in range(n_enq)] entry_state = [ - m.state(entry_spec, clk=clk, rst=rst, prefix=f"ent{i}_", init=0) + m.state(entry_spec, clk=cd.clk, rst=cd.rst, prefix=f"ent{i}_", init=0) for i in range(e) ] age_state = [ - [m.out(f"age_{i}_{j}", clk=clk, rst=rst, width=1, init=u(1, 0)) for j in range(e)] + [m.out(f"age_{i}_{j}", domain=cd, width=1, init=u(1, 0)) for j in range(e)] for i in range(e) ] ready_state = [ m.out( f"ready_ptag_{t}", - clk=clk, - rst=rst, + domain=cd, width=1, init=u(1, (int(cfg.init_ready_mask) >> t) & 1), ) for t in range(p) ] - issued_total_q = m.out("issued_total_q", clk=clk, rst=rst, width=issued_total_w, init=u(issued_total_w, 
0)) + issued_total_q = m.out("issued_total_q", domain=cd, width=issued_total_w, init=u(issued_total_w, 0)) cur = _snapshot_entries(m, entry_state, e) @@ -496,8 +492,7 @@ def build( if __name__ == "__main__": print( - compile( - build, + compile_cycle_aware(build, name="issq", entries=16, ptag_count=64, diff --git a/designs/IssueQueue/tb_issq.py b/designs/IssueQueue/tb_issq.py index a62ff84..ff73dfa 100644 --- a/designs/IssueQueue/tb_issq.py +++ b/designs/IssueQueue/tb_issq.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -24,6 +24,7 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) entries = 16 ptag_count = 64 enq_ports = 2 @@ -89,38 +90,41 @@ def tb(t: Tb) -> None: else: raise RuntimeError("test stream did not drain (possible deadlock)") - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(len(cycles) + 64) - t.expect("occupancy", 0, at=0, phase="pre") - t.print_every("issq", start=0, every=8, ports=["occupancy", "issued_total"]) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(len(cycles) + 64) + tb.expect("occupancy", 0, phase="pre") + tb.print_every("issq", start=0, every=8, ports=["occupancy", "issued_total"]) + # --- cycle 0 --- for cyc, (lane_valid, lane_uops, obs) in enumerate(cycles): + if cyc > 0: + tb.next() # --- advance to next cycle --- + for k in range(enq_ports): uop = lane_uops[k] v = 1 if lane_valid[k] else 0 - t.drive(f"enq{k}_valid", v, at=cyc) - t.drive(f"enq{k}_src0_valid", int(uop.src0.valid), at=cyc) - t.drive(f"enq{k}_src0_ptag", int(uop.src0.ptag), at=cyc) - t.drive(f"enq{k}_src0_ready", int(uop.src0.ready), at=cyc) - t.drive(f"enq{k}_src1_valid", int(uop.src1.valid), at=cyc) - t.drive(f"enq{k}_src1_ptag", 
int(uop.src1.ptag), at=cyc) - t.drive(f"enq{k}_src1_ready", int(uop.src1.ready), at=cyc) - t.drive(f"enq{k}_dst_valid", int(uop.dst.valid), at=cyc) - t.drive(f"enq{k}_dst_ptag", int(uop.dst.ptag), at=cyc) - t.drive(f"enq{k}_dst_ready", int(uop.dst.ready), at=cyc) - t.drive(f"enq{k}_payload", int(uop.payload), at=cyc) + tb.drive(f"enq{k}_valid", v) + tb.drive(f"enq{k}_src0_valid", int(uop.src0.valid)) + tb.drive(f"enq{k}_src0_ptag", int(uop.src0.ptag)) + tb.drive(f"enq{k}_src0_ready", int(uop.src0.ready)) + tb.drive(f"enq{k}_src1_valid", int(uop.src1.valid)) + tb.drive(f"enq{k}_src1_ptag", int(uop.src1.ptag)) + tb.drive(f"enq{k}_src1_ready", int(uop.src1.ready)) + tb.drive(f"enq{k}_dst_valid", int(uop.dst.valid)) + tb.drive(f"enq{k}_dst_ptag", int(uop.dst.ptag)) + tb.drive(f"enq{k}_dst_ready", int(uop.dst.ready)) + tb.drive(f"enq{k}_payload", int(uop.payload)) _ = obs - t.finish(at=len(cycles) - 1) + tb.finish() if __name__ == "__main__": print( - compile( - build, + compile_cycle_aware(build, name="tb_issq_top", entries=16, ptag_count=64, diff --git a/designs/RegisterFile/emulate_regfile.py b/designs/RegisterFile/emulate_regfile.py new file mode 100644 index 0000000..b19f82d --- /dev/null +++ b/designs/RegisterFile/emulate_regfile.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +emulate_regfile.py — True RTL simulation of the 256-entry, 10R/5W register file. + +Runs: + 1. Functional correctness tests (write then read-back, constant ROM, etc.) + 2. Performance benchmark: 100K cycles of mixed read/write traffic. 
+""" +from __future__ import annotations + +import ctypes +import random +import sys +import time +from pathlib import Path + +RESET = "\033[0m"; BOLD = "\033[1m"; DIM = "\033[2m" +RED = "\033[31m"; GREEN = "\033[32m"; YELLOW = "\033[33m"; CYAN = "\033[36m" + +NR = 10 +NW = 5 +PTAG_COUNT = 256 +CONST_COUNT = 128 +MASK64 = (1 << 64) - 1 + + +def const64(ptag: int) -> int: + v = ptag & 0xFFFF_FFFF + return ((v << 32) | v) & MASK64 + + +class RegFileRTL: + def __init__(self, lib_path: str | None = None): + if lib_path is None: + lib_path = str(Path(__file__).resolve().parent / "libregfile_sim.dylib") + L = ctypes.CDLL(lib_path) + + L.rf_create.restype = ctypes.c_void_p + L.rf_destroy.argtypes = [ctypes.c_void_p] + L.rf_reset.argtypes = [ctypes.c_void_p, ctypes.c_uint64] + L.rf_drive_read.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint8] + L.rf_drive_write.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint8, + ctypes.c_uint8, ctypes.c_uint64] + L.rf_tick.argtypes = [ctypes.c_void_p, ctypes.c_uint64] + L.rf_get_rdata.argtypes = [ctypes.c_void_p, ctypes.c_uint32] + L.rf_get_rdata.restype = ctypes.c_uint64 + L.rf_get_cycle.argtypes = [ctypes.c_void_p] + L.rf_get_cycle.restype = ctypes.c_uint64 + L.rf_run_bench.argtypes = [ctypes.c_void_p, ctypes.c_uint64] + L.rf_run_bench_cd.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_uint32] + + self._L = L + self._c = L.rf_create() + + def __del__(self): + if hasattr(self, "_c") and self._c: + self._L.rf_destroy(self._c) + + def reset(self): + self._L.rf_reset(self._c, 2) + + def drive_read(self, lane: int, addr: int): + self._L.rf_drive_read(self._c, lane, addr & 0xFF) + + def drive_write(self, lane: int, en: int, addr: int, data: int): + self._L.rf_drive_write(self._c, lane, en & 1, addr & 0xFF, data & MASK64) + + def tick(self, n: int = 1): + self._L.rf_tick(self._c, n) + + def get_rdata(self, lane: int) -> int: + return self._L.rf_get_rdata(self._c, lane) + + @property + def cycle(self) -> int: + 
return self._L.rf_get_cycle(self._c) + + def run_bench(self, n_cycles: int): + self._L.rf_run_bench(self._c, n_cycles) + + def run_bench_cd(self, n_cycles: int, active_pct: int = 100): + self._L.rf_run_bench_cd(self._c, n_cycles, active_pct) + + +def test_functional(rf: RegFileRTL) -> tuple[int, int]: + passed = 0 + failed = 0 + + def check(desc: str, got: int, exp: int): + nonlocal passed, failed + if got == exp: + passed += 1 + else: + failed += 1 + print(f" {RED}FAIL{RESET} {desc}: got=0x{got:016X} exp=0x{exp:016X}") + + rf.reset() + + # ── Test 1: constant ROM reads ── + print(f" {DIM}[T1]{RESET} Constant ROM reads (addr 0..9)...") + for i in range(NR): + rf.drive_read(i, i) + rf.tick(1) + for i in range(NR): + check(f"const[{i}]", rf.get_rdata(i), const64(i)) + + # ── Test 2: uninitialized data reads should be 0 ── + print(f" {DIM}[T2]{RESET} Uninitialized data reads (addr 128..137)...") + for i in range(NR): + rf.drive_read(i, CONST_COUNT + i) + rf.tick(1) + for i in range(NR): + check(f"uninit[{CONST_COUNT + i}]", rf.get_rdata(i), 0) + + # ── Test 3: write then read-back ── + print(f" {DIM}[T3]{RESET} Write then read-back (5 entries)...") + test_data = [ + (128, 0x1111222233334444), + (129, 0x5555666677778888), + (130, 0xDEADBEEFCAFEBABE), + (200, 0x89ABCDEF01234567), + (255, 0x0123456789ABCDEF), + ] + for lane, (addr, data) in enumerate(test_data): + rf.drive_write(lane, 1, addr, data) + rf.tick(1) + # clear writes, set up reads + for lane in range(NW): + rf.drive_write(lane, 0, 0, 0) + for i, (addr, _) in enumerate(test_data): + rf.drive_read(i, addr) + for i in range(len(test_data), NR): + rf.drive_read(i, 0) + rf.tick(1) + for i, (addr, data) in enumerate(test_data): + check(f"wb[{addr}]", rf.get_rdata(i), data) + + # ── Test 4: constant ROM writes are ignored ── + print(f" {DIM}[T4]{RESET} Writes to constant ROM are ignored...") + rf.drive_write(0, 1, 7, 0xAAAAAAAAAAAAAAAA) + rf.drive_write(1, 1, 127, 0xBBBBBBBBBBBBBBBB) + for lane in range(2, NW): + 
rf.drive_write(lane, 0, 0, 0) + rf.tick(1) + rf.drive_write(0, 0, 0, 0) + rf.drive_write(1, 0, 0, 0) + rf.drive_read(0, 7) + rf.drive_read(1, 127) + rf.tick(1) + check("const[7] unchanged", rf.get_rdata(0), const64(7)) + check("const[127] unchanged", rf.get_rdata(1), const64(127)) + + # ── Test 5: overwrite existing entries ── + print(f" {DIM}[T5]{RESET} Overwrite existing entries...") + rf.drive_write(0, 1, 128, 0x0BADF00D0BADF00D) + rf.drive_write(1, 1, 129, 0x0102030405060708) + for lane in range(2, NW): + rf.drive_write(lane, 0, 0, 0) + rf.tick(1) + for lane in range(NW): + rf.drive_write(lane, 0, 0, 0) + rf.drive_read(0, 128) + rf.drive_read(1, 129) + rf.tick(1) + check("overwrite[128]", rf.get_rdata(0), 0x0BADF00D0BADF00D) + check("overwrite[129]", rf.get_rdata(1), 0x0102030405060708) + + return passed, failed + + +def benchmark(rf: RegFileRTL, n_cycles: int) -> float: + rf.reset() + + # warm up + rf.run_bench(1000) + + # timed run + t0 = time.perf_counter() + rf.run_bench(n_cycles) + t1 = time.perf_counter() + return t1 - t0 + + +def main(): + print(f"\n{BOLD}{CYAN}RegisterFile RTL Simulation{RESET}") + print(f" Config: {PTAG_COUNT} entries, {CONST_COUNT} constants, {NR}R/{NW}W, 64-bit data") + print(f"{'=' * 60}\n") + + rf = RegFileRTL() + + # ── Functional tests ── + print(f"{BOLD}Functional Correctness Tests{RESET}") + passed, failed = test_functional(rf) + total = passed + failed + if failed == 0: + print(f"\n {GREEN}{BOLD}ALL {total} checks PASSED{RESET}\n") + else: + print(f"\n {RED}{BOLD}{failed}/{total} checks FAILED{RESET}\n") + + # ── Benchmark: 100% active ── + N = 100_000 + print(f"{BOLD}Performance Benchmark ({N // 1000}K cycles, 100% active){RESET}") + print(f" Mixed random read/write traffic per cycle...") + + elapsed = benchmark(rf, N) + khz = N / elapsed / 1000 + print(f"\n Cycles: {N:>12,}") + print(f" Elapsed: {elapsed:>12.4f} s") + print(f" Throughput:{khz:>12.1f} Kcycles/s") + print(f" Per cycle: {elapsed / N * 1e6:>12.2f} us") + + # ── 
Benchmark: change-detection with varying activity rates ── + print(f"\n{BOLD}Change-Detection Benchmark ({N // 1000}K cycles){RESET}") + for pct in [100, 50, 25, 10, 1]: + rf.reset() + rf.run_bench_cd(1000, pct) # warm up + t0 = time.perf_counter() + rf.run_bench_cd(N, pct) + t1 = time.perf_counter() + el = t1 - t0 + kc = N / el / 1000 + print(f" {pct:3d}% active: {el:.4f}s ({kc:.1f} Kcycles/s)") + + print(f"\n{GREEN}{BOLD}Done.{RESET}\n") + + sys.exit(1 if failed else 0) + + +if __name__ == "__main__": + main() diff --git a/designs/RegisterFile/pgo_profiles/_pgo_train.py b/designs/RegisterFile/pgo_profiles/_pgo_train.py new file mode 100644 index 0000000..d837ada --- /dev/null +++ b/designs/RegisterFile/pgo_profiles/_pgo_train.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 +import ctypes, sys +L = ctypes.CDLL('designs/RegisterFile/pgo_profiles/libinstr.dylib') +L.rf_create.restype = ctypes.c_void_p +L.rf_reset.argtypes = [ctypes.c_void_p, ctypes.c_uint64] +L.rf_run_bench.argtypes = [ctypes.c_void_p, ctypes.c_uint64] +L.rf_destroy.argtypes = [ctypes.c_void_p] +c = L.rf_create() +L.rf_reset(c, 2) +L.rf_run_bench(c, 10000) +L.rf_destroy(c) diff --git a/designs/RegisterFile/pgo_profiles/default.profraw b/designs/RegisterFile/pgo_profiles/default.profraw new file mode 100644 index 0000000..1dc0c80 Binary files /dev/null and b/designs/RegisterFile/pgo_profiles/default.profraw differ diff --git a/designs/RegisterFile/pgo_profiles/merged.profdata b/designs/RegisterFile/pgo_profiles/merged.profdata new file mode 100644 index 0000000..1b13ee5 Binary files /dev/null and b/designs/RegisterFile/pgo_profiles/merged.profdata differ diff --git a/designs/RegisterFile/pgo_profiles2/default.profraw b/designs/RegisterFile/pgo_profiles2/default.profraw new file mode 100644 index 0000000..d8df3a1 Binary files /dev/null and b/designs/RegisterFile/pgo_profiles2/default.profraw differ diff --git a/designs/RegisterFile/pgo_profiles2/merged.profdata 
b/designs/RegisterFile/pgo_profiles2/merged.profdata new file mode 100644 index 0000000..8521253 Binary files /dev/null and b/designs/RegisterFile/pgo_profiles2/merged.profdata differ diff --git a/designs/RegisterFile/pgo_profiles3/default.profraw b/designs/RegisterFile/pgo_profiles3/default.profraw new file mode 100644 index 0000000..8ae995d Binary files /dev/null and b/designs/RegisterFile/pgo_profiles3/default.profraw differ diff --git a/designs/RegisterFile/pgo_profiles3/merged.profdata b/designs/RegisterFile/pgo_profiles3/merged.profdata new file mode 100644 index 0000000..8e39288 Binary files /dev/null and b/designs/RegisterFile/pgo_profiles3/merged.profdata differ diff --git a/designs/RegisterFile/regfile.profdata b/designs/RegisterFile/regfile.profdata new file mode 100644 index 0000000..cc5ed58 Binary files /dev/null and b/designs/RegisterFile/regfile.profdata differ diff --git a/designs/RegisterFile/regfile.profraw b/designs/RegisterFile/regfile.profraw new file mode 100644 index 0000000..8558bde Binary files /dev/null and b/designs/RegisterFile/regfile.profraw differ diff --git a/designs/RegisterFile/regfile.py b/designs/RegisterFile/regfile.py index bffb69a..8f3df05 100644 --- a/designs/RegisterFile/regfile.py +++ b/designs/RegisterFile/regfile.py @@ -1,13 +1,17 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module -from pycircuit.lib import RegFile - - -@module -def build( - m: Circuit, - *, +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + CycleAwareSignal, + cas, + compile_cycle_aware, + mux, + u, +) + + +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, ptag_count: int = 256, const_count: int = 128, nr: int = 10, @@ -27,52 +31,78 @@ def build( raise ValueError("regfile nw must be > 0") ptag_w = max(1, (ptag_n - 1).bit_length()) + storage_depth = ptag_n - const_n + cmp_w = ptag_w + 1 + + # ══════════════════════════════════════════════════════════════ + # Cycle 0 — Inputs + # 
══════════════════════════════════════════════════════════════ + raddr = [cas(domain, m.input(f"raddr{i}", width=ptag_w), cycle=0) for i in range(nr_n)] + wen = [cas(domain, m.input(f"wen{i}", width=1), cycle=0) for i in range(nw_n)] + waddr = [cas(domain, m.input(f"waddr{i}", width=ptag_w), cycle=0) for i in range(nw_n)] + wdata = [cas(domain, m.input(f"wdata{i}", width=64), cycle=0) for i in range(nw_n)] + + wdata_lo = [wd[0:32] for wd in wdata] + wdata_hi = [wd[32:64] for wd in wdata] + + # ══════════════════════════════════════════════════════════════ + # Cycle 0 — Storage state (feedback registers via domain.state) + # ══════════════════════════════════════════════════════════════ + bank0 = [domain.state(width=32, reset_value=0, name=f"rf_bank0_{i}") for i in range(storage_depth)] + bank1 = [domain.state(width=32, reset_value=0, name=f"rf_bank1_{i}") for i in range(storage_depth)] + + # ══════════════════════════════════════════════════════════════ + # Cycle 0 — Combinational read logic + # ══════════════════════════════════════════════════════════════ + zero32 = cas(domain, m.const(0, width=32), cycle=0) + zero64 = cas(domain, m.const(0, width=64), cycle=0) - clk = m.clock("clk") - rst = m.reset("rst") - - raddr = [m.input(f"raddr{i}", width=ptag_w) for i in range(nr_n)] - wen = [m.input(f"wen{i}", width=1) for i in range(nw_n)] - waddr = [m.input(f"waddr{i}", width=ptag_w) for i in range(nw_n)] - wdata = [m.input(f"wdata{i}", width=64) for i in range(nw_n)] - - raddr_bus = raddr[0] - for i in range(1, nr_n): - raddr_bus = m.cat(raddr[i], raddr_bus) - - wen_bus = wen[0] - for i in range(1, nw_n): - wen_bus = m.cat(wen[i], wen_bus) - - waddr_bus = waddr[0] - for i in range(1, nw_n): - waddr_bus = m.cat(waddr[i], waddr_bus) - - wdata_bus = wdata[0] - for i in range(1, nw_n): - wdata_bus = m.cat(wdata[i], wdata_bus) - - rf = RegFile( - m, - clk=clk, - rst=rst, - raddr_bus=raddr_bus, - wen_bus=wen_bus, - waddr_bus=waddr_bus, - wdata_bus=wdata_bus, - 
ptag_count=ptag_n, - const_count=const_n, - nr=nr_n, - nw=nw_n, - ) - - rdata_bus = rf["rdata_bus"].read() for i in range(nr_n): - m.output(f"rdata{i}", rdata_bus[i * 64 : (i + 1) * 64]) + ra = raddr[i] + ra_ext = cas(domain, ra.wire + u(cmp_w, 0), cycle=0) + is_valid = ra_ext < cas(domain, m.const(ptag_n, width=cmp_w), cycle=0) + is_const = ra_ext < cas(domain, m.const(const_n, width=cmp_w), cycle=0) + + if ra.wire.width > 32: + const32 = cas(domain, ra.wire[0:32], cycle=0) + else: + const32 = cas(domain, ra.wire + u(32, 0), cycle=0) + const64 = cas(domain, m.cat(const32.wire, const32.wire), cycle=0) + + store_lo: CycleAwareSignal = zero32 + store_hi: CycleAwareSignal = zero32 + for sidx in range(storage_depth): + ptag = const_n + sidx + hit = ra == cas(domain, m.const(ptag, width=ptag_w), cycle=0) + store_lo = mux(hit, bank0[sidx], store_lo) + store_hi = mux(hit, bank1[sidx], store_hi) + store64 = cas(domain, m.cat(store_hi.wire, store_lo.wire), cycle=0) + + lane_data = mux(is_const, const64, store64) + lane_data = mux(is_valid, lane_data, zero64) + m.output(f"rdata{i}", lane_data.wire) + + # ══════════════════════════════════════════════════════════════ + # domain.next() → Cycle 1 — Synchronous write (close feedback) + # ══════════════════════════════════════════════════════════════ + domain.next() + + for sidx in range(storage_depth): + ptag = const_n + sidx + we_any = cas(domain, m.const(0, width=1), cycle=0) + next_lo: CycleAwareSignal = bank0[sidx] + next_hi: CycleAwareSignal = bank1[sidx] + for lane in range(nw_n): + hit = wen[lane] & (waddr[lane] == cas(domain, m.const(ptag, width=ptag_w), cycle=0)) + we_any = we_any | hit + next_lo = mux(hit, wdata_lo[lane], next_lo) + next_hi = mux(hit, wdata_hi[lane], next_hi) + bank0[sidx].set(next_lo, when=we_any) + bank1[sidx].set(next_hi, when=we_any) build.__pycircuit_name__ = "regfile" if __name__ == "__main__": - print(compile(build, name="regfile").emit_mlir()) + print(compile_cycle_aware(build, name="regfile", 
eager=True).emit_mlir()) diff --git a/designs/RegisterFile/regfile_capi.cpp b/designs/RegisterFile/regfile_capi.cpp new file mode 100644 index 0000000..2b6d586 --- /dev/null +++ b/designs/RegisterFile/regfile_capi.cpp @@ -0,0 +1,257 @@ +/** + * regfile_capi.cpp — C API wrapper for the RegisterFile RTL model. + * + * Build (from pyCircuit root): + * c++ -std=c++17 -O2 -shared -fPIC -I include \ + * -o designs/RegisterFile/libregfile_sim.dylib \ + * designs/RegisterFile/regfile_capi.cpp + */ +#include +#include +#include +#include +#include + +#include "generated/regfile_gen.hpp" + +using pyc::cpp::Wire; +using pyc::cpp::InputFingerprint; + +static constexpr unsigned NR = 10; +static constexpr unsigned NW = 5; +static constexpr unsigned PTAG_W = 8; + +struct SimContext { + pyc::gen::RegFile__p6da24dd3 dut{}; + pyc::cpp::Testbench tb; + uint64_t cycle = 0; + + InputFingerprint<80, 5, 40, 320> input_fp; + bool eval_dirty = true; + + SimContext() + : tb(dut), + input_fp(dut.raddr_bus, dut.wen_bus, dut.waddr_bus, dut.wdata_bus) { + tb.addClock(dut.clk, 1); + } + + void mark_inputs_dirty() { eval_dirty = true; } + + void eval_if_dirty() { + if (eval_dirty || input_fp.check_and_capture()) { + dut.eval(); + eval_dirty = false; + } + } + + void force_eval() { + dut.eval(); + input_fp.capture(); + eval_dirty = false; + } +}; + +static void pack_raddr(SimContext *c, const uint8_t addrs[NR]) { + uint64_t w0 = 0; + for (unsigned i = 0; i < 8; i++) + w0 |= (uint64_t)addrs[i] << (i * PTAG_W); + uint64_t w1 = 0; + for (unsigned i = 8; i < NR; i++) + w1 |= (uint64_t)addrs[i] << ((i - 8) * PTAG_W); + c->dut.raddr_bus.setWord(0, w0); + c->dut.raddr_bus.setWord(1, w1); +} + +static void pack_write(SimContext *c, const uint8_t wen[NW], + const uint8_t waddr[NW], const uint64_t wdata[NW]) { + uint64_t wen_val = 0; + for (unsigned i = 0; i < NW; i++) + if (wen[i]) wen_val |= (1u << i); + c->dut.wen_bus = Wire<5>((uint64_t)wen_val); + + uint64_t wa = 0; + for (unsigned i = 0; i < NW; i++) 
+ wa |= (uint64_t)waddr[i] << (i * PTAG_W); + c->dut.waddr_bus = Wire<40>(wa); + + for (unsigned i = 0; i < NW; i++) + c->dut.wdata_bus.setWord(i, wdata[i]); +} + +static uint64_t extract_rdata(SimContext *c, unsigned lane) { + return c->dut.rdata_bus.word(lane); +} + +extern "C" { + +SimContext *rf_create() { return new SimContext(); } +void rf_destroy(SimContext *c) { delete c; } + +void rf_reset(SimContext *c, uint64_t n) { + c->dut.wen_bus = Wire<5>(0u); + c->dut.raddr_bus = Wire<80>(0u); + c->dut.waddr_bus = Wire<40>(0u); + for (unsigned i = 0; i < NW; i++) + c->dut.wdata_bus.setWord(i, 0); + c->tb.reset(c->dut.rst, n, 1); + c->force_eval(); + c->cycle = 0; +} + +void rf_drive_read(SimContext *c, uint32_t lane, uint8_t addr) { + uint64_t w = c->dut.raddr_bus.word(lane / 8); + unsigned shift = (lane % 8) * PTAG_W; + w &= ~((uint64_t)0xFF << shift); + w |= (uint64_t)addr << shift; + c->dut.raddr_bus.setWord(lane / 8, w); + c->mark_inputs_dirty(); +} + +void rf_drive_write(SimContext *c, uint32_t lane, uint8_t en, + uint8_t addr, uint64_t data) { + uint64_t wen_val = c->dut.wen_bus.value(); + if (en) wen_val |= (1u << lane); else wen_val &= ~(1u << lane); + c->dut.wen_bus = Wire<5>((uint64_t)wen_val); + + uint64_t wa = c->dut.waddr_bus.value(); + unsigned shift = lane * PTAG_W; + wa &= ~((uint64_t)0xFF << shift); + wa |= (uint64_t)addr << shift; + c->dut.waddr_bus = Wire<40>(wa); + + c->dut.wdata_bus.setWord(lane, data); + c->mark_inputs_dirty(); +} + +void rf_tick(SimContext *c, uint64_t n) { + c->tb.runCycles(n); + c->cycle += n; + c->eval_dirty = true; +} + +uint64_t rf_get_rdata(SimContext *c, uint32_t lane) { + return extract_rdata(c, lane); +} + +uint64_t rf_get_cycle(SimContext *c) { return c->cycle; } + +// High-performance benchmark loop with change-detection fast path. +// Inlines the clock toggling and eval to avoid Testbench dispatch overhead. 
+void rf_run_bench(SimContext *c, uint64_t n_cycles) { + uint8_t raddrs[NR]; + uint8_t wen[NW] = {}; + uint8_t waddr[NW] = {}; + uint64_t wdata[NW] = {}; + + auto &dut = c->dut; + + uint64_t rng = 0xDEADBEEF12345678ULL; + auto xorshift = [&]() -> uint64_t { + rng ^= rng << 13; + rng ^= rng >> 7; + rng ^= rng << 17; + return rng; + }; + + for (uint64_t i = 0; i < n_cycles; i++) { + // Drive random inputs + uint64_t r = xorshift(); + for (unsigned j = 0; j < NR; j++) + raddrs[j] = (uint8_t)((r >> (j * 2)) & 0xFF); + pack_raddr(c, raddrs); + + r = xorshift(); + for (unsigned j = 0; j < NW; j++) { + wen[j] = (r >> j) & 1; + waddr[j] = (uint8_t)((r >> (8 + j * 8)) & 0xFF); + wdata[j] = xorshift(); + } + pack_write(c, wen, waddr, wdata); + + // Pre-posedge combinational settle + dut.eval(); + + // Posedge + dut.clk = Wire<1>(1u); + dut.tick(); + + // Post-posedge combinational settle + dut.eval(); + + // Negedge — lightweight: just update clkPrev on all registers + dut.clk = Wire<1>(0u); + dut.tick(); + + c->cycle++; + } +} + +// Benchmark loop with idle cycles to demonstrate change-detection benefit. +// Alternates between 'active_pct' % active cycles (random traffic) and +// idle cycles (no input changes, eval skippable). 
+void rf_run_bench_cd(SimContext *c, uint64_t n_cycles, uint32_t active_pct) { + auto &dut = c->dut; + auto &fp = c->input_fp; + + uint64_t rng = 0xDEADBEEF12345678ULL; + auto xorshift = [&]() -> uint64_t { + rng ^= rng << 13; + rng ^= rng >> 7; + rng ^= rng << 17; + return rng; + }; + + uint64_t evals_skipped = 0; + + for (uint64_t i = 0; i < n_cycles; i++) { + bool active = (xorshift() % 100) < active_pct; + + if (active) { + // Drive new random inputs + uint64_t r = xorshift(); + uint64_t w0 = 0; + for (unsigned j = 0; j < 8; j++) + w0 |= (uint64_t)((uint8_t)((r >> (j * 2)) & 0xFF)) << (j * PTAG_W); + uint64_t w1 = 0; + for (unsigned j = 8; j < NR; j++) + w1 |= (uint64_t)((uint8_t)((r >> (j * 2)) & 0xFF)) << ((j - 8) * PTAG_W); + dut.raddr_bus.setWord(0, w0); + dut.raddr_bus.setWord(1, w1); + + r = xorshift(); + uint64_t wen_val = r & 0x1F; + dut.wen_bus = Wire<5>((uint64_t)wen_val); + + uint64_t wa = 0; + for (unsigned j = 0; j < NW; j++) + wa |= (uint64_t)((uint8_t)((r >> (8 + j * 8)) & 0xFF)) << (j * PTAG_W); + dut.waddr_bus = Wire<40>(wa); + + for (unsigned j = 0; j < NW; j++) + dut.wdata_bus.setWord(j, xorshift()); + } + + // Change-detection eval: skip if inputs are identical to last capture + if (fp.check_and_capture()) { + dut.eval(); + } else { + evals_skipped++; + } + + // Posedge + dut.clk = Wire<1>(1u); + dut.tick(); + + // Post-posedge settle (registers may have changed, must re-eval) + dut.eval(); + fp.capture(); + + // Negedge + dut.clk = Wire<1>(0u); + dut.tick(); + + c->cycle++; + } +} + +} // extern "C" diff --git a/designs/RegisterFile/tb_regfile.py b/designs/RegisterFile/tb_regfile.py index ff8715b..dfa99fa 100644 --- a/designs/RegisterFile/tb_regfile.py +++ b/designs/RegisterFile/tb_regfile.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent 
if str(_THIS_DIR) not in sys.path: @@ -14,9 +14,10 @@ @testbench def tb(t: Tb) -> None: - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(64) + tb = CycleAwareTb(t) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(64) nr = 10 nw = 5 @@ -36,23 +37,23 @@ def read_expected(addr: int, storage: dict[int, int]) -> int: return int(storage.get(a, 0)) & mask64 return 0 - def drive_cycle(cyc: int, reads: list[int], writes: list[tuple[int, int, int]]) -> None: + def drive_cycle(reads: list[int], writes: list[tuple[int, int, int]]) -> None: if len(reads) != nr: raise ValueError(f"tb reads length mismatch: got {len(reads)} expected {nr}") for lane in range(nr): - t.drive(f"raddr{lane}", int(reads[lane]), at=cyc) + tb.drive(f"raddr{lane}", int(reads[lane])) for lane in range(nw): - t.drive(f"wen{lane}", 0, at=cyc) - t.drive(f"waddr{lane}", 0, at=cyc) - t.drive(f"wdata{lane}", 0, at=cyc) + tb.drive(f"wen{lane}", 0) + tb.drive(f"waddr{lane}", 0) + tb.drive(f"wdata{lane}", 0) for lane, waddr, wdata in writes: if lane < 0 or lane >= nw: raise ValueError(f"tb write lane out of range: {lane}") - t.drive(f"wen{lane}", 1, at=cyc) - t.drive(f"waddr{lane}", int(waddr), at=cyc) - t.drive(f"wdata{lane}", int(wdata) & mask64, at=cyc) + tb.drive(f"wen{lane}", 1) + tb.drive(f"waddr{lane}", int(waddr)) + tb.drive(f"wdata{lane}", int(wdata) & mask64) seq = [ { @@ -119,9 +120,12 @@ def drive_cycle(cyc: int, reads: list[int], writes: list[tuple[int, int, int]]) storage: dict[int, int] = {} for cyc, step in enumerate(seq): + if cyc > 0: + tb.next() # --- advance to next cycle --- + reads = list(step["reads"]) writes = list(step["writes"]) - drive_cycle(cyc, reads, writes) + drive_cycle(reads, writes) for _, waddr, wdata in writes: wa = int(waddr) @@ -130,10 +134,10 @@ def drive_cycle(cyc: int, reads: list[int], writes: list[tuple[int, int, int]]) for lane in range(nr): exp = read_expected(reads[lane], storage) - 
t.expect(f"rdata{lane}", exp, at=cyc, msg=f"regfile mismatch cycle={cyc} lane={lane}") + tb.expect(f"rdata{lane}", exp, msg=f"regfile mismatch cycle={cyc} lane={lane}") - t.finish(at=len(seq) - 1) + tb.finish() if __name__ == "__main__": - print(compile(build, name="tb_regfile_top").emit_mlir()) + print(compile_cycle_aware(build, name="tb_regfile_top", eager=True).emit_mlir()) diff --git a/designs/examples/arith/arith.py b/designs/examples/arith/arith.py index 23a299c..f5ce5b2 100644 --- a/designs/examples/arith/arith.py +++ b/designs/examples/arith/arith.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, const, ct, module, spec, u +from pycircuit import Circuit, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, ct, module, spec, u @spec.valueclass @@ -28,8 +28,7 @@ def _lane_mask(m: Circuit, *, width: int) -> int: return ct.bitmask(w) -@module -def build(m: Circuit, lanes: int = 8, lane_width: int = 16) -> None: +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, lanes: int = 8, lane_width: int = 16) -> None: cfg = _derive_cfg(m, lanes=lanes, lane_width=lane_width) acc_w = _acc_width(m, cfg) lane_mask = _lane_mask(m, width=int(cfg.lane_width)) @@ -47,4 +46,4 @@ def build(m: Circuit, lanes: int = 8, lane_width: int = 16) -> None: if __name__ == "__main__": - print(compile(build, name="arith", lanes=8, lane_width=16).emit_mlir()) + print(compile_cycle_aware(build, name="arith", eager=True, lanes=8, lane_width=16).emit_mlir()) diff --git a/designs/examples/arith/tb_arith.py b/designs/examples/arith/tb_arith.py index 8276e22..8c3299e 100644 --- a/designs/examples/arith/tb_arith.py +++ b/designs/examples/arith/tb_arith.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: 
@@ -15,15 +15,19 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("a", 1, at=0) - t.drive("b", 2, at=0) - t.expect("sum", 3, at=0) - t.expect("lane_mask", 0xFFFF, at=0) - t.expect("acc_width", 19, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("a", 1) + tb.drive("b", 2) + tb.expect("sum", 3) + tb.expect("lane_mask", 0xFFFF) + tb.expect("acc_width", 19) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_arith_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_arith_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/boundary_value_ports/boundary_value_ports.py b/designs/examples/boundary_value_ports/boundary_value_ports.py index 4ec15de..52ee87c 100644 --- a/designs/examples/boundary_value_ports/boundary_value_ports.py +++ b/designs/examples/boundary_value_ports/boundary_value_ports.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module, u +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, u @module(value_params={"gain": "i8", "bias": "i32", "enable": "i1"}) @@ -10,8 +10,8 @@ def _lane(m: Circuit, x, gain, bias, enable, *, width: int = 32): m.output("y", y) -@module -def build(m: Circuit, *, width: int = 32): +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, width: int = 32): + _ = domain seed = m.input("seed", width=width) lane0 = m.new( @@ -40,4 +40,4 @@ def build(m: Circuit, *, width: int = 32): build.__pycircuit_name__ = "boundary_value_ports" if __name__ == "__main__": - print(compile(build, name="boundary_value_ports", width=32).emit_mlir()) + print(compile_cycle_aware(build, name="boundary_value_ports", width=32).emit_mlir()) diff --git a/designs/examples/boundary_value_ports/tb_boundary_value_ports.py 
b/designs/examples/boundary_value_ports/tb_boundary_value_ports.py index a2c1a17..a67205d 100644 --- a/designs/examples/boundary_value_ports/tb_boundary_value_ports.py +++ b/designs/examples/boundary_value_ports/tb_boundary_value_ports.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,12 +15,14 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("seed", 10, at=0) - t.expect("acc", 48, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("seed", 10) + tb.expect("acc", 48) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_boundary_value_ports_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_boundary_value_ports_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/bundle_probe_expand/bundle_probe_expand.py b/designs/examples/bundle_probe_expand/bundle_probe_expand.py index 02f91db..1d6d6e0 100644 --- a/designs/examples/bundle_probe_expand/bundle_probe_expand.py +++ b/designs/examples/bundle_probe_expand/bundle_probe_expand.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, ProbeBuilder, ProbeView, compile, const, module, probe, spec +from pycircuit import Circuit, ProbeBuilder, ProbeView, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, module, probe, spec @const @@ -14,8 +14,7 @@ def _probe_struct(m: Circuit): ) -@module -def build(m: Circuit) -> None: +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) -> None: _clk = m.clock("clk") _rst = m.reset("rst") @@ -23,6 +22,7 @@ def build(m: Circuit) -> None: inp = m.inputs(s, prefix="in_") build.__pycircuit_name__ = 
"bundle_probe_expand" +build.__pycircuit_kind__ = "module" @probe(target=build, name="pv") @@ -39,4 +39,4 @@ def bundle_probe(p: ProbeBuilder, dut: ProbeView) -> None: if __name__ == "__main__": - print(compile(build, name="bundle_probe_expand").emit_mlir()) + print(compile_cycle_aware(build, name="bundle_probe_expand", eager=True).emit_mlir()) diff --git a/designs/examples/bundle_probe_expand/tb_bundle_probe_expand.py b/designs/examples/bundle_probe_expand/tb_bundle_probe_expand.py index d5ba24c..b091be3 100644 --- a/designs/examples/bundle_probe_expand/tb_bundle_probe_expand.py +++ b/designs/examples/bundle_probe_expand/tb_bundle_probe_expand.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, TbProbes, compile, testbench +from pycircuit import CycleAwareTb, Tb, TbProbes, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,28 +15,31 @@ @testbench def tb(t: Tb, probes: TbProbes) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] _ = probes["dut:probe.pv.in.a"] _ = probes["dut:probe.pv.in.b.c"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=0) - t.timeout(int(p["timeout"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=0) + tb.timeout(int(p["timeout"])) - t.drive("in_a", 0, at=0) - t.drive("in_b_c", 0, at=0) + # --- cycle 0 --- + tb.drive("in_a", 0) + tb.drive("in_b_c", 0) - t.drive("in_a", 0x12, at=0) - t.drive("in_b_c", 1, at=0) - t.expect("in_a", 0x12, at=0, phase="pre") - t.expect("in_b_c", 1, at=0, phase="pre") + tb.drive("in_a", 0x12) + tb.drive("in_b_c", 1) + tb.expect("in_a", 0x12, phase="pre") + tb.expect("in_b_c", 1, phase="pre") - t.drive("in_a", 0x34, at=1) - t.drive("in_b_c", 0, at=1) - t.expect("in_a", 0x34, at=1, phase="pre") - t.expect("in_b_c", 0, at=1, phase="pre") + tb.next() # --- cycle 1 --- + tb.drive("in_a", 0x34) + tb.drive("in_b_c", 0) + tb.expect("in_a", 
0x34, phase="pre") + tb.expect("in_b_c", 0, phase="pre") - t.finish(at=int(p["finish"])) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_bundle_probe_expand_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_bundle_probe_expand_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/cache_params/cache_params.py b/designs/examples/cache_params/cache_params.py index 91be813..31a2c7a 100644 --- a/designs/examples/cache_params/cache_params.py +++ b/designs/examples/cache_params/cache_params.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, ct, module, const, u +from pycircuit import Circuit, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, ct, u @const @@ -28,15 +28,15 @@ def _cache_cfg( return (ways_i, sets_i, line_b, off_bits, idx_bits, tag_bits, line_words) -@module def build( - m: Circuit, + m: CycleAwareCircuit, domain: CycleAwareDomain, ways: int = 4, sets: int = 64, line_bytes: int = 64, addr_width: int = 40, data_width: int = 64, ) -> None: + _ = domain ways_cfg, sets_cfg, line_bytes_cfg, off_bits, idx_bits, tag_bits, line_words = _cache_cfg( m, ways=ways, @@ -61,7 +61,7 @@ def build( if __name__ == "__main__": print( - compile(build, name="cache_params", + compile_cycle_aware(build, name="cache_params", eager=True, ways=4, sets=64, line_bytes=64, diff --git a/designs/examples/cache_params/tb_cache_params.py b/designs/examples/cache_params/tb_cache_params.py index 14a1dfd..9b731d6 100644 --- a/designs/examples/cache_params/tb_cache_params.py +++ b/designs/examples/cache_params/tb_cache_params.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,14 +15,18 @@ @testbench 
def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("addr", 0, at=0) - t.expect("tag", 0, at=0) - t.expect("line_words", 8, at=0) - t.expect("tag_bits", 28, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("addr", 0) + tb.expect("tag", 0) + tb.expect("line_words", 8) + tb.expect("tag_bits", 28) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_cache_params_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_cache_params_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/calculator/calculator.py b/designs/examples/calculator/calculator.py index 32255cc..2afcf26 100644 --- a/designs/examples/calculator/calculator.py +++ b/designs/examples/calculator/calculator.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module, unsigned, u +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, unsigned, u KEY_ADD = 10 KEY_SUB = 11 @@ -15,18 +15,18 @@ OP_DIV = 3 -@module -def build(m: Circuit) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) -> None: + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst key = m.input("key", width=5) key_press = m.input("key_press", width=1) - lhs = m.out("lhs", clk=clk, rst=rst, width=64, init=u(64, 0)) - rhs = m.out("rhs", clk=clk, rst=rst, width=64, init=u(64, 0)) - op = m.out("op", clk=clk, rst=rst, width=2, init=u(2, 0)) - in_rhs = m.out("in_rhs", clk=clk, rst=rst, width=1, init=u(1, 0)) - display = m.out("display_r", clk=clk, rst=rst, width=64, init=u(64, 0)) + lhs = m.out("lhs", domain=cd, width=64, init=u(64, 0)) + rhs = m.out("rhs", domain=cd, width=64, init=u(64, 0)) + op = m.out("op", domain=cd, width=2, init=u(2, 0)) + in_rhs = m.out("in_rhs", domain=cd, width=1, 
init=u(1, 0)) + display = m.out("display_r", domain=cd, width=64, init=u(64, 0)) digit = unsigned(key[0:4]) + u(64, 0) is_digit = key_press & (key <= u(5, 9)) @@ -95,4 +95,4 @@ def build(m: Circuit) -> None: if __name__ == "__main__": - print(compile(build, name="calculator").emit_mlir()) + print(compile_cycle_aware(build, name="calculator").emit_mlir()) diff --git a/designs/examples/calculator/emulate_calculator.py b/designs/examples/calculator/emulate_calculator.py index e36bc0b..34cd07d 100644 --- a/designs/examples/calculator/emulate_calculator.py +++ b/designs/examples/calculator/emulate_calculator.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 -from pycircuit import s # -*- coding: utf-8 -*- +from __future__ import annotations + """ emulate_calculator.py — True RTL simulation of the 16-digit calculator with decimal support, animated terminal display. @@ -18,7 +19,6 @@ Run: python designs/examples/calculator/emulate_calculator.py """ -from __future__ import annotations import ctypes, re as _re, sys, time from pathlib import Path diff --git a/designs/examples/calculator/tb_calculator.py b/designs/examples/calculator/tb_calculator.py index 4768f3c..c4ef0fb 100644 --- a/designs/examples/calculator/tb_calculator.py +++ b/designs/examples/calculator/tb_calculator.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,15 +15,17 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("key_press", 0, at=0) - t.drive("key", 0, at=0) - t.expect("display", 0, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + 
tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("key_press", 0) + tb.drive("key", 0) + tb.expect("display", 0) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_calculator_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_calculator_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/counter/counter.py b/designs/examples/counter/counter.py index 2663691..ac2ada4 100644 --- a/designs/examples/counter/counter.py +++ b/designs/examples/counter/counter.py @@ -1,22 +1,26 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module, u +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, + mux, +) -@module -def build(m: Circuit, width: int = 8) -> None: - clk = m.clock("clk") - rst = m.reset("rst") - en = m.input("enable", width=1) +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, width: int = 8) -> None: + enable = cas(domain, m.input("enable", width=1), cycle=0) + count = domain.state(width=width, reset_value=0, name="count") - count = m.out("count_q", clk=clk, rst=rst, width=width, init=u(width, 0)) - count.set(count.out() + 1, when=en) - m.output("count", count) + m.output("count", count.wire) + domain.next() + count.set(count + 1, when=enable) build.__pycircuit_name__ = "counter" if __name__ == "__main__": - print(compile(build, name="counter", width=8).emit_mlir()) + print(compile_cycle_aware(build, name="counter", eager=True, width=8).emit_mlir()) diff --git a/designs/examples/counter/tb_counter.py b/designs/examples/counter/tb_counter.py index 607101b..660909f 100644 --- a/designs/examples/counter/tb_counter.py +++ b/designs/examples/counter/tb_counter.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = 
Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,15 +15,30 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("enable", 1, at=0) - for cyc in range(5): - t.expect("count", cyc + 1, at=cyc) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("enable", 1) + tb.expect("count", 1) + + tb.next() # --- cycle 1 --- + tb.expect("count", 2) + + tb.next() # --- cycle 2 --- + tb.expect("count", 3) + + tb.next() # --- cycle 3 --- + tb.expect("count", 4) + + tb.next() # --- cycle 4 --- + tb.expect("count", 5) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_counter_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_counter_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/decode_rules/decode_rules.py b/designs/examples/decode_rules/decode_rules.py index 8d01299..7c3c269 100644 --- a/designs/examples/decode_rules/decode_rules.py +++ b/designs/examples/decode_rules/decode_rules.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, const, module, spec, u +from pycircuit import Circuit, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, module, spec, u @const @@ -15,8 +15,7 @@ def _decode_rules(m: Circuit): ) -@module -def build(m: Circuit): +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) : insn = m.input("insn", width=8) op = u(4, 0) ln = u(3, 0) @@ -32,4 +31,4 @@ def build(m: Circuit): build.__pycircuit_name__ = "decode_rules" if __name__ == "__main__": - print(compile(build, name="decode_rules").emit_mlir()) + print(compile_cycle_aware(build, name="decode_rules").emit_mlir()) diff --git 
a/designs/examples/decode_rules/tb_decode_rules.py b/designs/examples/decode_rules/tb_decode_rules.py index fb9f61b..5fe127a 100644 --- a/designs/examples/decode_rules/tb_decode_rules.py +++ b/designs/examples/decode_rules/tb_decode_rules.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,13 +15,15 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("insn", 0x10, at=0) - t.expect("op", 1, at=0) - t.expect("len", 4, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("insn", 0x10) + tb.expect("op", 1) + tb.expect("len", 4) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_decode_rules_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_decode_rules_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/digital_clock/digital_clock.py b/designs/examples/digital_clock/digital_clock.py index fb90adc..1bad44f 100644 --- a/designs/examples/digital_clock/digital_clock.py +++ b/designs/examples/digital_clock/digital_clock.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, cat, compile, function, module, u +from pycircuit import Circuit, cat, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, function, module, u MODE_RUN = 0 MODE_SET_HOUR = 1 @@ -15,21 +15,21 @@ def _to_bcd8(m: Circuit, v): return cat(tens[0:4], ones[0:4]) -@module -def build(m: Circuit, clk_freq: int = 50_000_000) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, clk_freq: int = 50_000_000) -> None: + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst 
btn_set = m.input("btn_set", width=1) btn_plus = m.input("btn_plus", width=1) btn_minus = m.input("btn_minus", width=1) prescaler_w = max((int(clk_freq) - 1).bit_length(), 1) - prescaler = m.out("prescaler", clk=clk, rst=rst, width=prescaler_w, init=u(prescaler_w, 0)) - sec = m.out("sec", clk=clk, rst=rst, width=6, init=u(6, 0)) - minute = m.out("minute", clk=clk, rst=rst, width=6, init=u(6, 0)) - hour = m.out("hour", clk=clk, rst=rst, width=5, init=u(5, 0)) - mode = m.out("mode", clk=clk, rst=rst, width=2, init=u(2, MODE_RUN)) - blink = m.out("blink", clk=clk, rst=rst, width=1, init=u(1, 0)) + prescaler = m.out("prescaler", domain=cd, width=prescaler_w, init=u(prescaler_w, 0)) + sec = m.out("sec", domain=cd, width=6, init=u(6, 0)) + minute = m.out("minute", domain=cd, width=6, init=u(6, 0)) + hour = m.out("hour", domain=cd, width=5, init=u(5, 0)) + mode = m.out("mode", domain=cd, width=2, init=u(2, MODE_RUN)) + blink = m.out("blink", domain=cd, width=1, init=u(1, 0)) tick_1hz = prescaler == u(prescaler_w, clk_freq - 1) @@ -85,4 +85,4 @@ def build(m: Circuit, clk_freq: int = 50_000_000) -> None: if __name__ == "__main__": - print(compile(build, name="digital_clock", clk_freq=50_000_000).emit_mlir()) + print(compile_cycle_aware(build, name="digital_clock", clk_freq=50_000_000).emit_mlir()) diff --git a/designs/examples/digital_clock/emulate_digital_clock.py b/designs/examples/digital_clock/emulate_digital_clock.py index 18380aa..14a7f21 100644 --- a/designs/examples/digital_clock/emulate_digital_clock.py +++ b/designs/examples/digital_clock/emulate_digital_clock.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 -from pycircuit import s # -*- coding: utf-8 -*- +from __future__ import annotations + """ emulate_digital_clock.py — True RTL simulation of the digital clock with an animated terminal display. 
@@ -17,7 +18,6 @@ Run: python designs/examples/digital_clock/emulate_digital_clock.py """ -from __future__ import annotations import ctypes import os diff --git a/designs/examples/digital_clock/tb_digital_clock.py b/designs/examples/digital_clock/tb_digital_clock.py index ec1b3e8..8a03030 100644 --- a/designs/examples/digital_clock/tb_digital_clock.py +++ b/designs/examples/digital_clock/tb_digital_clock.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,16 +15,18 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("btn_set", 0, at=0) - t.drive("btn_plus", 0, at=0) - t.drive("btn_minus", 0, at=0) - t.expect("seconds_bcd", 0, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("btn_set", 0) + tb.drive("btn_plus", 0) + tb.drive("btn_minus", 0) + tb.expect("seconds_bcd", 0) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_digital_clock_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_digital_clock_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/digital_filter/README.md b/designs/examples/digital_filter/README.md new file mode 100644 index 0000000..4655eef --- /dev/null +++ b/designs/examples/digital_filter/README.md @@ -0,0 +1,75 @@ +# 4-Tap FIR Feed-Forward Filter (pyCircuit) + +A 4-tap direct-form FIR (Finite Impulse Response) filter implemented in +pyCircuit's unified signal model, with true RTL simulation and waveform +visualization. 
+ +## Transfer Function + +``` +y[n] = c0·x[n] + c1·x[n-1] + c2·x[n-2] + c3·x[n-3] +``` + +Default coefficients: `c0=1, c1=2, c2=3, c3=4` + +## Architecture + +``` +x_in ──┬──[×c0]──┐ + │ │ + z⁻¹─[×c1]─(+)──┐ + │ │ + z⁻¹─[×c2]─────(+)──┐ + │ │ + z⁻¹─[×c3]──────────(+)──→ y_out +``` + +Single-cycle design: 3-stage delay line (shift register) + 4 parallel +multipliers + accumulator tree. + +| Register | Width | Description | +|----------|-------|-------------| +| delay_1 | 16 | x[n-1] | +| delay_2 | 16 | x[n-2] | +| delay_3 | 16 | x[n-3] | +| y_valid | 1 | Output valid (1-cycle delayed x_valid) | + +Accumulator width: DATA_W + COEFF_W + 2 guard bits = 34 bits (signed). + +## Ports + +| Port | Dir | Width | Description | +|------|-----|-------|-------------| +| x_in | in | 16 | Input sample (signed) | +| x_valid | in | 1 | Input strobe | +| y_out | out | 34 | Filter output (signed) | +| y_valid | out | 1 | Output valid | + +## Build & Run + +```bash +# 1. Compile RTL +PYTHONPATH=python:. python -m pycircuit.cli emit \ + examples/digital_filter/digital_filter.py \ + -o examples/generated/digital_filter/digital_filter.pyc +build/bin/pyc-compile examples/generated/digital_filter/digital_filter.pyc \ + --emit=cpp -o examples/generated/digital_filter/digital_filter_gen.hpp + +# 2. Build shared library +c++ -std=c++17 -O2 -shared -fPIC -I include -I . \ + -o examples/digital_filter/libfilter_sim.dylib \ + examples/digital_filter/filter_capi.cpp + +# 3. Run emulator +python examples/digital_filter/emulate_filter.py +``` + +## Test Scenarios + +| # | Input | Description | +|---|-------|-------------| +| 1 | Impulse [1,0,0,...] | Verifies impulse response = coefficients | +| 2 | Step [1,1,1,...] | Verifies step response converges to sum(coeffs)=10 | +| 3 | Ramp [0,1,2,...] 
| Verifies linear input response | +| 4 | Alternating ±100 | Tests signed arithmetic with cancellation | +| 5 | Large values (10000) | Tests near-overflow behavior | diff --git a/designs/examples/digital_filter/__init__.py b/designs/examples/digital_filter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/designs/examples/digital_filter/digital_filter.py b/designs/examples/digital_filter/digital_filter.py new file mode 100644 index 0000000..724fc50 --- /dev/null +++ b/designs/examples/digital_filter/digital_filter.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +"""4-tap Feed-Forward (FIR) Filter — pyCircuit V5 cycle-aware. + +Implements: + y[n] = c0·x[n] + c1·x[n-1] + c2·x[n-2] + c3·x[n-3] +""" +from __future__ import annotations + +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, +) + + +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, + TAPS: int = 4, + DATA_W: int = 16, + COEFF_W: int = 16, + COEFFS: tuple = (1, 2, 3, 4), +) -> None: + assert len(COEFFS) == TAPS, f"need {TAPS} coefficients, got {len(COEFFS)}" + + GUARD = (TAPS - 1).bit_length() + ACC_W = DATA_W + COEFF_W + GUARD + + x_in = cas(domain, m.input("x_in", width=DATA_W), cycle=0) + x_valid = cas(domain, m.input("x_valid", width=1), cycle=0) + + delay_states = [domain.state(width=DATA_W, reset_value=0, name=f"delay_{i}") for i in range(1, TAPS)] + + taps_wire = [x_in.wire] + [st.wire for st in delay_states] + + coeff_wires = [m.const(cv, width=ACC_W) for cv in COEFFS] + + acc_w = m.const(0, width=ACC_W) + for i in range(TAPS): + tap_ext = taps_wire[i].as_signed()._sext(width=ACC_W) + product = tap_ext * coeff_wires[i] + acc_w = acc_w + product + + y_comb = cas(domain, acc_w[0:ACC_W], cycle=0) + + y_out_state = domain.state(width=ACC_W, reset_value=0, name="y_out_reg") + y_valid_state = domain.state(width=1, reset_value=0, name="y_valid_reg") + + m.output("y_out", y_out_state.wire) + m.output("y_valid", y_valid_state.wire) + + 
domain.next() + + delay_states[0].set(x_in, when=x_valid) + for i in range(1, len(delay_states)): + delay_states[i].set(delay_states[i - 1], when=x_valid) + + y_out_state.set(y_comb, when=x_valid) + y_valid_state.set(x_valid) + + +build.__pycircuit_name__ = "digital_filter" + +if __name__ == "__main__": + print(compile_cycle_aware(build, name="digital_filter", eager=True, + TAPS=4, DATA_W=16, COEFF_W=16, COEFFS=(1, 2, 3, 4)).emit_mlir()) diff --git a/designs/examples/digital_filter/emulate_filter.py b/designs/examples/digital_filter/emulate_filter.py new file mode 100644 index 0000000..db6a3a0 --- /dev/null +++ b/designs/examples/digital_filter/emulate_filter.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +emulate_filter.py — True RTL simulation of the 4-tap FIR filter +with animated terminal visualization. + +Shows the filter structure, delay line contents, coefficients, +input/output waveforms, and step-by-step operation. + +Build (from pyCircuit root): + PYTHONPATH=python:. python -m pycircuit.cli emit \ + examples/digital_filter/digital_filter.py \ + -o examples/generated/digital_filter/digital_filter.pyc + build/bin/pyc-compile examples/generated/digital_filter/digital_filter.pyc \ + --emit=cpp -o examples/generated/digital_filter/digital_filter_gen.hpp + c++ -std=c++17 -O2 -shared -fPIC -I include -I . 
\ + -o examples/digital_filter/libfilter_sim.dylib \ + examples/digital_filter/filter_capi.cpp + +Run: + python examples/digital_filter/emulate_filter.py +""" +from __future__ import annotations + +import ctypes +import re as _re +import struct +import sys +import time +from pathlib import Path + +# ═══════════════════════════════════════════════════════════════════ +# ANSI +# ═══════════════════════════════════════════════════════════════════ +RESET = "\033[0m"; BOLD = "\033[1m"; DIM = "\033[2m" +RED = "\033[31m"; GREEN = "\033[32m"; YELLOW = "\033[33m" +CYAN = "\033[36m"; WHITE = "\033[37m"; MAGENTA = "\033[35m" +BG_GREEN = "\033[42m"; BLACK = "\033[30m"; BLUE = "\033[34m" + +_ANSI = _re.compile(r'\x1b\[[0-9;]*m') +def _vl(s): return len(_ANSI.sub('', s)) +def _pad(s, w): return s + ' ' * max(0, w - _vl(s)) +def clear(): sys.stdout.write("\033[2J\033[H"); sys.stdout.flush() + +# ═══════════════════════════════════════════════════════════════════ +# Filter coefficients (must match RTL) +# ═══════════════════════════════════════════════════════════════════ +COEFFS = (1, 2, 3, 4) +TAPS = len(COEFFS) +DATA_W = 16 + +# ═══════════════════════════════════════════════════════════════════ +# RTL wrapper +# ═══════════════════════════════════════════════════════════════════ +class FilterRTL: + def __init__(self, lib_path=None): + if lib_path is None: + lib_path = str(Path(__file__).resolve().parent / "libfilter_sim.dylib") + L = ctypes.CDLL(lib_path) + L.fir_create.restype = ctypes.c_void_p + L.fir_destroy.argtypes = [ctypes.c_void_p] + L.fir_reset.argtypes = [ctypes.c_void_p, ctypes.c_uint64] + L.fir_push_sample.argtypes = [ctypes.c_void_p, ctypes.c_int16] + L.fir_idle.argtypes = [ctypes.c_void_p, ctypes.c_uint64] + L.fir_get_y_out.argtypes = [ctypes.c_void_p]; L.fir_get_y_out.restype = ctypes.c_int64 + L.fir_get_y_valid.argtypes = [ctypes.c_void_p]; L.fir_get_y_valid.restype = ctypes.c_uint32 + L.fir_get_cycle.argtypes = [ctypes.c_void_p]; L.fir_get_cycle.restype = 
ctypes.c_uint64 + self._L, self._c = L, L.fir_create() + self._delay = [0] * TAPS # Python-side tracking for display + + def __del__(self): + if hasattr(self,'_c') and self._c: self._L.fir_destroy(self._c) + + def reset(self): + self._L.fir_reset(self._c, 2) + self._delay = [0] * TAPS + + def push(self, sample: int): + self._L.fir_push_sample(self._c, sample & 0xFFFF) + # Track delay line for display + for i in range(TAPS - 1, 0, -1): + self._delay[i] = self._delay[i - 1] + self._delay[0] = sample + + def idle(self, n=4): + self._L.fir_idle(self._c, n) + + @property + def y_out(self): + raw = self._L.fir_get_y_out(self._c) + # Sign-extend from ACC_W bits + ACC_W = DATA_W + 16 + (TAPS - 1).bit_length() + if raw >= (1 << (ACC_W - 1)): + raw -= (1 << ACC_W) + return raw + + @property + def y_valid(self): return bool(self._L.fir_get_y_valid(self._c)) + @property + def cycle(self): return self._L.fir_get_cycle(self._c) + + def expected_output(self): + """Compute expected y using Python for verification.""" + return sum(self._delay[i] * COEFFS[i] for i in range(TAPS)) + +# ═══════════════════════════════════════════════════════════════════ +# Terminal UI +# ═══════════════════════════════════════════════════════════════════ +BOX_W = 64 + +def _bl(content): + return f" {CYAN}║{RESET}{_pad(content, BOX_W)}{CYAN}║{RESET}" + +def _bar_char(val, max_abs, width=20): + """Render a horizontal bar for a signed value.""" + if max_abs == 0: max_abs = 1 + half = width // 2 + pos = int(abs(val) / max_abs * half) + pos = min(pos, half) + if val >= 0: + bar = " " * half + "│" + f"{GREEN}{'█' * pos}{RESET}" + " " * (half - pos) + else: + bar = " " * (half - pos) + f"{RED}{'█' * pos}{RESET}" + "│" + " " * half + return bar + +def draw(sim, x_history, y_history, message="", test_info="", step=-1): + clear() + bar = "═" * BOX_W + + print(f"\n {CYAN}╔{bar}╗{RESET}") + print(_bl(f" {BOLD}{WHITE}4-TAP FIR FILTER — TRUE RTL SIMULATION{RESET}")) + print(f" {CYAN}╠{bar}╣{RESET}") + + if 
test_info: + print(_bl(f" {YELLOW}{test_info}{RESET}")) + print(f" {CYAN}╠{bar}╣{RESET}") + + # Filter structure diagram + print(_bl("")) + print(_bl(f" {BOLD}y[n] = c0·x[n] + c1·x[n-1] + c2·x[n-2] + c3·x[n-3]{RESET}")) + print(_bl(f" {DIM}Coefficients: c0={COEFFS[0]}, c1={COEFFS[1]}, c2={COEFFS[2]}, c3={COEFFS[3]}{RESET}")) + print(_bl("")) + + # Delay line contents + print(_bl(f" {BOLD}{CYAN}Delay Line:{RESET}")) + for i in range(TAPS): + tag = "x[n] " if i == 0 else f"x[n-{i}]" + val = sim._delay[i] + coef = COEFFS[i] + prod = val * coef + vc = f"{GREEN}" if val >= 0 else f"{RED}" + pc = f"{GREEN}" if prod >= 0 else f"{RED}" + print(_bl(f" {tag} = {vc}{val:>7}{RESET} × c{i}={coef:>3} = {pc}{prod:>10}{RESET}")) + + expected = sim.expected_output() + actual = sim.y_out + match = actual == expected + mc = GREEN if match else RED + + print(_bl(f" {'─' * 48}")) + print(_bl(f" {BOLD}y_out = {mc}{actual:>10}{RESET} " + f"(expected: {expected:>10} {'✓' if match else '✗'})")) + print(_bl("")) + + # Waveform display (last 16 samples) + WAVE_LEN = 16 + max_x = max((abs(v) for v in x_history[-WAVE_LEN:]), default=1) or 1 + max_y = max((abs(v) for v in y_history[-WAVE_LEN:]), default=1) or 1 + max_all = max(max_x, max_y) + + print(_bl(f" {BOLD}{CYAN}Input Waveform (last {min(len(x_history), WAVE_LEN)} samples):{RESET}")) + for v in x_history[-WAVE_LEN:]: + print(_bl(f" {v:>7} {_bar_char(v, max_all)}")) + + print(_bl("")) + print(_bl(f" {BOLD}{CYAN}Output Waveform:{RESET}")) + for v in y_history[-WAVE_LEN:]: + print(_bl(f" {v:>7} {_bar_char(v, max_all)}")) + + print(_bl("")) + print(_bl(f" Cycle: {DIM}{sim.cycle}{RESET}")) + + if message: + print(f" {CYAN}╠{bar}╣{RESET}") + print(_bl(f" {BOLD}{WHITE}{message}{RESET}")) + print(f" {CYAN}╚{bar}╝{RESET}") + print() + + +# ═══════════════════════════════════════════════════════════════════ +# Test scenarios +# ═══════════════════════════════════════════════════════════════════ + +def main(): + print(" Loading FIR filter RTL 
simulation...") + sim = FilterRTL() + sim.reset() + sim.idle(4) + print(f" {GREEN}RTL model loaded. Coefficients: {COEFFS}{RESET}") + time.sleep(0.5) + + x_hist = [] + y_hist = [] + all_ok = True + + def run_scenario(name, num, inputs, sim, x_hist, y_hist): + """Run a filter test scenario. Returns True if all outputs match. + + The RTL output is registered (1-cycle latency): after pushing x[n], + the y_out we read corresponds to the computation from x[n]'s state + (delay line updated, then combinational result captured). + We compare against the Python model which tracks the delay line + identically. + """ + nonlocal all_ok + sim.reset(); x_hist.clear(); y_hist.clear() + info = f"Test {num}: {name}" + + draw(sim, x_hist, y_hist, name, test_info=info) + time.sleep(0.8) + + ok_all = True + for i, x in enumerate(inputs): + sim.push(x) + x_hist.append(x) + y = sim.y_out + y_hist.append(y) + exp = sim.expected_output() + ok = (y == exp) + if not ok: + ok_all = False + all_ok = False + st = f"{GREEN}✓{RESET}" if ok else f"{RED}✗ exp {exp}{RESET}" + draw(sim, x_hist, y_hist, + f"Push x={x:>6}, y={y:>8} {st}", + test_info=info) + time.sleep(0.5) + + result = f"{GREEN}PASS{RESET}" if ok_all else f"{RED}FAIL{RESET}" + draw(sim, x_hist, y_hist, + f"{name} — {result}", test_info=info) + time.sleep(0.8) + return ok_all + + # ── Test 1: Impulse ────────────────────────────────────── + run_scenario("Impulse [1, 0, 0, 0, 0, 0, 0, 0]", 1, + [1, 0, 0, 0, 0, 0, 0, 0], sim, x_hist, y_hist) + + # ── Test 2: Step ───────────────────────────────────────── + run_scenario("Step [1, 1, 1, 1, 1, 1, 1, 1]", 2, + [1]*8, sim, x_hist, y_hist) + + # ── Test 3: Ramp ───────────────────────────────────────── + run_scenario("Ramp [0, 1, 2, 3, 4, 5, 6, 7]", 3, + list(range(8)), sim, x_hist, y_hist) + + # ── Test 4: Alternating ±100 ───────────────────────────── + run_scenario("Alternating ±100", 4, + [100, -100, 100, -100, 100, -100, 100, -100], + sim, x_hist, y_hist) + + # ── Test 5: Large values 
───────────────────────────────── + run_scenario("Large values (10000)", 5, + [10000, 10000, 10000, 10000, 0, 0, 0, 0], + sim, x_hist, y_hist) + + # ── Summary ────────────────────────────────────────────── + if all_ok: + draw(sim, x_hist, y_hist, + f"All 5 tests PASSED! Filter verified against RTL.", + test_info="Complete") + time.sleep(2.0) + print(f" {GREEN}{BOLD}All tests passed (TRUE RTL SIMULATION).{RESET}\n") + else: + draw(sim, x_hist, y_hist, + f"{RED}Some tests FAILED!{RESET}", + test_info="Complete") + time.sleep(2.0) + print(f" {RED}{BOLD}Some tests failed.{RESET}\n") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/designs/examples/digital_filter/filter_capi.cpp b/designs/examples/digital_filter/filter_capi.cpp new file mode 100644 index 0000000..5072e1b --- /dev/null +++ b/designs/examples/digital_filter/filter_capi.cpp @@ -0,0 +1,59 @@ +/** + * filter_capi.cpp — C API wrapper for the 4-tap FIR filter RTL. + * + * Build (from pyCircuit root): + * c++ -std=c++17 -O2 -shared -fPIC -I include -I . \ + * -o examples/digital_filter/libfilter_sim.dylib \ + * examples/digital_filter/filter_capi.cpp + */ +#include +#include +#include + +#include "examples/generated/digital_filter/digital_filter_gen.hpp" + +using pyc::cpp::Wire; + +struct SimContext { + pyc::gen::digital_filter dut{}; + pyc::cpp::Testbench tb; + uint64_t cycle = 0; + SimContext() : tb(dut) { tb.addClock(dut.clk, 1); } +}; + +extern "C" { + +SimContext* fir_create() { return new SimContext(); } +void fir_destroy(SimContext* c) { delete c; } + +void fir_reset(SimContext* c, uint64_t n) { + c->tb.reset(c->dut.rst, n, 1); + c->dut.eval(); + c->cycle = 0; +} + +void fir_push_sample(SimContext* c, int16_t sample) { + // Assert x_in + x_valid for 1 cycle. + // The registered output captures the result on this clock edge. 
+ c->dut.x_in = Wire<16>(static_cast(static_cast(sample))); + c->dut.x_valid = Wire<1>(1u); + c->tb.runCycles(1); + c->cycle++; + // Deassert and idle 1 cycle so output is stable for reading. + c->dut.x_valid = Wire<1>(0u); + c->dut.x_in = Wire<16>(0u); + c->tb.runCycles(1); + c->cycle++; +} + +void fir_idle(SimContext* c, uint64_t n) { + c->dut.x_valid = Wire<1>(0u); + c->tb.runCycles(n); + c->cycle += n; +} + +int64_t fir_get_y_out(SimContext* c) { return static_cast(c->dut.y_out.value()); } +uint32_t fir_get_y_valid(SimContext* c) { return c->dut.y_valid.value(); } +uint64_t fir_get_cycle(SimContext* c) { return c->cycle; } + +} // extern "C" diff --git a/designs/examples/dodgeball_game/README.md b/designs/examples/dodgeball_game/README.md new file mode 100644 index 0000000..bbe9df8 --- /dev/null +++ b/designs/examples/dodgeball_game/README.md @@ -0,0 +1,66 @@ +# Dodgeball Game (pyCircuit) + +A cycle-aware rewrite of the dodgeball VGA demo. The design keeps the original +FSM and object motion timing while adding `left/right` movement for the player. +The terminal emulator renders a downsampled VGA view to keep runtime low. + +**Key files** +- `lab_final_top.py`: pyCircuit top-level (game FSM, objects, player, VGA colors). +- `lab_final_VGA.py`: VGA timing generator (640x480 @ 60Hz). +- `dodgeball_capi.cpp`: C API wrapper for ctypes simulation. +- `emulate_dodgeball.py`: terminal visualization + optional auto-build. +- `stimuli/basic.py`: external stimulus for `START/left/right/RST_BTN`. 
+ +## Ports + +| Port | Dir | Width | Description | +|------|-----|-------|-------------| +| `clk` | in | 1 | System clock | +| `rst` | in | 1 | Synchronous reset (for deterministic init) | +| `RST_BTN` | in | 1 | Game reset input (matches reference behavior) | +| `START` | in | 1 | Start game | +| `left` | in | 1 | Move player left (game tick) | +| `right` | in | 1 | Move player right (game tick) | +| `VGA_HS_O` | out | 1 | VGA HSync | +| `VGA_VS_O` | out | 1 | VGA VSync | +| `VGA_R` | out | 4 | VGA red (MSB used) | +| `VGA_G` | out | 4 | VGA green (MSB used) | +| `VGA_B` | out | 4 | VGA blue (MSB used) | +| `dbg_state` | out | 3 | FSM state (0 init, 1 play, 2 over) | +| `dbg_j` | out | 5 | Object step counter | +| `dbg_player_x` | out | 4 | Player column (0-15) | +| `dbg_ob*_x/y` | out | 4 | Object positions | + +## Run (Auto-Build) + +The emulator will build the C++ simulation library if it is missing. Use +`--rebuild` to force regeneration. + +```bash +python3 examples/dodgeball_game/emulate_dodgeball.py +python3 examples/dodgeball_game/emulate_dodgeball.py --rebuild +``` + +## Manual Build and Run + +```bash +PYTHONPATH=python:. python3 -m pycircuit.cli emit \ + examples/dodgeball_game/lab_final_top.py \ + -o examples/generated/dodgeball_game/dodgeball_game.pyc + +./build/bin/pyc-compile examples/generated/dodgeball_game/dodgeball_game.pyc \ + --emit=cpp --out-dir=examples/generated/dodgeball_game + +c++ -std=c++17 -O2 -shared -fPIC -I include -I . \ + -o examples/dodgeball_game/libdodgeball_sim.dylib \ + examples/dodgeball_game/dodgeball_capi.cpp + +python3 examples/dodgeball_game/emulate_dodgeball.py --stim basic +``` + +## Stimuli + +Stimulus is separated from the DUT and loaded as a module. +Available modules live under `examples/dodgeball_game/stimuli/`. + +- `basic`: start, move left, then move right, plus a reset/restart sequence. 
diff --git a/designs/examples/dodgeball_game/__init__.py b/designs/examples/dodgeball_game/__init__.py new file mode 100644 index 0000000..dd630ac --- /dev/null +++ b/designs/examples/dodgeball_game/__init__.py @@ -0,0 +1 @@ +# Package marker for dodgeball_game example. diff --git a/designs/examples/dodgeball_game/dodgeball_capi.cpp b/designs/examples/dodgeball_game/dodgeball_capi.cpp new file mode 100644 index 0000000..bcdc45e --- /dev/null +++ b/designs/examples/dodgeball_game/dodgeball_capi.cpp @@ -0,0 +1,82 @@ +/** + * dodgeball_capi.cpp — C API wrapper around the generated RTL model. + * + * Build: + * cd + * c++ -std=c++17 -O2 -shared -fPIC -I include -I . \ + * -o examples/dodgeball_game/libdodgeball_sim.dylib \ + * examples/dodgeball_game/dodgeball_capi.cpp + */ + +#include +#include +#include + +#include "../generated/dodgeball_game/dodgeball_game.hpp" + +using pyc::cpp::Wire; + +struct SimContext { + pyc::gen::dodgeball_game dut{}; + pyc::cpp::Testbench tb; + uint64_t cycle = 0; + + SimContext() : tb(dut) { + tb.addClock(dut.clk, /*halfPeriodSteps=*/1); + } +}; + +extern "C" { + +SimContext* db_create() { + return new SimContext(); +} + +void db_destroy(SimContext* ctx) { + delete ctx; +} + +void db_reset(SimContext* ctx, uint64_t cycles) { + ctx->tb.reset(ctx->dut.rst, /*cyclesAsserted=*/cycles, /*cyclesDeasserted=*/1); + ctx->dut.eval(); + ctx->cycle = 0; +} + +void db_set_inputs(SimContext* ctx, int rst_btn, int start, int left, int right) { + ctx->dut.RST_BTN = Wire<1>(rst_btn ? 1u : 0u); + ctx->dut.START = Wire<1>(start ? 1u : 0u); + ctx->dut.left = Wire<1>(left ? 1u : 0u); + ctx->dut.right = Wire<1>(right ? 
1u : 0u); +} + +void db_tick(SimContext* ctx) { + ctx->tb.runCycles(1); + ctx->cycle++; +} + +void db_run_cycles(SimContext* ctx, uint64_t n) { + ctx->tb.runCycles(n); + ctx->cycle += n; +} + +// VGA outputs +uint32_t db_get_vga_hs(SimContext* ctx) { return ctx->dut.VGA_HS_O.value(); } +uint32_t db_get_vga_vs(SimContext* ctx) { return ctx->dut.VGA_VS_O.value(); } +uint32_t db_get_vga_r(SimContext* ctx) { return ctx->dut.VGA_R.value(); } +uint32_t db_get_vga_g(SimContext* ctx) { return ctx->dut.VGA_G.value(); } +uint32_t db_get_vga_b(SimContext* ctx) { return ctx->dut.VGA_B.value(); } + +// Debug outputs +uint32_t db_get_state(SimContext* ctx) { return ctx->dut.dbg_state.value(); } +uint32_t db_get_j(SimContext* ctx) { return ctx->dut.dbg_j.value(); } +uint32_t db_get_player_x(SimContext* ctx) { return ctx->dut.dbg_player_x.value(); } +uint32_t db_get_ob1_x(SimContext* ctx) { return ctx->dut.dbg_ob1_x.value(); } +uint32_t db_get_ob1_y(SimContext* ctx) { return ctx->dut.dbg_ob1_y.value(); } +uint32_t db_get_ob2_x(SimContext* ctx) { return ctx->dut.dbg_ob2_x.value(); } +uint32_t db_get_ob2_y(SimContext* ctx) { return ctx->dut.dbg_ob2_y.value(); } +uint32_t db_get_ob3_x(SimContext* ctx) { return ctx->dut.dbg_ob3_x.value(); } +uint32_t db_get_ob3_y(SimContext* ctx) { return ctx->dut.dbg_ob3_y.value(); } + +uint64_t db_get_cycle(SimContext* ctx) { return ctx->cycle; } + +} // extern "C" diff --git a/designs/examples/dodgeball_game/emulate_dodgeball.py b/designs/examples/dodgeball_game/emulate_dodgeball.py new file mode 100644 index 0000000..0b8c26c --- /dev/null +++ b/designs/examples/dodgeball_game/emulate_dodgeball.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +emulate_dodgeball.py — True RTL simulation of the dodgeball game +with a terminal visualization. + +By default the script will build the C++ simulation library if missing. +Use --rebuild to force regeneration. 
+""" +from __future__ import annotations + +import argparse +import ctypes +import importlib +import os +import shutil +import subprocess +import sys +import time +from pathlib import Path + +# ============================================================================= +# ANSI helpers +# ============================================================================= + +RESET = "\033[0m" +BOLD = "\033[1m" +DIM = "\033[2m" +RED = "\033[31m" +GREEN = "\033[32m" +YELLOW = "\033[33m" +BLUE = "\033[34m" +CYAN = "\033[36m" +WHITE = "\033[37m" + + +def clear_screen() -> None: + print("\033[2J\033[H", end="") + + +# ============================================================================= +# RTL simulation wrapper (ctypes -> compiled C++ netlist) +# ============================================================================= + +MAIN_CLK_BIT = 20 +CYCLES_PER_TICK = 1 << (MAIN_CLK_BIT + 1) + + +class DodgeballRTL: + def __init__(self, lib_path: str | None = None): + if lib_path is None: + lib_path = str(Path(__file__).resolve().parent / "libdodgeball_sim.dylib") + self._lib = ctypes.CDLL(lib_path) + + self._lib.db_create.restype = ctypes.c_void_p + self._lib.db_destroy.argtypes = [ctypes.c_void_p] + self._lib.db_reset.argtypes = [ctypes.c_void_p, ctypes.c_uint64] + self._lib.db_set_inputs.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] + self._lib.db_tick.argtypes = [ctypes.c_void_p] + self._lib.db_run_cycles.argtypes = [ctypes.c_void_p, ctypes.c_uint64] + + for name in ( + "db_get_state", "db_get_j", "db_get_player_x", + "db_get_ob1_x", "db_get_ob1_y", + "db_get_ob2_x", "db_get_ob2_y", + "db_get_ob3_x", "db_get_ob3_y", + "db_get_vga_hs", "db_get_vga_vs", + "db_get_vga_r", "db_get_vga_g", "db_get_vga_b", + ): + getattr(self._lib, name).argtypes = [ctypes.c_void_p] + getattr(self._lib, name).restype = ctypes.c_uint32 + + self._lib.db_get_cycle.argtypes = [ctypes.c_void_p] + self._lib.db_get_cycle.restype = ctypes.c_uint64 + + self._ctx 
= self._lib.db_create() + self.rst_btn = 0 + self.start = 0 + self.left = 0 + self.right = 0 + + def __del__(self): + if hasattr(self, "_ctx") and self._ctx: + self._lib.db_destroy(self._ctx) + + def reset(self, cycles: int = 2): + self._lib.db_reset(self._ctx, cycles) + + def _apply_inputs(self): + self._lib.db_set_inputs(self._ctx, self.rst_btn, self.start, self.left, self.right) + + def tick(self): + self._apply_inputs() + self._lib.db_tick(self._ctx) + + def run_cycles(self, n: int): + self._apply_inputs() + self._lib.db_run_cycles(self._ctx, n) + + @property + def state(self) -> int: + return int(self._lib.db_get_state(self._ctx)) + + @property + def j(self) -> int: + return int(self._lib.db_get_j(self._ctx)) + + @property + def player_x(self) -> int: + return int(self._lib.db_get_player_x(self._ctx)) + + @property + def ob1(self) -> tuple[int, int]: + return (int(self._lib.db_get_ob1_x(self._ctx)), int(self._lib.db_get_ob1_y(self._ctx))) + + @property + def ob2(self) -> tuple[int, int]: + return (int(self._lib.db_get_ob2_x(self._ctx)), int(self._lib.db_get_ob2_y(self._ctx))) + + @property + def ob3(self) -> tuple[int, int]: + return (int(self._lib.db_get_ob3_x(self._ctx)), int(self._lib.db_get_ob3_y(self._ctx))) + + @property + def cycle(self) -> int: + return int(self._lib.db_get_cycle(self._ctx)) + + +# ============================================================================= +# Build helpers +# ============================================================================= + + +def _find_root() -> Path: + return Path(__file__).resolve().parents[2] + + +def _find_pyc_compile(root: Path) -> Path: + candidates = [ + root / "build-top" / "bin" / "pyc-compile", + root / "build" / "bin" / "pyc-compile", + root / "pyc" / "mlir" / "build" / "bin" / "pyc-compile", + ] + for c in candidates: + if c.is_file() and os.access(c, os.X_OK): + return c + found = shutil.which("pyc-compile") + if found: + return Path(found) + raise RuntimeError("missing pyc-compile (build 
it with: scripts/pyc build)") + + +def _ensure_built(force: bool = False) -> None: + root = _find_root() + lib_path = Path(__file__).resolve().parent / "libdodgeball_sim.dylib" + srcs = [ + root / "examples" / "dodgeball_game" / "lab_final_top.py", + root / "examples" / "dodgeball_game" / "lab_final_VGA.py", + root / "examples" / "dodgeball_game" / "dodgeball_capi.cpp", + ] + if lib_path.exists() and not force: + lib_mtime = lib_path.stat().st_mtime + if all(s.exists() and s.stat().st_mtime <= lib_mtime for s in srcs): + return + + gen_dir = root / "examples" / "generated" / "dodgeball_game" + gen_dir.mkdir(parents=True, exist_ok=True) + + env = os.environ.copy() + py_path = f"{root}/python:{root}" + if env.get("PYTHONPATH"): + py_path = f"{py_path}:{env['PYTHONPATH']}" + env["PYTHONPATH"] = py_path + + subprocess.run( + [ + sys.executable, + "-m", + "pycircuit.cli", + "emit", + "examples/dodgeball_game/lab_final_top.py", + "-o", + str(gen_dir / "dodgeball_game.pyc"), + ], + cwd=root, + env=env, + check=True, + ) + + pyc_compile = _find_pyc_compile(root) + subprocess.run( + [ + str(pyc_compile), + str(gen_dir / "dodgeball_game.pyc"), + "--emit=cpp", + f"--out-dir={gen_dir}", + ], + cwd=root, + check=True, + ) + + subprocess.run( + [ + "c++", + "-std=c++17", + "-O2", + "-shared", + "-fPIC", + "-I", + "include", + "-I", + ".", + "-o", + str(lib_path), + "examples/dodgeball_game/dodgeball_capi.cpp", + ], + cwd=root, + check=True, + ) + + +# ============================================================================= +# Rendering (downsampled VGA) +# ============================================================================= + +ACTIVE_W = 640 +ACTIVE_H = 480 +SCALE_X = 40 +SCALE_Y = 40 +GRID_W = ACTIVE_W // SCALE_X +GRID_H = ACTIVE_H // SCALE_Y + +_COLOR = { + (0, 0, 0): f"{DIM}.{RESET}", + (1, 0, 0): f"{RED}#{RESET}", + (0, 1, 0): f"{GREEN}#{RESET}", + (0, 0, 1): f"{BLUE}#{RESET}", + (1, 1, 0): f"{YELLOW}#{RESET}", + (1, 0, 1): f"{RED}#{RESET}", + (0, 1, 1): 
f"{CYAN}#{RESET}", + (1, 1, 1): f"{WHITE}#{RESET}", +} + +STATE_NAMES = { + 0: "INIT", + 1: "PLAY", + 2: "OVER", +} + + +def _vga_color_at( + x: int, + y: int, + *, + state: int, + player_x: int, + objects: list[tuple[int, int]], +) -> tuple[int, int, int]: + def in_range(v: int, lo: int, hi: int) -> bool: + return (v > lo) and (v < hi) + + sq_player = ( + in_range(x, 40 * player_x, 40 * (player_x + 1)) and + in_range(y, 400, 440) + ) + + def sq_object(ox: int, oy: int) -> bool: + return ( + in_range(x, 40 * ox, 40 * (ox + 1)) and + in_range(y, 40 * oy, 40 * (oy + 1)) + ) + + sq_obj1 = sq_object(*objects[0]) + sq_obj2 = sq_object(*objects[1]) + sq_obj3 = sq_object(*objects[2]) + + over_wire = in_range(x, 0, 640) and in_range(y, 0, 480) + down = in_range(x, 0, 640) and in_range(y, 440, 480) + up = in_range(x, 0, 640) and in_range(y, 0, 40) + + over = (state == 2) + not_over = not over + + r = 1 if (sq_player and not_over) else 0 + b = 1 if ((sq_obj1 or sq_obj2 or sq_obj3 or down or up) and not_over) else 0 + g = 1 if (over_wire and over) else 0 + return (r, g, b) + + +def render_vga_sampled(state: int, player_x: int, objects: list[tuple[int, int]]) -> list[str]: + lines: list[str] = [] + for row in range(GRID_H): + y = row * SCALE_Y + (SCALE_Y // 2) + line = [] + for col in range(GRID_W): + x = col * SCALE_X + (SCALE_X // 2) + rgb = _vga_color_at(x, y, state=state, player_x=player_x, objects=objects) + line.append(_COLOR.get(rgb, _COLOR[(0, 0, 0)])) + lines.append("".join(line)) + return lines + + +# ============================================================================= +# Stimulus loading +# ============================================================================= + + +def _load_stimulus(name: str): + if "." 
in name: + return importlib.import_module(name) + try: + return importlib.import_module(f"examples.dodgeball_game.stimuli.{name}") + except ModuleNotFoundError: + root = _find_root() + sys.path.insert(0, str(root)) + return importlib.import_module(f"examples.dodgeball_game.stimuli.{name}") + + +def main(): + ap = argparse.ArgumentParser(description="Dodgeball terminal emulator") + ap.add_argument( + "--stim", + default="basic", + help="Stimulus module name (e.g. basic)", + ) + ap.add_argument( + "--rebuild", + action="store_true", + help="Force rebuild of the C++ simulation library", + ) + args = ap.parse_args() + + _ensure_built(force=args.rebuild) + + stim = _load_stimulus(args.stim) + + rtl = DodgeballRTL() + rtl.reset() + if hasattr(stim, "init"): + stim.init(rtl) + + total_ticks = int(getattr(stim, "total_ticks", lambda: 20)()) + frame_sleep = float(getattr(stim, "sleep_s", lambda: 0.08)()) + + for tick in range(total_ticks): + if hasattr(stim, "step"): + stim.step(tick, rtl) + rtl.run_cycles(CYCLES_PER_TICK) + + clear_screen() + + state_name = STATE_NAMES.get(rtl.state, f"S{rtl.state}") + objs = [rtl.ob1, rtl.ob2, rtl.ob3] + grid_lines = render_vga_sampled(rtl.state, rtl.player_x, objs) + + print(f"{BOLD}{CYAN}dodgeball_game{RESET} tick={tick}") + print(f"cycle={rtl.cycle} state={state_name} j={rtl.j} main_clk_bit={MAIN_CLK_BIT}") + print(f"RST_BTN={rtl.rst_btn} START={rtl.start} left={rtl.left} right={rtl.right}") + print(f"note: VGA shown with {GRID_W}x{GRID_H} downsample") + print("") + for line in grid_lines: + print(line) + + time.sleep(frame_sleep) + + +if __name__ == "__main__": + main() diff --git a/designs/examples/dodgeball_game/lab_final_VGA.py b/designs/examples/dodgeball_game/lab_final_VGA.py new file mode 100644 index 0000000..694a5f5 --- /dev/null +++ b/designs/examples/dodgeball_game/lab_final_VGA.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +"""VGA timing generator — pyCircuit v4.0 rewrite of lab_final_VGA.v. 
+ +Implements the same 640x480@60Hz timing logic with 800x524 total counts. +""" +from __future__ import annotations + +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, u + +# VGA timing constants (same as reference Verilog) +HS_STA = 16 +HS_END = 16 + 96 +HA_STA = 16 + 96 + 48 +VS_STA = 480 + 11 +VS_END = 480 + 11 + 2 +VA_END = 480 +LINE = 800 +SCREEN = 524 + + +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) -> None: + """Standalone VGA module (ports mirror the reference Verilog).""" + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst + + i_pix_stb = m.input("i_pix_stb", width=1) + + h_count = m.out("vga_h_count", domain=cd, width=10, init=u(10, 0)) + v_count = m.out("vga_v_count", domain=cd, width=10, init=u(10, 0)) + + h = h_count.out() + v = v_count.out() + + h_end = h == u(10, LINE) + v_end = v == u(10, SCREEN) + + h_inc = h + u(10, 1) + v_inc = v + u(10, 1) + + h_after = u(10, 0) if h_end else h_inc + v_after = v_inc if h_end else v + v_after = u(10, 0) if v_end else v_after + + h_next = h_after if i_pix_stb else h + v_next = v_after if i_pix_stb else v + + o_hs = ~((h >= u(10, HS_STA)) & (h < u(10, HS_END))) + o_vs = ~((v >= u(10, VS_STA)) & (v < u(10, VS_END))) + + o_x = u(10, 0) if (h < u(10, HA_STA)) else (h - u(10, HA_STA)) + y_full = u(10, VA_END - 1) if (v >= u(10, VA_END)) else v + o_y = y_full[0:9] + + o_blanking = (h < u(10, HA_STA)) | (v > u(10, VA_END - 1)) + o_animate = (v == u(10, VA_END - 1)) & (h == u(10, LINE)) + + h_count.set(h_next) + v_count.set(v_next) + + m.output("o_hs", o_hs) + m.output("o_vs", o_vs) + m.output("o_blanking", o_blanking) + m.output("o_animate", o_animate) + m.output("o_x", o_x) + m.output("o_y", o_y) + + +build.__pycircuit_name__ = "lab_final_vga" + +if __name__ == "__main__": + print(compile_cycle_aware(build, name="lab_final_vga").emit_mlir()) diff --git a/designs/examples/dodgeball_game/lab_final_top.py b/designs/examples/dodgeball_game/lab_final_top.py 
new file mode 100644 index 0000000..78e1940 --- /dev/null +++ b/designs/examples/dodgeball_game/lab_final_top.py @@ -0,0 +1,283 @@ +# -*- coding: utf-8 -*- +"""Dodgeball top — pyCircuit v4.0 rewrite of lab_final_top.v. + +Notes: +- `clk` corresponds to the original `CLK_in`. +- A synchronous `rst` port is introduced for deterministic initialization. +- The internal game logic still uses `RST_BTN` exactly like the reference. +""" +from __future__ import annotations + +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, u + +# VGA timing constants (same as lab_final_VGA) +HS_STA = 16 +HS_END = 16 + 96 +HA_STA = 16 + 96 + 48 +VS_STA = 480 + 11 +VS_END = 480 + 11 + 2 +VA_END = 480 +LINE = 800 +SCREEN = 524 + + +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, MAIN_CLK_BIT: int = 20) -> None: + if MAIN_CLK_BIT < 0 or MAIN_CLK_BIT > 24: + raise ValueError("MAIN_CLK_BIT must be in [0, 24]") + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst + + # ================================================================ + # Inputs + # ================================================================ + rst_btn = m.input("RST_BTN", width=1) + start = m.input("START", width=1) + left = m.input("left", width=1) + right = m.input("right", width=1) + + # ================================================================ + # Registers + # ================================================================ + cnt = m.out("pix_cnt", domain=cd, width=16, init=u(16, 0)) + pix_stb = m.out("pix_stb", domain=cd, width=1, init=u(1, 0)) + main_clk = m.out("main_clk", domain=cd, width=25, init=u(25, 0)) + + player_x = m.out("player_x", domain=cd, width=4, init=u(4, 8)) + j = m.out("j", domain=cd, width=5, init=u(5, 0)) + + ob1_x = m.out("ob1_x", domain=cd, width=4, init=u(4, 1)) + ob2_x = m.out("ob2_x", domain=cd, width=4, init=u(4, 4)) + ob3_x = m.out("ob3_x", domain=cd, width=4, init=u(4, 7)) + + ob1_y = m.out("ob1_y", domain=cd, width=4, 
init=u(4, 0)) + ob2_y = m.out("ob2_y", domain=cd, width=4, init=u(4, 0)) + ob3_y = m.out("ob3_y", domain=cd, width=4, init=u(4, 0)) + + fsm_state = m.out("fsm_state", domain=cd, width=3, init=u(3, 0)) + + # ================================================================ + # Combinational logic + # ================================================================ + + # --- Pixel strobe divider --- + cnt_ext = cnt.out() | u(17, 0) + sum17 = cnt_ext + u(17, 0x4000) + cnt_next = sum17[0:16] + pix_stb_next = sum17[16] + + # --- Main clock divider bit (for game logic tick) --- + main_clk_next = main_clk.out() + u(25, 1) + main_bit = main_clk.out()[MAIN_CLK_BIT] + main_next_bit = main_clk_next[MAIN_CLK_BIT] + game_tick = (~main_bit) & main_next_bit + + # --- VGA timing (inlined from lab_final_VGA) --- + vga_h_count = m.out("vga_h_count", domain=cd, width=10, init=u(10, 0)) + vga_v_count = m.out("vga_v_count", domain=cd, width=10, init=u(10, 0)) + + vh = vga_h_count.out() + vv = vga_v_count.out() + + vh_end = vh == u(10, LINE) + vv_end = vv == u(10, SCREEN) + + vh_inc = vh + u(10, 1) + vv_inc = vv + u(10, 1) + + vh_after = u(10, 0) if vh_end else vh_inc + vv_after = vv_inc if vh_end else vv + vv_after = u(10, 0) if vv_end else vv_after + + i_pix_stb = pix_stb.out() + vh_next = vh_after if i_pix_stb else vh + vv_next = vv_after if i_pix_stb else vv + + vga_hs = ~((vh >= u(10, HS_STA)) & (vh < u(10, HS_END))) + vga_vs = ~((vv >= u(10, VS_STA)) & (vv < u(10, VS_END))) + + vga_x_raw = u(10, 0) if (vh < u(10, HA_STA)) else (vh - u(10, HA_STA)) + vga_y_full = u(10, VA_END - 1) if (vv >= u(10, VA_END)) else vv + vga_y_raw = vga_y_full[0:9] + + vga_h_count.set(vh_next) + vga_v_count.set(vv_next) + + x = vga_x_raw + y = vga_y_raw + + # --- Read register Q outputs for combinational logic --- + px = player_x.out() + jv = j.out() + o1x = ob1_x.out(); o1y = ob1_y.out() + o2x = ob2_x.out(); o2y = ob2_y.out() + o3x = ob3_x.out(); o3y = ob3_y.out() + fsm = fsm_state.out() + + # --- 
Collision detection --- + collision = ( + ((o1x == px) & (o1y == u(4, 10))) | + ((o2x == px) & (o2y == u(4, 10))) | + ((o3x == px) & (o3y == u(4, 10))) + ) + + # --- Object motion increments (boolean -> 4-bit) --- + inc1 = ((jv > u(5, 0)) & (jv < u(5, 13))) | u(4, 0) + inc2 = ((jv > u(5, 3)) & (jv < u(5, 16))) | u(4, 0) + inc3 = ((jv > u(5, 7)) & (jv < u(5, 20))) | u(4, 0) + + # --- FSM state flags --- + st0 = fsm == u(3, 0) + st1 = fsm == u(3, 1) + st2 = fsm == u(3, 2) + + cond_state0 = game_tick & st0 + cond_state1 = game_tick & st1 + cond_state2 = game_tick & st2 + + cond_start = cond_state0 & start + cond_rst_s1 = cond_state1 & rst_btn + cond_rst_s2 = cond_state2 & rst_btn + cond_collision = cond_state1 & collision + cond_j20 = cond_state1 & (jv == u(5, 20)) + + # --- Player movement (left/right) --- + left_only = left & ~right + right_only = right & ~left + can_left = px > u(4, 0) + can_right = px < u(4, 15) + move_left = cond_state1 & left_only & can_left + move_right = cond_state1 & right_only & can_right + + # --- VGA draw logic --- + x10 = x + y10 = y | u(10, 0) + + player_x0 = (px | u(10, 0)) * u(10, 40) + player_x1 = ((px + u(4, 1)) | u(10, 0)) * u(10, 40) + + ob1_x0 = (o1x | u(10, 0)) * u(10, 40) + ob1_x1 = ((o1x + u(4, 1)) | u(10, 0)) * u(10, 40) + ob1_y0 = (o1y | u(10, 0)) * u(10, 40) + ob1_y1 = ((o1y + u(4, 1)) | u(10, 0)) * u(10, 40) + + ob2_x0 = (o2x | u(10, 0)) * u(10, 40) + ob2_x1 = ((o2x + u(4, 1)) | u(10, 0)) * u(10, 40) + ob2_y0 = (o2y | u(10, 0)) * u(10, 40) + ob2_y1 = ((o2y + u(4, 1)) | u(10, 0)) * u(10, 40) + + ob3_x0 = (o3x | u(10, 0)) * u(10, 40) + ob3_x1 = ((o3x + u(4, 1)) | u(10, 0)) * u(10, 40) + ob3_y0 = (o3y | u(10, 0)) * u(10, 40) + ob3_y1 = ((o3y + u(4, 1)) | u(10, 0)) * u(10, 40) + + sq_player = ( + (x10 > player_x0) & (y10 > u(10, 400)) & + (x10 < player_x1) & (y10 < u(10, 440)) + ) + + sq_object1 = ( + (x10 > ob1_x0) & (y10 > ob1_y0) & + (x10 < ob1_x1) & (y10 < ob1_y1) + ) + sq_object2 = ( + (x10 > ob2_x0) & (y10 > ob2_y0) & + 
(x10 < ob2_x1) & (y10 < ob2_y1) + ) + sq_object3 = ( + (x10 > ob3_x0) & (y10 > ob3_y0) & + (x10 < ob3_x1) & (y10 < ob3_y1) + ) + + over_wire = ( + (x10 > u(10, 0)) & (y10 > u(10, 0)) & + (x10 < u(10, 640)) & (y10 < u(10, 480)) + ) + down = ( + (x10 > u(10, 0)) & (y10 > u(10, 440)) & + (x10 < u(10, 640)) & (y10 < u(10, 480)) + ) + up = ( + (x10 > u(10, 0)) & (y10 > u(10, 0)) & + (x10 < u(10, 640)) & (y10 < u(10, 40)) + ) + + fsm_over = fsm == u(3, 2) + not_over = ~fsm_over + + circle = u(1, 0) + + vga_r_bit = sq_player & not_over + vga_b_bit = (sq_object1 | sq_object2 | sq_object3 | down | up) & not_over + vga_g_bit = circle | (over_wire & fsm_over) + + vga_r = m.cat(vga_r_bit, u(3, 0)) + vga_g = m.cat(vga_g_bit, u(3, 0)) + vga_b = m.cat(vga_b_bit, u(3, 0)) + + # ================================================================ + # Register updates (last-write-wins order mirrors Verilog) + # ================================================================ + + # Clock divider flops + cnt.set(cnt_next) + pix_stb.set(pix_stb_next) + main_clk.set(main_clk_next) + + # FSM state + fsm_state.set(u(3, 1), when=cond_start) + fsm_state.set(u(3, 0), when=cond_rst_s1) + fsm_state.set(u(3, 2), when=cond_collision) + fsm_state.set(u(3, 0), when=cond_rst_s2) + + # j counter + j.set(u(5, 0), when=cond_rst_s1) + j.set(u(5, 0), when=cond_j20) + j.set(jv + u(5, 1), when=cond_state1) + j.set(u(5, 0), when=cond_rst_s2) + + # player movement + player_x.set(px - u(4, 1), when=move_left) + player_x.set(px + u(4, 1), when=move_right) + + # object Y updates + ob1_y.set(u(4, 0), when=cond_rst_s1) + ob1_y.set(u(4, 0), when=cond_j20) + ob1_y.set(o1y + inc1, when=cond_state1) + ob1_y.set(u(4, 0), when=cond_rst_s2) + + ob2_y.set(u(4, 0), when=cond_rst_s1) + ob2_y.set(u(4, 0), when=cond_j20) + ob2_y.set(o2y + inc2, when=cond_state1) + ob2_y.set(u(4, 0), when=cond_rst_s2) + + ob3_y.set(u(4, 0), when=cond_rst_s1) + ob3_y.set(u(4, 0), when=cond_j20) + ob3_y.set(o3y + inc3, when=cond_state1) + 
ob3_y.set(u(4, 0), when=cond_rst_s2) + + # ================================================================ + # Outputs + # ================================================================ + m.output("VGA_HS_O", vga_hs) + m.output("VGA_VS_O", vga_vs) + m.output("VGA_R", vga_r) + m.output("VGA_G", vga_g) + m.output("VGA_B", vga_b) + + # Debug / visualization taps + m.output("dbg_state", fsm_state) + m.output("dbg_j", j) + m.output("dbg_player_x", player_x) + m.output("dbg_ob1_x", ob1_x) + m.output("dbg_ob1_y", ob1_y) + m.output("dbg_ob2_x", ob2_x) + m.output("dbg_ob2_y", ob2_y) + m.output("dbg_ob3_x", ob3_x) + m.output("dbg_ob3_y", ob3_y) + + +build.__pycircuit_name__ = "dodgeball_game" + +if __name__ == "__main__": + print(compile_cycle_aware(build, name="dodgeball_game", MAIN_CLK_BIT=20).emit_mlir()) diff --git a/designs/examples/dodgeball_game/reference/lab_final_VGA.v b/designs/examples/dodgeball_game/reference/lab_final_VGA.v new file mode 100644 index 0000000..6c6d8b9 --- /dev/null +++ b/designs/examples/dodgeball_game/reference/lab_final_VGA.v @@ -0,0 +1,56 @@ +`timescale 1ns / 1ps + +module vga( + input wire i_clk, // base clock + input wire i_pix_stb, // pixel clock strobe + output wire o_hs, // horizontal sync + output wire o_vs, // vertical sync + output wire o_blanking, // high during blanking interval + output wire o_animate, // high for one tick at end of active drawing + output wire [9:0] o_x, // current pixel x position: 10-bit value: 0-1023 + output wire [8:0] o_y // current pixel y position: 9-bit value: 0-511 + ); + + localparam HS_STA = 16; // horizontal sync start + localparam HS_END = 16 + 96; // horizontal sync end + localparam HA_STA = 16 + 96 + 48; // horizontal active pixel start + localparam VS_STA = 480 + 11; // vertical sync start + localparam VS_END = 480 + 11 + 2; // vertical sync end + localparam VA_END = 480; // vertical active pixel end + localparam LINE = 800; // complete line (pixels) + localparam SCREEN = 524; // complete screen 
(lines) + + reg [9:0] h_count = 0; // line position: 10-bit value: 0-1023 + reg [9:0] v_count = 0; // screen position: 10-bit value: 0-1023 + + // generate horizontal and vertical sync signals (both active low for 640x480) + assign o_hs = ~((h_count >= HS_STA) & (h_count < HS_END)); + assign o_vs = ~((v_count >= VS_STA) & (v_count < VS_END)); + + // keep x and y bound within the active pixels + assign o_x = (h_count < HA_STA) ? 0 : (h_count - HA_STA); + assign o_y = (v_count >= VA_END) ? (VA_END - 1) : (v_count); + + // blanking: high within the blanking period + assign o_blanking = ((h_count < HA_STA) | (v_count > VA_END - 1)); + + // animate: high for one tick at the end of the final active pixel line + assign o_animate = ((v_count == VA_END - 1) & (h_count == LINE)); + + always @ (posedge i_clk) + begin + if (i_pix_stb) // once per pixel + begin + if (h_count == LINE) // end of line + begin + h_count <= 0; + v_count <= v_count + 1; + end + else + h_count <= h_count + 1; + + if (v_count == SCREEN) // end of screen + v_count <= 0; + end + end +endmodule diff --git a/designs/examples/dodgeball_game/reference/lab_final_top.v b/designs/examples/dodgeball_game/reference/lab_final_top.v new file mode 100644 index 0000000..d5d18f2 --- /dev/null +++ b/designs/examples/dodgeball_game/reference/lab_final_top.v @@ -0,0 +1,139 @@ +`timescale 1ns / 1ps +////////////////////////////////////////////////////////////////////////////////// +// Company: +// Engineer: +// +// Create Date: 2018/06/09 20:25:15 +// Design Name: +// Module Name: lab_final_top +// Project Name: +// Target Devices: +// Tool Versions: +// Description: +// +// Dependencies: +// +// Revision: +// Revision 0.01 - File Created +// Additional Comments: +// +////////////////////////////////////////////////////////////////////////////////// + + +module top( + input wire CLK_in, // board clock: 100 MHz + input wire RST_BTN, // reset button + input wire START, //game start + output wire VGA_HS_O, // horizontal sync 
output + output wire VGA_VS_O, // vertical sync output + output wire [3:0] VGA_R, // 4-bit VGA red output + output wire [3:0] VGA_G, // 4-bit VGA green output + output wire [3:0] VGA_B, // 4-bit VGA blue output + input wire left, + input wire right + ); + +// wire rst = ~RST_BTN; // reset is active low on Arty + + // generate a 25 MHz pixel strobe + reg [15:0] cnt = 0; + reg pix_stb = 0; + reg [24:0]MAIN_CLK = 0; + always@(posedge CLK_in) + MAIN_CLK <= MAIN_CLK + 1; + always @(posedge CLK_in) + {pix_stb, cnt} <= cnt + 16'h4000; // divide clock by 4: (2^16)/4 = 0x4000 + + wire [9:0] x; // current pixel x position: 10-bit value: 0-1023 + wire [8:0] y; // current pixel y position: 9-bit value: 0-511 + + vga display ( + .i_clk(CLK_in), + .i_pix_stb(pix_stb), + .o_hs(VGA_HS_O), + .o_vs(VGA_VS_O), + .o_x(x), + .o_y(y) + ); + + wire sq_player; + wire sq_object1; + wire sq_object2; + wire sq_object3; + wire over_wire; + wire down; + wire up; + + reg [3:0]i=8; + reg [4:0]j=0; + + reg [3:0]MAIN_OB_1_x=1; + reg [3:0]MAIN_OB_2_x=4; + reg [3:0]MAIN_OB_3_x=7; + reg [3:0]MAIN_OB_1_y=0; + reg [3:0]MAIN_OB_2_y=0; + reg [3:0]MAIN_OB_3_y=0; + reg [2:0]FSM_state; + //0 initial + //1 gaming + //2 over + always@(posedge MAIN_CLK[22])begin + case(FSM_state) + 0: + begin + if (START == 1)begin + FSM_state <= 1; + end + end + 1: + begin + if (RST_BTN == 1)begin + FSM_state <= 0; + j <= 0; + MAIN_OB_1_y <= 0; + MAIN_OB_2_y <= 0; + MAIN_OB_3_y <= 0; + end + if ((MAIN_OB_1_x == i && MAIN_OB_1_y == 10) || (MAIN_OB_2_x == i && MAIN_OB_2_y == 10) || (MAIN_OB_3_x == i && MAIN_OB_3_y == 10)) + FSM_state <= 2; + if (j == 20)begin + j <= 0; + MAIN_OB_1_y <= 0; + MAIN_OB_2_y <= 0; + MAIN_OB_3_y <= 0; + end + begin + j <= j+1; + MAIN_OB_1_y <= MAIN_OB_1_y + ((j>0)&&(j<13)); + MAIN_OB_2_y <= MAIN_OB_2_y + ((j>3)&&(j<16)); + MAIN_OB_3_y <= MAIN_OB_3_y + ((j>7)&&(j<20)); + end + end + 2: + begin + if (RST_BTN == 1)begin + FSM_state <= 0; + j <= 0; + MAIN_OB_1_y <= 0; + MAIN_OB_2_y <= 0; + MAIN_OB_3_y <= 
0; + end + end + endcase + end + + wire circle; + + assign sq_player=((x > 40*i) & (y > 400) & (x < 40*(i+1)) & (y < 440)) ? 1 : 0; + assign sq_object1=((x > 40*MAIN_OB_1_x) & (y > 40*MAIN_OB_1_y) & (x < 40*(MAIN_OB_1_x+1)) & (y < 40*(MAIN_OB_1_y+1))) ? 1 : 0; + assign sq_object2=((x > 40*MAIN_OB_2_x) & (y > 40*MAIN_OB_2_y) & (x < 40*(MAIN_OB_2_x+1)) & (y < 40*(MAIN_OB_2_y+1))) ? 1 : 0; + assign sq_object3=((x > 40*MAIN_OB_3_x) & (y > 40*MAIN_OB_3_y) & (x < 40*(MAIN_OB_3_x+1)) & (y < 40*(MAIN_OB_3_y+1))) ? 1 : 0; + assign over_wire=((x > 0) & (y > 0) & (x < 640) & (y < 480)) ? 1 : 0; + assign down=((x > 0) & (y > 440) & (x < 640) & (y < 480)) ? 1 : 0; + assign down=((x > 0) & (y > 0) & (x < 640) & (y < 40)) ? 1 : 0; + + assign VGA_R[3] = (sq_player & ~(FSM_state == 2)); // square b is red + assign VGA_B[3] = ((sq_object1|sq_object2|sq_object3|down|up) & ~(FSM_state == 2)); + assign VGA_G[3] = (circle|(over_wire & (FSM_state == 2))); + +endmodule \ No newline at end of file diff --git a/designs/examples/dodgeball_game/stimuli/__init__.py b/designs/examples/dodgeball_game/stimuli/__init__.py new file mode 100644 index 0000000..3b2c7a8 --- /dev/null +++ b/designs/examples/dodgeball_game/stimuli/__init__.py @@ -0,0 +1 @@ +# Package marker for dodgeball_game stimuli. 
diff --git a/designs/examples/dodgeball_game/stimuli/basic.py b/designs/examples/dodgeball_game/stimuli/basic.py new file mode 100644 index 0000000..290b2d3 --- /dev/null +++ b/designs/examples/dodgeball_game/stimuli/basic.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +"""Basic stimulus for the dodgeball demo.""" +from __future__ import annotations + + +def init(rtl) -> None: + rtl.rst_btn = 0 + rtl.start = 0 + rtl.left = 0 + rtl.right = 0 + + +def total_ticks() -> int: + return 24 + + +def sleep_s() -> float: + return 0.08 + + +def step(tick: int, rtl) -> None: + # Start the game at tick 0 + rtl.start = 1 if tick == 0 else 0 + + # Move left for a few ticks, then right + rtl.left = 1 if 4 <= tick < 7 else 0 + rtl.right = 1 if 9 <= tick < 12 else 0 + + # Demonstrate reset and restart + rtl.rst_btn = 1 if tick == 16 else 0 + if tick == 18: + rtl.start = 1 diff --git a/designs/examples/fastfwd/fastfwd.py b/designs/examples/fastfwd/fastfwd.py index 35bf7d6..3cf114d 100644 --- a/designs/examples/fastfwd/fastfwd.py +++ b/designs/examples/fastfwd/fastfwd.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, ct, module, const, u +from pycircuit import Circuit, CycleAwareCircuit, CycleAwareDomain, const, ct, u @const @@ -14,9 +14,8 @@ def _total_engines(m: Circuit, n_fe: int | None, eng_per_lane: int) -> int: return max(1, int(eng_per_lane)) * ct.div_ceil(4, 1) -@module def build( - m: Circuit, + m: CycleAwareCircuit, domain: CycleAwareDomain, N_FE: int | None = None, ENG_PER_LANE: int = 1, LANE_Q_DEPTH: int = 16, @@ -27,6 +26,7 @@ def build( STASH_WIN: int = 6, BKPR_SLACK: int = 1, ) -> None: + _ = domain _ = (LANE_Q_DEPTH, ENG_Q_DEPTH, ROB_DEPTH, SEQ_W, HIST_DEPTH, STASH_WIN, BKPR_SLACK) total_eng = _total_engines(m, N_FE, ENG_PER_LANE) diff --git a/designs/examples/fastfwd/tb_fastfwd.py b/designs/examples/fastfwd/tb_fastfwd.py index 9aec8d2..f569ad1 100644 --- a/designs/examples/fastfwd/tb_fastfwd.py +++ 
b/designs/examples/fastfwd/tb_fastfwd.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,11 +15,15 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.expect("pkt_in_bkpr", 0, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.expect("pkt_in_bkpr", 0) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_fastfwd_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_fastfwd_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/fifo_loopback/fifo_loopback.py b/designs/examples/fifo_loopback/fifo_loopback.py index 8017f78..4ee791b 100644 --- a/designs/examples/fifo_loopback/fifo_loopback.py +++ b/designs/examples/fifo_loopback/fifo_loopback.py @@ -1,18 +1,18 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain -@module -def build(m: Circuit, depth: int = 2) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, depth: int = 2) -> None: + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst in_valid = m.input("in_valid", width=1) in_data = m.input("in_data", width=8) out_ready = m.input("out_ready", width=1) - q = m.rv_queue("q", clk=clk, rst=rst, width=8, depth=depth) + q = m.rv_queue("q", domain=cd, width=8, depth=depth) q.push(in_data, when=in_valid) p = q.pop(when=out_ready) @@ -26,4 +26,4 @@ def build(m: Circuit, depth: int = 2) -> None: if __name__ == "__main__": - print(compile(build, name="fifo_loopback", depth=2).emit_mlir()) + 
print(compile_cycle_aware(build, name="fifo_loopback", eager=True, depth=2).emit_mlir()) diff --git a/designs/examples/fifo_loopback/tb_fifo_loopback.py b/designs/examples/fifo_loopback/tb_fifo_loopback.py index 7791f01..1065ff8 100644 --- a/designs/examples/fifo_loopback/tb_fifo_loopback.py +++ b/designs/examples/fifo_loopback/tb_fifo_loopback.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,15 +15,19 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("in_valid", 1, at=0) - t.drive("in_data", 0x2A, at=0) - t.drive("out_ready", 1, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("in_valid", 1) + tb.drive("in_data", 0x2A) + tb.drive("out_ready", 1) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_fifo_loopback_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_fifo_loopback_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/fm16/README.md b/designs/examples/fm16/README.md new file mode 100644 index 0000000..7efb742 --- /dev/null +++ b/designs/examples/fm16/README.md @@ -0,0 +1,54 @@ +# FM16 — 16-NPU Full-Mesh System Simulation + +Cycle-accurate simulation of a 16-chip Ascend950-like NPU cluster with +full-mesh interconnect topology. + +## System Architecture + +``` + NPU0 ──4 links── NPU1 ──4 links── NPU2 ... + │╲ │╲ + │ ╲ full mesh │ ╲ + │ ╲ (4 links │ ╲ + │ ╲ per pair)│ ╲ + NPU3 ──────────── NPU4 ... 
(16 NPUs total) +``` + +### NPU Node (Ascend950 simplified) +- **HBM**: 1.6 Tbps bandwidth (packet injection) +- **UB Ports**: 18×4×112 Gbps (simplified to N mesh ports) +- Routing: destination-based (dst → output port mapping) +- Output FIFOs per port with round-robin arbitration + +### SW5809s Switch (simplified) +- 16×8×112 Gbps ports +- VOQ (Virtual Output Queue) per (input, output) pair +- Crossbar with round-robin / MDRR scheduling + +### Packet Format +- 512 bytes per packet +- 32-bit descriptor: src[4] | dst[4] | seq[8] | tag[16] + +## Topology +- **Full mesh**: 4 links per NPU pair (16×15/2 = 120 bidirectional pairs) +- **All-to-all traffic**: each NPU continuously sends to all other NPUs + +## Files + +| File | Description | +|------|-------------| +| `npu_node.py` | pyCircuit RTL of single NPU (compile-verified) | +| `sw5809s.py` | pyCircuit RTL of switch (compile-verified) | +| `fm16_system.py` | Python behavioral system simulator with real-time visualization | + +## Run + +```bash +python examples/fm16/fm16_system.py +``` + +## Statistics +- Per-NPU delivered bandwidth (bar chart) +- Aggregate system bandwidth (Gbps) +- Latency distribution: avg, P50, P95, P99 +- Histogram visualization diff --git a/designs/examples/fm16/__init__.py b/designs/examples/fm16/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/designs/examples/fm16/fm16_system.py b/designs/examples/fm16/fm16_system.py new file mode 100644 index 0000000..144d68f --- /dev/null +++ b/designs/examples/fm16/fm16_system.py @@ -0,0 +1,606 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +FM16 vs SW16 System Comparison Simulator. + +Compares two 16-NPU topologies side-by-side: + + FM16: Full Mesh — 4 direct links between every NPU pair + (16×15/2 = 120 bidirectional link-pairs, 480 total links) + Each pair: 4 × 112 Gbps = 448 Gbps + + SW16: Star via SW5809s — each NPU connects to a central switch + with 8×4 = 32 links (simplified to SW_LINKS_PER_NPU). 
+ Switch: VOQ + crossbar + round-robin (MDRR). + Path: NPU → switch → NPU (2 hops) + +Both run all-to-all continuous 512B packet traffic from 4Tbps HBM. + +Usage: + python examples/fm16/fm16_system.py +""" +from __future__ import annotations + +import collections +import random +import re as _re +import sys +import time +from dataclasses import dataclass, field + +# ═══════════════════════════════════════════════════════════════════ +# ANSI +# ═══════════════════════════════════════════════════════════════════ +RESET = "\033[0m"; BOLD = "\033[1m"; DIM = "\033[2m" +RED = "\033[31m"; GREEN = "\033[32m"; YELLOW = "\033[33m" +CYAN = "\033[36m"; WHITE = "\033[37m"; MAGENTA = "\033[35m"; BLUE = "\033[34m" +_ANSI = _re.compile(r'\x1b\[[0-9;]*m') +def _vl(s): return len(_ANSI.sub('', s)) +def _pad(s, w): return s + ' ' * max(0, w - _vl(s)) +def clear(): sys.stdout.write("\033[2J\033[H"); sys.stdout.flush() + +# ═══════════════════════════════════════════════════════════════════ +# Parameters +# ═══════════════════════════════════════════════════════════════════ +N_NPUS = 16 +FM_LINKS_PER_PAIR = 4 # FM16: 4 links per NPU pair +SW_LINKS_PER_NPU = 32 # SW16: 32 links from each NPU to the switch (8×4) +SW_XBAR_LINKS = 512 # SW5809s: 512×512 physical links (112Gbps each) +SW_LINKS_PER_PORT = 4 # 4 links bundled as 1 logical port +SW_XBAR_PORTS = SW_XBAR_LINKS // SW_LINKS_PER_PORT # 128 logical ports +SW_PORTS_PER_NPU = SW_LINKS_PER_NPU // SW_LINKS_PER_PORT # 8 logical ports per NPU +PKT_SIZE = 512 # bytes +LINK_BW_GBPS = 112 # Gbps per link +HBM_BW_TBPS = 4.0 # Tbps HBM per NPU +PKT_TIME_NS = PKT_SIZE * 8 / LINK_BW_GBPS # ~36.6 ns +HBM_INJECT_PROB = min(1.0, HBM_BW_TBPS * 1000 / LINK_BW_GBPS / N_NPUS) +INJECT_BATCH = 8 # ~8 pkt/cycle/NPU ≈ SW capacity (128 ports / 16 NPUs) +FIFO_DEPTH = 64 +VOQ_DEPTH = 32 +SIM_CYCLES = 3000 +DISPLAY_INTERVAL = 150 + +FM_LINK_LATENCY = 3 # direct mesh: 3 cycle pipeline +SW_LINK_LATENCY = 2 # NPU→switch or switch→NPU: 2 cycles each 
+SW_XBAR_LATENCY = 1 # switch internal crossbar: 1 cycle + + +# ═══════════════════════════════════════════════════════════════════ +# Packet +# ═══════════════════════════════════════════════════════════════════ +@dataclass +class Packet: + src: int + dst: int + seq: int + inject_cycle: int + def latency(self, now): return now - self.inject_cycle + + +# ═══════════════════════════════════════════════════════════════════ +# NPU Node (shared by both topologies) +# ═══════════════════════════════════════════════════════════════════ +class NPUNode: + def __init__(self, nid, n_ports): + self.id = nid + self.n_ports = n_ports + self.out_fifos = [collections.deque(maxlen=FIFO_DEPTH) for _ in range(n_ports)] + self.seq = 0 + self.pkts_injected = 0 + self.pkts_delivered = 0 + self.latencies: list[int] = [] + + def inject(self, cycle, rng): + for _ in range(INJECT_BATCH): + if rng.random() > HBM_INJECT_PROB: + continue + dst = self.id + while dst == self.id: + dst = rng.randint(0, N_NPUS - 1) + pkt = Packet(self.id, dst, self.seq, cycle) + self.seq += 1 + port = dst % self.n_ports + if len(self.out_fifos[port]) < FIFO_DEPTH: + self.out_fifos[port].append(pkt) + self.pkts_injected += 1 + + def tx(self, port): + if self.out_fifos[port]: + return self.out_fifos[port].popleft() + return None + + def rx(self, pkt, cycle): + self.pkts_delivered += 1 + self.latencies.append(pkt.latency(cycle)) + + +# ═══════════════════════════════════════════════════════════════════ +# SW5809s Switch (behavioral — VOQ + crossbar + round-robin) +# ═══════════════════════════════════════════════════════════════════ +class SW5809s: + """SW5809s: 512×512 link crossbar, 128×128 logical port crossbar. + + Physical: 512 input links × 512 output links (each 112 Gbps). + Logical: every 4 links are bundled into 1 port → 128×128 port crossbar. + Each logical port is independently arbitrated: up to + SW_LINKS_PER_PORT (4) packets per cycle. + + NPU mapping: NPU i → ports [i*8 .. i*8+7] (8 ports, 32 links). 
+ + Ingress path for a packet from src_npu to dst_npu: + 1. Pick one of dst_npu's 8 egress ports via ECMP hash/policy + 2. Enqueue into VOQ[input_port][chosen_egress_port] + 3. Egress arbiter grants crossbar connection and delivers + + ECMP modes: + 'independent' : each input port has its own independent RR per dest NPU. + This is the REAL hardware behavior — causes VOQ collision + because uncoordinated RR pointers naturally converge. + 'coordinated' : a single global RR per dest NPU shared across all input + ports — ideal distribution, no collision (reference). + + VOQ collision: when multiple input ports independently pick the *same* + egress port for the same destination NPU, those packets pile up in + VOQs targeting that one port while the other 7 ports sit idle. + This increases tail latency significantly under high load. + """ + + def __init__(self, ecmp_mode: str = "independent"): + self.n_ports = SW_XBAR_PORTS # 128 + self.ports_per_npu = SW_PORTS_PER_NPU # 8 + self.pkts_per_port = SW_LINKS_PER_PORT # 4 + self.ecmp_mode = ecmp_mode + + self.voqs = [[collections.deque(maxlen=VOQ_DEPTH) + for _ in range(self.n_ports)] + for _ in range(self.n_ports)] + self.rr = [0] * self.n_ports + + # Independent mode: each input port has its own RR pointer per dest NPU + # Shape: [n_ports][N_NPUS] — 128 × 16 = 2048 independent counters + self.ingress_rr = [[0] * N_NPUS for _ in range(self.n_ports)] + + # Coordinated mode: single global RR per dest NPU (ideal reference) + self.global_rr = [0] * N_NPUS + + self.rng = random.Random(123) + + # Statistics + self.pkts_switched = 0 + self.pkts_enqueued = 0 + self.pkts_dropped = 0 # VOQ full drops + self.port_enq_count = [0] * self.n_ports # per-egress-port cumulative enqueue + self._voq_max_depth = [0] * self.n_ports # per-egress-port peak VOQ depth + self._voq_depth_sum = [0] * self.n_ports # for computing average + self._voq_snapshot_count = 0 + + def npu_to_ports(self, npu_id): + base = npu_id * self.ports_per_npu + return 
range(base, base + self.ports_per_npu) + + def enqueue(self, src_npu, in_port_hint, pkt): + """Enqueue packet arriving at a specific input port. + + in_port_hint: the physical input port index (within src NPU's 8 ports). + The input port uses its OWN independent RR to pick the egress port. + """ + dst_npu = pkt.dst + if dst_npu == src_npu or dst_npu >= N_NPUS: + return False + + # Determine actual input port + in_port = src_npu * self.ports_per_npu + (in_port_hint % self.ports_per_npu) + dst_base = dst_npu * self.ports_per_npu + + # ECMP: pick one of dst_npu's 8 egress ports + if self.ecmp_mode == "independent": + # Each input port has its own RR counter per dest NPU + idx = self.ingress_rr[in_port][dst_npu] + self.ingress_rr[in_port][dst_npu] = (idx + 1) % self.ports_per_npu + else: # coordinated + # Global RR shared by ALL input ports → perfect distribution + idx = self.global_rr[dst_npu] + self.global_rr[dst_npu] = (idx + 1) % self.ports_per_npu + + out_port = dst_base + idx + + if len(self.voqs[in_port][out_port]) < VOQ_DEPTH: + self.voqs[in_port][out_port].append(pkt) + self.pkts_enqueued += 1 + self.port_enq_count[out_port] += 1 + return True + self.pkts_dropped += 1 + return False + + def schedule(self): + """Crossbar scheduling: each egress port independently arbitrates + to select exactly 1 packet per cycle from all input-port VOQs. + + 128 egress ports × 1 pkt/cycle = 128 pkt/cycle max throughput. + Round-robin arbiter per egress port scans across 128 input ports. 
+ """ + delivered = [] + for out_port in range(self.n_ports): + dest_npu = out_port // self.ports_per_npu + # Round-robin: pick 1 packet from any input port's VOQ + for offset in range(self.n_ports): + in_port = (self.rr[out_port] + offset) % self.n_ports + if in_port // self.ports_per_npu == dest_npu: + continue # skip loopback + if self.voqs[in_port][out_port]: + pkt = self.voqs[in_port][out_port].popleft() + self.rr[out_port] = (in_port + 1) % self.n_ports + self.pkts_switched += 1 + delivered.append((dest_npu, pkt)) + break # exactly 1 per egress port per cycle + return delivered + + def occupancy(self): + return sum(len(self.voqs[i][j]) + for i in range(self.n_ports) for j in range(self.n_ports)) + + def snapshot_voq_depths(self): + """Snapshot current VOQ depths per egress port. Call every cycle.""" + for out_port in range(self.n_ports): + depth = sum(len(self.voqs[i][out_port]) for i in range(self.n_ports)) + if depth > self._voq_max_depth[out_port]: + self._voq_max_depth[out_port] = depth + self._voq_depth_sum[out_port] += depth + self._voq_snapshot_count += 1 + + def voq_depth_stats(self): + """Return per-dest-NPU VOQ depth stats: (avg_of_avg, avg_of_max, max_of_max).""" + if self._voq_snapshot_count == 0: + return 0, 0, 0 + npu_avg = [] + npu_max = [] + for npu in range(N_NPUS): + ports = self.npu_to_ports(npu) + port_avgs = [self._voq_depth_sum[p] / self._voq_snapshot_count for p in ports] + port_maxs = [self._voq_max_depth[p] for p in ports] + npu_avg.append(sum(port_avgs) / len(port_avgs)) + npu_max.append(max(port_maxs)) + return (sum(npu_avg) / len(npu_avg), + sum(npu_max) / len(npu_max), + max(npu_max)) + + def port_load_imbalance(self): + """Return (min, avg, max) cumulative enqueue count across egress ports per NPU.""" + imbalances = [] + for npu in range(N_NPUS): + ports = self.npu_to_ports(npu) + counts = [self.port_enq_count[p] for p in ports] + if max(counts) > 0: + imbalances.append((min(counts), sum(counts)/len(counts), max(counts))) + if 
not imbalances: + return 0, 0, 0 + mins = [x[0] for x in imbalances] + avgs = [x[1] for x in imbalances] + maxs = [x[2] for x in imbalances] + return sum(mins)/len(mins), sum(avgs)/len(avgs), sum(maxs)/len(maxs) + + +# ═══════════════════════════════════════════════════════════════════ +# FM16 Topology: full mesh, 4 links per pair +# ═══════════════════════════════════════════════════════════════════ +class FM16System: + def __init__(self): + self.npus = [NPUNode(i, N_NPUS) for i in range(N_NPUS)] + self.cycle = 0 + self.rng = random.Random(42) + self._inflight: list[tuple[int, Packet]] = [] + + def step(self): + for npu in self.npus: + npu.inject(self.cycle, self.rng) + + for npu in self.npus: + for port in range(N_NPUS): + for _ in range(FM_LINKS_PER_PAIR): + pkt = npu.tx(port) + if pkt is None: break + if pkt.dst == npu.id: continue + qlat = len(npu.out_fifos[port]) + self._inflight.append((self.cycle + FM_LINK_LATENCY + qlat, pkt)) + + keep = [] + for (t, pkt) in self._inflight: + if t <= self.cycle: + self.npus[pkt.dst].rx(pkt, self.cycle) + else: + keep.append((t, pkt)) + self._inflight = keep + self.cycle += 1 + + def stats(self): + return _compute_stats(self.npus, self.cycle) + + +# ═══════════════════════════════════════════════════════════════════ +# SW16 Topology: star through SW5809s +# ═══════════════════════════════════════════════════════════════════ +class SW16System: + def __init__(self, ecmp_mode="coordinated"): # forwarded to SW5809s: "independent", anything else is handled as coordinated + self.ecmp_mode = ecmp_mode + self.npus = [NPUNode(i, N_NPUS) for i in range(N_NPUS)] + self.switch = SW5809s(ecmp_mode=ecmp_mode) + self.cycle = 0 + self.rng = random.Random(42) + self._to_switch: list[tuple[int, int, int, Packet]] = [] # (arrive, src_npu, in_port_idx, pkt) + self._to_npu: list[tuple[int, Packet]] = [] # (arrive, pkt) + + def step(self): + for npu in self.npus: + npu.inject(self.cycle, self.rng) + + # NPU → switch: each NPU can push up to SW_LINKS_PER_NPU pkts/cycle + # Packets are distributed across the NPU's 8 input ports via RR + for npu 
in self.npus: + sent = 0 + for port in range(N_NPUS): + while sent < SW_LINKS_PER_NPU: + pkt = npu.tx(port) + if pkt is None: break + if pkt.dst == npu.id: continue + # Assign to one of src NPU's 8 input ports (RR) + in_port_idx = sent % SW_PORTS_PER_NPU + self._to_switch.append((self.cycle + SW_LINK_LATENCY, + npu.id, in_port_idx, pkt)) + sent += 1 + + # Deliver to switch — each packet arrives at a specific input port + keep = [] + for (t, src, port_idx, pkt) in self._to_switch: + if t <= self.cycle: + self.switch.enqueue(src, port_idx, pkt) + else: + keep.append((t, src, port_idx, pkt)) + self._to_switch = keep + + # Switch crossbar: 128 ports × 1 pkt/port = 128 pkt/cycle max + self.switch.snapshot_voq_depths() # track VOQ depths before scheduling + delivered = self.switch.schedule() + for (dst_npu, pkt) in delivered: + self._to_npu.append((self.cycle + SW_XBAR_LATENCY + SW_LINK_LATENCY, pkt)) + + # Deliver to destination NPU + keep2 = [] + for (t, pkt) in self._to_npu: + if t <= self.cycle: + self.npus[pkt.dst].rx(pkt, self.cycle) + else: + keep2.append((t, pkt)) + self._to_npu = keep2 + + self.cycle += 1 + + def stats(self): + s = _compute_stats(self.npus, self.cycle) + s["sw_occupancy"] = self.switch.occupancy() + s["sw_switched"] = self.switch.pkts_switched + return s + + +# ═══════════════════════════════════════════════════════════════════ +# Statistics helper +# ═══════════════════════════════════════════════════════════════════ +def _compute_stats(npus, cycle): # aggregate latency percentiles, bandwidth and packet counters over all NPUs + all_lats = [] + total_inj = total_del = 0 + for n in npus: + all_lats.extend(n.latencies) + total_inj += n.pkts_injected + total_del += n.pkts_delivered + if not all_lats: # nothing delivered yet: zeros under the same keys as the populated result ("bw_gbps" kept as a legacy alias) + return {"avg":0,"p50":0,"p95":0,"p99":0,"max_lat":0, + "bw_gbps":0,"agg_bw_gbps":0,"per_npu_bw_gbps":0,"inj":total_inj,"del":total_del,"npu_del":[0]*len(npus)} + all_lats.sort() + n = len(all_lats) + t_ns = cycle * PKT_TIME_NS + n_npus = len(npus) + agg_bw = total_del * PKT_SIZE * 8 / t_ns if t_ns > 0 else 0 + return { + "avg": sum(all_lats)/n, + "p50": 
all_lats[n//2], + "p95": all_lats[int(n*0.95)], + "p99": all_lats[int(n*0.99)], + "max_lat": all_lats[-1], + "agg_bw_gbps": agg_bw, + "per_npu_bw_gbps": agg_bw / n_npus if n_npus > 0 else 0, + "inj": total_inj, + "del": total_del, + "npu_del": [npu.pkts_delivered for npu in npus], + } + +def _hist(npus, bins=12): + lats = [] + for n in npus: lats.extend(n.latencies) + if not lats: return [], 0, 0 + lo, hi = min(lats), max(lats) + if lo == hi: return [len(lats)], lo, hi + bw = max(1, (hi - lo + bins - 1) // bins) + h = [0] * bins + for l in lats: + h[min((l - lo) // bw, bins - 1)] += 1 + return h, lo, hi + + +# ═══════════════════════════════════════════════════════════════════ +# Side-by-side visualization +# ═══════════════════════════════════════════════════════════════════ +COL_W = 35 # width of each column +BOX_W = COL_W * 2 + 5 # total inner width + +def _bl(content): + return f" {CYAN}║{RESET}{_pad(content, BOX_W)}{CYAN}║{RESET}" + +def _bar(v, mx, w=14, ch="█", co=GREEN): + if mx <= 0: return "" + n = min(int(v / mx * w), w) + return f"{co}{ch*n}{RESET}" + +def _side(left, right): + """Render two strings side-by-side in the box.""" + return _bl(f" {_pad(left, COL_W)} │ {_pad(right, COL_W)}") + +def draw(fm, sw, cycle): + clear() + bar = "═" * BOX_W + sf = fm.stats() + ss = sw.stats() + pct = cycle * 100 // SIM_CYCLES + + print(f"\n {CYAN}╔{bar}╗{RESET}") + print(_bl(f" {BOLD}{WHITE}FM16 vs SW16 — Side-by-Side Comparison{RESET}")) + print(f" {CYAN}╠{bar}╣{RESET}") + print(_bl(f" {DIM}16 NPU | HBM {HBM_BW_TBPS}Tbps | 512B pkts | All-to-all{RESET}")) + prog = _bar(cycle, SIM_CYCLES, 30, "█", CYAN) + print(_bl(f" Cycle {cycle}/{SIM_CYCLES} [{prog}] {pct}%")) + print(f" {CYAN}╠{bar}╣{RESET}") + + # Headers + print(_side(f"{BOLD}{YELLOW}FM16 (Full Mesh){RESET}", + f"{BOLD}{MAGENTA}SW16 (Switch){RESET}")) + print(_side(f"{DIM}4 links/pair, 1 hop{RESET}", + f"{DIM}{SW_XBAR_LINKS}×{SW_XBAR_LINKS} xbar, {SW_LINKS_PER_PORT}link/port, 2 hop{RESET}")) + print(_bl(f" {'─' 
* COL_W} │ {'─' * COL_W}")) + + # Bandwidth (per NPU) + fm_max = (N_NPUS - 1) * FM_LINKS_PER_PAIR * LINK_BW_GBPS # 15×4×112 = 6720 + sw_max = SW_LINKS_PER_NPU * LINK_BW_GBPS # 32×112 = 3584 + # But switch crossbar limits to 1 pkt/output/cycle → effective max: + sw_eff = LINK_BW_GBPS # 1 pkt per output per cycle = 112 Gbps per dest + print(_side(f"Per-NPU BW: {BOLD}{sf['per_npu_bw_gbps']:>6.0f}{RESET} Gbps", + f"Per-NPU BW: {BOLD}{ss['per_npu_bw_gbps']:>6.0f}{RESET} Gbps")) + print(_side(f" (max: {fm_max} Gbps mesh)", + f" (max: {sw_max} Gbps link)")) + print(_side(f"Aggregate: {sf['agg_bw_gbps']:>8.0f} Gbps", + f"Aggregate: {ss['agg_bw_gbps']:>8.0f} Gbps")) + print(_side(f"Injected: {sf['inj']:>8d}", + f"Injected: {ss['inj']:>8d}")) + print(_side(f"Delivered: {sf['del']:>8d}", + f"Delivered: {ss['del']:>8d}")) + sw_extra = f" SW queued: {ss.get('sw_occupancy',0):>5d}" + print(_side("", sw_extra)) + + print(_bl(f" {'─' * COL_W} │ {'─' * COL_W}")) + + # Latency + print(_side(f"Avg: {YELLOW}{sf['avg']:>5.1f}{RESET} P50:{sf['p50']:>3d} P99:{sf['p99']:>3d}", + f"Avg: {YELLOW}{ss['avg']:>5.1f}{RESET} P50:{ss['p50']:>3d} P99:{ss['p99']:>3d}")) + print(_side(f"Max: {sf['max_lat']:>3d} cycles", + f"Max: {ss['max_lat']:>3d} cycles")) + + print(_bl(f" {'─' * COL_W} │ {'─' * COL_W}")) + + # Per-NPU bars + print(_side(f"{BOLD}Per-NPU delivered:{RESET}", f"{BOLD}Per-NPU delivered:{RESET}")) + max_f = max(sf["npu_del"]) if sf["npu_del"] else 1 + max_s = max(ss["npu_del"]) if ss["npu_del"] else 1 + mx = max(max_f, max_s, 1) + for i in range(N_NPUS): + fd = sf["npu_del"][i] if i < len(sf["npu_del"]) else 0 + sd = ss["npu_del"][i] if i < len(ss["npu_del"]) else 0 + fb = _bar(fd, mx, 12, "█", GREEN) + sb = _bar(sd, mx, 12, "█", MAGENTA) + print(_side(f" {i:>2d}:{fb}{fd:>6d}", f" {i:>2d}:{sb}{sd:>6d}")) + + print(_bl(f" {'─' * COL_W} │ {'─' * COL_W}")) + + # Latency histograms + hf, lof, hif = _hist(fm.npus, bins=8) + hs, los, his = _hist(sw.npus, bins=8) + print(_side(f"{BOLD}Latency 
Histogram:{RESET}", f"{BOLD}Latency Histogram:{RESET}")) + maxh = max(max(hf, default=1), max(hs, default=1), 1) + nbins = max(len(hf), len(hs)) + for bi in range(nbins): + bwf = max(1, (hif - lof + len(hf) - 1) // len(hf)) if hf else 1 + bws = max(1, (his - los + len(hs) - 1) // len(hs)) if hs else 1 + fv = hf[bi] if bi < len(hf) else 0 + sv = hs[bi] if bi < len(hs) else 0 + flo = lof + bi * bwf if hf else 0 + slo = los + bi * bws if hs else 0 + fb = _bar(fv, maxh, 10, "▓", GREEN) + sb = _bar(sv, maxh, 10, "▓", MAGENTA) + print(_side(f" {flo:>3d}+: {fb}{fv:>6d}", f" {slo:>3d}+: {sb}{sv:>6d}")) + + print(_bl("")) + print(f" {CYAN}╚{bar}╝{RESET}") + print() + + +# ═══════════════════════════════════════════════════════════════════ +# Main +# ═══════════════════════════════════════════════════════════════════ +def main(): + print(f" {BOLD}FM16 vs SW16 — Topology + ECMP Collision Comparison{RESET}") + print(f" Initializing 3 systems (FM16 + SW16-independent + SW16-coordinated)...") + + fm = FM16System() + sw_ind = SW16System(ecmp_mode="independent") # real hardware: VOQ collision + sw_crd = SW16System(ecmp_mode="coordinated") # ideal: no collision + + print(f" {GREEN}Systems ready. 
Running {SIM_CYCLES} cycles...{RESET}") + time.sleep(0.3) + + t0 = time.time() + for cyc in range(SIM_CYCLES): + fm.step() + sw_ind.step() + sw_crd.step() + if (cyc + 1) % DISPLAY_INTERVAL == 0 or cyc == SIM_CYCLES - 1: + draw(fm, sw_ind, cyc + 1) + elapsed = time.time() - t0 + if elapsed < 0.3: + time.sleep(0.03) + t1 = time.time() + + sf = fm.stats() + si = sw_ind.stats() + sc = sw_crd.stats() + li_min, li_avg, li_max = sw_ind.switch.port_load_imbalance() + lc_min, lc_avg, lc_max = sw_crd.switch.port_load_imbalance() + vi_avg, vi_avg_max, vi_peak = sw_ind.switch.voq_depth_stats() + vc_avg, vc_avg_max, vc_peak = sw_crd.switch.voq_depth_stats() + + print(f" {GREEN}{BOLD}Simulation complete!{RESET} ({t1-t0:.2f}s)") + print(f" {'─'*72}") + print(f" {'':24s} {'FM16':>14s} {'SW16-indep':>14s} {'SW16-coord':>14s}") + print(f" {'Per-NPU BW (Gbps)':24s} {sf['per_npu_bw_gbps']:>14.0f} {si['per_npu_bw_gbps']:>14.0f} {sc['per_npu_bw_gbps']:>14.0f}") + print(f" {'Aggregate BW (Gbps)':24s} {sf['agg_bw_gbps']:>14.0f} {si['agg_bw_gbps']:>14.0f} {sc['agg_bw_gbps']:>14.0f}") + print(f" {'Avg Latency (cycles)':24s} {sf['avg']:>14.1f} {si['avg']:>14.1f} {sc['avg']:>14.1f}") + print(f" {'P50 Latency':24s} {sf['p50']:>14d} {si['p50']:>14d} {sc['p50']:>14d}") + print(f" {'P95 Latency':24s} {sf['p95']:>14d} {si['p95']:>14d} {sc['p95']:>14d}") + print(f" {'P99 Latency':24s} {sf['p99']:>14d} {si['p99']:>14d} {sc['p99']:>14d}") + print(f" {'Max Latency':24s} {sf['max_lat']:>14d} {si['max_lat']:>14d} {sc['max_lat']:>14d}") + print(f" {'Delivered pkts':24s} {sf['del']:>14d} {si['del']:>14d} {sc['del']:>14d}") + print(f" {'Dropped pkts':24s} {'N/A':>14s} {si.get('sw_dropped',sw_ind.switch.pkts_dropped):>14d} {sc.get('sw_dropped',sw_crd.switch.pkts_dropped):>14d}") + print(f" {'─'*72}") + + print(f"\n {YELLOW}ECMP VOQ Collision Analysis:{RESET}") + print(f" Each input port independently round-robins across 8 egress ports.") + print(f" 'independent': 128 uncoordinated RR pointers → collisions") 
+ print(f" 'coordinated': 1 global RR per dest NPU → no collision (ideal)") + print(f"") + print(f" {'Cumulative enqueue (per dest port)':40s} {'Independent':>14s} {'Coordinated':>14s}") + print(f" {' Min enqueued':40s} {li_min:>14.0f} {lc_min:>14.0f}") + print(f" {' Avg enqueued':40s} {li_avg:>14.0f} {lc_avg:>14.0f}") + print(f" {' Max enqueued':40s} {li_max:>14.0f} {lc_max:>14.0f}") + if li_avg > 0: + print(f" {' Max/Avg ratio':40s} {li_max/li_avg:>14.2f}x {lc_max/lc_avg:>14.2f}x") + print(f"") + print(f" {'VOQ depth (per egress port)':40s} {'Independent':>14s} {'Coordinated':>14s}") + print(f" {' Avg depth':40s} {vi_avg:>14.1f} {vc_avg:>14.1f}") + print(f" {' Avg peak depth':40s} {vi_avg_max:>14.1f} {vc_avg_max:>14.1f}") + print(f" {' Max peak depth (worst port)':40s} {vi_peak:>14d} {vc_peak:>14d}") + print(f"") + print(f" VOQ collision causes the {'independent':s} mode to have") + if si['p99'] > sc['p99']: + print(f" {RED}higher P99 latency: {si['p99']} vs {sc['p99']} cycles{RESET}") + else: + print(f" similar latency (collision effect minimal at this load level)") + print() + + +if __name__ == "__main__": + main() diff --git a/designs/examples/fm16/npu_node.py b/designs/examples/fm16/npu_node.py new file mode 100644 index 0000000..2f3aeb1 --- /dev/null +++ b/designs/examples/fm16/npu_node.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +"""Simplified NPU node — pyCircuit V5 cycle-aware.""" +from __future__ import annotations + +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + compile_cycle_aware, + mux, +) + +PKT_W = 32 + + +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, N_PORTS: int = 4, FIFO_DEPTH: int = 8, NODE_ID: int = 0) -> None: + cd = domain.clock_domain + + hbm_pkt = m.input("hbm_pkt", width=PKT_W) + hbm_valid = m.input("hbm_valid", width=1) + + rx_pkts = [m.input(f"rx_pkt_{i}", width=PKT_W) for i in range(N_PORTS)] + rx_vals = [m.input(f"rx_valid_{i}", width=1) for i in range(N_PORTS)] + + fifos = [] + for i in 
range(N_PORTS): + q = m.rv_queue(f"oq_{i}", domain=cd, width=PKT_W, depth=FIFO_DEPTH) + fifos.append(q) + + PORT_BITS = max((N_PORTS - 1).bit_length(), 1) + hbm_dst = hbm_pkt[24:28] + hbm_port = hbm_dst[0:PORT_BITS] + + for j in range(N_PORTS): + merged_data = m.const(0, width=PKT_W) + merged_valid = m.const(0, width=1) + + for i in range(N_PORTS): + rx_dst_i = rx_pkts[i][24:28] + rx_port_i = rx_dst_i[0:PORT_BITS] + fwd_match = (rx_port_i == m.const(j, width=PORT_BITS)) & rx_vals[i] + merged_data = mux(fwd_match, rx_pkts[i], merged_data) + merged_valid = fwd_match | merged_valid + + hbm_match_j = hbm_valid & (hbm_port == m.const(j, width=PORT_BITS)) + merged_data = mux(hbm_match_j, hbm_pkt, merged_data) + merged_valid = hbm_match_j | merged_valid + + fifos[j].push(merged_data, when=merged_valid) + + tx_pkts = [] + tx_vals = [] + for i in range(N_PORTS): + pop_result = fifos[i].pop(when=m.const(1, width=1)) + tx_pkts.append(pop_result.data) + tx_vals.append(pop_result.valid) + + hbm_ready_sig = m.const(1, width=1) + + for i in range(N_PORTS): + m.output(f"tx_pkt_{i}", tx_pkts[i]) + m.output(f"tx_valid_{i}", tx_vals[i]) + m.output("hbm_ready", hbm_ready_sig) + + +build.__pycircuit_name__ = "npu_node" + +if __name__ == "__main__": + circuit = compile_cycle_aware(build, name="npu_node", eager=True, + N_PORTS=4, FIFO_DEPTH=8, NODE_ID=0) + print(circuit.emit_mlir()[:500]) + print(f"... 
({len(circuit.emit_mlir())} chars)") diff --git a/designs/examples/fm16/sw5809s.py b/designs/examples/fm16/sw5809s.py new file mode 100644 index 0000000..9d4d8ef --- /dev/null +++ b/designs/examples/fm16/sw5809s.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +"""Simplified SW5809s switch — pyCircuit V5 cycle-aware.""" +from __future__ import annotations + +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, + mux, +) + +PKT_W = 32 + + +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, N_PORTS: int = 4, VOQ_DEPTH: int = 4) -> None: + cd = domain.clock_domain + + PORT_BITS = max((N_PORTS - 1).bit_length(), 1) + + in_pkts = [m.input(f"in_pkt_{i}", width=PKT_W) for i in range(N_PORTS)] + in_vals = [m.input(f"in_valid_{i}", width=1) for i in range(N_PORTS)] + + voqs = [] + for i in range(N_PORTS): + row = [] + for j in range(N_PORTS): + q = m.rv_queue(f"voq_{i}_{j}", domain=cd, width=PKT_W, depth=VOQ_DEPTH) + row.append(q) + voqs.append(row) + + for i in range(N_PORTS): + pkt_dst = in_pkts[i][24:28][0:PORT_BITS] + for j in range(N_PORTS): + dst_match = (pkt_dst == m.const(j, width=PORT_BITS)) & in_vals[i] + voqs[i][j].push(in_pkts[i], when=dst_match) + + rr_states = [domain.state(width=PORT_BITS, reset_value=0, name=f"rr_{j}") for j in range(N_PORTS)] + + out_pkts = [] + out_vals = [] + + for j in range(N_PORTS): + peeks = [] + for i in range(N_PORTS): + peek = voqs[i][j].pop(when=m.const(0, width=1)) + peeks.append(peek) + + sel_pkt = m.const(0, width=PKT_W) + sel_val = m.const(0, width=1) + + for i in range(N_PORTS): + has_data = peeks[i].valid + sel_pkt = mux(has_data, peeks[i].data, sel_pkt) + sel_val = has_data | sel_val + + out_pkts.append(sel_pkt) + out_vals.append(sel_val) + + domain.next() + + for j in range(N_PORTS): + rr_cur = rr_states[j] + wrap = rr_cur == cas(domain, m.const(N_PORTS - 1, width=PORT_BITS), cycle=0) + next_rr = mux(wrap, cas(domain, m.const(0, width=PORT_BITS), cycle=0), rr_cur + 1) + 
rr_states[j].set(next_rr, when=cas(domain, out_vals[j], cycle=0)) + + for j in range(N_PORTS): + m.output(f"out_pkt_{j}", out_pkts[j]) + m.output(f"out_valid_{j}", out_vals[j]) + + +build.__pycircuit_name__ = "sw5809s" + +if __name__ == "__main__": + circuit = compile_cycle_aware(build, name="sw5809s", eager=True, + N_PORTS=4, VOQ_DEPTH=4) + print(circuit.emit_mlir()[:500]) + print(f"... ({len(circuit.emit_mlir())} chars)") diff --git a/designs/examples/fmac/README.md b/designs/examples/fmac/README.md new file mode 100644 index 0000000..54a42c7 --- /dev/null +++ b/designs/examples/fmac/README.md @@ -0,0 +1,94 @@ +# BF16 Fused Multiply-Accumulate (FMAC) + +A BF16 floating-point fused multiply-accumulate unit with 4-stage pipeline, +built from primitive standard cells (half adders, full adders, MUXes). + +## Operation + +``` +acc_out (FP32) = acc_in (FP32) + a (BF16) × b (BF16) +``` + +## Formats + +| Format | Bits | Layout | Bias | +|--------|------|--------|------| +| BF16 | 16 | sign(1) \| exp(8) \| mantissa(7) | 127 | +| FP32 | 32 | sign(1) \| exp(8) \| mantissa(23) | 127 | + +## 4-Stage Pipeline — Critical Path Summary + +``` + Stage 1: Unpack + PP + 2×CSA depth = 13 ██████ + Stage 2: Complete Multiply depth = 22 ███████████ + Stage 3: Align + Add depth = 21 ██████████ + Stage 4: Normalize + Pack depth = 31 ███████████████ + ────────────────────────────────────────────── + Total combinational depth depth = 87 + Max stage (critical path) depth = 31 +``` + +| Stage | Function | Depth | Key Components | +|-------|----------|------:|----------------| +| 1 | Unpack BF16, exp add, **PP generation + 2 CSA rounds** | 13 | Bit extract, MUX, 10-bit RCA, AND array, 2× 3:2 CSA | +| 2 | Complete multiply (remaining CSA + carry-select final add) | 22 | 3:2 CSA rounds, 16-bit carry-select adder | +| 3 | Align exponents, add/sub mantissas | 21 | Exponent compare, 5-level barrel shift, 26-bit RCA, magnitude compare | +| 4 | Normalize, pack FP32 | 31 | 26-bit LZC (priority MUX), 
5-level barrel shift left/right, exponent adjust | + +**Pipeline balance**: The 8×8 multiplier is split across Stages 1 and 2. +Stage 1 generates partial products (AND gate array) and runs 2 rounds of +3:2 carry-save compression, reducing 8 rows to ~4. The intermediate +carry-save rows are stored in pipeline registers. Stage 2 completes the +reduction and uses a carry-select adder for the final addition. This +achieves good balance: **13 / 22 / 21 / 31** (critical path in Stage 4). + +## Design Hierarchy + +``` +bf16_fmac.py (top level) +└── primitive_standard_cells.py + ├── half_adder, full_adder (1-bit) + ├── ripple_carry_adder (N-bit) + ├── partial_product_array (AND gate array) + ├── compress_3to2 (CSA) (carry-save adder) + ├── reduce_partial_products (Wallace tree) + ├── unsigned_multiplier (N×M multiply) + ├── barrel_shift_right/left (MUX layers) + └── leading_zero_count (priority encoder) +``` + +## Files + +| File | Description | +|------|-------------| +| `primitive_standard_cells.py` | HA, FA, RCA, CSA, multiplier, shifters, LZC | +| `bf16_fmac.py` | 4-stage pipelined FMAC | +| `fmac_capi.cpp` | C API wrapper | +| `test_bf16_fmac.py` | 100 test cases (true RTL simulation) | + +## Build & Run + +```bash +# 1. Compile RTL +PYTHONPATH=python:. python -m pycircuit.cli emit \ + examples/fmac/bf16_fmac.py \ + -o examples/generated/fmac/bf16_fmac.pyc +build/bin/pyc-compile examples/generated/fmac/bf16_fmac.pyc \ + --emit=cpp -o examples/generated/fmac/bf16_fmac_gen.hpp + +# 2. Build shared library +c++ -std=c++17 -O2 -shared -fPIC -I include -I . \ + -o examples/fmac/libfmac_sim.dylib examples/fmac/fmac_capi.cpp + +# 3. 
Run 100 test cases +python examples/fmac/test_bf16_fmac.py +``` + +## Test Results + +100 test cases verified against Python float reference via true RTL simulation: + +- **100/100 passed** +- **Max relative error**: 5.36e-04 (limited by BF16's 7-bit mantissa) +- **Test groups**: simple values, powers of 2, small fractions, accumulation + chains, sign cancellation (acc ≈ -a×b), and 40 random cases diff --git a/designs/examples/fmac/__init__.py b/designs/examples/fmac/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/designs/examples/fmac/bf16_fmac.py b/designs/examples/fmac/bf16_fmac.py new file mode 100644 index 0000000..8dc217c --- /dev/null +++ b/designs/examples/fmac/bf16_fmac.py @@ -0,0 +1,366 @@ +# -*- coding: utf-8 -*- +"""BF16 Fused Multiply-Accumulate (FMAC) — 4-stage pipeline, pyCircuit v4.0. + +Computes: acc += a * b + where a, b are BF16 (1-8-7 format), acc is FP32 (1-8-23 format). + +BF16 format: sign(1) | exponent(8) | mantissa(7) bias=127 +FP32 format: sign(1) | exponent(8) | mantissa(23) bias=127 + +Pipeline stages: + Stage 1 (cycle 0→1): Unpack operands, product sign/exponent, partial products and 2 CSA rounds + depth = 13 (exponent add, then AND array and 2× 3:2 CSA) + Stage 2 (cycle 1→2): Complete the 8×8 mantissa multiply from the registered carry-save rows + depth = 22 (remaining 3:2 CSA rounds, 16-bit carry-select adder) + Stage 3 (cycle 2→3): Align product to accumulator (barrel shift), add/subtract mantissas + depth = 21 (normalize, exponent compare, barrel shift, add/sub) + Stage 4 (cycle 3→4): Normalize result (LZC + shift + exponent adjust), pack FP32 + depth = 31 (LZC, barrel shift, exponent adjust, pack) + +All arithmetic built from primitive standard cells (HA, FA, RCA, MUX). 
+""" +from __future__ import annotations + +import sys +from pathlib import Path + +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, u, s + +try: + from .primitive_standard_cells import ( + unsigned_multiplier, ripple_carry_adder_packed, + barrel_shift_right, barrel_shift_left, leading_zero_count, + multiplier_pp_and_partial_reduce, multiplier_complete_reduce, + ) +except ImportError: + sys.path.insert(0, str(Path(__file__).resolve().parent)) + from primitive_standard_cells import ( + unsigned_multiplier, ripple_carry_adder_packed, + barrel_shift_right, barrel_shift_left, leading_zero_count, + multiplier_pp_and_partial_reduce, multiplier_complete_reduce, + ) + + +# ── Format constants ───────────────────────────────────────── +BF16_W = 16; BF16_EXP = 8; BF16_MAN = 7; BF16_BIAS = 127 +FP32_W = 32; FP32_EXP = 8; FP32_MAN = 23; FP32_BIAS = 127 + +# Internal mantissa with implicit 1: 8 bits for BF16 (1.7), 24 for FP32 (1.23) +BF16_MANT_FULL = BF16_MAN + 1 # 8 +FP32_MANT_FULL = FP32_MAN + 1 # 24 + +# Product mantissa: 8 × 8 = 16 bits (1.7 × 1.7 = 2.14, normalized to 1.15 → 16 bits) +PROD_MANT_W = BF16_MANT_FULL * 2 # 16 + +# Accumulator mantissa with guard bits for alignment: 26 bits +ACC_MANT_W = FP32_MANT_FULL + 2 # 26 (24 + 2 guard bits) + +_pipeline_depths: dict = {} + + +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) -> None: + pipeline_depths = {} + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst + + # ════════════════════════════════════════════════════════════ + # Inputs + # ════════════════════════════════════════════════════════════ + a_in = m.input("a_in", width=BF16_W) + b_in = m.input("b_in", width=BF16_W) + acc_in = m.input("acc_in", width=FP32_W) + valid_in = m.input("valid_in", width=1) + + # ════════════════════════════════════════════════════════════ + # Pipeline registers (all declared at top level) + # ════════════════════════════════════════════════════════════ + MAX_INTER_ROWS = 6 + 
+ # Stage 1→2 registers + s1_prod_sign = m.out("s1_prod_sign", domain=cd, width=1, init=u(1, 0)) + s1_prod_exp = m.out("s1_prod_exp", domain=cd, width=10, init=u(10, 0)) + s1_acc_sign = m.out("s1_acc_sign", domain=cd, width=1, init=u(1, 0)) + s1_acc_exp = m.out("s1_acc_exp", domain=cd, width=8, init=u(8, 0)) + s1_acc_mant = m.out("s1_acc_mant", domain=cd, width=FP32_MANT_FULL, init=u(FP32_MANT_FULL, 0)) + s1_prod_zero = m.out("s1_prod_zero", domain=cd, width=1, init=u(1, 0)) + s1_acc_zero = m.out("s1_acc_zero", domain=cd, width=1, init=u(1, 0)) + s1_valid = m.out("s1_valid", domain=cd, width=1, init=u(1, 0)) + s1_mul_rows = [m.out(f"s1_mul_row{i}", domain=cd, width=PROD_MANT_W, init=u(PROD_MANT_W, 0)) + for i in range(MAX_INTER_ROWS)] + s1_mul_nrows = m.out("s1_mul_nrows", domain=cd, width=4, init=u(4, 0)) + + # Stage 2→3 registers + s2_prod_mant = m.out("s2_prod_mant", domain=cd, width=PROD_MANT_W, init=u(PROD_MANT_W, 0)) + s2_prod_sign = m.out("s2_prod_sign", domain=cd, width=1, init=u(1, 0)) + s2_prod_exp = m.out("s2_prod_exp", domain=cd, width=10, init=u(10, 0)) + s2_acc_sign = m.out("s2_acc_sign", domain=cd, width=1, init=u(1, 0)) + s2_acc_exp = m.out("s2_acc_exp", domain=cd, width=8, init=u(8, 0)) + s2_acc_mant = m.out("s2_acc_mant", domain=cd, width=FP32_MANT_FULL, init=u(FP32_MANT_FULL, 0)) + s2_prod_zero = m.out("s2_prod_zero", domain=cd, width=1, init=u(1, 0)) + s2_acc_zero = m.out("s2_acc_zero", domain=cd, width=1, init=u(1, 0)) + s2_valid = m.out("s2_valid", domain=cd, width=1, init=u(1, 0)) + + # Stage 3→4 registers + s3_result_sign = m.out("s3_result_sign", domain=cd, width=1, init=u(1, 0)) + s3_result_exp = m.out("s3_result_exp", domain=cd, width=10, init=u(10, 0)) + s3_result_mant = m.out("s3_result_mant", domain=cd, width=ACC_MANT_W, init=u(ACC_MANT_W, 0)) + s3_valid = m.out("s3_valid", domain=cd, width=1, init=u(1, 0)) + + # Output registers + result_r = m.out("result", domain=cd, width=FP32_W, init=u(FP32_W, 0)) + valid_r = m.out("result_valid", 
domain=cd, width=1, init=u(1, 0)) + + # ════════════════════════════════════════════════════════════ + # STAGE 1 (cycle 0): Unpack + exponent add + # ════════════════════════════════════════════════════════════ + s1_depth = 0 + + # Unpack BF16 a + a_sign = a_in[15] + a_exp = a_in[7:15] # 8 bits + a_mant_raw = a_in[0:7] # 7 bits + a_is_zero = a_exp == u(8, 0) + a_mant = (u(BF16_MANT_FULL, 0) if a_is_zero else + ((u(1, 1) | u(BF16_MANT_FULL, 0)) << BF16_MAN | + (a_mant_raw | u(BF16_MANT_FULL, 0)))) + s1_depth = max(s1_depth, 3) # mux + or + + # Unpack BF16 b + b_sign = b_in[15] + b_exp = b_in[7:15] + b_mant_raw = b_in[0:7] + b_is_zero = b_exp == u(8, 0) + b_mant = (u(BF16_MANT_FULL, 0) if b_is_zero else + ((u(1, 1) | u(BF16_MANT_FULL, 0)) << BF16_MAN | + (b_mant_raw | u(BF16_MANT_FULL, 0)))) + + # Unpack FP32 accumulator + acc_sign = acc_in[31] + acc_exp = acc_in[23:31] # 8 bits + acc_mant_raw = acc_in[0:23] # 23 bits + acc_is_zero = acc_exp == u(8, 0) + acc_mant = (u(FP32_MANT_FULL, 0) if acc_is_zero else + ((u(1, 1) | u(FP32_MANT_FULL, 0)) << FP32_MAN | + (acc_mant_raw | u(FP32_MANT_FULL, 0)))) + + # Product sign = a_sign XOR b_sign + prod_sign = a_sign ^ b_sign + s1_depth = max(s1_depth, 1) + + # Product exponent = a_exp + b_exp - bias (10-bit to handle overflow) + prod_exp_sum = (a_exp | u(10, 0)) + (b_exp | u(10, 0)) + prod_exp = prod_exp_sum - u(10, BF16_BIAS) + s1_depth = max(s1_depth, 8) + + # Product is zero if either input is zero + prod_zero = a_is_zero | b_is_zero + + # ── Partial product generation + 2 CSA rounds (still in Stage 1) ── + CSA_ROUNDS_IN_S1 = 2 + mul_inter_rows, pp_csa_depth = multiplier_pp_and_partial_reduce( + m, a_mant, b_mant, + BF16_MANT_FULL, BF16_MANT_FULL, + csa_rounds=CSA_ROUNDS_IN_S1, name="mantmul" + ) + n_inter_rows = len(mul_inter_rows) + s1_depth = max(s1_depth, 8 + pp_csa_depth) + + pipeline_depths["Stage 1: Unpack + PP + 2×CSA"] = s1_depth + + # ──── Pipeline register write (stage 1) ──── + s1_prod_sign.set(prod_sign) + 
s1_prod_exp.set(prod_exp) + s1_acc_sign.set(acc_sign) + s1_acc_exp.set(acc_exp) + s1_acc_mant.set(acc_mant) + s1_prod_zero.set(prod_zero) + s1_acc_zero.set(acc_is_zero) + s1_valid.set(valid_in) + for i in range(MAX_INTER_ROWS): + if i < n_inter_rows: + s1_mul_rows[i].set(mul_inter_rows[i]) + else: + s1_mul_rows[i].set(u(PROD_MANT_W, 0)) + s1_mul_nrows.set(u(4, n_inter_rows)) + + # ════════════════════════════════════════════════════════════ + # STAGE 2 (cycle 1): Complete multiply (remaining CSA + carry-select) + # ════════════════════════════════════════════════════════════ + prod_mant, mul_depth = multiplier_complete_reduce( + m, [s1_mul_rows[i].out() for i in range(n_inter_rows)], + PROD_MANT_W, name="mantmul" + ) + pipeline_depths["Stage 2: Complete Multiply"] = mul_depth + + # ──── Pipeline register write (stage 2) ──── + s2_prod_mant.set(prod_mant) + s2_prod_sign.set(s1_prod_sign.out()) + s2_prod_exp.set(s1_prod_exp.out()) + s2_acc_sign.set(s1_acc_sign.out()) + s2_acc_exp.set(s1_acc_exp.out()) + s2_acc_mant.set(s1_acc_mant.out()) + s2_prod_zero.set(s1_prod_zero.out()) + s2_acc_zero.set(s1_acc_zero.out()) + s2_valid.set(s1_valid.out()) + + # ════════════════════════════════════════════════════════════ + # STAGE 3 (cycle 2): Align + Add + # ════════════════════════════════════════════════════════════ + s3_depth = 0 + + s2_pm = s2_prod_mant.out() + s2_pe = s2_prod_exp.out() + s2_ps = s2_prod_sign.out() + s2_as = s2_acc_sign.out() + s2_ae = s2_acc_exp.out() + s2_am = s2_acc_mant.out() + s2_pz = s2_prod_zero.out() + + # Normalize product mantissa: 8×8 product is in 2.14 format (16 bits). 
+ prod_msb = s2_pm[PROD_MANT_W - 1] + prod_mant_norm = (s2_pm >> 1) if prod_msb else s2_pm + prod_exp_norm = (s2_pe + 1) if prod_msb else s2_pe + s3_depth = s3_depth + 3 + + # Extend product mantissa to ACC_MANT_W (26 bits) + prod_mant_ext = (prod_mant_norm | u(ACC_MANT_W, 0)) << 9 + + # Extend accumulator mantissa to ACC_MANT_W + acc_mant_ext = s2_am | u(ACC_MANT_W, 0) + + # Determine exponent difference and align + prod_exp_8 = prod_exp_norm[0:8] + exp_diff_raw = prod_exp_8.as_signed() - s2_ae.as_signed() + exp_diff_pos = exp_diff_raw[0:8] + + prod_bigger = prod_exp_8 > s2_ae + exp_diff_abs = ((prod_exp_8 - s2_ae)[0:8] if prod_bigger else + (s2_ae - prod_exp_8)[0:8]) + s3_depth = s3_depth + 2 + + # Shift the smaller operand right to align + shift_5 = exp_diff_abs[0:5] + shift_capped = (u(5, ACC_MANT_W) if (exp_diff_abs > u(8, ACC_MANT_W)) + else shift_5) + + prod_aligned = (prod_mant_ext if prod_bigger else + barrel_shift_right(prod_mant_ext, shift_capped, ACC_MANT_W, 5, "prod_bsr")[0]) + acc_aligned = (barrel_shift_right(acc_mant_ext, shift_capped, ACC_MANT_W, 5, "acc_bsr")[0] + if prod_bigger else acc_mant_ext) + s3_depth = s3_depth + 12 + + result_exp = prod_exp_8 if prod_bigger else s2_ae + + # Add or subtract mantissas based on signs + same_sign = ~(s2_ps ^ s2_as) + sum_mant = ((prod_aligned | u(ACC_MANT_W+1, 0)) + + (acc_aligned | u(ACC_MANT_W+1, 0)))[0:ACC_MANT_W] + + mag_prod_ge = prod_aligned >= acc_aligned + diff_mant = ((prod_aligned - acc_aligned) if mag_prod_ge else + (acc_aligned - prod_aligned)) + + result_mant = sum_mant if same_sign else diff_mant + result_sign = (s2_ps if same_sign else + (s2_ps if mag_prod_ge else s2_as)) + s3_depth = s3_depth + 4 + + # Handle zeros + result_mant_final = acc_mant_ext if s2_pz else result_mant + result_exp_final = s2_ae if s2_pz else result_exp + result_sign_final = s2_as if s2_pz else result_sign + + pipeline_depths["Stage 3: Align + Add"] = s3_depth + + # ──── Pipeline register write (stage 3) ──── + 
s3_result_sign.set(result_sign_final) + s3_result_exp.set(result_exp_final | u(10, 0)) + s3_result_mant.set(result_mant_final) + s3_valid.set(s2_valid.out()) + + # ════════════════════════════════════════════════════════════ + # STAGE 4 (cycle 3): Normalize + Pack FP32 + # ════════════════════════════════════════════════════════════ + s4_depth = 0 + + s3_rm = s3_result_mant.out() + s3_re = s3_result_exp.out() + s3_rs = s3_result_sign.out() + s3_v = s3_valid.out() + + # Leading-zero count for normalization + lzc, lzc_depth = leading_zero_count(s3_rm, ACC_MANT_W, "norm_lzc") + s4_depth = s4_depth + lzc_depth + + GUARD_BITS = 2 + lzc_5 = lzc[0:5] + + need_left = lzc_5 > u(5, GUARD_BITS) + need_right = lzc_5 < u(5, GUARD_BITS) + + left_amt = (lzc_5 - u(5, GUARD_BITS))[0:5] + right_amt = (u(5, GUARD_BITS) - lzc_5)[0:5] + + left_shifted, bsl_depth = barrel_shift_left( + s3_rm, left_amt, ACC_MANT_W, 5, "norm_bsl") + right_shifted, _ = barrel_shift_right( + s3_rm, right_amt, ACC_MANT_W, 5, "norm_bsr") + + norm_mant = (left_shifted if need_left else + (right_shifted if need_right else s3_rm)) + s4_depth = s4_depth + bsl_depth + 4 + + # Adjust exponent: exp = exp + GUARD_BITS - lzc + norm_exp = s3_re + u(10, GUARD_BITS) - (lzc | u(10, 0)) + s4_depth = s4_depth + 4 + + # Extract FP32 mantissa: implicit 1 now at bit 23. 
+ fp32_mant = norm_mant[0:23] # 23 fractional bits + + # Pack FP32: sign(1) | exp(8) | mantissa(23) + fp32_exp = norm_exp[0:8] + + # Handle zero result + result_is_zero = s3_rm == u(ACC_MANT_W, 0) + fp32_packed = (u(FP32_W, 0) if result_is_zero else + (((s3_rs | u(FP32_W, 0)) << 31) | + ((fp32_exp | u(FP32_W, 0)) << 23) | + (fp32_mant | u(FP32_W, 0)))) + s4_depth = s4_depth + 3 + + pipeline_depths["Stage 4: Normalize + Pack"] = s4_depth + + # ──── Output register write ──── + result_r.set(fp32_packed, when=s3_v) + valid_r.set(s3_v) + + # ════════════════════════════════════════════════════════════ + # Outputs + # ════════════════════════════════════════════════════════════ + m.output("result", result_r) + m.output("result_valid", valid_r) + + _pipeline_depths.update(pipeline_depths) + + +build.__pycircuit_name__ = "bf16_fmac" + +if __name__ == "__main__": + _pipeline_depths.clear() + circuit = compile_cycle_aware(build, name="bf16_fmac") + + print("\n" + "=" * 60) + print(" BF16 FMAC — Pipeline Critical Path Analysis") + print("=" * 60) + total = 0 + for stage, depth in _pipeline_depths.items(): + print(f" {stage:<35s} depth = {depth:>3d}") + total += depth + print(f" {'─' * 50}") + print(f" {'Total combinational depth':<35s} depth = {total:>3d}") + print(f" {'Max stage depth (critical path)':<35s} depth = {max(_pipeline_depths.values()):>3d}") + print("=" * 60 + "\n") + + mlir = circuit.emit_mlir() + print(f"MLIR: {len(mlir)} chars") diff --git a/designs/examples/fmac/fmac_capi.cpp b/designs/examples/fmac/fmac_capi.cpp new file mode 100644 index 0000000..c61d8a3 --- /dev/null +++ b/designs/examples/fmac/fmac_capi.cpp @@ -0,0 +1,54 @@ +/** + * fmac_capi.cpp — C API for the BF16 FMAC RTL model. + * + * Build (from pyCircuit root): + * c++ -std=c++17 -O2 -shared -fPIC -I include -I . 
\ + * -o examples/fmac/libfmac_sim.dylib examples/fmac/fmac_capi.cpp + */ +#include +#include +#include + +#include "examples/generated/fmac/bf16_fmac_gen.hpp" + +using pyc::cpp::Wire; + +struct SimContext { + pyc::gen::bf16_fmac dut{}; + pyc::cpp::Testbench tb; + uint64_t cycle = 0; + SimContext() : tb(dut) { tb.addClock(dut.clk, 1); } +}; + +extern "C" { + +SimContext* fmac_create() { return new SimContext(); } +void fmac_destroy(SimContext* c) { delete c; } + +void fmac_reset(SimContext* c, uint64_t n) { + c->tb.reset(c->dut.rst, n, 1); + c->dut.eval(); + c->cycle = 0; +} + +void fmac_push(SimContext* c, uint16_t a_bf16, uint16_t b_bf16, uint32_t acc_fp32) { + c->dut.a_in = Wire<16>(a_bf16); + c->dut.b_in = Wire<16>(b_bf16); + c->dut.acc_in = Wire<32>(acc_fp32); + c->dut.valid_in = Wire<1>(1u); + c->tb.runCycles(1); + c->cycle++; + c->dut.valid_in = Wire<1>(0u); +} + +void fmac_idle(SimContext* c, uint64_t n) { + c->dut.valid_in = Wire<1>(0u); + c->tb.runCycles(n); + c->cycle += n; +} + +uint32_t fmac_get_result(SimContext* c) { return c->dut.result.value(); } +uint32_t fmac_get_result_valid(SimContext* c) { return c->dut.result_valid.value(); } +uint64_t fmac_get_cycle(SimContext* c) { return c->cycle; } + +} // extern "C" diff --git a/designs/examples/fmac/primitive_standard_cells.py b/designs/examples/fmac/primitive_standard_cells.py new file mode 100644 index 0000000..a859c09 --- /dev/null +++ b/designs/examples/fmac/primitive_standard_cells.py @@ -0,0 +1,450 @@ +# -*- coding: utf-8 -*- +"""Primitive standard cells for building arithmetic from first principles. + +All functions accept and return Wire. Inputs are at most +4 bits wide. Higher-level structures (RCA, multiplier, etc.) are +composed by calling these primitives hierarchically. + +Logic depth tracking: each function returns (result, depth) where depth +is the combinational gate-level depth (AND/OR/XOR = 1 level each). 
+""" +from __future__ import annotations + +from pycircuit.hw import Wire, Reg +from pycircuit import u + + +def _mux(sel, t, f): + """Hardware mux usable outside JIT context: sel=1→t, sel=0→f.""" + if isinstance(sel, Reg): + sel = sel.q + if isinstance(sel, Wire): + return sel._select_internal(t, f) + return t if sel else f + + +# ═══════════════════════════════════════════════════════════════════ +# Level 0 — single-gate primitives (depth = 1) +# ═══════════════════════════════════════════════════════════════════ + +def inv(a: Wire) -> tuple[Wire, int]: + """Inverter. depth=1.""" + return ~a, 1 + + +def and2(a, b) -> tuple[Wire, int]: + """2-input AND. depth=1.""" + return a & b, 1 + + +def or2(a, b) -> tuple[Wire, int]: + """2-input OR. depth=1.""" + return a | b, 1 + + +def xor2(a, b) -> tuple[Wire, int]: + """2-input XOR. depth=1.""" + return a ^ b, 1 + + +def mux2(sel, a_true, a_false) -> tuple[Wire, int]: + """2:1 MUX (sel=1 → a_true). depth=2 (AND-OR).""" + return _mux(sel, a_true, a_false), 2 + + +# ═══════════════════════════════════════════════════════════════════ +# Level 1 — half adder, full adder (depth = 2–3) +# ═══════════════════════════════════════════════════════════════════ + +def half_adder(a, b) -> tuple[Wire, Wire, int]: + """Half adder. Returns (sum, carry_out, depth). + sum = a ^ b (depth 1) + cout = a & b (depth 1) + Total depth = 1. + """ + s = a ^ b + c = a & b + return s, c, 1 + + +def full_adder(a, b, cin) -> tuple[Wire, Wire, int]: + """Full adder. Returns (sum, carry_out, depth). + sum = a ^ b ^ cin (depth 2: xor chain) + cout = (a & b) | (cin & (a ^ b)) (depth 2: xor+and | and, then or) + Total depth = 2. 
+ """ + ab = a ^ b # depth 1 + s = ab ^ cin # depth 2 + c = (a & b) | (cin & ab) # depth 2 (and + or in parallel with xor) + return s, c, 2 + + +# ═══════════════════════════════════════════════════════════════════ +# Level 2 — multi-bit adders (ripple-carry, depth = 2*N) +# ═══════════════════════════════════════════════════════════════════ + +def ripple_carry_adder(a_bits, b_bits, cin, name="rca"): + """N-bit ripple carry adder from full adders.""" + n = len(a_bits) + assert len(b_bits) == n, f"bit width mismatch: {n} vs {len(b_bits)}" + sums = [] + carry = cin + depth = 0 + for i in range(n): + s, carry, d = full_adder(a_bits[i], b_bits[i], carry) + depth = max(depth, 2 * (i + 1)) + sums.append(s) + return sums, carry, depth + + +def carry_select_adder(m, a_bits, b_bits, cin, name="csa"): + """N-bit carry-select adder — splits into halves for faster carry propagation.""" + n = len(a_bits) + assert len(b_bits) == n + if n <= 4: + return ripple_carry_adder(a_bits, b_bits, cin, name) + + half = n // 2 + lo_a, hi_a = a_bits[:half], a_bits[half:] + lo_b, hi_b = b_bits[:half], b_bits[half:] + + lo_sum, lo_cout, lo_depth = ripple_carry_adder( + lo_a, lo_b, cin, f"{name}_lo") + + zero_w = 0 + one_w = 1 + hi_sum0, hi_cout0, _ = ripple_carry_adder( + hi_a, hi_b, zero_w, f"{name}_hi0") + hi_sum1, hi_cout1, _ = ripple_carry_adder( + hi_a, hi_b, one_w, f"{name}_hi1") + + hi_sum = [_mux(lo_cout, hi_sum1[i], hi_sum0[i]) for i in range(len(hi_a))] + cout = _mux(lo_cout, hi_cout1, hi_cout0) + + depth = lo_depth + 2 + return lo_sum + hi_sum, cout, depth + + +def ripple_carry_adder_packed(a, b, cin, width, name="rca"): + """Packed version: takes N-bit signals, returns N-bit sum + cout.""" + a_bits = [a[i] for i in range(width)] + b_bits = [b[i] for i in range(width)] + cin_1 = cin if cin.width == 1 else cin[0] + + sum_bits, cout, depth = ripple_carry_adder(a_bits, b_bits, cin_1, name) + result = _recombine_bits(sum_bits, width) + return result, cout, depth + + +# 
═══════════════════════════════════════════════════════════════════ +# Level 3 — partial-product generation for multiplier +# ═══════════════════════════════════════════════════════════════════ + +def and_gate_array(a_bit, b_bits): + """AND a single bit with each bit of b. Returns (list of 1-bit signals, depth=1).""" + return [a_bit & bb for bb in b_bits], 1 + + +def partial_product_array(a_bits, b_bits): + """Generate partial products for a*b (unsigned).""" + pp_rows = [] + for i, ab in enumerate(a_bits): + row, _ = and_gate_array(ab, b_bits) + pp_rows.append((row, i)) + return pp_rows, 1 + + +# ═══════════════════════════════════════════════════════════════════ +# Level 4 — partial-product reduction (Wallace/Dadda tree) +# ═══════════════════════════════════════════════════════════════════ + +def compress_3to2(a_bits, b_bits, c_bits): + """3:2 compressor (carry-save adder): reduces 3 rows to 2. NOTE(review): assumes equal-length rows; a position where only a and c (or b and c) are present drops c.""" + n = max(len(a_bits), len(b_bits), len(c_bits)) + sums = [] + carries = [] + for i in range(n): + a = a_bits[i] if i < len(a_bits) else None + b = b_bits[i] if i < len(b_bits) else None + c = c_bits[i] if i < len(c_bits) else None + + if a is None and b is None and c is None: + continue + if a is not None and b is not None and c is not None: + s, co, _ = full_adder(a, b, c) + sums.append(s) + carries.append(co) + elif a is not None and b is not None: + s, co, _ = half_adder(a, b) + sums.append(s) + carries.append(co) + elif a is not None: + sums.append(a) + elif b is not None: + sums.append(b) + else: + sums.append(c) + + return sums, carries, 2 + + +def reduce_partial_products(m, pp_rows, result_width, name="mul"): + """Reduce partial product rows to 2 rows using 3:2 compressors, + then a final carry-select addition. + + `m` is a Circuit instance; it is only forwarded to `carry_select_adder` (zero padding uses the int constant 0). 
+ """ + zero = 0 + + rows = [] + for bits, shift in pp_rows: + padded = [None] * shift + list(bits) + [None] * (result_width - shift - len(bits)) + padded = padded[:result_width] + rows.append(padded) + + for r in range(len(rows)): + for col in range(result_width): + if rows[r][col] is None: + rows[r][col] = zero + + depth = 1 + + while len(rows) > 2: + new_rows = [] + i = 0 + round_depth = 0 + while i + 2 < len(rows): + a_row = rows[i] + b_row = rows[i + 1] + c_row = rows[i + 2] + s_row, c_row_out, d = compress_3to2(a_row, b_row, c_row) + c_shifted = [zero] + c_row_out + while len(s_row) < result_width: + s_row.append(zero) + while len(c_shifted) < result_width: + c_shifted.append(zero) + new_rows.append(s_row[:result_width]) + new_rows.append(c_shifted[:result_width]) + round_depth = max(round_depth, d) + i += 3 + while i < len(rows): + new_rows.append(rows[i]) + i += 1 + depth += round_depth + rows = new_rows + + if len(rows) == 2: + sum_bits, _, final_depth = carry_select_adder( + m, rows[0], rows[1], zero, name=f"{name}_final" + ) + depth += final_depth + elif len(rows) == 1: + sum_bits = rows[0] + else: + sum_bits = [zero] * result_width + + return sum_bits, depth + + +# ═══════════════════════════════════════════════════════════════════ +# Level 5 — N×M unsigned multiplier +# ═══════════════════════════════════════════════════════════════════ + +def unsigned_multiplier(m, a, b, a_width, b_width, name="umul"): + """Unsigned multiplier built from partial products + reduction tree. + + `m` is a Circuit instance. 
+ """ + result_width = a_width + b_width + + a_bits = [a[i] for i in range(a_width)] + b_bits = [b[i] for i in range(b_width)] + + pp_rows, pp_depth = partial_product_array(a_bits, b_bits) + product_bits, tree_depth = reduce_partial_products( + m, pp_rows, result_width, name=name + ) + + result = _recombine_bits(product_bits, result_width) + return result, pp_depth + tree_depth + + +def _recombine_bits(bits, width): + """Pack a list of 1-bit signals (Wire or int) into a single N-bit signal.""" + const_mask = 0 + wire_parts = [] + for i in range(min(len(bits), width)): + b = bits[i] + if isinstance(b, int): + if b & 1: + const_mask |= (1 << i) + else: + wire_parts.append((i, b)) + + if not wire_parts: + return u(width, const_mask) + + i0, b0 = wire_parts[0] + result = (b0 | u(width, 0)) << i0 + for idx, b in wire_parts[1:]: + result = result | ((b | u(width, 0)) << idx) + + if const_mask: + result = result | u(width, const_mask) + return result + + +# ── Split multiplier (for cross-pipeline-stage multiply) ───── + +def multiplier_pp_and_partial_reduce(m, a, b, a_width, b_width, + csa_rounds=2, name="umul"): + """Stage A of a split multiplier: generate partial products and + run *csa_rounds* levels of 3:2 compression. + + `m` is a Circuit instance. 
+ """ + result_width = a_width + b_width + zero = 0 + + a_bits = [a[i] for i in range(a_width)] + b_bits = [b[i] for i in range(b_width)] + + pp_rows, _ = partial_product_array(a_bits, b_bits) + depth = 1 + + rows = [] + for bits, shift in pp_rows: + padded = [None] * shift + list(bits) + [None] * (result_width - shift - len(bits)) + padded = padded[:result_width] + rows.append(padded) + for r in range(len(rows)): + for col in range(result_width): + if rows[r][col] is None: + rows[r][col] = zero + + for _round in range(csa_rounds): + if len(rows) <= 2: + break + new_rows = [] + i = 0 + round_depth = 0 + while i + 2 < len(rows): + s_row, c_row_out, d = compress_3to2(rows[i], rows[i+1], rows[i+2]) + c_shifted = [zero] + c_row_out + while len(s_row) < result_width: s_row.append(zero) + while len(c_shifted) < result_width: c_shifted.append(zero) + new_rows.append(s_row[:result_width]) + new_rows.append(c_shifted[:result_width]) + round_depth = max(round_depth, d) + i += 3 + while i < len(rows): + new_rows.append(rows[i]) + i += 1 + depth += round_depth + rows = new_rows + + packed = [] + for row in rows: + packed.append(_recombine_bits(row, result_width)) + + return packed, depth + + +def multiplier_complete_reduce(m, packed_rows, result_width, name="umul"): + """Stage B of a split multiplier: finish compression and final addition. + + `m` is a Circuit instance. 
+ """ + zero = 0 + + rows = [] + for packed in packed_rows: + rows.append([packed[i] for i in range(result_width)]) + + depth = 0 + + while len(rows) > 2: + new_rows = [] + i = 0 + round_depth = 0 + while i + 2 < len(rows): + s_row, c_row_out, d = compress_3to2(rows[i], rows[i+1], rows[i+2]) + c_shifted = [zero] + c_row_out + while len(s_row) < result_width: s_row.append(zero) + while len(c_shifted) < result_width: c_shifted.append(zero) + new_rows.append(s_row[:result_width]) + new_rows.append(c_shifted[:result_width]) + round_depth = max(round_depth, d) + i += 3 + while i < len(rows): + new_rows.append(rows[i]) + i += 1 + depth += round_depth + rows = new_rows + + if len(rows) == 2: + sum_bits, _, final_depth = carry_select_adder( + m, rows[0], rows[1], zero, name=f"{name}_final") + depth += final_depth + product = _recombine_bits(sum_bits, result_width) + elif len(rows) == 1: + product = _recombine_bits(rows[0], result_width) + else: + product = u(result_width, 0) + + return product, depth + + +# ═══════════════════════════════════════════════════════════════════ +# Level 6 — shifters (barrel shifter from MUX layers) +# ═══════════════════════════════════════════════════════════════════ + +def barrel_shift_right(data, shift_amt, data_width, shift_bits, name="bsr"): + """Barrel right-shifter built from MUX layers. + + Each layer handles one bit of the shift amount. + depth = 2 * shift_bits (each MUX = depth 2). + """ + result = data + depth = 0 + for i in range(shift_bits): + shift_by = 1 << i + shifted = result >> shift_by + result = _mux(shift_amt[i], shifted, result) + depth += 2 + return result, depth + + +def barrel_shift_left(data, shift_amt, data_width, shift_bits, name="bsl"): + """Barrel left-shifter built from MUX layers. + + depth = 2 * shift_bits. 
+ """ + result = data + depth = 0 + for i in range(shift_bits): + shift_by = 1 << i + shifted = result << shift_by + result = _mux(shift_amt[i], shifted, result) + depth += 2 + return result, depth + + +# ═══════════════════════════════════════════════════════════════════ +# Level 7 — leading-zero counter +# ═══════════════════════════════════════════════════════════════════ + +def leading_zero_count(data, width, name="lzc"): + """Count leading zeros using a priority encoder (MUX chain). + + The chain is serial (one 2:1 MUX per input bit; the highest set bit wins); the reported depth is the estimate 2 * ceil(log2(width)). + """ + lzc_width = (width - 1).bit_length() + 1 + + count = u(lzc_width, width) + for bit_pos in range(width): + leading_zeros = width - 1 - bit_pos + count = _mux(data[bit_pos], u(lzc_width, leading_zeros), count) + + depth = 2 * ((width - 1).bit_length()) + return count, depth diff --git a/designs/examples/fmac/test_bf16_fmac.py b/designs/examples/fmac/test_bf16_fmac.py new file mode 100644 index 0000000..cfdc8d7 --- /dev/null +++ b/designs/examples/fmac/test_bf16_fmac.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +test_bf16_fmac.py — 100 test cases for the BF16 FMAC via true RTL simulation. + +Tests: acc_out = acc_in + a_bf16 * b_bf16 (BF16 inputs, FP32 accumulator) + +Verifies against Python float reference. Allows small rounding error +because the RTL uses fixed-width mantissas and integer arithmetic. + +Build first (from pyCircuit root): + c++ -std=c++17 -O2 -shared -fPIC -I include -I . 
PIPELINE_DEPTH = 4  # 4-stage pipeline


class FmacRTL:
    """Thin ctypes wrapper around the compiled FMAC simulation library.

    The wrapper loads the shared library, declares the C prototypes it
    uses, and owns one simulator handle obtained from ``fmac_create``.
    The handle is released by :meth:`close` (also called from the
    finalizer); releasing is idempotent.

    Args:
        lib_path: Path to the shared library.  When ``None`` the library is
            looked up next to this file, preferring ``libfmac_sim.dylib``
            (the name used by the documented build command) and falling
            back to ``libfmac_sim.so`` if only that one exists.

    Raises:
        OSError: If the shared library cannot be loaded.
        RuntimeError: If ``fmac_create`` returns a NULL handle.
    """

    # Searched in order; the first entry is also the fallback when none
    # exists, so the loader error names the originally expected file.
    _LIB_NAMES = ("libfmac_sim.dylib", "libfmac_sim.so")

    def __init__(self, lib_path=None):
        # Initialise first so close()/__del__ are safe if loading fails below.
        self._L = None
        self._c = None

        if lib_path is None:
            lib_path = self._default_lib_path()
        L = ctypes.CDLL(lib_path)

        # Prototypes: without these ctypes would truncate pointers / 64-bit
        # values to the default C int.
        L.fmac_create.argtypes = []
        L.fmac_create.restype = ctypes.c_void_p
        L.fmac_destroy.argtypes = [ctypes.c_void_p]
        L.fmac_reset.argtypes = [ctypes.c_void_p, ctypes.c_uint64]
        L.fmac_push.argtypes = [ctypes.c_void_p, ctypes.c_uint16, ctypes.c_uint16, ctypes.c_uint32]
        L.fmac_idle.argtypes = [ctypes.c_void_p, ctypes.c_uint64]
        L.fmac_get_result.argtypes = [ctypes.c_void_p]
        L.fmac_get_result.restype = ctypes.c_uint32
        L.fmac_get_result_valid.argtypes = [ctypes.c_void_p]
        L.fmac_get_result_valid.restype = ctypes.c_uint32
        L.fmac_get_cycle.argtypes = [ctypes.c_void_p]
        L.fmac_get_cycle.restype = ctypes.c_uint64

        handle = L.fmac_create()
        if not handle:
            # A c_void_p restype yields None for NULL; fail here instead of
            # handing a NULL pointer to every later call.
            raise RuntimeError("fmac_create() returned a NULL simulator handle")
        self._L, self._c = L, handle

    @classmethod
    def _default_lib_path(cls):
        """Return the first existing candidate library next to this file."""
        here = Path(__file__).resolve().parent
        candidates = [here / lib_name for lib_name in cls._LIB_NAMES]
        for candidate in candidates:
            if candidate.exists():
                return str(candidate)
        return str(candidates[0])

    def _handle(self):
        """Return the live simulator handle or raise if it was released."""
        handle = getattr(self, "_c", None)
        if not handle:
            raise RuntimeError("FmacRTL simulator handle is closed")
        return handle

    def close(self):
        """Destroy the simulator handle; safe to call more than once."""
        handle = getattr(self, "_c", None)
        lib = getattr(self, "_L", None)
        # Clear before calling into C so a repeated close/finalizer run can
        # never destroy the same handle twice.
        self._c = None
        if handle and lib is not None:
            lib.fmac_destroy(handle)

    def __del__(self):
        self.close()

    def reset(self):
        """Reset the simulated design via ``fmac_reset`` with argument 2."""
        self._L.fmac_reset(self._handle(), 2)

    def compute(self, a_bf16: int, b_bf16: int, acc_fp32: int) -> int:
        """Push inputs, wait for pipeline, return FP32 result.

        Args:
            a_bf16: First operand as a BF16 bit pattern.
            b_bf16: Second operand as a BF16 bit pattern.
            acc_fp32: Accumulator input as an FP32 bit pattern.

        Returns:
            The value read from ``fmac_get_result`` after idling.
        """
        handle = self._handle()
        self._L.fmac_push(handle, a_bf16, b_bf16, acc_fp32)
        # Idle for the pipeline depth plus two extra cycles before sampling.
        # NOTE(review): fmac_get_result_valid is declared but not consulted
        # here — confirm the fixed wait is always sufficient.
        self._L.fmac_idle(handle, PIPELINE_DEPTH + 2)
        return self._L.fmac_get_result(handle)
def main():
    """Run all generated FMAC cases through the RTL simulation and report.

    For each case the BF16 operands and FP32 accumulator are pushed through
    ``FmacRTL.compute`` and the result is compared with a Python float
    reference ``acc + a * b`` computed from the same truncated bit
    patterns.  A per-case line and a summary are printed; the process exits
    with status 1 if any case is outside tolerance.
    """
    # Acceptance thresholds.  A relative check is meaningless when the
    # reference is exactly zero, so that case uses an absolute bound.
    rel_tol = 0.02   # 2% relative error tolerance for BF16 precision
    abs_tol = 0.01   # absolute bound used when the expected value is 0

    cases = make_test_cases()
    total = len(cases)  # derived, so the report stays right if the suite changes

    print(f" {BOLD}BF16 FMAC — {total} Test Cases (True RTL Simulation){RESET}")
    print(f" {'=' * 55}")

    # Print pipeline depth analysis
    print(f"\n {CYAN}Pipeline Critical Path Analysis:{RESET}")
    depths = {
        "Stage 1: Unpack + PP + 2×CSA": 13,
        "Stage 2: Complete Multiply": 22,
        "Stage 3: Align + Add": 21,
        "Stage 4: Normalize + Pack": 31,
    }
    for stage, d in depths.items():
        bar = "█" * (d // 2)
        print(f" {stage:<35s} depth={d:>3d} {CYAN}{bar}{RESET}")
    print(f" {'─' * 50}")
    print(f" {'Max stage (critical path)':<35s} depth={max(depths.values()):>3d}")
    print()

    sim = FmacRTL()
    sim.reset()

    passed = 0
    failed = 0
    max_err = 0.0

    # perf_counter is monotonic, unlike wall-clock time.time().
    t0 = time.perf_counter()

    for idx, (a_f, b_f, acc_f) in enumerate(cases, start=1):
        a_bf16 = float_to_bf16(a_f)
        b_bf16 = float_to_bf16(b_f)
        acc_u32 = float_to_fp32(acc_f)

        # RTL result
        result_u32 = sim.compute(a_bf16, b_bf16, acc_u32)
        rtl_f = fp32_to_float(result_u32)

        # Python reference: acc + a * b
        # Use BF16-truncated values for fair comparison
        a_exact = bf16_to_float(a_bf16)
        b_exact = bf16_to_float(b_bf16)
        acc_exact = fp32_to_float(acc_u32)
        expected_f = acc_exact + a_exact * b_exact

        if expected_f == 0:
            err = abs(rtl_f)
            ok = err < abs_tol
        else:
            err = abs(rtl_f - expected_f) / max(abs(expected_f), 1e-10)
            ok = err < rel_tol

        max_err = max(max_err, err)

        if ok:
            passed += 1
            status = f"{GREEN}PASS{RESET}"
        else:
            failed += 1
            status = f"{RED}FAIL{RESET}"

        # Print each test (failures are emphasised, passes dimmed)
        tag = f"{DIM}" if ok else f"{BOLD}"
        print(f" {tag}[{idx:3d}/{total}]{RESET} "
              f"a={a_exact:>9.4f} b={b_exact:>9.4f} acc={acc_exact:>10.4f} → "
              f"RTL={rtl_f:>12.4f} exp={expected_f:>12.4f} "
              f"err={err:.2e} {status}")

    t1 = time.perf_counter()

    print(f"\n {'=' * 55}")
    print(f" Results: {GREEN}{passed}{RESET}/{total} passed, "
          f"{RED}{failed}{RESET} failed")
    # Note: for cases whose expected value is 0 the tracked error is absolute.
    print(f" Max relative error: {max_err:.2e}")
    print(f" Time: {t1 - t0:.2f}s")

    if failed == 0:
        print(f" {GREEN}{BOLD}ALL {total} TESTS PASSED (TRUE RTL SIMULATION).{RESET}\n")
    else:
        print(f" {RED}{BOLD}{failed} tests failed.{RESET}\n")
        sys.exit(1)
range(stages): @@ -17,9 +19,8 @@ def build(m: Circuit, width: int = 8, stages: int = 3) -> None: m.output("y", v_conn) - build.__pycircuit_name__ = "hier_modules" if __name__ == "__main__": - print(compile(build, name="hier_modules", width=8, stages=3).emit_mlir()) + print(compile_cycle_aware(build, name="hier_modules", eager=True, width=8, stages=3).emit_mlir()) diff --git a/designs/examples/hier_modules/tb_hier_modules.py b/designs/examples/hier_modules/tb_hier_modules.py index 6aaac08..bcaf030 100644 --- a/designs/examples/hier_modules/tb_hier_modules.py +++ b/designs/examples/hier_modules/tb_hier_modules.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,12 +15,16 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("x", 1, at=0) - t.expect("y", 4, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("x", 1) + tb.expect("y", 4) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_hier_modules_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_hier_modules_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/huge_hierarchy_stress/huge_hierarchy_stress.py b/designs/examples/huge_hierarchy_stress/huge_hierarchy_stress.py index f33dff8..f5df4b7 100644 --- a/designs/examples/huge_hierarchy_stress/huge_hierarchy_stress.py +++ b/designs/examples/huge_hierarchy_stress/huge_hierarchy_stress.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, Connector, compile, const, ct, function, module, spec, u +from pycircuit import Circuit, Connector, compile_cycle_aware, CycleAwareCircuit, 
CycleAwareDomain, const, ct, function, module, spec, u from pycircuit.lib import Cache @@ -122,10 +122,7 @@ def _node( m.output("y", y) -@module -def build( - m: Circuit, - *, +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, width: int = 64, module_count: int = 32, hierarchy_depth: int = 2, @@ -133,8 +130,9 @@ def build( cache_ways: int = 4, cache_sets: int = 64, ): - clk = m.clock("clk") - rst = m.reset("rst") + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst in_spec = _top_in_struct(m, width=width) top_in = m.inputs(in_spec, prefix="") @@ -177,13 +175,12 @@ def build( cur = _mix3(m, cur, yi.read(), cur.lshr(amount=(i % max(1, width // 8)) + 1)) req_wmask_w = max(1, width // 8) - cache_req_wmask = u(req_wmask_w, ct.bitmask(req_wmask_w)) - cache_req_write = u(1, 0) - cache_req_valid = u(1, 1) + cache_req_wmask = m.const(ct.bitmask(req_wmask_w), width=req_wmask_w) + cache_req_write = m.const(0, width=1) + cache_req_valid = m.const(1, width=1) cache = Cache( m, - clk=clk, - rst=rst, + cd, req_valid=cache_req_valid, req_addr=cur, req_write=cache_req_write, @@ -210,7 +207,7 @@ def build( if __name__ == "__main__": print( - compile(build, name="huge_hierarchy_stress", + compile_cycle_aware(build, name="huge_hierarchy_stress", width=64, module_count=16, hierarchy_depth=2, diff --git a/designs/examples/huge_hierarchy_stress/tb_huge_hierarchy_stress.py b/designs/examples/huge_hierarchy_stress/tb_huge_hierarchy_stress.py index a174467..5ceec2c 100644 --- a/designs/examples/huge_hierarchy_stress/tb_huge_hierarchy_stress.py +++ b/designs/examples/huge_hierarchy_stress/tb_huge_hierarchy_stress.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,13 +15,15 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = 
TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("seed", 0x1234, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("seed", 0x1234) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_huge_hierarchy_stress_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_huge_hierarchy_stress_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/instance_map/instance_map.py b/designs/examples/instance_map/instance_map.py index c0a488c..d8a179d 100644 --- a/designs/examples/instance_map/instance_map.py +++ b/designs/examples/instance_map/instance_map.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, const, module, spec, u, wiring +from pycircuit import Circuit, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, module, spec, u, wiring @const @@ -15,7 +15,6 @@ def _unit_out_spec(m: Circuit, *, width: int): return spec.struct("unit_out").field("y", width=width).field("valid", width=1).build() -@module(structural=True) def _unit(m: Circuit, *, width: int = 32, gain: int = 1): in_spec = _unit_in_spec(m, width=width) out_spec = _unit_out_spec(m, width=width) @@ -34,8 +33,7 @@ def _top_struct(m: Circuit, *, width: int): return s.add_field("lsu", width=width).rename_field("bru", "branch").select_fields(["alu", "branch", "lsu"]) -@module -def build(m: Circuit, *, width: int = 32): +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, width: int = 32): top_spec = _top_struct(m, width=width) top_in = m.inputs(top_spec, prefix="in_") @@ -75,4 +73,4 @@ def build(m: Circuit, *, width: int = 32): build.__pycircuit_name__ = "instance_map" if __name__ == "__main__": - print(compile(build, name="instance_map", width=32).emit_mlir()) + 
print(compile_cycle_aware(build, name="instance_map", width=32).emit_mlir()) diff --git a/designs/examples/instance_map/tb_instance_map.py b/designs/examples/instance_map/tb_instance_map.py index e20bd15..80f98e7 100644 --- a/designs/examples/instance_map/tb_instance_map.py +++ b/designs/examples/instance_map/tb_instance_map.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,17 +15,21 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("in_alu", 0, at=0) - t.drive("in_branch", 0, at=0) - t.drive("in_lsu", 0, at=0) - t.expect("alu_y", 1, at=0) - t.expect("branch_y", 2, at=0) - t.expect("lsu_y", 3, at=0) - t.expect("acc", 6, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("in_alu", 0) + tb.drive("in_branch", 0) + tb.drive("in_lsu", 0) + tb.expect("alu_y", 1) + tb.expect("branch_y", 2) + tb.expect("lsu_y", 3) + tb.expect("acc", 6) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_instance_map_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_instance_map_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/interface_wiring/interface_wiring.py b/designs/examples/interface_wiring/interface_wiring.py index 6d0093c..f596583 100644 --- a/designs/examples/interface_wiring/interface_wiring.py +++ b/designs/examples/interface_wiring/interface_wiring.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, const, module, spec, wiring +from pycircuit import Circuit, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, module, spec, wiring @const @@ -10,7 +10,6 @@ def 
_pair_spec(m: Circuit, *, width: int): return base.remove_field("drop").rename_field("right", "rhs").select_fields(["left", "rhs"]) -@module def pair_add(m: Circuit, *, width: int = 16): spec = _pair_spec(m, width=width) ins = m.inputs(spec, prefix="in_") @@ -19,8 +18,7 @@ def pair_add(m: Circuit, *, width: int = 16): m.outputs(spec, {"left": a, "rhs": (a + b)[0:width]}, prefix="out_") -@module -def build(m: Circuit, *, width: int = 16): +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, width: int = 16): in_spec = _pair_spec(m, width=width) top_in = m.inputs(in_spec, prefix="top_in_") h = m.new( @@ -42,4 +40,4 @@ def build(m: Circuit, *, width: int = 16): build.__pycircuit_name__ = "interface_wiring" if __name__ == "__main__": - print(compile(build, name="interface_wiring", width=16).emit_mlir()) + print(compile_cycle_aware(build, name="interface_wiring", width=16).emit_mlir()) diff --git a/designs/examples/interface_wiring/tb_interface_wiring.py b/designs/examples/interface_wiring/tb_interface_wiring.py index b3fb4ea..412c2f2 100644 --- a/designs/examples/interface_wiring/tb_interface_wiring.py +++ b/designs/examples/interface_wiring/tb_interface_wiring.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,14 +15,18 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("top_in_left", 1, at=0) - t.drive("top_in_rhs", 2, at=0) - t.expect("top_out_left", 1, at=0) - t.expect("top_out_rhs", 3, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("top_in_left", 1) + tb.drive("top_in_rhs", 2) + tb.expect("top_out_left", 1) + tb.expect("top_out_rhs", 3) + + tb.finish(at=int(p["finish"])) if __name__ == 
"__main__": - print(compile(build, name="tb_interface_wiring_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_interface_wiring_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/issue_queue_2picker/issue_queue_2picker.py b/designs/examples/issue_queue_2picker/issue_queue_2picker.py index f531672..fb65afd 100644 --- a/designs/examples/issue_queue_2picker/issue_queue_2picker.py +++ b/designs/examples/issue_queue_2picker/issue_queue_2picker.py @@ -1,29 +1,29 @@ from __future__ import annotations -from pycircuit import Circuit, compile, function, module, u +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, + mux, +) -@function -def _shift4(m: Circuit, v: list, d: list, z): - _ = m +def _shift4(v: list, d: list, z): return [v[1], v[2], v[3], z], [d[1], d[2], d[3], d[3]] -@module -def build(m: Circuit) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) -> None: + in_valid = cas(domain, m.input("in_valid", width=1), cycle=0) + in_data = cas(domain, m.input("in_data", width=8), cycle=0) + out0_ready = cas(domain, m.input("out0_ready", width=1), cycle=0) + out1_ready = cas(domain, m.input("out1_ready", width=1), cycle=0) - in_valid = m.input("in_valid", width=1) - in_data = m.input("in_data", width=8) - out0_ready = m.input("out0_ready", width=1) - out1_ready = m.input("out1_ready", width=1) + vals = [domain.state(width=1, reset_value=0, name=f"val{i}") for i in range(4)] + data = [domain.state(width=8, reset_value=0, name=f"data{i}") for i in range(4)] - vals = [m.out(f"val{i}", clk=clk, rst=rst, width=1, init=u(1, 0)) for i in range(4)] - data = [m.out(f"data{i}", clk=clk, rst=rst, width=8, init=u(8, 0)) for i in range(4)] - - v0 = [x.out() for x in vals] - d0 = [x.out() for x in data] + v0 = [x for x in vals] + d0 = [x for x in data] out0_valid = v0[0] out1_valid = v0[1] pop0 = out0_valid & out0_ready @@ -31,14 +31,14 
@@ def build(m: Circuit) -> None: in_ready = ~v0[3] | pop0 push = in_valid & in_ready - z1 = u(1, 0) - s1_v, s1_d = _shift4(m, v0, d0, z1) - a1_v = [s1_v[i] if pop0 else v0[i] for i in range(4)] - a1_d = [s1_d[i] if pop0 else d0[i] for i in range(4)] + zero1 = cas(domain, m.const(0, width=1), cycle=0) + s1_v, s1_d = _shift4(v0, d0, zero1) + a1_v = [mux(pop0, s1_v[i], v0[i]) for i in range(4)] + a1_d = [mux(pop0, s1_d[i], d0[i]) for i in range(4)] - s2_v, s2_d = _shift4(m, a1_v, a1_d, z1) - a2_v = [s2_v[i] if pop1 else a1_v[i] for i in range(4)] - a2_d = [s2_d[i] if pop1 else a1_d[i] for i in range(4)] + s2_v, s2_d = _shift4(a1_v, a1_d, zero1) + a2_v = [mux(pop1, s2_v[i], a1_v[i]) for i in range(4)] + a2_d = [mux(pop1, s2_d[i], a1_d[i]) for i in range(4)] en = [] pref = push @@ -47,20 +47,21 @@ def build(m: Circuit) -> None: en.append(en_i) pref = pref & a2_v[i] - for i in range(4): - vals[i].set(a2_v[i] | en[i]) - data[i].set(in_data if en[i] else a2_d[i]) + m.output("in_ready", in_ready.wire) + m.output("out0_valid", out0_valid.wire) + m.output("out0_data", d0[0].wire) + m.output("out1_valid", out1_valid.wire) + m.output("out1_data", d0[1].wire) - m.output("in_ready", in_ready) - m.output("out0_valid", out0_valid) - m.output("out0_data", d0[0]) - m.output("out1_valid", out1_valid) - m.output("out1_data", d0[1]) + domain.next() + for i in range(4): + vals[i].set(a2_v[i] | en[i]) + data[i].set(mux(en[i], in_data, a2_d[i])) build.__pycircuit_name__ = "issue_queue_2picker" if __name__ == "__main__": - print(compile(build, name="issue_queue_2picker").emit_mlir()) + print(compile_cycle_aware(build, name="issue_queue_2picker", eager=True).emit_mlir()) diff --git a/designs/examples/issue_queue_2picker/tb_issue_queue_2picker.py b/designs/examples/issue_queue_2picker/tb_issue_queue_2picker.py index b7cd8b2..1e86567 100644 --- a/designs/examples/issue_queue_2picker/tb_issue_queue_2picker.py +++ b/designs/examples/issue_queue_2picker/tb_issue_queue_2picker.py @@ -3,7 +3,7 @@ 
import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,17 +15,19 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("in_valid", 0, at=0) - t.drive("in_data", 0, at=0) - t.drive("out0_ready", 0, at=0) - t.drive("out1_ready", 0, at=0) - t.expect("in_ready", 1, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("in_valid", 0) + tb.drive("in_data", 0) + tb.drive("out0_ready", 0) + tb.drive("out1_ready", 0) + tb.expect("in_ready", 1) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_issue_queue_2picker_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_issue_queue_2picker_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/jit_control_flow/jit_control_flow.py b/designs/examples/jit_control_flow/jit_control_flow.py index 938b59b..a58b1bb 100644 --- a/designs/examples/jit_control_flow/jit_control_flow.py +++ b/designs/examples/jit_control_flow/jit_control_flow.py @@ -1,33 +1,37 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module, u - - -@module -def build(m: Circuit, rounds: int = 4) -> None: - a = m.input("a", width=8) - b = m.input("b", width=8) - op = m.input("op", width=2) - - acc = a + u(8, 0) - if op == u(2, 0): - acc = a + b - elif op == u(2, 1): - acc = a - b - elif op == u(2, 2): - acc = a ^ b - else: - acc = a & b +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, + mux, + u, +) + + +def build(m: 
CycleAwareCircuit, domain: CycleAwareDomain, rounds: int = 4) -> None: + a = cas(domain, m.input("a", width=8), cycle=0) + b = cas(domain, m.input("b", width=8), cycle=0) + op = cas(domain, m.input("op", width=2), cycle=0) + + op0 = cas(domain, m.const(0, width=2), cycle=0) + op1 = cas(domain, m.const(1, width=2), cycle=0) + op2 = cas(domain, m.const(2, width=2), cycle=0) + + acc = mux(op == op0, a + b, + mux(op == op1, a - b, + mux(op == op2, a ^ b, + a & b))) for _ in range(rounds): acc = acc + 1 - m.output("result", acc) - + m.output("result", acc.wire) build.__pycircuit_name__ = "jit_control_flow" if __name__ == "__main__": - print(compile(build, name="jit_control_flow", rounds=4).emit_mlir()) + print(compile_cycle_aware(build, name="jit_control_flow", eager=True, rounds=4).emit_mlir()) diff --git a/designs/examples/jit_control_flow/tb_jit_control_flow.py b/designs/examples/jit_control_flow/tb_jit_control_flow.py index e11e5f8..733a7de 100644 --- a/designs/examples/jit_control_flow/tb_jit_control_flow.py +++ b/designs/examples/jit_control_flow/tb_jit_control_flow.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,14 +15,18 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("a", 1, at=0) - t.drive("b", 2, at=0) - t.drive("op", 0, at=0) - t.expect("result", 7, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("a", 1) + tb.drive("b", 2) + tb.drive("op", 0) + tb.expect("result", 7) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_jit_control_flow_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_jit_control_flow_top", 
eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/jit_pipeline_vec/jit_pipeline_vec.py b/designs/examples/jit_pipeline_vec/jit_pipeline_vec.py index eb204ae..662ea70 100644 --- a/designs/examples/jit_pipeline_vec/jit_pipeline_vec.py +++ b/designs/examples/jit_pipeline_vec/jit_pipeline_vec.py @@ -1,36 +1,34 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module, u +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, + mux, +) -@module -def build(m: Circuit, stages: int = 3) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, stages: int = 3) -> None: + a = cas(domain, m.input("a", width=16), cycle=0) + b = cas(domain, m.input("b", width=16), cycle=0) + sel = cas(domain, m.input("sel", width=1), cycle=0) - a = m.input("a", width=16) - b = m.input("b", width=16) - sel = m.input("sel", width=1) - - tag = a == b - data = a + b if sel else a ^ b + tag = (a == b) + data = mux(sel, a + b, a ^ b) for i in range(stages): - tag_q = m.out(f"tag_s{i}", clk=clk, rst=rst, width=1, init=u(1, 0)) - data_q = m.out(f"data_s{i}", clk=clk, rst=rst, width=16, init=u(16, 0)) - tag_q.set(tag) - data_q.set(data) - tag = tag_q.out() - data = data_q.out() - - m.output("tag", tag) - m.output("data", data) - m.output("lo8", data[0:8]) + domain.next() + tag = cas(domain, domain.cycle(tag, name=f"tag_s{i}"), cycle=0) + data = cas(domain, domain.cycle(data, name=f"data_s{i}"), cycle=0) + m.output("tag", tag.wire) + m.output("data", data.wire) + m.output("lo8", data.wire[0:8]) build.__pycircuit_name__ = "jit_pipeline_vec" if __name__ == "__main__": - print(compile(build, name="jit_pipeline_vec", stages=3).emit_mlir()) + print(compile_cycle_aware(build, name="jit_pipeline_vec", eager=True, stages=3).emit_mlir()) diff --git a/designs/examples/jit_pipeline_vec/tb_jit_pipeline_vec.py b/designs/examples/jit_pipeline_vec/tb_jit_pipeline_vec.py 
index 3756a9c..6ac4ab4 100644 --- a/designs/examples/jit_pipeline_vec/tb_jit_pipeline_vec.py +++ b/designs/examples/jit_pipeline_vec/tb_jit_pipeline_vec.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,15 +15,19 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("a", 1, at=0) - t.drive("b", 1, at=0) - t.drive("sel", 1, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("a", 1) + tb.drive("b", 1) + tb.drive("sel", 1) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_jit_pipeline_vec_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_jit_pipeline_vec_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/mem_rdw_olddata/mem_rdw_olddata.py b/designs/examples/mem_rdw_olddata/mem_rdw_olddata.py index 291f373..f966a82 100644 --- a/designs/examples/mem_rdw_olddata/mem_rdw_olddata.py +++ b/designs/examples/mem_rdw_olddata/mem_rdw_olddata.py @@ -1,12 +1,12 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain -@module -def build(m: Circuit, depth: int = 4, data_width: int = 32, addr_width: int = 2) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, depth: int = 4, data_width: int = 32, addr_width: int = 2) -> None: + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst ren = 
m.input("ren", width=1) raddr = m.input("raddr", width=addr_width) @@ -35,5 +35,5 @@ def build(m: Circuit, depth: int = 4, data_width: int = 32, addr_width: int = 2) if __name__ == "__main__": - print(compile(build, name="mem_rdw_olddata", depth=4, data_width=32, addr_width=2).emit_mlir()) + print(compile_cycle_aware(build, name="mem_rdw_olddata", eager=True, depth=4, data_width=32, addr_width=2).emit_mlir()) diff --git a/designs/examples/mem_rdw_olddata/tb_mem_rdw_olddata.py b/designs/examples/mem_rdw_olddata/tb_mem_rdw_olddata.py index 5ff6e76..0dad91b 100644 --- a/designs/examples/mem_rdw_olddata/tb_mem_rdw_olddata.py +++ b/designs/examples/mem_rdw_olddata/tb_mem_rdw_olddata.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,44 +15,47 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- # Default drives. - t.drive("ren", 0, at=0) - t.drive("raddr", 0, at=0) - t.drive("wvalid", 0, at=0) - t.drive("waddr", 0, at=0) - t.drive("wdata", 0, at=0) - t.drive("wstrb", 0, at=0) + tb.drive("ren", 0) + tb.drive("raddr", 0) + tb.drive("wvalid", 0) + tb.drive("waddr", 0) + tb.drive("wdata", 0) + tb.drive("wstrb", 0) # Cycle 0: write old value. - t.drive("wvalid", 1, at=0) - t.drive("waddr", 0, at=0) - t.drive("wdata", 0x11111111, at=0) - t.drive("wstrb", 0xF, at=0) + tb.drive("wvalid", 1) + tb.drive("waddr", 0) + tb.drive("wdata", 0x11111111) + tb.drive("wstrb", 0xF) + tb.next() # --- cycle 1 --- # Cycle 1: read+write same address -> expect old-data. 
- t.drive("ren", 1, at=1) - t.drive("raddr", 0, at=1) - t.drive("wvalid", 1, at=1) - t.drive("waddr", 0, at=1) - t.drive("wdata", 0x22222222, at=1) - t.drive("wstrb", 0xF, at=1) - t.expect("rdata", 0x11111111, at=1, phase="post", msg="RDW must return old-data") - + tb.drive("ren", 1) + tb.drive("raddr", 0) + tb.drive("wvalid", 1) + tb.drive("waddr", 0) + tb.drive("wdata", 0x22222222) + tb.drive("wstrb", 0xF) + tb.expect("rdata", 0x11111111, phase="post", msg="RDW must return old-data") + + tb.next() # --- cycle 2 --- # Cycle 2: read again -> expect new data. - t.drive("wvalid", 0, at=2) - t.drive("wstrb", 0, at=2) - t.drive("ren", 1, at=2) - t.drive("raddr", 0, at=2) - t.expect("rdata", 0x22222222, at=2, phase="post") + tb.drive("wvalid", 0) + tb.drive("wstrb", 0) + tb.drive("ren", 1) + tb.drive("raddr", 0) + tb.expect("rdata", 0x22222222, phase="post") - t.finish(at=int(p["finish"])) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_mem_rdw_olddata_top", **DEFAULT_PARAMS).emit_mlir()) - + print(compile_cycle_aware(build, name="tb_mem_rdw_olddata_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/module_collection/module_collection.py b/designs/examples/module_collection/module_collection.py index fd93b21..0343774 100644 --- a/designs/examples/module_collection/module_collection.py +++ b/designs/examples/module_collection/module_collection.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, const, module, spec, u, wiring +from pycircuit import Circuit, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, module, spec, u, wiring @const @@ -15,7 +15,6 @@ def _lane_out_spec(m: Circuit, *, width: int): return _lane_in_spec(m, width=width).rename_field("payload.data", "sum").add_field("meta.idx", width=8) -@module(structural=True) def _lane(m: Circuit, *, width: int = 32): in_spec = _lane_in_spec(m, width=width) out_spec = _lane_out_spec(m, 
width=width) @@ -36,8 +35,7 @@ def _lane(m: Circuit, *, width: int = 32): ) -@module -def build(m: Circuit, *, width: int = 32, lanes: int = 8): +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, width: int = 32, lanes: int = 8): seed = m.input("seed", width=width) in_spec = _lane_in_spec(m, width=width) @@ -72,4 +70,4 @@ def build(m: Circuit, *, width: int = 32, lanes: int = 8): build.__pycircuit_name__ = "module_collection" if __name__ == "__main__": - print(compile(build, name="module_collection", width=32, lanes=8).emit_mlir()) + print(compile_cycle_aware(build, name="module_collection", width=32, lanes=8).emit_mlir()) diff --git a/designs/examples/module_collection/tb_module_collection.py b/designs/examples/module_collection/tb_module_collection.py index 66885d7..711b516 100644 --- a/designs/examples/module_collection/tb_module_collection.py +++ b/designs/examples/module_collection/tb_module_collection.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,12 +15,14 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.timeout(int(p["timeout"])) - t.drive("seed", 0, at=0) - t.expect("acc", 100, at=0) - t.finish(at=int(p["finish"])) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("seed", 0) + tb.expect("acc", 100) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_module_collection_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_module_collection_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/multiclock_regs/multiclock_regs.py b/designs/examples/multiclock_regs/multiclock_regs.py index 917dad5..41bd5e9 100644 --- a/designs/examples/multiclock_regs/multiclock_regs.py +++ 
b/designs/examples/multiclock_regs/multiclock_regs.py @@ -1,17 +1,20 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module, u +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, u +from pycircuit.hw import ClockDomain -@module -def build(m: Circuit) -> None: +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) -> None: + _ = domain clk_a = m.clock("clk_a") rst_a = m.reset("rst_a") clk_b = m.clock("clk_b") rst_b = m.reset("rst_b") + cd_a = ClockDomain(clk=clk_a, rst=rst_a) + cd_b = ClockDomain(clk=clk_b, rst=rst_b) - a = m.out("a_q", clk=clk_a, rst=rst_a, width=8, init=u(8, 0)) - b = m.out("b_q", clk=clk_b, rst=rst_b, width=8, init=u(8, 0)) + a = m.out("a_q", domain=cd_a, width=8, init=u(8, 0)) + b = m.out("b_q", domain=cd_b, width=8, init=u(8, 0)) a.set(a.out() + 1) b.set(b.out() + 1) @@ -25,4 +28,4 @@ def build(m: Circuit) -> None: if __name__ == "__main__": - print(compile(build, name="multiclock_regs").emit_mlir()) + print(compile_cycle_aware(build, name="multiclock_regs").emit_mlir()) diff --git a/designs/examples/multiclock_regs/tb_multiclock_regs.py b/designs/examples/multiclock_regs/tb_multiclock_regs.py index 98b691d..add0130 100644 --- a/designs/examples/multiclock_regs/tb_multiclock_regs.py +++ b/designs/examples/multiclock_regs/tb_multiclock_regs.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,14 +15,16 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk_a") - t.clock("clk_b") - t.reset("rst_a", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("rst_b", 0, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk_a") + tb.clock("clk_b") + tb.reset("rst_a", 
cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("rst_b", 0) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_multiclock_regs_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_multiclock_regs_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/net_resolution_depth_smoke/net_resolution_depth_smoke.py b/designs/examples/net_resolution_depth_smoke/net_resolution_depth_smoke.py index 300c826..ac1da4f 100644 --- a/designs/examples/net_resolution_depth_smoke/net_resolution_depth_smoke.py +++ b/designs/examples/net_resolution_depth_smoke/net_resolution_depth_smoke.py @@ -1,26 +1,30 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, +) -@module -def build(m: Circuit, width: int = 8) -> None: - clk = m.clock("clk") - rst = m.reset("rst") - in_x = m.input("in_x", width=width) +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, width: int = 8) -> None: + in_x = cas(domain, m.input("in_x", width=width), cycle=0) d0 = in_x + 1 d1 = d0 + 1 d2 = d1 + 1 d3 = d2 + 1 - q = m.out("q", clk=clk, rst=rst, width=width, init=0) + q = domain.state(width=width, reset_value=0, name="q") + m.output("y", q.wire) + + domain.next() q.set(d3) - m.output("y", q) build.__pycircuit_name__ = "net_resolution_depth_smoke" if __name__ == "__main__": - print(compile(build, name="net_resolution_depth_smoke", width=8).emit_mlir()) + print(compile_cycle_aware(build, name="net_resolution_depth_smoke", eager=True, width=8).emit_mlir()) diff --git a/designs/examples/net_resolution_depth_smoke/tb_net_resolution_depth_smoke.py b/designs/examples/net_resolution_depth_smoke/tb_net_resolution_depth_smoke.py index 29969ac..50532b8 100644 --- a/designs/examples/net_resolution_depth_smoke/tb_net_resolution_depth_smoke.py +++ 
b/designs/examples/net_resolution_depth_smoke/tb_net_resolution_depth_smoke.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,21 +15,24 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=0) - t.timeout(int(p["timeout"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=0) + tb.timeout(int(p["timeout"])) - t.drive("in_x", 1, at=0) - t.expect("y", 0, at=0, phase="pre") - t.expect("y", 5, at=0, phase="post") + # --- cycle 0 --- + tb.drive("in_x", 1) + tb.expect("y", 0, phase="pre") + tb.expect("y", 5, phase="post") - t.drive("in_x", 2, at=1) - t.expect("y", 5, at=1, phase="pre") - t.expect("y", 6, at=1, phase="post") + tb.next() # --- cycle 1 --- + tb.drive("in_x", 2) + tb.expect("y", 5, phase="pre") + tb.expect("y", 6, phase="post") - t.finish(at=int(p["finish"])) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_net_resolution_depth_smoke_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_net_resolution_depth_smoke_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/obs_points/obs_points.py b/designs/examples/obs_points/obs_points.py index d260722..2b093aa 100644 --- a/designs/examples/obs_points/obs_points.py +++ b/designs/examples/obs_points/obs_points.py @@ -1,27 +1,28 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module, u +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, +) -@module -def build(m: Circuit, width: int = 8) -> None: - clk = m.clock("clk") - rst = m.reset("rst") - - x = m.input("x", width=width) +def 
build(m: CycleAwareCircuit, domain: CycleAwareDomain, width: int = 8) -> None: + x = cas(domain, m.input("x", width=width), cycle=0) y = x + 1 - # Sample-and-hold: capture combinational `y` into state `q` each cycle. - q = m.out("q_q", clk=clk, rst=rst, width=width, init=u(width, 0)) - q.set(y) + q = domain.state(width=width, reset_value=0, name="q") - m.output("y", y) - m.output("q", q) + m.output("y", y.wire) + m.output("q", q.wire) + + domain.next() + q.set(y) build.__pycircuit_name__ = "obs_points" if __name__ == "__main__": - print(compile(build, name="obs_points", width=8).emit_mlir()) - + print(compile_cycle_aware(build, name="obs_points", eager=True, width=8).emit_mlir()) diff --git a/designs/examples/obs_points/tb_obs_points.py b/designs/examples/obs_points/tb_obs_points.py index 5d0f8a1..3497a5f 100644 --- a/designs/examples/obs_points/tb_obs_points.py +++ b/designs/examples/obs_points/tb_obs_points.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,28 +15,30 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=0) - t.timeout(int(p["timeout"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=0) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- # Default drives. - t.drive("x", 0, at=0) - + tb.drive("x", 0) # Cycle 0: comb changes visible at pre; state updates visible at post. 
- t.drive("x", 10, at=0) - t.expect("y", 11, at=0, phase="pre", msg="TICK-OBS: comb must reflect current drives") - t.expect("q", 0, at=0, phase="pre", msg="TICK-OBS: state is pre-commit") - t.expect("q", 11, at=0, phase="post", msg="XFER-OBS: state commit is visible") + tb.drive("x", 10) + tb.expect("y", 11, phase="pre", msg="TICK-OBS: comb must reflect current drives") + tb.expect("q", 0, phase="pre", msg="TICK-OBS: state is pre-commit") + tb.expect("q", 11, phase="post", msg="XFER-OBS: state commit is visible") + tb.next() # --- cycle 1 --- # Cycle 1: repeat with a new drive to validate both obs points again. - t.drive("x", 20, at=1) - t.expect("y", 21, at=1, phase="pre") - t.expect("q", 11, at=1, phase="pre") - t.expect("q", 21, at=1, phase="post") + tb.drive("x", 20) + tb.expect("y", 21, phase="pre") + tb.expect("q", 11, phase="pre") + tb.expect("q", 21, phase="post") - t.finish(at=int(p["finish"])) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_obs_points_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_obs_points_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/pipeline_builder/pipeline_builder.py b/designs/examples/pipeline_builder/pipeline_builder.py index d3b4b79..46e67f7 100644 --- a/designs/examples/pipeline_builder/pipeline_builder.py +++ b/designs/examples/pipeline_builder/pipeline_builder.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, ConnectorStruct, compile, const, module, spec, u +from pycircuit import Circuit, ConnectorStruct, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, module, spec, u @const @@ -18,21 +18,21 @@ def _pipe_struct(m: Circuit, *, width: int): ) -@module -def build(m: Circuit, *, width: int = 32): - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, width: int = 32): + cd = domain.clock_domain + clk = cd.clk + rst = 
cd.rst s = _pipe_struct(m, width=width) in_b = m.inputs(s, prefix="in_") - st0 = m.state(s, clk=clk, rst=rst, prefix="st0_") + st0 = m.state(s, clk=cd.clk, rst=cd.rst, prefix="st0_") m.connect(st0, in_b) st1_in = st0.flatten() st1_in["payload.word"] = (st0["payload.word"].read() + u(width, 1))[0:width] - st1 = m.state(s, clk=clk, rst=rst, prefix="st1_") + st1 = m.state(s, clk=cd.clk, rst=cd.rst, prefix="st1_") m.connect(st1, ConnectorStruct(st1_in, spec=s)) m.outputs(s, st1, prefix="out_") @@ -40,4 +40,4 @@ def build(m: Circuit, *, width: int = 32): build.__pycircuit_name__ = "pipeline_builder" if __name__ == "__main__": - print(compile(build, name="pipeline_builder", width=32).emit_mlir()) + print(compile_cycle_aware(build, name="pipeline_builder", width=32).emit_mlir()) diff --git a/designs/examples/pipeline_builder/tb_pipeline_builder.py b/designs/examples/pipeline_builder/tb_pipeline_builder.py index a0ea85e..188383d 100644 --- a/designs/examples/pipeline_builder/tb_pipeline_builder.py +++ b/designs/examples/pipeline_builder/tb_pipeline_builder.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,15 +15,17 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("in_payload_word", 5, at=0) - t.drive("in_ctrl_valid", 1, at=0) - t.expect("out_ctrl_valid", 0, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- + tb.drive("in_payload_word", 5) + tb.drive("in_ctrl_valid", 1) + tb.expect("out_ctrl_valid", 0) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - 
print(compile(build, name="tb_pipeline_builder_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_pipeline_builder_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/reset_invalidate_order_smoke/reset_invalidate_order_smoke.py b/designs/examples/reset_invalidate_order_smoke/reset_invalidate_order_smoke.py index aec9b6d..8f9f725 100644 --- a/designs/examples/reset_invalidate_order_smoke/reset_invalidate_order_smoke.py +++ b/designs/examples/reset_invalidate_order_smoke/reset_invalidate_order_smoke.py @@ -1,20 +1,29 @@ from __future__ import annotations -from pycircuit import Circuit, ProbeBuilder, ProbeView, compile, module, probe +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + ProbeBuilder, + ProbeView, + cas, + compile_cycle_aware, + mux, + probe, +) -@module -def build(m: Circuit, width: int = 8) -> None: - clk = m.clock("clk") - rst = m.reset("rst") - en = m.input("en", width=1) +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, width: int = 8) -> None: + en = cas(domain, m.input("en", width=1), cycle=0) - q = m.out("q", clk=clk, rst=rst, width=width, init=0) - q.set(q.out() + 1, when=en) - m.output("y", q) + q = domain.state(width=width, reset_value=0, name="q") + m.output("y", q.wire) + + domain.next() + q.set(q + 1, when=en) build.__pycircuit_name__ = "reset_invalidate_order_smoke" +build.__pycircuit_kind__ = "module" @probe(target=build, name="reset") @@ -29,4 +38,4 @@ def reset_probe(p: ProbeBuilder, dut: ProbeView, width: int = 8) -> None: if __name__ == "__main__": - print(compile(build, name="reset_invalidate_order_smoke", width=8).emit_mlir()) + print(compile_cycle_aware(build, name="reset_invalidate_order_smoke", eager=True, width=8).emit_mlir()) diff --git a/designs/examples/reset_invalidate_order_smoke/tb_reset_invalidate_order_smoke.py b/designs/examples/reset_invalidate_order_smoke/tb_reset_invalidate_order_smoke.py index 8d4165f..83f8da2 100644 --- 
a/designs/examples/reset_invalidate_order_smoke/tb_reset_invalidate_order_smoke.py +++ b/designs/examples/reset_invalidate_order_smoke/tb_reset_invalidate_order_smoke.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,21 +15,24 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=0) - t.timeout(int(p["timeout"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=0) + tb.timeout(int(p["timeout"])) - t.drive("en", 1, at=0) - t.expect("y", 0, at=0, phase="pre") - t.expect("y", 1, at=0, phase="post") + # --- cycle 0 --- + tb.drive("en", 1) + tb.expect("y", 0, phase="pre") + tb.expect("y", 1, phase="post") - t.drive("en", 1, at=1) - t.expect("y", 1, at=1, phase="pre") - t.expect("y", 2, at=1, phase="post") + tb.next() # --- cycle 1 --- + tb.drive("en", 1) + tb.expect("y", 1, phase="pre") + tb.expect("y", 2, phase="post") - t.finish(at=int(p["finish"])) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_reset_invalidate_order_smoke_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_reset_invalidate_order_smoke_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/struct_transform/struct_transform.py b/designs/examples/struct_transform/struct_transform.py index a1afa96..3b42650 100644 --- a/designs/examples/struct_transform/struct_transform.py +++ b/designs/examples/struct_transform/struct_transform.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pycircuit import Circuit, compile, const, module, spec, u +from pycircuit import Circuit, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, const, 
module, spec, u @const @@ -27,15 +27,13 @@ def _pipe_struct(m: Circuit, *, width: int): return spec.with_prefix("u_") -@module -def build(m: Circuit, *, width: int = 32): - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *, width: int = 32): + cd = domain.clock_domain spec = _pipe_struct(m, width=width) ins = m.inputs(spec, prefix="in_") - regs = m.state(spec, clk=clk, rst=rst, prefix="st_") + regs = m.state(spec, clk=cd.clk, rst=cd.rst, prefix="st_") m.connect(regs, ins) op = regs["u_hdr.op"].read() @@ -49,4 +47,4 @@ def build(m: Circuit, *, width: int = 32): build.__pycircuit_name__ = "struct_transform" if __name__ == "__main__": - print(compile(build, name="struct_transform", width=32).emit_mlir()) + print(compile_cycle_aware(build, name="struct_transform", width=32).emit_mlir()) diff --git a/designs/examples/struct_transform/tb_struct_transform.py b/designs/examples/struct_transform/tb_struct_transform.py index c67cb79..3683e63 100644 --- a/designs/examples/struct_transform/tb_struct_transform.py +++ b/designs/examples/struct_transform/tb_struct_transform.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,18 +15,20 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) - t.drive("in_u_hdr_op", 1, at=0) - t.drive("in_u_hdr_dst", 2, at=0) - t.drive("in_u_payload_word", 3, at=0) - t.drive("in_u_ctrl_valid", 1, at=0) - t.expect("out_u_ctrl_valid", 1, at=0) - t.expect("out_u_payload_word", 5, at=0) - t.finish(at=int(p["finish"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + # 
--- cycle 0 --- + tb.drive("in_u_hdr_op", 1) + tb.drive("in_u_hdr_dst", 2) + tb.drive("in_u_payload_word", 3) + tb.drive("in_u_ctrl_valid", 1) + tb.expect("out_u_ctrl_valid", 1) + tb.expect("out_u_payload_word", 5) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_struct_transform_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_struct_transform_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/sync_mem_init_zero/sync_mem_init_zero.py b/designs/examples/sync_mem_init_zero/sync_mem_init_zero.py index 1fde3d6..9676817 100644 --- a/designs/examples/sync_mem_init_zero/sync_mem_init_zero.py +++ b/designs/examples/sync_mem_init_zero/sync_mem_init_zero.py @@ -1,12 +1,12 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module +from pycircuit import Circuit, module, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain -@module -def build(m: Circuit, depth: int = 4, data_width: int = 32, addr_width: int = 2) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, depth: int = 4, data_width: int = 32, addr_width: int = 2) -> None: + cd = domain.clock_domain + clk = cd.clk + rst = cd.rst ren = m.input("ren", width=1) raddr = m.input("raddr", width=addr_width) @@ -35,5 +35,5 @@ def build(m: Circuit, depth: int = 4, data_width: int = 32, addr_width: int = 2) if __name__ == "__main__": - print(compile(build, name="sync_mem_init_zero", depth=4, data_width=32, addr_width=2).emit_mlir()) + print(compile_cycle_aware(build, name="sync_mem_init_zero", eager=True, depth=4, data_width=32, addr_width=2).emit_mlir()) diff --git a/designs/examples/sync_mem_init_zero/tb_sync_mem_init_zero.py b/designs/examples/sync_mem_init_zero/tb_sync_mem_init_zero.py index d9b7a9c..b0614f4 100644 --- a/designs/examples/sync_mem_init_zero/tb_sync_mem_init_zero.py +++ 
b/designs/examples/sync_mem_init_zero/tb_sync_mem_init_zero.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,29 +15,31 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(int(p["timeout"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- # Default drives (no writes). - t.drive("wvalid", 0, at=0) - t.drive("waddr", 0, at=0) - t.drive("wdata", 0, at=0) - t.drive("wstrb", 0, at=0) + tb.drive("wvalid", 0) + tb.drive("waddr", 0) + tb.drive("wdata", 0) + tb.drive("wstrb", 0) # Read from unwritten addresses: deterministic sim init must be 0. - t.drive("ren", 1, at=0) - t.drive("raddr", 1, at=0) - t.expect("rdata", 0, at=0, phase="post", msg="sync_mem must initialize entries to 0 (deterministic sim)") + tb.drive("ren", 1) + tb.drive("raddr", 1) + tb.expect("rdata", 0, phase="post", msg="sync_mem must initialize entries to 0 (deterministic sim)") - t.drive("ren", 1, at=1) - t.drive("raddr", 3, at=1) - t.expect("rdata", 0, at=1, phase="post") + tb.next() # --- cycle 1 --- + tb.drive("ren", 1) + tb.drive("raddr", 3) + tb.expect("rdata", 0, phase="post") - t.finish(at=int(p["finish"])) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_sync_mem_init_zero_top", **DEFAULT_PARAMS).emit_mlir()) - + print(compile_cycle_aware(build, name="tb_sync_mem_init_zero_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/trace_dsl_smoke/tb_trace_dsl_smoke.py b/designs/examples/trace_dsl_smoke/tb_trace_dsl_smoke.py index fa1450d..fc57ca8 100644 --- 
a/designs/examples/trace_dsl_smoke/tb_trace_dsl_smoke.py +++ b/designs/examples/trace_dsl_smoke/tb_trace_dsl_smoke.py @@ -3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,34 +15,38 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=0) - t.timeout(int(p["timeout"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=0) + tb.timeout(int(p["timeout"])) + # --- cycle 0 --- # Cycle 0: pre is init, post sees commit. - t.drive("in_x", 0x12, at=0) - t.expect("y0", 0x00, at=0, phase="pre") - t.expect("y1", 0x00, at=0, phase="pre") - t.expect("y0", 0x12, at=0, phase="post") - t.expect("y1", 0x12, at=0, phase="post") + tb.drive("in_x", 0x12) + tb.expect("y0", 0x00, phase="pre") + tb.expect("y1", 0x00, phase="pre") + tb.expect("y0", 0x12, phase="post") + tb.expect("y1", 0x12, phase="post") + tb.next() # --- cycle 1 --- # Cycle 1: same behavior with a new drive. - t.drive("in_x", 0x34, at=1) - t.expect("y0", 0x12, at=1, phase="pre") - t.expect("y1", 0x12, at=1, phase="pre") - t.expect("y0", 0x34, at=1, phase="post") - t.expect("y1", 0x34, at=1, phase="post") + tb.drive("in_x", 0x34) + tb.expect("y0", 0x12, phase="pre") + tb.expect("y1", 0x12, phase="pre") + tb.expect("y0", 0x34, phase="post") + tb.expect("y1", 0x34, phase="post") + tb.next() # --- cycle 2 --- # Cycle 2: stable drive (committed output holds; trace still records Write intent; Decision 0053). 
- t.drive("in_x", 0x34, at=2) - t.expect("y0", 0x34, at=2, phase="pre") - t.expect("y1", 0x34, at=2, phase="pre") - t.expect("y0", 0x34, at=2, phase="post") - t.expect("y1", 0x34, at=2, phase="post") + tb.drive("in_x", 0x34) + tb.expect("y0", 0x34, phase="pre") + tb.expect("y1", 0x34, phase="pre") + tb.expect("y0", 0x34, phase="post") + tb.expect("y1", 0x34, phase="post") - t.finish(at=int(p["finish"])) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_trace_dsl_smoke_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_trace_dsl_smoke_top", **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/trace_dsl_smoke/trace_dsl_smoke.py b/designs/examples/trace_dsl_smoke/trace_dsl_smoke.py index 168ad9f..9813104 100644 --- a/designs/examples/trace_dsl_smoke/trace_dsl_smoke.py +++ b/designs/examples/trace_dsl_smoke/trace_dsl_smoke.py @@ -1,21 +1,21 @@ from __future__ import annotations -from pycircuit import Circuit, ProbeBuilder, ProbeView, compile, module, probe +from pycircuit import Circuit, ProbeBuilder, ProbeView, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, module, probe +from pycircuit.hw import ClockDomain @module -def leaf(m: Circuit) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def leaf(m: Circuit, clk, rst) -> None: + cd = ClockDomain(clk=clk, rst=rst) x = m.input("in_x", width=8) - r = m.out("r", clk=clk, rst=rst, width=8, init=0) + r = m.out("r", domain=cd, width=8, init=0) r.set(x) m.output("out_y", r) -@module -def build(m: Circuit) -> None: +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) -> None: + _ = domain clk = m.clock("clk") rst = m.reset("rst") x = m.input("in_x", width=8) @@ -41,4 +41,4 @@ def leaf_pipeview(p: ProbeBuilder, dut: ProbeView) -> None: if __name__ == "__main__": - print(compile(build, name="trace_dsl_smoke").emit_mlir()) + print(compile_cycle_aware(build, name="trace_dsl_smoke").emit_mlir()) diff --git 
a/designs/examples/traffic_lights_ce_pyc/PLAN.md b/designs/examples/traffic_lights_ce_pyc/PLAN.md new file mode 100644 index 0000000..d009fd1 --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/PLAN.md @@ -0,0 +1,53 @@ +# PLAN: traffic_lights_ce_pyc + +## Core observations from Traffic-lights-ce + +- Two-direction intersection with East/West (main) and North/South (secondary). +- Default timing: EW green 45s, EW yellow 5s, NS green 30s, NS yellow 5s. +- Red durations are derived from the opposite direction's green+yellow (EW red = 30+5, NS red = 45+5). +- Yellow blinks at 1 Hz during yellow phases. +- Emergency mode forces all-red and displays "88" on both countdowns. +- Original design uses separate countdown modules per direction and an edge-trigger to make single-cycle change pulses. + +## Implementation plan for pyCircuit + +- Build a new example under `examples/traffic_lights_ce_pyc/` with a cycle-aware design. +- Top-level outputs are 8-bit BCD countdowns (`ew_bcd`, `ns_bcd`) plus discrete red/yellow/green lights. +- Reuse `examples/digital_clock/bcd.py` for BCD conversion (`bin_to_bcd_60`). +- Use a combined 4-phase FSM: EW_GREEN -> EW_YELLOW -> NS_GREEN -> NS_YELLOW -> EW_GREEN +- Maintain two countdown registers (EW/NS). Decrement on each 1 Hz tick. + - Reload only the direction whose light changes. + - Red durations are derived from opposite green+yellow. +- Emergency behavior: + - Outputs forced to all-red and BCD=0x88. + - Internal counters and phase freeze while `emergency=1` or `go=0`. +- Provide a C API wrapper and a terminal emulator similar to `digital_clock`. 
+ +## Deliverables + +- `traffic_lights_ce.py` (pyCircuit design) +- `traffic_lights_capi.cpp` (C API wrapper) +- `emulate_traffic_lights.py` (terminal visualization) +- `README.md` (build and run instructions) +- `PLAN.md` (this document) +- `__init__.py` (package marker) + +## Interfaces (planned) + +- Inputs: `clk`, `rst`, `go`, `emergency` +- Outputs: + - `ew_bcd`, `ns_bcd` (8-bit BCD, `{tens, ones}`) + - `ew_red/ew_yellow/ew_green`, `ns_red/ns_yellow/ns_green` + +## JIT parameters (planned) + +- `CLK_FREQ` (Hz) +- `EW_GREEN_S`, `EW_YELLOW_S` +- `NS_GREEN_S`, `NS_YELLOW_S` +- Derived: `EW_RED_S = NS_GREEN_S + NS_YELLOW_S`, `NS_RED_S = EW_GREEN_S + EW_YELLOW_S` + +## Test/usage (planned) + +- Generate MLIR via `pycircuit.cli emit` with optional `--param CLK_FREQ=1000` for faster emulation. +- Compile to Verilog/C++ using `pyc-compile --emit=verilog/cpp`. +- Build shared lib and run `emulate_traffic_lights.py`. diff --git a/designs/examples/traffic_lights_ce_pyc/README.md b/designs/examples/traffic_lights_ce_pyc/README.md new file mode 100644 index 0000000..8d140a5 --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/README.md @@ -0,0 +1,78 @@ +# Traffic Lights (pyCircuit) + +A cycle-aware traffic lights controller based on the [Traffic-lights-ce](https://github.com/Starrynightzyq/Traffic-lights-ce) design. +It exposes BCD countdowns for East/West and North/South, plus discrete red/yellow/green lights. +The terminal emulator renders a simple 7-seg view and can load multiple stimulus patterns. + +**Key files** +- `traffic_lights_ce.py`: pyCircuit implementation of the FSM, countdowns, blink, and outputs. +- `traffic_lights_capi.cpp`: C API wrapper around the generated C++ model for ctypes. +- `emulate_traffic_lights.py`: terminal visualization; drives the DUT via the C API. +- `stimuli/*.py`: independent stimulus modules (driver logic separated from the DUT). +- `PLAN.md`: design notes and implementation plan. 
+ +## Ports + +| Port | Dir | Width | Description | +|------|-----|-------|-------------| +| `clk` | in | 1 | System clock | +| `rst` | in | 1 | Synchronous reset | +| `go` | in | 1 | Run/pause (1=run, 0=freeze) | +| `emergency` | in | 1 | Emergency override (1=all red, BCD=88) | +| `ew_bcd` | out | 8 | East/West countdown BCD `{tens,ones}` | +| `ns_bcd` | out | 8 | North/South countdown BCD `{tens,ones}` | +| `ew_red` | out | 1 | East/West red | +| `ew_yellow` | out | 1 | East/West yellow (blink) | +| `ew_green` | out | 1 | East/West green | +| `ns_red` | out | 1 | North/South red | +| `ns_yellow` | out | 1 | North/South yellow (blink) | +| `ns_green` | out | 1 | North/South green | + +## JIT parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `CLK_FREQ` | 50_000_000 | System clock frequency (Hz) | +| `EW_GREEN_S` | 45 | East/West green time (seconds) | +| `EW_YELLOW_S` | 5 | East/West yellow time (seconds) | +| `NS_GREEN_S` | 30 | North/South green time (seconds) | +| `NS_YELLOW_S` | 5 | North/South yellow time (seconds) | + +Derived durations: +- `EW_RED_S = NS_GREEN_S + NS_YELLOW_S` +- `NS_RED_S = EW_GREEN_S + EW_YELLOW_S` + +## Build and Run + +The emulator assumes `CLK_FREQ=1000` for fast visualization. Set it via +`PYC_TL_CLK_FREQ=1000` when emitting the design. The following sequence is +verified end-to-end (including all stimuli): + +```bash +PYC_TL_CLK_FREQ=1000 PYTHONPATH=python python3 -m pycircuit.cli emit \ + examples/traffic_lights_ce_pyc/traffic_lights_ce.py \ + -o /tmp/traffic_lights_ce_pyc.pyc + +./build/bin/pyc-compile /tmp/traffic_lights_ce_pyc.pyc \ + --emit=verilog --out-dir=examples/generated/traffic_lights_ce_pyc + +./build/bin/pyc-compile /tmp/traffic_lights_ce_pyc.pyc \ + --emit=cpp --out-dir=examples/generated/traffic_lights_ce_pyc + +c++ -std=c++17 -O2 -shared -fPIC -I include -I . 
\ + -o examples/traffic_lights_ce_pyc/libtraffic_lights_sim.dylib \ + examples/traffic_lights_ce_pyc/traffic_lights_capi.cpp + +python3 examples/traffic_lights_ce_pyc/emulate_traffic_lights.py --stim basic +python3 examples/traffic_lights_ce_pyc/emulate_traffic_lights.py --stim emergency_pulse +python3 examples/traffic_lights_ce_pyc/emulate_traffic_lights.py --stim pause_resume +``` + +## Stimuli + +Stimulus is loaded as an independent module, separate from the DUT. +Available modules live under `examples/traffic_lights_ce_pyc/stimuli/`. + +- `basic`: continuous run, no interruptions +- `emergency_pulse`: assert emergency for a window +- `pause_resume`: toggle `go` to pause/resume diff --git a/designs/examples/traffic_lights_ce_pyc/__init__.py b/designs/examples/traffic_lights_ce_pyc/__init__.py new file mode 100644 index 0000000..5b0a864 --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/__init__.py @@ -0,0 +1 @@ +# Package marker for traffic_lights_ce_pyc example. diff --git a/designs/examples/traffic_lights_ce_pyc/emulate_traffic_lights.py b/designs/examples/traffic_lights_ce_pyc/emulate_traffic_lights.py new file mode 100644 index 0000000..9f0568b --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/emulate_traffic_lights.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +emulate_traffic_lights.py — True RTL simulation of the traffic lights +with a terminal visualization. + +Build the shared library first: + cd /path/to/checkout # the directory that contains include/ and examples/ + c++ -std=c++17 -O2 -shared -fPIC -I include -I . 
\ + -o examples/traffic_lights_ce_pyc/libtraffic_lights_sim.dylib \ + examples/traffic_lights_ce_pyc/traffic_lights_capi.cpp + +Then run: + python examples/traffic_lights_ce_pyc/emulate_traffic_lights.py +""" +from __future__ import annotations + +import argparse +import ctypes +import importlib +import sys +import time +from pathlib import Path + +# ============================================================================= +# ANSI helpers +# ============================================================================= + +RESET = "\033[0m" +BOLD = "\033[1m" +DIM = "\033[2m" +RED = "\033[31m" +YELLOW = "\033[33m" +GREEN = "\033[32m" +WHITE = "\033[37m" +CYAN = "\033[36m" + + +def clear_screen() -> None: + print("\033[2J\033[H", end="") + + +# ============================================================================= +# 7-segment ASCII art +# ============================================================================= + +_SEG = { + 0: (" _ ", "| |", "|_|"), + 1: (" ", " |", " |"), + 2: (" _ ", " _|", "|_ "), + 3: (" _ ", " _|", " _|"), + 4: (" ", "|_|", " |"), + 5: (" _ ", "|_ ", " _|"), + 6: (" _ ", "|_ ", "|_|"), + 7: (" _ ", " |", " |"), + 8: (" _ ", "|_|", "|_|"), + 9: (" _ ", "|_|", " _|"), +} + + +def _digit_rows(d: int, color: str = WHITE) -> list[str]: + rows = _SEG.get(d, _SEG[0]) + return [f"{color}{r}{RESET}" for r in rows] + + +def _box(rows: list[str]) -> list[str]: + """Wrap content rows with a 1-char ASCII border.""" + if not rows: + raise ValueError("expected at least 1 row for box content") + width = len(rows[0]) + if any(len(r) != width for r in rows): + raise ValueError("all rows must be the same width for box") + top = "+" + "-" * width + "+" + mid = [f"|{r}|" for r in rows] + return [top, *mid, top] + + +def _light_cluster(label: str, on: int, color: str) -> list[str]: + """3x3 letter cluster representing a single light.""" + ch = label if on else label.lower() + paint = color if on else DIM + row = f"{paint}{ch*3}{RESET}" + return [row, 
class TrafficLightsRTL:
    """ctypes wrapper around the compiled traffic-lights C API.

    The instance owns an opaque simulation handle returned by ``tl_create``.
    DUT inputs are exposed as the plain attributes ``go`` and ``emergency``;
    they are pushed into the model immediately before every ``tick`` /
    ``run_cycles`` call.
    """

    # C entry points that return nothing.
    _VOID_CALLS = {
        "tl_destroy": [ctypes.c_void_p],
        "tl_reset": [ctypes.c_void_p, ctypes.c_uint64],
        "tl_set_inputs": [ctypes.c_void_p, ctypes.c_int, ctypes.c_int],
        "tl_tick": [ctypes.c_void_p],
        "tl_run_cycles": [ctypes.c_void_p, ctypes.c_uint64],
    }
    # C getters that take the handle and return a uint32.
    _U32_GETTERS = (
        "tl_get_ew_bcd", "tl_get_ns_bcd",
        "tl_get_ew_red", "tl_get_ew_yellow", "tl_get_ew_green",
        "tl_get_ns_red", "tl_get_ns_yellow", "tl_get_ns_green",
    )

    def __init__(self, lib_path: str | None = None):
        """Load the shared library and create a simulation context.

        Args:
            lib_path: Path to the shared library. Defaults to
                ``libtraffic_lights_sim.dylib`` next to this file.

        Raises:
            OSError: If the library cannot be loaded (raised by ``ctypes.CDLL``).
            RuntimeError: If ``tl_create`` returns a NULL handle.
        """
        if lib_path is None:
            lib_path = str(Path(__file__).resolve().parent / "libtraffic_lights_sim.dylib")
        # Set before anything can fail so close()/__del__ are always safe.
        self._ctx = None
        self._lib = ctypes.CDLL(lib_path)

        self._lib.tl_create.argtypes = []
        self._lib.tl_create.restype = ctypes.c_void_p
        for name, argtypes in self._VOID_CALLS.items():
            fn = getattr(self._lib, name)
            fn.argtypes = argtypes
            fn.restype = None  # void in C; the ctypes default would be int
        for name in self._U32_GETTERS:
            fn = getattr(self._lib, name)
            fn.argtypes = [ctypes.c_void_p]
            fn.restype = ctypes.c_uint32
        self._lib.tl_get_cycle.argtypes = [ctypes.c_void_p]
        self._lib.tl_get_cycle.restype = ctypes.c_uint64

        # A c_void_p restype maps NULL to None; passing that on would make the
        # C side dereference a null pointer.
        self._ctx = self._lib.tl_create()
        if not self._ctx:
            raise RuntimeError("tl_create() returned NULL")
        self.go = 0
        self.emergency = 0

    def close(self) -> None:
        """Destroy the simulation context; safe to call more than once."""
        ctx = getattr(self, "_ctx", None)
        if ctx:
            self._ctx = None
            self._lib.tl_destroy(ctx)

    def __del__(self):
        self.close()

    def _require_ctx(self):
        """Return the live handle, or raise if the context was closed."""
        ctx = self._ctx
        if not ctx:
            raise RuntimeError("simulation context is closed")
        return ctx

    @staticmethod
    def _split_bcd(value: int) -> tuple[int, int]:
        """Split a packed BCD byte into ``(tens, ones)`` nibbles."""
        return ((value >> 4) & 0xF, value & 0xF)

    def reset(self, cycles: int = 2):
        """Call ``tl_reset`` with the given number of reset cycles."""
        self._lib.tl_reset(self._require_ctx(), cycles)

    def _apply_inputs(self):
        """Push the current ``go`` / ``emergency`` attributes into the model."""
        self._lib.tl_set_inputs(self._require_ctx(), self.go, self.emergency)

    def tick(self):
        """Apply the inputs, then advance the model by one clock cycle."""
        self._apply_inputs()
        self._lib.tl_tick(self._require_ctx())

    def run_cycles(self, n: int):
        """Apply the inputs, then advance the model by ``n`` clock cycles."""
        self._apply_inputs()
        self._lib.tl_run_cycles(self._require_ctx(), n)

    @property
    def ew_bcd(self) -> tuple[int, int]:
        """East/West countdown as ``(tens, ones)``."""
        return self._split_bcd(self._lib.tl_get_ew_bcd(self._require_ctx()))

    @property
    def ns_bcd(self) -> tuple[int, int]:
        """North/South countdown as ``(tens, ones)``."""
        return self._split_bcd(self._lib.tl_get_ns_bcd(self._require_ctx()))

    @property
    def ew_lights(self) -> tuple[int, int, int]:
        """East/West lights as ``(red, yellow, green)``."""
        ctx = self._require_ctx()
        return (
            int(self._lib.tl_get_ew_red(ctx)),
            int(self._lib.tl_get_ew_yellow(ctx)),
            int(self._lib.tl_get_ew_green(ctx)),
        )

    @property
    def ns_lights(self) -> tuple[int, int, int]:
        """North/South lights as ``(red, yellow, green)``."""
        ctx = self._require_ctx()
        return (
            int(self._lib.tl_get_ns_red(ctx)),
            int(self._lib.tl_get_ns_yellow(ctx)),
            int(self._lib.tl_get_ns_green(ctx)),
        )

    @property
    def cycle(self) -> int:
        """Cycle counter as reported by ``tl_get_cycle``."""
        return int(self._lib.tl_get_cycle(self._require_ctx()))
def main():
    """Run the terminal emulator: parse options, then step and draw the model.

    Each loop iteration represents one simulated second: the stimulus hook (if
    any) may change the inputs, the current outputs are drawn, and the model is
    then advanced by ``RTL_CLK_FREQ`` cycles.
    """
    parser = argparse.ArgumentParser(description="Traffic lights terminal emulator")
    parser.add_argument(
        "--stim",
        default="emergency_pulse",
        help="Stimulus module name (e.g. basic, emergency_pulse, pause_resume)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print extra debug info (BCD values as integers)",
    )
    opts = parser.parse_args()

    stim = _load_stimulus(opts.stim)

    sim = TrafficLightsRTL()
    sim.reset()
    if not hasattr(stim, "init"):
        # No init hook: fall back to "running, no emergency".
        sim.go = 1
        sim.emergency = 0
    else:
        stim.init(sim)

    # Both settings are optional callables on the stimulus module.
    n_seconds = int(getattr(stim, "total_seconds", lambda: 120)())
    frame_delay = float(getattr(stim, "sleep_s", lambda: 0.08)())

    for sec in range(n_seconds):
        if hasattr(stim, "step"):
            stim.step(sec, sim)

        clear_screen()
        ew_tens, ew_ones = sim.ew_bcd
        ns_tens, ns_ones = sim.ns_bcd

        # Blank separator row before each direction's panel.
        body = [
            "",
            *render_direction("EW", ew_tens, ew_ones, sim.ew_lights),
            "",
            *render_direction("NS", ns_tens, ns_ones, sim.ns_lights),
        ]

        print(f"{CYAN}traffic_lights_ce_pyc{RESET} cycle={sim.cycle} sec={sec}")
        print(f"go={sim.go} emergency={sim.emergency} CLK_FREQ={RTL_CLK_FREQ}")
        if opts.debug:
            print(
                f"ew_bcd={ew_tens}{ew_ones} ({ew_tens * 10 + ew_ones}) "
                f"ns_bcd={ns_tens}{ns_ones} ({ns_tens * 10 + ns_ones})"
            )
        for row in body:
            print(row)

        sim.run_cycles(RTL_CLK_FREQ)
        time.sleep(frame_delay)
b/designs/examples/traffic_lights_ce_pyc/stimuli/__init__.py @@ -0,0 +1 @@ +"""Stimulus modules for traffic_lights_ce_pyc emulator.""" diff --git a/designs/examples/traffic_lights_ce_pyc/stimuli/basic.py b/designs/examples/traffic_lights_ce_pyc/stimuli/basic.py new file mode 100644 index 0000000..3166552 --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/stimuli/basic.py @@ -0,0 +1,20 @@ +"""Basic stimulus: run continuously with no interruptions.""" + + +def total_seconds() -> int: + return 120 + + +def sleep_s() -> float: + return 0.08 + + +def init(rtl) -> None: + rtl.go = 1 + rtl.emergency = 0 + + +def step(sec: int, rtl) -> None: + _ = sec + _ = rtl + # No changes during run. diff --git a/designs/examples/traffic_lights_ce_pyc/stimuli/emergency_pulse.py b/designs/examples/traffic_lights_ce_pyc/stimuli/emergency_pulse.py new file mode 100644 index 0000000..952d9aa --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/stimuli/emergency_pulse.py @@ -0,0 +1,21 @@ +"""Emergency pulse stimulus: inject emergency for a short window.""" + + +def total_seconds() -> int: + return 140 + + +def sleep_s() -> float: + return 0.08 + + +def init(rtl) -> None: + rtl.go = 1 + rtl.emergency = 0 + + +def step(sec: int, rtl) -> None: + if sec == 60: + rtl.emergency = 1 + if sec == 72: + rtl.emergency = 0 diff --git a/designs/examples/traffic_lights_ce_pyc/stimuli/pause_resume.py b/designs/examples/traffic_lights_ce_pyc/stimuli/pause_resume.py new file mode 100644 index 0000000..6b53fb1 --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/stimuli/pause_resume.py @@ -0,0 +1,21 @@ +"""Pause/resume stimulus: toggles go while running.""" + + +def total_seconds() -> int: + return 140 + + +def sleep_s() -> float: + return 0.08 + + +def init(rtl) -> None: + rtl.go = 1 + rtl.emergency = 0 + + +def step(sec: int, rtl) -> None: + if sec == 50: + rtl.go = 0 + if sec == 65: + rtl.go = 1 diff --git a/designs/examples/traffic_lights_ce_pyc/traffic_lights_capi.cpp 
b/designs/examples/traffic_lights_ce_pyc/traffic_lights_capi.cpp new file mode 100644 index 0000000..e4da887 --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/traffic_lights_capi.cpp @@ -0,0 +1,73 @@ +/** + * traffic_lights_capi.cpp — C API wrapper around the generated RTL model. + * + * Build: + * cd + * c++ -std=c++17 -O2 -shared -fPIC -I include -I . \ + * -o examples/traffic_lights_ce_pyc/libtraffic_lights_sim.dylib \ + * examples/traffic_lights_ce_pyc/traffic_lights_capi.cpp + */ + +#include +#include +#include + +#include "../generated/traffic_lights_ce_pyc/traffic_lights_ce_pyc.hpp" + +using pyc::cpp::Wire; + +struct SimContext { + pyc::gen::traffic_lights_ce_pyc dut{}; + pyc::cpp::Testbench tb; + uint64_t cycle = 0; + + SimContext() : tb(dut) { + tb.addClock(dut.clk, /*halfPeriodSteps=*/1); + } +}; + +extern "C" { + +SimContext* tl_create() { + return new SimContext(); +} + +void tl_destroy(SimContext* ctx) { + delete ctx; +} + +void tl_reset(SimContext* ctx, uint64_t cycles) { + ctx->tb.reset(ctx->dut.rst, /*cyclesAsserted=*/cycles, /*cyclesDeasserted=*/1); + ctx->dut.eval(); + ctx->cycle = 0; +} + +void tl_set_inputs(SimContext* ctx, int go, int emergency) { + ctx->dut.go = Wire<1>(go ? 1u : 0u); + ctx->dut.emergency = Wire<1>(emergency ? 
1u : 0u); +} + +void tl_tick(SimContext* ctx) { + ctx->tb.runCycles(1); + ctx->cycle++; +} + +void tl_run_cycles(SimContext* ctx, uint64_t n) { + ctx->tb.runCycles(n); + ctx->cycle += n; +} + +uint32_t tl_get_ew_bcd(SimContext* ctx) { return ctx->dut.ew_bcd.value(); } +uint32_t tl_get_ns_bcd(SimContext* ctx) { return ctx->dut.ns_bcd.value(); } + +uint32_t tl_get_ew_red(SimContext* ctx) { return ctx->dut.ew_red.value(); } +uint32_t tl_get_ew_yellow(SimContext* ctx) { return ctx->dut.ew_yellow.value(); } +uint32_t tl_get_ew_green(SimContext* ctx) { return ctx->dut.ew_green.value(); } + +uint32_t tl_get_ns_red(SimContext* ctx) { return ctx->dut.ns_red.value(); } +uint32_t tl_get_ns_yellow(SimContext* ctx) { return ctx->dut.ns_yellow.value(); } +uint32_t tl_get_ns_green(SimContext* ctx) { return ctx->dut.ns_green.value(); } + +uint64_t tl_get_cycle(SimContext* ctx) { return ctx->cycle; } + +} // extern "C" diff --git a/designs/examples/traffic_lights_ce_pyc/traffic_lights_ce.py b/designs/examples/traffic_lights_ce_pyc/traffic_lights_ce.py new file mode 100644 index 0000000..e38e636 --- /dev/null +++ b/designs/examples/traffic_lights_ce_pyc/traffic_lights_ce.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +"""Traffic Lights Controller — pyCircuit v4.0 design. + +Reimplements the Traffic-lights-ce project in the pyCircuit unified signal model. +Outputs are BCD countdowns per direction plus discrete red/yellow/green lights. 
def build(m: CycleAwareCircuit, domain: CycleAwareDomain, *,
          CLK_FREQ: int = 50_000_000,
          EW_GREEN_S: int = 45,
          EW_YELLOW_S: int = 5,
          NS_GREEN_S: int = 30,
          NS_YELLOW_S: int = 5,
) -> None:
    """Build the traffic-lights controller.

    A prescaler produces a once-per-``CLK_FREQ``-cycles tick while
    ``go & ~emergency`` holds. On each tick the per-direction countdowns
    decrement, and a 2-bit phase register steps through EW green, EW yellow,
    NS green and NS yellow whenever the active countdown is at zero. While
    ``emergency`` is high both directions show red and both BCD outputs read
    ``0x88``.

    Args:
        m: Circuit under construction (inputs, registers, outputs).
        domain: Cycle-aware domain; its ``clock_domain`` clocks every register.
        CLK_FREQ: Clock cycles per countdown step; must be > 0.
        EW_GREEN_S, EW_YELLOW_S, NS_GREEN_S, NS_YELLOW_S: Phase durations,
            each > 0. The derived red times must also stay <= 59.

    Raises:
        ValueError: If a duration is <= 0, a (derived) duration exceeds 59,
            or ``CLK_FREQ`` is <= 0.
    """
    if min(EW_GREEN_S, EW_YELLOW_S, NS_GREEN_S, NS_YELLOW_S) <= 0:
        raise ValueError("all durations must be > 0")

    EW_RED_S = NS_GREEN_S + NS_YELLOW_S
    NS_RED_S = EW_GREEN_S + EW_YELLOW_S

    max_dur = max(EW_GREEN_S, EW_YELLOW_S, NS_GREEN_S, NS_YELLOW_S, EW_RED_S, NS_RED_S)
    if max_dur > 59:
        raise ValueError("all durations must be <= 59 to fit bin_to_bcd_60")
    # Checked after the duration checks so their error precedence is unchanged.
    # A non-positive value would otherwise yield a 1-bit prescaler compared
    # against the constant CLK_FREQ - 1 < 0.
    if CLK_FREQ <= 0:
        raise ValueError("CLK_FREQ must be > 0")
    cd = domain.clock_domain
    # NOTE(review): clk/rst are not referenced below; kept in case the
    # attribute access has an effect inside the frontend — confirm.
    clk = cd.clk
    rst = cd.rst

    # ================================================================
    # Inputs
    # ================================================================
    go = m.input("go", width=1)
    emergency = m.input("emergency", width=1)

    # ================================================================
    # Registers
    # ================================================================
    # Wide enough to hold CLK_FREQ - 1 and max_dur respectively.
    PRESCALER_W = max((CLK_FREQ - 1).bit_length(), 1)
    CNT_W = max(max_dur.bit_length(), 1)

    prescaler_r = m.out("prescaler", domain=cd, width=PRESCALER_W, init=u(PRESCALER_W, 0))
    phase_r = m.out("phase", domain=cd, width=2, init=u(2, PH_EW_GREEN))
    ew_cnt_r = m.out("ew_cnt", domain=cd, width=CNT_W, init=u(CNT_W, EW_GREEN_S))
    ns_cnt_r = m.out("ns_cnt", domain=cd, width=CNT_W, init=u(CNT_W, NS_RED_S))
    blink_r = m.out("blink", domain=cd, width=1, init=u(1, 0))

    # ================================================================
    # Combinational logic
    # ================================================================
    pv = prescaler_r.out()
    ph = phase_r.out()
    ew = ew_cnt_r.out()
    ns = ns_cnt_r.out()
    bl = blink_r.out()

    # Counting is frozen while paused (go=0) or during an emergency.
    en = go & (~emergency)

    # 1 Hz tick via prescaler (gated by en)
    tick_raw = pv == u(PRESCALER_W, CLK_FREQ - 1)
    tick_1hz = tick_raw & en
    inner_prescaler = u(PRESCALER_W, 0) if tick_raw else (pv + 1)
    prescaler_next = inner_prescaler if en else pv

    # Phase flags
    is_ew_green = ph == u(2, PH_EW_GREEN)
    is_ew_yellow = ph == u(2, PH_EW_YELLOW)
    is_ns_green = ph == u(2, PH_NS_GREEN)
    is_ns_yellow = ph == u(2, PH_NS_YELLOW)
    yellow_active = is_ew_yellow | is_ns_yellow

    # Countdown end flags
    ew_end = ew == u(CNT_W, 0)
    ns_end = ns == u(CNT_W, 0)

    ew_cnt_dec = ew - 1
    ns_cnt_dec = ns - 1

    # Phase transitions (when counter reaches 0 on a tick)
    cond_ew_to_yellow = tick_1hz & is_ew_green & ew_end
    cond_ew_to_ns_green = tick_1hz & is_ew_yellow & ew_end
    cond_ns_to_yellow = tick_1hz & is_ns_green & ns_end
    cond_ns_to_ew_green = tick_1hz & is_ns_yellow & ns_end

    # Later assignments take priority over earlier ones.
    phase_next = ph
    phase_next = u(2, PH_EW_YELLOW) if cond_ew_to_yellow else phase_next
    phase_next = u(2, PH_NS_GREEN) if cond_ew_to_ns_green else phase_next
    phase_next = u(2, PH_NS_YELLOW) if cond_ns_to_yellow else phase_next
    phase_next = u(2, PH_EW_GREEN) if cond_ns_to_ew_green else phase_next

    # EW countdown: decrement on a tick unless already 0; reload on transitions.
    ew_cnt_next = ew
    ew_cnt_next = ew_cnt_dec if (tick_1hz & (~ew_end)) else ew_cnt_next
    ew_cnt_next = u(CNT_W, EW_YELLOW_S) if cond_ew_to_yellow else ew_cnt_next
    ew_cnt_next = u(CNT_W, EW_RED_S) if cond_ew_to_ns_green else ew_cnt_next
    ew_cnt_next = u(CNT_W, EW_GREEN_S) if cond_ns_to_ew_green else ew_cnt_next

    # NS countdown: same scheme as EW.
    ns_cnt_next = ns
    ns_cnt_next = ns_cnt_dec if (tick_1hz & (~ns_end)) else ns_cnt_next
    ns_cnt_next = u(CNT_W, NS_GREEN_S) if cond_ew_to_ns_green else ns_cnt_next
    ns_cnt_next = u(CNT_W, NS_YELLOW_S) if cond_ns_to_yellow else ns_cnt_next
    ns_cnt_next = u(CNT_W, NS_RED_S) if cond_ns_to_ew_green else ns_cnt_next

    # BCD conversion (combinational)
    ew_bcd_raw = bin_to_bcd_60(m, ew, CNT_W)
    ns_bcd_raw = bin_to_bcd_60(m, ns, CNT_W)

    # Lights (base, before emergency override); yellow follows the blink bit.
    ew_red_base = is_ns_green | is_ns_yellow
    ew_green_base = is_ew_green
    ew_yellow_base = is_ew_yellow & bl

    ns_red_base = is_ew_green | is_ew_yellow
    ns_green_base = is_ns_green
    ns_yellow_base = is_ns_yellow & bl

    # Emergency overrides: all red, BCD shows 88.
    ew_bcd = u(8, 0x88) if emergency else ew_bcd_raw
    ns_bcd = u(8, 0x88) if emergency else ns_bcd_raw

    ew_red = u(1, 1) if emergency else ew_red_base
    ew_yellow = u(1, 0) if emergency else ew_yellow_base
    ew_green = u(1, 0) if emergency else ew_green_base

    ns_red = u(1, 1) if emergency else ns_red_base
    ns_yellow = u(1, 0) if emergency else ns_yellow_base
    ns_green = u(1, 0) if emergency else ns_green_base

    # ================================================================
    # Register updates
    # ================================================================
    prescaler_r.set(prescaler_next)
    phase_r.set(phase_next)
    ew_cnt_r.set(ew_cnt_next)
    ns_cnt_r.set(ns_cnt_next)

    # Blink: reset to 0 when not in yellow; toggle on tick_1hz while yellow.
    blink_r.set(u(1, 0), when=~yellow_active)
    blink_r.set(~bl, when=tick_1hz & yellow_active)

    # ================================================================
    # Outputs
    # ================================================================
    m.output("ew_bcd", ew_bcd)
    m.output("ns_bcd", ns_bcd)
    m.output("ew_red", ew_red)
    m.output("ew_yellow", ew_yellow)
    m.output("ew_green", ew_green)
    m.output("ns_red", ns_red)
    m.output("ns_yellow", ns_yellow)
    m.output("ns_green", ns_green)
tb.reset("rst", cycles_asserted=2, cycles_deasserted=1) + tb.timeout(int(p["timeout"])) + + # --- cycle 0 --- + tb.drive("a", 3) + tb.drive("b", 1) + tb.drive("sel", 1) + tb.expect("y", 1) + + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_wire_ops_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_wire_ops_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/wire_ops/wire_ops.py b/designs/examples/wire_ops/wire_ops.py index 2808742..c0e3fcf 100644 --- a/designs/examples/wire_ops/wire_ops.py +++ b/designs/examples/wire_ops/wire_ops.py @@ -1,27 +1,28 @@ from __future__ import annotations -from pycircuit import Circuit, compile, module, u +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + cas, + compile_cycle_aware, + mux, +) -@module -def build(m: Circuit) -> None: - clk = m.clock("clk") - rst = m.reset("rst") +def build(m: CycleAwareCircuit, domain: CycleAwareDomain) -> None: + a = cas(domain, m.input("a", width=8), cycle=0) + b = cas(domain, m.input("b", width=8), cycle=0) + sel = cas(domain, m.input("sel", width=1), cycle=0) - a = m.input("a", width=8) - b = m.input("b", width=8) - sel = m.input("sel", width=1) - - y = a & b if sel else a ^ b - y_q = m.out("y_q", clk=clk, rst=rst, width=8, init=u(8, 0)) - y_q.set(y) - - m.output("y", y_q) + result = mux(sel, a & b, a ^ b) + domain.next() + y = domain.cycle(result, name="y") + m.output("y", y) build.__pycircuit_name__ = "wire_ops" if __name__ == "__main__": - print(compile(build, name="wire_ops").emit_mlir()) + print(compile_cycle_aware(build, name="wire_ops", eager=True).emit_mlir()) diff --git a/designs/examples/xz_value_model_smoke/tb_xz_value_model_smoke.py b/designs/examples/xz_value_model_smoke/tb_xz_value_model_smoke.py index 5bac3fc..0b692a3 100644 --- a/designs/examples/xz_value_model_smoke/tb_xz_value_model_smoke.py +++ b/designs/examples/xz_value_model_smoke/tb_xz_value_model_smoke.py @@ 
-3,7 +3,7 @@ import sys from pathlib import Path -from pycircuit import Tb, compile, testbench +from pycircuit import CycleAwareTb, Tb, compile_cycle_aware, CycleAwareCircuit, CycleAwareDomain, testbench _THIS_DIR = Path(__file__).resolve().parent if str(_THIS_DIR) not in sys.path: @@ -15,21 +15,24 @@ @testbench def tb(t: Tb) -> None: + tb = CycleAwareTb(t) p = TB_PRESETS["smoke"] - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=0) - t.timeout(int(p["timeout"])) + tb.clock("clk") + tb.reset("rst", cycles_asserted=2, cycles_deasserted=0) + tb.timeout(int(p["timeout"])) - t.drive("in_a", 0x12, at=0) - t.expect("y", 0x00, at=0, phase="pre") - t.expect("y", 0x12, at=0, phase="post") + # --- cycle 0 --- + tb.drive("in_a", 0x12) + tb.expect("y", 0x00, phase="pre") + tb.expect("y", 0x12, phase="post") - t.drive("in_a", 0x56, at=1) - t.expect("y", 0x12, at=1, phase="pre") - t.expect("y", 0x56, at=1, phase="post") + tb.next() # --- cycle 1 --- + tb.drive("in_a", 0x56) + tb.expect("y", 0x12, phase="pre") + tb.expect("y", 0x56, phase="post") - t.finish(at=int(p["finish"])) + tb.finish(at=int(p["finish"])) if __name__ == "__main__": - print(compile(build, name="tb_xz_value_model_smoke_top", **DEFAULT_PARAMS).emit_mlir()) + print(compile_cycle_aware(build, name="tb_xz_value_model_smoke_top", eager=True, **DEFAULT_PARAMS).emit_mlir()) diff --git a/designs/examples/xz_value_model_smoke/xz_value_model_smoke.py b/designs/examples/xz_value_model_smoke/xz_value_model_smoke.py index f49e844..190cd41 100644 --- a/designs/examples/xz_value_model_smoke/xz_value_model_smoke.py +++ b/designs/examples/xz_value_model_smoke/xz_value_model_smoke.py @@ -1,21 +1,28 @@ from __future__ import annotations -from pycircuit import Circuit, ProbeBuilder, ProbeView, compile, module, probe +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + ProbeBuilder, + ProbeView, + cas, + compile_cycle_aware, + probe, +) -@module -def build(m: Circuit, width: int = 8) -> None: - 
clk = m.clock("clk") - rst = m.reset("rst") - in_a = m.input("in_a", width=width) +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, width: int = 8) -> None: + in_a = cas(domain, m.input("in_a", width=width), cycle=0) - q = m.out("q", clk=clk, rst=rst, width=width, init=0) - q.set(in_a) + q = domain.state(width=width, reset_value=0, name="q") + m.output("y", q.wire) - m.output("y", q) + domain.next() + q.set(in_a) build.__pycircuit_name__ = "xz_value_model_smoke" +build.__pycircuit_kind__ = "module" @probe(target=build, name="value") @@ -30,4 +37,4 @@ def value_probe(p: ProbeBuilder, dut: ProbeView, width: int = 8) -> None: if __name__ == "__main__": - print(compile(build, name="xz_value_model_smoke", width=8).emit_mlir()) + print(compile_cycle_aware(build, name="xz_value_model_smoke", eager=True, width=8).emit_mlir()) diff --git a/docs/PyCircuit V5 Programming Tutorial.md b/docs/PyCircuit V5 Programming Tutorial.md new file mode 100644 index 0000000..671edd0 --- /dev/null +++ b/docs/PyCircuit V5 Programming Tutorial.md @@ -0,0 +1,1070 @@ +# PyCircuit Programming Tutorial + +**作者:Liao Heng** + +**版本:1.0** + +--- + +## 目录 + +1. [概述](#概述) +2. [核心概念](#核心概念) + - [Clock Domain(时钟域)](#clock-domain时钟域) + - [Signal(信号)](#signal信号) + - [Module(模块)](#module模块) + - [clock_domain.next()](#clock_domainnext) + - [clock_domain.prev()](#clock_domainprev) + - [clock_domain.push() / pop()](#clock_domainpush--pop) + - [clock_domain.cycle()](#clock_domaincycle) + - [Nested Module(嵌套模块)](#nested-module嵌套模块) +3. [自动周期平衡](#自动周期平衡) +4. [两种输出模式](#两种输出模式) +5. [编程范例](#编程范例) + - [范例1:频率分频器(testdivider.py)](#范例1频率分频器testdividerpy) + - [范例2:实时时钟系统(testproject.py)](#范例2实时时钟系统testprojectpy) + - [范例3:RISC-V CPU(riscv.py)](#范例3risc-v-cpuriscvpy) +6. [生成的电路图](#生成的电路图) +7. 
[最佳实践](#最佳实践) + +--- + +## 概述 + +PyCircuit 是一个基于 Python 的硬件描述语言(HDL)框架,专为数字电路设计而创建。它提供了一种直观的方式来描述时序逻辑电路,核心特性包括: + +- **周期感知信号(Cycle-aware Signals)**:每个信号都携带其时序周期信息 +- **多时钟域支持**:独立管理多个时钟域及其复位信号 +- **自动周期平衡**:自动插入延迟(DFF)或反馈(FB)以对齐信号时序 +- **自动变量名提取**:使用 JIT 方法从 Python 源码提取变量名 +- **层次化/扁平化输出**:支持两种电路描述模式 + +### 安装与导入 + +```python +from pyCircuit import ( + pyc_ClockDomain, # 时钟域 + pyc_Signal, # 信号类 + pyc_CircuitModule, # 电路模块基类 + pyc_CircuitLogger, # 电路日志器 + signal, # 信号创建快捷方式 + log, # 日志函数 + mux # 多路选择器 +) +import pyCircuit +from pyVisualize import visualize_circuit # 可视化工具 +``` + +--- + +## 核心概念 + +### Clock Domain(时钟域) + +时钟域是 PyCircuit 中最基础的概念,它代表一个独立的时钟信号及其相关的时序逻辑。 + +#### 创建时钟域 + +```python +# 语法 +clock_domain = pyc_ClockDomain(name, frequency_desc="", reset_active_high=False) + +# 示例 +cpu_clk = pyc_ClockDomain("CPU_CLK", "100MHz CPU clock", reset_active_high=False) +rtc_clk = pyc_ClockDomain("RTC_CLK", "1Hz RTC domain", reset_active_high=False) +``` + +**参数说明:** +- `name`:时钟域名称(字符串) +- `frequency_desc`:频率描述(可选,用于文档) +- `reset_active_high`:复位信号极性,`False` 表示低电平有效(rstn) + +#### 创建复位信号 + +```python +rst = clock_domain.create_reset() # 创建复位信号 +# 自动命名为 {domain_name}_rstn 或 {domain_name}_rst +``` + +#### 创建输入信号 + +```python +clk_in = clock_domain.create_signal("clock_input") +data_in = clock_domain.create_signal("data_input") +``` + +--- + +### Signal(信号) + +信号是 PyCircuit 中的基本数据单元,每个信号都包含: +- 表达式(expression) +- 周期(cycle) +- 时钟域(domain) +- 位宽(width,可选) + +#### 信号创建语法 + +```python +# 方式1:使用 signal 快捷方式(推荐) +counter = signal[7:0](value=0) | "8-bit counter" +data = signal[31:0](value="input_data") | "32-bit data" +flag = signal(value="condition") | "Boolean flag" + +# 方式2:动态位宽 +bits = 8 +reg = signal[f"{bits}-1:0"](value=0) | "Dynamic width register" + +# 方式3:位选择表达式 +opcode = signal[6:0](value=f"{instruction}[6:0]") | "Opcode field" +``` + +**语法说明:** +- `signal[high:low](value=...)`:创建指定位宽的信号 +- `| "description"`:管道运算符添加描述(可选但推荐) +- `value` 可以是: + - 整数常量:`0`, `0xFF` + - 
字符串表达式:`"input_data"`, `"a + b"` + - 格式化字符串:`f"{other_signal}[7:0]"` + +#### 信号运算 + +PyCircuit 重载了 Python 运算符,支持硬件描述式的信号运算: + +```python +# 算术运算 +sum_val = (a + b) | "Addition" +diff = (a - b) | "Subtraction" +prod = (a * b) | "Multiplication" + +# 逻辑运算 +and_result = (a & b) | "Bitwise AND" +or_result = (a | b) | "Bitwise OR" +xor_result = (a ^ b) | "Bitwise XOR" +not_result = (~a) | "Bitwise NOT" + +# 比较运算 +eq = (a == b) | "Equal" +ne = (a != b) | "Not equal" +lt = (a < b) | "Less than" +gt = (a > b) | "Greater than" + +# 多路选择器 +result = mux(condition, true_value, false_value) | "Mux selection" +``` + +--- + +### Module(模块) + +模块是电路设计的基本组织单元,封装了一组相关的信号和逻辑。 + +#### 定义模块类 + +```python +class MyModule(pyc_CircuitModule): + """自定义电路模块""" + + def __init__(self, name, clock_domain): + super().__init__(name, clock_domain=clock_domain) + # 初始化模块参数 + + def build(self, input1, input2): + """构建模块逻辑""" + with self.module( + inputs=[input1, input2], + description="Module description" + ) as mod: + # 模块内部逻辑 + result = (input1 + input2) | "Sum" + + # 设置输出 + mod.outputs = [result] + + return result +``` + +#### 模块上下文管理器 + +`self.module()` 返回一个上下文管理器,用于: +- 记录模块边界 +- 管理输入/输出信号 +- 在嵌套模块中正确处理时钟周期 + +```python +with self.module( + inputs=[sig1, sig2], # 输入信号列表 + description="描述文字" # 模块描述 +) as mod: + # 模块逻辑 + mod.outputs = [out1, out2] # 设置输出 +``` + +--- + +### clock_domain.next() + +`next()` 方法推进时钟周期边界,标记时序逻辑的分界点。 + +#### 语法 + +```python +self.clock_domain.next() # 推进到下一个时钟周期 +``` + +#### 语义 + +- 调用 `next()` 后,所有新创建的信号将属于新的周期 +- 用于分隔组合逻辑和时序逻辑 +- 在流水线设计中标记各级边界 + +#### 示例 + +```python +def build(self, input_data): + with self.module(inputs=[input_data]) as mod: + # Cycle 0: 输入处理 + processed = (input_data & 0xFF) | "Masked input" + + self.clock_domain.next() # 推进到 Cycle 1 + + # Cycle 1: 进一步处理 + result = (processed + 1) | "Incremented" + + self.clock_domain.next() # 推进到 Cycle 2 + + # Cycle 2: 输出 + output = result | "Final output" + mod.outputs = [output] +``` + +--- + +### 
clock_domain.prev() + +`prev()` 方法将时钟周期回退一步,与 `next()` 相反。 + +#### 语法 + +```python +self.clock_domain.prev() # 回退到上一个时钟周期 +``` + +#### 语义 + +- 调用 `prev()` 后,当前默认周期减 1 +- 允许在过程式编程中灵活调整周期位置 +- 周期计数可以变为负数(这是设计允许的) + +#### 示例 + +```python +def build(self, input_data): + with self.module(inputs=[input_data]) as mod: + # Cycle 0 + a = input_data | "Input" + + self.clock_domain.next() # -> Cycle 1 + b = (a + 1) | "Incremented" + + self.clock_domain.next() # -> Cycle 2 + c = (b * 2) | "Doubled" + + self.clock_domain.prev() # -> Cycle 1 (回退) + # 现在我们回到了 Cycle 1,可以添加更多同周期的信号 + d = (a - 1) | "Decremented" +``` + +--- + +### clock_domain.push() / pop() + +`push()` 和 `pop()` 方法提供周期状态的栈管理,允许子函数拥有独立的周期划分而不影响调用者。 + +#### 语法 + +```python +self.clock_domain.push() # 保存当前周期到栈 +# ... 进行周期操作 ... +self.clock_domain.pop() # 恢复之前保存的周期 +``` + +#### 语义 + +- `push()` 将当前周期状态保存到用户周期栈 +- `pop()` 从栈中弹出并恢复周期状态 +- 支持嵌套调用(多层 push/pop) +- 如果 `pop()` 在没有匹配的 `push()` 时调用,会抛出 `RuntimeError` + +#### 使用场景 + +这些方法特别适合过程式编程,允许不同的子函数拥有独立的周期管理策略: + +```python +class MyModule(pyc_CircuitModule): + def helper_function_a(self, data): + """子函数 A:使用 2 个周期""" + self.clock_domain.push() # 保存调用者的周期状态 + + # 进行自己的周期划分 + result = data | "Input" + self.clock_domain.next() + result = (result + 1) | "Processed" + self.clock_domain.next() + final = (result * 2) | "Final" + + self.clock_domain.pop() # 恢复调用者的周期状态 + return final + + def helper_function_b(self, data): + """子函数 B:使用 1 个周期""" + self.clock_domain.push() # 保存调用者的周期状态 + + # 不同的周期划分策略 + result = (data & 0xFF) | "Masked" + self.clock_domain.next() + output = (result | 0x100) | "Flagged" + + self.clock_domain.pop() # 恢复调用者的周期状态 + return output + + def build(self, input_data): + with self.module(inputs=[input_data]) as mod: + # Cycle 0 + processed = input_data | "Input" + + # 调用子函数,它们各自管理自己的周期 + result_a = self.helper_function_a(processed) + result_b = self.helper_function_b(processed) + + # 仍在 Cycle 0(子函数的周期操作不影响这里) + combined = (result_a + result_b) | "Combined" + + 
mod.outputs = [combined] +``` + +#### 嵌套使用示例 + +```python +def outer_function(self, data): + self.clock_domain.push() # 保存周期 0 + + self.clock_domain.next() # -> 周期 1 + intermediate = self.inner_function(data) # inner 也可以 push/pop + + self.clock_domain.next() # -> 周期 2 + result = intermediate | "Result" + + self.clock_domain.pop() # 恢复周期 0 + return result + +def inner_function(self, data): + self.clock_domain.push() # 保存周期 1 + + self.clock_domain.next() # -> 周期 2 + self.clock_domain.next() # -> 周期 3 + processed = data | "Processed" + + self.clock_domain.pop() # 恢复周期 1 + return processed +``` + +--- + +### clock_domain.cycle() + +`cycle()` 方法实现 D 触发器(单周期延迟),用于创建时序元件。 + +#### 语法 + +```python +registered = self.clock_domain.cycle(signal, description="", reset_value=None) +``` + +**参数:** +- `signal`:要寄存的信号 +- `description`:描述(可选) +- `reset_value`:复位值(可选) + +#### 语义 + +- 输出信号的周期 = 输入信号周期 + 1 +- 如果指定 `reset_value`,生成带复位的 DFF +- 等效于 Verilog 的 `always @(posedge clk)` 块 + +#### 示例 + +```python +# 简单寄存器 +data_reg = self.clock_domain.cycle(data, "Data register") + +# 带复位值的计数器 +counter_reg = self.clock_domain.cycle(counter_next, reset_value=0) | "Counter register" + +# 流水线寄存器 +stage1_reg = self.clock_domain.cycle(stage0_out, "Pipeline stage 1") +stage2_reg = self.clock_domain.cycle(stage1_reg, "Pipeline stage 2") +``` + +--- + +### Nested Module(嵌套模块) + +PyCircuit 支持模块的层次化设计,允许在一个模块内实例化其他模块。 + +#### 语法 + +```python +# 在父模块的 build 方法中 +submodule = SubModuleClass("instance_name", self.clock_domain) +outputs = submodule.build(input1, input2) +``` + +#### 子模块周期隔离 + +子模块内部调用 `clock_domain.next()` 不会影响父模块的周期状态: + +```python +class ParentModule(pyc_CircuitModule): + def build(self, input_data): + with self.module(inputs=[input_data]) as mod: + # 父模块 Cycle 0 + processed = input_data | "Input" + + self.clock_domain.next() # 父模块推进到 Cycle 1 + + # 实例化子模块 + child = ChildModule("child", self.clock_domain) + result = child.build(processed) # 子模块内部可以有自己的 next() + + # 仍在父模块 Cycle 1(子模块的 next() 
不影响父模块) + output = result | "Output" + mod.outputs = [output] +``` + +#### 层次化 vs 扁平化 + +PyCircuit 支持两种输出模式: + +1. **层次化模式(Hierarchical)**:保留模块边界,显示嵌套结构 +2. **扁平化模式(Flatten)**:展开所有子模块,信号名带模块前缀 + +```python +# 层次化模式 +hier_logger = pyc_CircuitLogger("circuit.txt", is_flatten=False) + +# 扁平化模式 +flat_logger = pyc_CircuitLogger("circuit.txt", is_flatten=True) +``` + +--- + +## 自动周期平衡 + +PyCircuit 的核心特性之一是自动周期平衡(Automatic Cycle Balancing)。 + +### 规则 + +当组合不同周期的信号时: +- **输出周期 ≥ max(输入周期)** +- 如果输入周期 < 输出周期:自动插入 `DFF`(延迟) +- 如果输入周期 > 输出周期:自动插入 `FB`(反馈) +- 如果输入周期 == 输出周期:直接使用 + +### 示例 + +```python +# sig_a 在 Cycle 0,sig_b 在 Cycle 2 +result = (sig_a & sig_b) | "Combined" +# 输出:result 在 Cycle 2,sig_a 自动延迟 2 个周期 +``` + +生成的描述: +``` +result = (DFF(DFF(sig_a)) & sig_b) + → Cycle balancing: inputs at [0, 2] → output at 2 +``` + +--- + +## 两种输出模式 + +### 层次化模式(Hierarchical Mode) + +保留模块层次结构,每个模块独立显示: + +``` +┌────────────────────────────────────────────────────────────────────────────┐ +│ MODULE: ParentModule │ +└────────────────────────────────────────────────────────────────────────────┘ + INPUTS: + • input_signal [cycle=0, domain=CLK] + + SUBMODULES: + • ChildModule + - Inputs: processed + - Outputs: result + + OUTPUTS: + • output [cycle=2, domain=CLK] +``` + +### 扁平化模式(Flatten Mode) + +展开所有子模块,信号名带模块前缀: + +``` +┌────────────────────────────────────────────────────────────────────────────┐ +│ MODULE: TopLevel │ +└────────────────────────────────────────────────────────────────────────────┘ + SIGNALS: + ChildModule.internal_sig = ... + ChildModule.result = ... 
+ output = ChildModule.result +``` + +--- + +## 编程范例 + +### 范例1:频率分频器(testdivider.py) + +这是一个简单的频率分频器,将输入时钟分频为指定倍数。 + +#### 代码 + +```python +class FrequencyDivider(pyc_CircuitModule): + """ + 频率分频器模块 + """ + + def __init__(self, name, divide_by, input_clock_domain): + super().__init__(name, clock_domain=input_clock_domain) + self.divide_by = divide_by + self.counter_bits = (divide_by - 1).bit_length() + + def build(self, clk_in): + """构建分频器电路""" + with self.module( + inputs=[clk_in], + description=f"Frequency Divider: Divide by {self.divide_by}" + ) as mod: + # 初始化计数器(Cycle -1:初始化信号) + counter = signal[f"{self.counter_bits}-1:0"](value=0) | "Counter initial value" + + # 计数器逻辑 + counter_next = (counter + 1) | "Counter increment" + counter_eq = (counter == (self.divide_by - 1)) | f"Counter == {self.divide_by-1}" + counter_wrap = mux(counter_eq, 0, counter_next) | "Counter wrap-around" + + self.clock_domain.next() # 推进到下一周期 + + # 更新计数器(反馈) + counter = counter_wrap | "update counter" + + # 输出使能信号 + clk_enable = (counter == (self.divide_by - 1)) | "Clock enable output" + + mod.outputs = [clk_enable] + + return clk_enable +``` + +#### 使用方法 + +```python +def main(): + # 创建时钟域 + clk_domain = pyc_ClockDomain("DIV_CLK", "Divider clock domain") + clk_domain.create_reset() + + clk_domain.next() + clk_in = clk_domain.create_signal("clock_in") + + # 实例化分频器 + divider = FrequencyDivider("Divider13", 13, clk_domain) + clk_enable = divider.build(clk_in) +``` + +#### 生成的电路描述 + +**层次化模式(hier_testdivider.txt):** + +``` +================================================================================ +CIRCUIT DESCRIPTION (HIERARCHICAL MODE) +================================================================================ + +┌────────────────────────────────────────────────────────────────────────────┐ +│ MODULE: Divider13 │ +│ Frequency Divider: Divide by 13 │ +└────────────────────────────────────────────────────────────────────────────┘ + + INPUTS: + • clock_in [cycle=-1, 
domain=DIV_CLK] + + SIGNALS: + + ────────────────────────────────────────────────────────────────────── + CYCLE -1 + ────────────────────────────────────────────────────────────────────── + + counter = forward_declare("Counter initial value") + // Counter initial value + + + ────────────────────────────────────────────────────────────────────── + CYCLE 1 + ────────────────────────────────────────────────────────────────────── + + counter_next = (counter + 1) + // Counter increment + → Cycle balancing: inputs at [-1] → output at 1 + + counter_eq = (counter == (self.divide_by - 1)) + // Counter == 12 + → Cycle balancing: inputs at [-1] → output at 1 + + counter_wrap = mux(counter_eq, 0, counter_next) + // Counter wrap-around (mux) + + + ────────────────────────────────────────────────────────────────────── + CYCLE 2 + ────────────────────────────────────────────────────────────────────── + + counter = counter_wrap + // Feedback: update counter + → Cycle balancing: inputs at [1] → output at 2 + + clk_enable = (counter == (self.divide_by - 1)) + // Clock enable output + + OUTPUTS: + • clk_enable [cycle=2, domain=DIV_CLK] +``` + +#### 电路图 + +![Hierarchical Divider](hier_testdivider.pdf) + +![Flatten Divider](flat_testdivider.pdf) + +--- + +### 范例2:实时时钟系统(testproject.py) + +这是一个完整的实时时钟系统,包含: +- 高频振荡器时钟域 +- 频率分频器(1024分频) +- 带 SET/PLUS/MINUS 按钮的实时时钟 + +#### 多时钟域示例 + +```python +# 创建两个独立的时钟域 +osc_domain = pyc_ClockDomain("OSC_CLK", "High-frequency oscillator domain") +rtc_domain = pyc_ClockDomain("RTC_CLK", "1Hz RTC domain") + +# 各自创建复位信号 +osc_rst = osc_domain.create_reset() +rtc_rst = rtc_domain.create_reset() +``` + +#### 实时时钟模块 + +```python +class RealTimeClock(pyc_CircuitModule): + """带按钮控制的实时时钟""" + + STATE_RUNNING = 0 + STATE_SETTING_HOUR = 1 + STATE_SETTING_MINUTE = 2 + STATE_SETTING_SECOND = 3 + + def __init__(self, name, rtc_clock_domain): + super().__init__(name, clock_domain=rtc_clock_domain) + + def build(self, clk_enable, set_btn, plus_btn, minus_btn): + with 
self.module( + inputs=[clk_enable, set_btn, plus_btn, minus_btn], + description="Real-Time Clock with SET/PLUS/MINUS control" + ) as mod: + # 初始化时间计数器 + sec = signal[5:0](value=0) | "Seconds" + min = signal[5:0](value=0) | "Minutes" + hr = signal[4:0](value=0) | "Hours" + state = signal[1:0](value=self.STATE_RUNNING) | "State" + + self.clock_domain.next() + + # 状态机逻辑 + state_is_running = (state == self.STATE_RUNNING) | "Check RUNNING" + # ... 更多逻辑 ... + + self.clock_domain.next() + + # 寄存时间值 + seconds_out = self.clock_domain.cycle(sec_next, reset_value=0) + minutes_out = self.clock_domain.cycle(min_next, reset_value=0) + hours_out = self.clock_domain.cycle(hr_next, reset_value=0) + + mod.outputs = [seconds_out, minutes_out, hours_out, state] +``` + +#### 生成的电路描述 + +**层次化模式部分输出(hier_circuit.txt):** + +``` +┌────────────────────────────────────────────────────────────────────────────┐ +│ MODULE: FreqDiv1024 │ +│ Frequency Divider: Divide by 1024 │ +└────────────────────────────────────────────────────────────────────────────┘ + + INPUTS: + • oscillator_in [cycle=-1, domain=OSC_CLK] + + SIGNALS: + ... + + OUTPUTS: + • clk_enable [cycle=3, domain=OSC_CLK] + + +┌────────────────────────────────────────────────────────────────────────────┐ +│ MODULE: RTC │ +│ Real-Time Clock with SET/PLUS/MINUS control buttons │ +└────────────────────────────────────────────────────────────────────────────┘ + + INPUTS: + • clk_enable [cycle=3, domain=OSC_CLK] + • SET_btn [cycle=-1, domain=RTC_CLK] + • PLUS_btn [cycle=-1, domain=RTC_CLK] + • MINUS_btn [cycle=-1, domain=RTC_CLK] + ... 
+``` + +#### 电路图 + +**频率分频器模块:** + +![FreqDiv1024](hier_FreqDiv1024.pdf) + +**实时时钟模块:** + +![RTC](hier_RTC.pdf) + +**扁平化模式完整电路:** + +![Flatten Circuit](flat_circuit_diagram.pdf) + +--- + +### 范例3:RISC-V CPU(riscv.py) + +这是一个完整的 RISC-V CPU 实现,展示了 PyCircuit 处理复杂层次化设计的能力。 + +#### CPU 结构 + +``` +RISCVCpu +├── InstructionDecoder (指令解码器) +├── RegisterFile (寄存器文件) +├── ALU (算术逻辑单元) +└── ExceptionHandler (异常处理器) +``` + +#### 5 级流水线实现 + +```python +class RISCVCpu(pyc_CircuitModule): + def build(self, instruction_mem_data, data_mem_data, interrupt_req): + with self.module(inputs=[...]) as mod: + # ========== STAGE 1: INSTRUCTION FETCH ========== + pc = signal[31:0](value=0) | "Program Counter" + + self.clock_domain.next() # Cycle 1 + pc_next = pc + 4 | "PC + 4" + instruction = instruction_mem_data | "Fetched instruction" + + # ========== STAGE 2: INSTRUCTION DECODE ========== + self.clock_domain.next() # Cycle 2 + instruction_reg = self.clock_domain.cycle(instruction) + + # 实例化解码器子模块 + decoder = InstructionDecoder("Decoder", self.clock_domain) + (opcode, funct3, ...) = decoder.build(instruction_reg) + + # 实例化寄存器文件 + reg_file = RegisterFile("RegFile", self.clock_domain) + rs1_data, rs2_data = reg_file.build(rs1, rs2, ...) + + # ========== STAGE 3: EXECUTE ========== + self.clock_domain.next() # Cycle 3 + + # 实例化 ALU + alu = ALU("ALU", self.clock_domain) + alu_result, zero_flag, lt_flag = alu.build(...) + + # ========== STAGE 4: MEMORY ACCESS ========== + self.clock_domain.next() # Cycle 4 + + # 异常处理 + exc_handler = ExceptionHandler("ExceptionHandler", self.clock_domain) + exception_valid, exception_code, ... = exc_handler.build(...) + + # ========== STAGE 5: WRITE BACK ========== + self.clock_domain.next() # Cycle 5 + + wb_data = mux(mem_read_wb, mem_data_wb, alu_result_wb) | "Write-back data" +``` + +#### 子模块示例:ALU + +```python +class ALU(pyc_CircuitModule): + """算术逻辑单元""" + + ALU_ADD = 0 + ALU_SUB = 1 + ALU_AND = 2 + # ... 
更多操作码 + + def build(self, operand_a, operand_b, alu_op): + with self.module(inputs=[operand_a, operand_b, alu_op]) as mod: + # 算术运算 + add_result = (operand_a + operand_b) | "ALU ADD" + sub_result = (operand_a - operand_b) | "ALU SUB" + + # 逻辑运算 + and_result = (operand_a & operand_b) | "ALU AND" + or_result = (operand_a | operand_b) | "ALU OR" + + # 使用 mux 链选择结果 + result = mux(alu_op == self.ALU_SUB, sub_result, add_result) + result = mux(alu_op == self.ALU_AND, and_result, result) + # ... + + mod.outputs = [result, zero_flag, lt_flag] +``` + +#### 生成的电路描述 + +**层次化模式(hier_riscv.txt)部分:** + +``` +┌────────────────────────────────────────────────────────────────────────────┐ +│ MODULE: RISCVCpu │ +│ RISC-V CPU: 5-stage pipeline with precise exception handling │ +└────────────────────────────────────────────────────────────────────────────┘ + + INPUTS: + • instruction_mem_data [cycle=-1, domain=CPU_CLK] + • data_mem_data [cycle=-1, domain=CPU_CLK] + • interrupt_req [cycle=-1, domain=CPU_CLK] + + SUBMODULES: + • Decoder + • RegFile + • ALU + • ExceptionHandler + + OUTPUTS: + • pc [cycle=6, domain=CPU_CLK] + • instruction_mem_addr [cycle=6, domain=CPU_CLK] + ... 
+``` + +#### 电路图 + +**RISC-V CPU 顶层模块(层次化):** + +![RISC-V CPU](hier_riscv_RISCVCpu.pdf) + +**指令解码器模块:** + +![Decoder](hier_riscv_Decoder.pdf) + +**寄存器文件模块:** + +![RegFile](hier_riscv_RegFile.pdf) + +**ALU 模块:** + +![ALU](hier_riscv_ALU.pdf) + +**扁平化模式完整 CPU:** + +![Flatten RISC-V](flat_riscv_RISCVCpu.pdf) + +--- + +## 生成的电路图 + +PyCircuit 使用 `pyVisualize` 模块生成电路图,支持 PDF 和 PNG 格式。 + +### 使用方法 + +```python +from pyVisualize import visualize_circuit + +# 生成完整电路图 +pdf_file = visualize_circuit( + logger, + figsize=(18, 14), + output_file="circuit_diagram.pdf" +) + +# 生成单个模块的电路图 +module_pdf = visualize_circuit( + logger, + module_name="ALU", + output_file="alu_diagram.pdf" +) +``` + +### 输出文件列表 + +| 文件名 | 说明 | +|--------|------| +| `hier_testdivider.txt` | 分频器层次化描述 | +| `flat_testdivider.txt` | 分频器扁平化描述 | +| `hier_testdivider.pdf` | 分频器层次化电路图 | +| `flat_testdivider.pdf` | 分频器扁平化电路图 | +| `hier_circuit.txt` | RTC系统层次化描述 | +| `flat_circuit.txt` | RTC系统扁平化描述 | +| `hier_FreqDiv1024.pdf` | 频率分频器电路图 | +| `hier_RTC.pdf` | 实时时钟电路图 | +| `hier_riscv.txt` | RISC-V CPU 层次化描述 | +| `flat_riscv.txt` | RISC-V CPU 扁平化描述 | +| `hier_riscv_*.pdf` | 各模块层次化电路图 | +| `flat_riscv_*.pdf` | 扁平化电路图 | + +--- + +## 最佳实践 + +### 1. 模块设计原则 + +```python +class GoodModule(pyc_CircuitModule): + def __init__(self, name, clock_domain, param1, param2): + super().__init__(name, clock_domain=clock_domain) + self.param1 = param1 # 保存配置参数 + self.param2 = param2 + + def build(self, input1, input2): + # 使用 with 语句管理模块上下文 + with self.module( + inputs=[input1, input2], + description=f"Module with param1={self.param1}" + ) as mod: + # 模块逻辑 + result = ... + + # 明确设置输出 + mod.outputs = [result] + + return result # 返回输出信号供父模块使用 +``` + +### 2. 信号命名规范 + +```python +# ✓ 好的命名 +counter_next = (counter + 1) | "Counter next value" +data_valid_reg = self.clock_domain.cycle(data_valid) | "Registered valid" + +# ✗ 避免的命名 +x = (a + b) | "Some signal" # 太简短 +temp = result | "" # 无描述 +``` + +### 3. 
周期管理 + +```python +# ✓ 明确标记周期边界 +self.clock_domain.next() # Cycle N -> N+1 + +# 使用 cycle() 创建寄存器 +registered_data = self.clock_domain.cycle(data, reset_value=0) | "Registered data" + +# ✓ 理解自动周期平衡 +# 当组合不同周期的信号时,系统会自动插入延迟 +``` + +### 4. 层次化设计 + +```python +# ✓ 合理拆分模块 +class TopLevel(pyc_CircuitModule): + def build(self, ...): + with self.module(...) as mod: + # 实例化功能子模块 + decoder = Decoder("decoder", self.clock_domain) + alu = ALU("alu", self.clock_domain) + + # 连接子模块 + decoded = decoder.build(instruction) + result = alu.build(op_a, op_b, alu_op) +``` + +### 5. 调试技巧 + +```python +# 使用描述帮助调试 +signal_name = expression | "Descriptive comment for debugging" + +# 检查生成的 .txt 文件确认: +# - 信号周期是否正确 +# - 自动周期平衡是否如预期 +# - 模块层次是否正确 +``` + +--- + +## 附录:API 参考 + +### pyc_ClockDomain + +| 方法 | 说明 | +|------|------| +| `__init__(name, frequency_desc, reset_active_high)` | 创建时钟域 | +| `create_reset()` | 创建复位信号 | +| `create_signal(name)` | 创建输入信号 | +| `next()` | 推进时钟周期(周期 +1) | +| `prev()` | 回退时钟周期(周期 -1) | +| `push()` | 保存当前周期状态到栈 | +| `pop()` | 从栈恢复周期状态 | +| `cycle(signal, description, reset_value)` | 创建寄存器(DFF) | + +### pyc_CircuitModule + +| 方法 | 说明 | +|------|------| +| `__init__(name, clock_domain)` | 初始化模块 | +| `module(inputs, description)` | 模块上下文管理器 | +| `build(...)` | 构建模块逻辑(需重写) | + +### pyc_CircuitLogger + +| 方法 | 说明 | +|------|------| +| `__init__(filename, is_flatten)` | 创建日志器 | +| `write_to_file()` | 写入电路描述文件 | +| `reset()` | 重置日志器状态 | + +### 全局函数 + +| 函数 | 说明 | +|------|------| +| `signal[high:low](value=...)` | 创建信号 | +| `mux(condition, true_val, false_val)` | 多路选择器 | +| `log(signal)` | 记录信号(用于调试) | + +--- + +**Copyright © 2024 Liao Heng. 
All rights reserved.** + diff --git a/docs/PyCurcit V5_CYCLE_AWARE_API.md b/docs/PyCurcit V5_CYCLE_AWARE_API.md new file mode 100644 index 0000000..073ecf6 --- /dev/null +++ b/docs/PyCurcit V5_CYCLE_AWARE_API.md @@ -0,0 +1,387 @@ +# PyCircuit Cycle-Aware API Reference + +**Version: 2.0** + +--- + +## Overview + +The cycle-aware system is a new programming paradigm for PyCircuit that tracks signal timing cycles automatically. Key features include: + +- **Cycle-aware Signals**: Each signal carries its cycle information +- **Automatic Cycle Balancing**: Automatic DFF insertion when combining signals of different cycles +- **Domain-based Cycle Management**: `next()`, `prev()`, `push()`, `pop()` methods for cycle control +- **JIT Compilation**: Python source code compiles to MLIR hardware description + +## Installation + +```python +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + CycleAwareSignal, + compile_cycle_aware, + mux, +) +``` + +--- + +## Core Components + +### CycleAwareCircuit + +The main circuit builder class that manages clock domains and signal generation. + +```python +m = CycleAwareCircuit("my_circuit") +``` + +**Methods:** + +| Method | Description | +|--------|-------------| +| `create_domain(name)` | Create a new clock domain | +| `get_default_domain()` | Get the default clock domain | +| `const_signal(value, width, domain)` | Create a constant signal | +| `input_signal(name, width, domain)` | Create an input signal | +| `output(name, signal)` | Register an output signal | +| `emit_mlir()` | Generate MLIR representation | + +### CycleAwareDomain + +Manages clock cycle state for a specific clock domain. 
+ +```python +domain = m.create_domain("clk") +``` + +**Methods:** + +| Method | Description | +|--------|-------------| +| `create_signal(name, width)` | Create an input signal | +| `create_const(value, width, name)` | Create a constant signal | +| `next()` | Advance current cycle by 1 | +| `prev()` | Decrease current cycle by 1 | +| `push()` | Save current cycle to stack | +| `pop()` | Restore cycle from stack | +| `cycle(signal, reset_value, name)` | Insert DFF register | + +### CycleAwareSignal + +Wrapper that carries cycle information along with the underlying MLIR signal. + +**Attributes:** + +| Attribute | Description | +|-----------|-------------| +| `sig` | Underlying MLIR Signal | +| `cycle` | Current cycle number | +| `domain` | Associated CycleAwareDomain | +| `name` | Signal name for debugging | +| `signed` | Whether signal is signed | + +**Operator Overloading:** + +The arithmetic and bitwise Python operators are overloaded with automatic cycle balancing; comparisons in the example below are written as method calls (`eq`, `lt`, `gt`, `le`, `ge`): + +```python +# Arithmetic +result = a + b # Addition +result = a - b # Subtraction +result = a * b # Multiplication + +# Bitwise +result = a & b # AND +result = a | b # OR +result = a ^ b # XOR +result = ~a # NOT +result = a << n # Left shift +result = a >> n # Right shift + +# Comparison +result = a.eq(b) # Equal +result = a.lt(b) # Less than +result = a.gt(b) # Greater than +result = a.le(b) # Less or equal +result = a.ge(b) # Greater or equal +``` + +**Signal Methods:** + +| Method | Description | +|--------|-------------| +| `select(true_val, false_val)` | Conditional selection (mux) | +| `trunc(width)` | Truncate to width bits | +| `zext(width)` | Zero extend to width bits | +| `sext(width)` | Sign extend to width bits | +| `slice(high, low)` | Extract bit slice | +| `named(name)` | Add debug name | +| `as_signed()` | Mark as signed | +| `as_unsigned()` | Mark as unsigned | + +--- + +## Automatic Cycle Balancing + +When combining signals with different cycles, the system automatically inserts DFF chains to 
align timing. + +### Rule + +``` +output_cycle = max(input_cycles) +earlier_signals → automatically delayed via DFF insertion +``` + +### Example + +```python +def design(m: CycleAwareCircuit, domain: CycleAwareDomain): + # Cycle 0: Input + data_in = domain.create_signal("data_in", width=8) + + # Save reference at Cycle 0 + data_at_cycle0 = data_in + + domain.next() # -> Cycle 1 + stage1 = domain.cycle(data_in, reset_value=0, name="stage1") + + domain.next() # -> Cycle 2 + stage2 = domain.cycle(stage1, reset_value=0, name="stage2") + + # data_at_cycle0 is at Cycle 0, stage2 is at Cycle 2 + # System automatically inserts 2-level DFF chain for data_at_cycle0 + combined = data_at_cycle0 + stage2 # Output at Cycle 2 + + m.output("result", combined.sig) +``` + +Generated MLIR shows automatic DFF insertion: + +```mlir +%data_delayed1 = pyc.reg %clk, %rst, %en, %data_at_cycle0, %reset_val : i8 +%data_delayed2 = pyc.reg %clk, %rst, %en, %data_delayed1, %reset_val : i8 +%result = pyc.add %data_delayed2, %stage2 : i8 +``` + +--- + +## Cycle Management + +### next() / prev() + +Advance or decrease the current cycle counter. + +```python +# Cycle 0 +a = domain.create_signal("a", width=8) + +domain.next() # -> Cycle 1 +b = domain.cycle(a, name="b") + +domain.next() # -> Cycle 2 +c = domain.cycle(b, name="c") + +domain.prev() # -> Cycle 1 +# Can add more signals at Cycle 1 +d = (a + 1) # Also at Cycle 1 (with auto balancing) +``` + +### push() / pop() + +Save and restore cycle state for nested function calls. 
+ +```python +def helper_function(domain: CycleAwareDomain, data): + domain.push() # Save caller's cycle + + # Internal cycle management + domain.next() + result = domain.cycle(data, name="helper_reg") + domain.next() + final = result + 1 + + domain.pop() # Restore caller's cycle + return final + +def main_design(m: CycleAwareCircuit, domain: CycleAwareDomain): + data = domain.create_signal("data", width=8) + + # Call helper - its internal next() doesn't affect our cycle + result = helper_function(domain, data) + + # Still at our original cycle + domain.next() # Our own cycle advancement +``` + +### cycle() + +Insert a DFF register (single-cycle delay). + +```python +# Basic register +reg = domain.cycle(data, name="data_reg") + +# Register with reset value +counter_reg = domain.cycle(counter_next, reset_value=0, name="counter") +``` + +--- + +## JIT Compilation + +### compile_cycle_aware() + +Compile a Python function to a CycleAwareCircuit. + +```python +def my_design(m: CycleAwareCircuit, domain: CycleAwareDomain, width: int = 8): + # Design logic + data = domain.create_signal("data", width=width) + processed = data + 1 + domain.next() + output = domain.cycle(processed, name="output") + m.output("out", output.sig) + +# Compile +circuit = compile_cycle_aware(my_design, name="my_circuit", width=16) + +# Generate MLIR +mlir_code = circuit.emit_mlir() +``` + +### Parameters + +| Parameter | Description | +|-----------|-------------| +| `fn` | Python function to compile | +| `name` | Circuit name (optional) | +| `domain_name` | Default clock domain name (default: "clk") | +| `**params` | Additional parameters passed to function | + +### Return Statement + +The JIT compiler handles return statements by registering outputs: + +```python +def design(m: CycleAwareCircuit, domain: CycleAwareDomain): + data = domain.create_signal("data", width=8) + result = data + 1 + return result # Automatically becomes output "result" +``` + +--- + +## Global Functions + +### mux() + 
+Conditional selection with automatic cycle balancing. + +```python +result = mux(condition, true_value, false_value) +``` + +**Parameters:** + +- `condition`: CycleAwareSignal (1-bit) for selection +- `true_value`: Value when condition is true (CycleAwareSignal or int) +- `false_value`: Value when condition is false (CycleAwareSignal or int) + +**Example:** + +```python +enable = domain.create_signal("enable", width=1) +data = domain.create_signal("data", width=8) +result = mux(enable, data + 1, data) # Increment when enabled +``` + +--- + +## Complete Example + +```python +# -*- coding: utf-8 -*- +"""Counter with enable - cycle-aware implementation.""" + +from pycircuit import ( + CycleAwareCircuit, + CycleAwareDomain, + compile_cycle_aware, + mux, +) + + +def counter_with_enable( + m: CycleAwareCircuit, + domain: CycleAwareDomain, + width: int = 8, +): + """8-bit counter with enable control.""" + + # Cycle 0: Inputs + enable = domain.create_signal("enable", width=1) + + # Counter initial value + count = domain.create_const(0, width=width, name="count_init") + + # Combinational logic + count_next = count + 1 + count_with_enable = mux(enable, count_next, count) + + # Cycle 1: Register + domain.next() + count_reg = domain.cycle(count_with_enable, reset_value=0, name="count") + + # Output + m.output("count", count_reg.sig) + + +if __name__ == "__main__": + circuit = compile_cycle_aware(counter_with_enable, name="counter", width=8) + print(circuit.emit_mlir()) +``` + +--- + +## Migration from Legacy API + +| Legacy API | Cycle-Aware API | +|------------|-----------------| +| `Circuit` | `CycleAwareCircuit` | +| `ClockDomain` | `CycleAwareDomain` | +| `Wire` / `Reg` | `CycleAwareSignal` | +| `compile()` | `compile_cycle_aware()` | +| Manual DFF insertion | Automatic via `domain.cycle()` | +| No cycle tracking | Full cycle tracking | + +--- + +## Best Practices + +1. 
**Use descriptive names**: The `named()` method helps with debugging + ```python + result = (a + b).named("sum_ab") + ``` + +2. **Mark cycle boundaries clearly**: Use comments to document pipeline stages + ```python + # === Stage 1: Fetch === + domain.next() + ``` + +3. **Use push/pop for helper functions**: Avoid cycle state leakage + ```python + def helper(domain, data): + domain.push() + # ... logic ... + domain.pop() + return result + ``` + +4. **Let automatic balancing work**: Trust the system to insert DFFs when needed + +--- + +**Copyright (C) 2024-2026 PyCircuit Contributors** diff --git a/docs/cycle_balance_improvement.md b/docs/cycle_balance_improvement.md new file mode 100644 index 0000000..fe6c23a --- /dev/null +++ b/docs/cycle_balance_improvement.md @@ -0,0 +1,100 @@ +# Cycle balance 设计改进(pyCircuit) + +## 1. 背景与问题 + +在 **cycle-aware** 编译模型中,每个数据值关联一个 **逻辑周期索引(occurrence / stage cycle)**,表示该值在流水线或调度语义下“有效”的周期。当 `pyc.assign` 的左值(目标线网)与右值在该索引上不一致时,编译器需要在右值侧插入 **寄存器(DFF / `pyc.reg`)** 做 **cycle balance**,使对齐后的右值与左值处于同一周期。 + +**Fanout 冗余问题**:若同一右值 SSA 被多个左值引用,且各自独立做 balance,可能在每条路径上各插一条等长延迟链,导致: + +- 寄存器与连线重复,面积与功耗上升; +- 行为虽可能对齐,但结构非最小。 + +**期望**:编译器应 **intern(复用)** 延迟结果——对同一 `(右值, 时钟上下文, 复位上下文, 延迟深度 d)` 只保留一条延迟链,所有需要 `d` 拍对齐的 `assign` 共用其输出。 + +## 2. 当前 pyCircuit 编译器实现(摘要) + +### 2.1 驱动与前端 + +- Python `pycircuit` 前端通过 `Module`/`Circuit` 生成文本 **`.pyc`(MLIR)**。 +- **pyc4.0 以 cycle-aware 为推荐主路径**:`m.clock()` 返回 **`ClockHandle`**,用 **`clk.next()`** 推进当前 occurrence;对 **`named_wire` 的 `m.assign`** 自动写入 `dst_cycle`/`src_cycle`;亦可显式传 `assign(..., dst_cycle=, src_cycle=)`。教程见 `docs/pyCircuit_Tutorial.md` §3.1。 + +### 2.2 `pycc` 流水线(与 cycle 相关的位置) + +典型优化与合法性顺序(节选,见 `compiler/mlir/tools/pycc.cpp`): + +1. 契约与层次:`pyc-check-frontend-contract`、`inline`、规范化、CSE、SCCP +2. 结构整理:`pyc-lower-scf-to-pyc-static` +3. **周期对齐**:`pyc-cycle-balance`(按 `dst_cycle`/`src_cycle` 插入并 **复用** 共享延迟寄存器) +4. 线网:`pyc-eliminate-wires`、`pyc-eliminate-dead-state`、`pyc-comb-canonicalize`、… +5. 
合法性:`pyc-check-comb-cycles`、`pyc-check-clock-domains` +6. 寄存器打包:`pyc-pack-i1-regs` +7. 组合融合:`pyc-fuse-comb`(可选) +8. 深度统计:`pyc-check-logic-depth` + +组合环检查依赖 `pyc.reg` 等作为时序割点;`pyc-cycle-balance` 新增的寄存器同样参与该割集。 + +### 2.3 当前 PYC IR(与本文相关部分) + +| 构造 | 角色 | +|------|------| +| `pyc.wire` | 组合线网占位 | +| `pyc.assign` | `dst`(须为 `wire` 结果)← `src` | +| `pyc.reg` | `clk, rst, en, next, init` → `q` | +| `pyc.comb` | 融合组合区(与 tick/transfer 后端协作) | +| `pyc.instance` | 层次实例 | + +周期语义 **尚未** 作为一等类型出现在类型系统里;若引入 cycle balance,宜先用 **assign 上的可选属性** 或独立 metadata pass 输入,再逐步规范化。 + +## 3. 设计目标(新要求) + +1. **正确性**:`dst_cycle` 与 `src_cycle` 给定且 `dst_cycle >= src_cycle` 时,插入 `dst_cycle - src_cycle` 拍延迟,使驱动 `dst` 的数据与左值周期一致(在单时钟域、与既有 `tick/transfer` 语义一致的前提下)。 +2. **共享延迟**:同一 `(src, clk, rst, d)` 只构建 **一条** `d` 级寄存器链(或等价结构),多 `assign` 复用最后一级 `q`(及中间级若需要)。 +3. **时钟域**:首版可要求 **单主时钟/复位**(与模块内既有 `pyc.reg` 一致);多域需显式扩展(绑定到域 ID 或不同 `clk/rst` 对)。 +4. **可观测性**:插入的寄存器可带 `pyc.name` 前缀(如 `pyc_cyclebal_`)便于波形与调试。 +5. **默认行为不变**:未携带周期属性的 `pyc.assign` 不做任何处理,行为与当前实现保持一致,保证现有设计零差异。 + +## 4. 
实现方案概要 + +### 4.1 IR 扩展 + +在 `pyc.assign` 上增加 **可选** 属性: + +- `dst_cycle`:`i64`,左值周期索引 +- `src_cycle`:`i64`,右值周期索引 + +约定:二者 **同时出现或同时省略**;若出现,必须 `dst_cycle >= src_cycle`。深度 `d = dst_cycle - src_cycle`;`d == 0` 时不插入寄存器,并可剥离属性。 + +### 4.2 新 Pass:`pyc-cycle-balance` + +- **作用域**:`func.func` 内(与多数 PYC transform 一致)。 +- **算法要点**: + - 从函数体中解析 **默认 `clk/rst`**(例如取第一个 `pyc.reg` 的时钟与复位;若存在多组不一致则报错)。 + - 对每个带周期属性的 `pyc.assign`,计算 `d`,调用 `getOrCreateDelayed(src, d, clk, rst)`: + - 内部缓存 `map[(src,clk,rst,d)] → q`; + - 递归构造:`delayed(src,0)=src`;`delayed(src,k)` = 一级 `pyc.reg`,`next = delayed(src,k-1)`,`en = 1`,`init = 0`。 + - 将 `assign` 的 `src` 操作数替换为延迟链输出;**移除**周期属性,避免重复执行。 +- **插入位置**:在对应 `pyc.assign` **之前**(保证 `src` 支配新寄存器)。 +- **流水线位置**:在 **`pyc-eliminate-wires` 之前** 运行——此时仍保留 `wire`+`assign` 形态,与 `assign` 校验一致。 + +### 4.3 后续可选工作 + +- 前端/Python 生成 `dst_cycle`/`src_cycle`。 +- 与 `pyc-check-clock-domains` 对齐:显式校验 balance 寄存器与目标 assign 的域一致。 +- 带 `en` 的流水线寄存(非恒 1)的精确语义与共享策略。 +- 在 `pyc-fuse-comb` 之后是否再跑一遍 CSE 以合并重复别名。 + +## 5. 文档索引 + +更细的步骤、文件清单与验收标准见 **`docs/cycle_balance_improvement_detailed_plan.md`**。 + +## 6. 
实现落点(代码) + +| 组件 | 路径 | +|------|------| +| IR:`pyc.assign` 周期属性 | `compiler/mlir/include/pyc/Dialect/PYC/PYCOps.td` | +| 校验 | `compiler/mlir/lib/Dialect/PYC/PYCOps.cpp`(`AssignOp::verify`) | +| Pass | `compiler/mlir/lib/Transforms/CycleBalancePass.cpp`(`--pyc-cycle-balance`) | +| 注册与链接 | `Passes.h`、`compiler/mlir/CMakeLists.txt` | +| 流水线 | `compiler/mlir/tools/pycc.cpp`(`createCycleBalancePass` 位于 lower-scf 与 eliminate-wires 之间) | + +另:`pycc.cpp` 中对 `GreedyRewriteConfig` 使用 `setMaxIterations` / `setMaxNumRewrites`,以兼容 LLVM 21 将对应字段改为私有的变更。 diff --git a/docs/cycle_balance_improvement_detailed_plan.md b/docs/cycle_balance_improvement_detailed_plan.md new file mode 100644 index 0000000..13162ae --- /dev/null +++ b/docs/cycle_balance_improvement_detailed_plan.md @@ -0,0 +1,49 @@ +# Cycle balance 详细实施计划 + +本文档是 `cycle_balance_improvement.md` 的落地细化,并记录已执行项。 + +## 阶段 A:IR 与校验 + +| 步骤 | 内容 | 状态 | +|------|------|------| +| A1 | 在 `include/pyc/Dialect/PYC/PYCOps.td` 为 `PYC_AssignOp` 增加 `OptionalAttr`:`dst_cycle`、`src_cycle` | 已完成 | +| A2 | 重新 TableGen(构建时自动生成) | 随构建 | +| A3 | 在 `lib/Dialect/PYC/PYCOps.cpp` 的 `AssignOp::verify` 中:若仅一侧有属性则报错;若 `dst_cycle < src_cycle` 则报错 | 已完成 | + +## 阶段 B:Pass 实现 + +| 步骤 | 内容 | 状态 | +|------|------|------| +| B1 | 新增 `lib/Transforms/CycleBalancePass.cpp`:`OperationPass`,参数名 `pyc-cycle-balance` | 已完成 | +| B2 | `inferClkRst`:遍历 `pyc.reg` 取第一组 `(clk,rst)` 并检查全体一致;若无 `reg` 则尝试入口块 `!pyc.clock` / `!pyc.reset` 参数 | 已完成 | +| B3 | `getOrCreateDelayed`:`std::map` 键 `(src,clk,rst,depth)`;在 `inner` 定义之后插入下一级 `pyc.reg` 以保证支配 | 已完成 | +| B4 | 遍历带双属性的 `pyc.assign`:`d = dst - src`;`d==0` 删属性;`d>0` 替换 `src` 后删属性 | 已完成 | +| B5 | 插入的 `pyc.reg` 带 `pyc.name` = `pyc_cyclebal_N` | 已完成 | + +## 阶段 C:集成 + +| 步骤 | 内容 | 状态 | +|------|------|------| +| C1 | `include/pyc/Transforms/Passes.h` 声明 `createCycleBalancePass()` | 已完成 | +| C2 | `compiler/mlir/CMakeLists.txt` 将 `CycleBalancePass.cpp` 加入 `pyc_transforms` | 已完成 | +| C3 | 
`tools/pycc.cpp`:`pyc-lower-scf-to-pyc-static` → `pyc-cycle-balance` → `pyc-eliminate-wires` | 已完成 | + +## 阶段 D:验收 + +| 步骤 | 内容 | 状态 | +|------|------|------| +| D1 | 完整链接 `pycc` | 已在 LLVM 21 上通过;`pycc.cpp` 改用 `setMaxIterations` / `setMaxNumRewrites` | +| D2 | 手写 `.pyc`:两 `assign` 共享 `src`、相同 `d`,确认仅一条深度为 `d` 的寄存器链 | 建议 | +| D3 | 无周期属性的 IR | pass 为 no-op | + +## 风险与回滚 + +- **风险**:多组 `(clk,rst)` 的模块在存在带属性 `assign` 时会被拒绝。 +- **缓解**:无带属性 `assign` 时不做 `clk/rst` 一致性扫描。 +- **回滚**:从 `pycc` 移除 `createCycleBalancePass` 一行即可。 + +## 执行记录 + +- **IR**:`pyc.assign` 可选 `dst_cycle` / `src_cycle`(`i64`),须成对且 `dst_cycle >= src_cycle`。 +- **共享**:缓存键含原始 `src` 的 opaque 指针、`clk`/`rst`、`depth`;多 `assign` 同 `(src,d)` 复用同一末级 `q`。 +- **前端**:`Circuit.assign` / `Module.assign` 支持关键字参数 `dst_cycle`、`src_cycle`(须成对),生成带属性的 `pyc.assign`。 diff --git a/docs/designs_upgrade_to_v5.md b/docs/designs_upgrade_to_v5.md new file mode 100644 index 0000000..bf6e69f --- /dev/null +++ b/docs/designs_upgrade_to_v5.md @@ -0,0 +1,1626 @@ +# PyCircuit Designs — V5 Cycle-Aware 升级计划 + +**版本**: 1.0 +**日期**: 2026-03-26 + +--- + +## 目标 + +将 `designs/` 下**全部**设计升级为 PyCircuit V5 的 cycle-aware 编程风格: + +1. **函数签名** `(m: CycleAwareCircuit, domain: CycleAwareDomain, ...)` +2. **编译入口** `compile_cycle_aware(build, name=..., eager=True)` +3. **输入信号** 用 `cas(domain, m.input(...), cycle=0)` 包装为 `CycleAwareSignal` +4. **反馈寄存器** 用 `domain.state(width=..., reset_value=..., name=...)` 声明 +5. **流水级边界** 用 `domain.next()` 标记,不同周期的逻辑分段书写 +6. **组合选择** 用 `mux(cond, a, b)` 替代 `if Wire else` 或 `_select_internal()` +7. **管线寄存器** 用 `domain.cycle(sig, name=...)` 替代手动 `m.out().set()` +8. 
**子模块** 保留 `@module` / `m.new` / `m.array` 用法不变 + +--- + +## 改造难度分级 + +| 等级 | 含义 | 工作量 | +|------|------|--------| +| ★☆☆ | 纯组合或单寄存器,无 JIT `if Wire`,改签名+换 `domain.state()`/`mux()` 即可 | < 30 min | +| ★★☆ | 有 JIT `if Wire` 或多寄存器,需逐个替换为 `mux()` 并加 `domain.next()` | 1–3 h | +| ★★★ | 多级流水/复杂 FSM/大量 JIT 条件,需重构逻辑结构、划分 cycle 阶段 | 3–8 h | + +--- + +## 时序分类总览 + +在深入分析每个设计的源代码后,按**实际时序结构**分类如下: + +| 时序类型 | 设计数 | 设计列表 | +|---------|--------|---------| +| **纯组合** (0 寄存器) | 11 | jit_control_flow, hier_modules, module_collection, interface_wiring, instance_map, fastfwd, decode_rules, cache_params, arith, bundle_probe_expand, BypassUnit | +| **单寄存器反馈** | 6 | counter(1), wire_ops(1), obs_points(1), net_resolution_depth_smoke(1), xz_value_model_smoke(1), reset_invalidate_order_smoke(1) | +| **多寄存器/FSM** | 8 | multiclock_regs(2), digital_filter(5), digital_clock(6,FSM), calculator(5,FSM), traffic_lights_ce(5,FSM), dodgeball_game(14+2,FSM), trace_dsl_smoke(2子模块), issue_queue_2picker(8) | +| **多级流水线** | 3 | bf16_fmac(**30**寄存器/**4**级), jit_pipeline_vec(**6**/**3**级), pipeline_builder(**2**/**2**级) | +| **大型设计** | 2 | RegisterFile(**256** domain.state, 2 cycle), IssueQueue(**321** m.out, 单周期状态机) | +| **IP 封装** | 5 | fifo_loopback(rv_queue), mem_rdw_olddata(sync_mem), sync_mem_init_zero(sync_mem), npu_node(rv_queue×4), sw5809s(rv_queue×16 + 4寄存器) | +| **层次化** | 2 | huge_hierarchy_stress(叶子含寄存器), struct_transform(1 m.state) | +| **非硬件** | 1 | fm16_system(纯 Python 行为模型,无需迁移) | + +--- + +## 一、大型设计(designs/ 根目录) + +### 1. 
RegisterFile (`designs/RegisterFile/regfile.py`) — ✅ 已完成 + +| 项目 | 内容 | +|------|------| +| **功能** | 256 条目、128 常量 ROM、10R/5W、64-bit 参数化寄存器堆 | +| **时序类型** | **多寄存器 2-cycle 设计**(读写分相) | +| **寄存器数** | **256 个 `domain.state()`**(128 × bank0[32b] + 128 × bank1[32b]) | +| **端口** | 25 输入(10 raddr + 5 wen + 5 waddr + 5 wdata) · 10 输出(rdata0–9) | +| **当前状态** | **已完成 V5 改造** | + +#### 详细时序结构 + +``` +┌─── Cycle 0:组合读 ──────────────────────────────────────┐ +│ • 25 个输入 cas() 包装(raddr/wen/waddr/wdata) │ +│ • 256 个 domain.state() 声明(bank0[0..127], bank1[0..127])│ +│ • 对每个读口 i (0..9): │ +│ - 地址比较:是常量区? 是合法 ptag? │ +│ - mux() 选择:常量拼接 / 存储体读出 / 零值 │ +│ • m.output("rdata{i}", lane_data.wire) │ +│ • ~3860 次 mux() 调用 │ +├─── domain.next() ────────────────────────────────────────┤ +│ │ +├─── Cycle 1:同步写回 ────────────────────────────────────┐ +│ • 对每个存储项 sidx (0..127): │ +│ - 累加各写口的 hit → we_any │ +│ - mux() 链选出 next_lo / next_hi │ +│ - bank0[sidx].set(next_lo, when=we_any) │ +│ - bank1[sidx].set(next_hi, when=we_any) │ +└──────────────────────────────────────────────────────────┘ +``` + +| V5 API 使用 | 数量 | +|-------------|------| +| `cas()` | ~2135 | +| `mux()` | ~3860 | +| `domain.state()` | 256 | +| `domain.next()` | 1 | +| `domain.cycle()` | 0 | + +#### 验证状态 +- 29/29 功能测试通过,100K 周期仿真 57.4 Kcycles/s + +--- + +### 2. 
IssueQueue (`designs/IssueQueue/issq.py`) — ★★★ + +| 项目 | 内容 | +|------|------| +| **功能** | 多入多出发射队列:entry 状态管理、年龄矩阵排序、ptag 就绪广播、按龄优先发射 | +| **时序类型** | **大量寄存器的单周期状态机**(所有组合决策 + 状态更新在同一拍完成) | +| **寄存器数** | **321 个 `m.out()`**(默认 entries=16, ptag_count=64)| +| **端口** | `enq_ports`×(1+struct) 输入 · `issue_ports`×(1+struct) + `enq_ports` + 2 输出 | +| **JIT `if Wire`** | issq_config.py 中 4 处 | +| **`@function`** | issq.py 10 个 + issq_config.py 8 个 = **18 个** | + +#### 寄存器分解(默认参数) + +| 寄存器组 | 公式 | 默认数量 | 位宽/个 | 总 bit | +|---------|------|---------|---------|--------| +| entry 状态(valid/src/dst/payload) | entries | 16 | 57b | 912b | +| 年龄矩阵 `age_{i}_{j}` | entries² | 256 | 1b | 256b | +| 就绪表 `ready_ptag_{t}` | ptag_count | 64 | 1b | 64b | +| 已发射计数 `issued_total_q` | 1 | 1 | 16b | 16b | +| **合计** | | **321** m.out + 16 entry | | **~1248b** | + +#### 详细时序结构 + +``` +┌─── 单周期逻辑(当前拍输入 + 上拍状态 → 本拍输出 + 下拍状态)──┐ +│ │ +│ 1. _snapshot_entries:从 entry_state[0..15] 读取当前状态 │ +│ 2. _select_oldest_ready: │ +│ • entry_ready = valid & src0_ready & src1_ready │ +│ • 年龄矩阵仲裁 → 选最老 ready entry(one-hot) │ +│ • 多发射口串行扣除已选 → issue_sel[], issue_valid[] │ +│ 3. _allocate_enqueue_lanes: │ +│ • 在空槽上分配入队 → alloc_lane[], next_valid[] │ +│ 4. _emit_issue_ports: │ +│ • one-hot mux → iss{k}_valid, iss{k}_* 输出 │ +│ 5. _issue_wake_vectors: │ +│ • 同拍旁路 wakeup: wake_valid/wake_ptag │ +│ 6. _write_entry_next_state: │ +│ • 对每个 slot: keep / new_alloc 选择 │ +│ • src ready 合并: 原值 | ready_table查找 | 同拍wake旁路 │ +│ → entry_state[i].set(next) │ +│ 7. _update_age_state: │ +│ • age[i][j] 更新: keep+keep→保留, keep+new→1, new+new→lane_lt │ +│ → age[i][j].set(next) │ +│ 8. _update_ready_table: │ +│ • ready_state[t].set(old | wake_t) │ +│ 9. 
_emit_debug_and_ready: │ +│ • occupancy, issued_total 计数 → 输出 │ +│ • issued_total_q.set(issued_total_q.out() + issue_count) │ +└──────────────────────────────────────────────────────────────┘ +``` + +#### V5 改造方案 + +| 步骤 | 改造内容 | +|------|---------| +| **签名** | `def build(m: CycleAwareCircuit, domain: CycleAwareDomain, ...)` | +| **Cycle 0:输入** | `enq_valid/data/ptag` 用 `cas()` 包装 | +| **Cycle 0:状态声明** | 16 个 entry → `domain.state()` × 16(需按 struct 字段分别声明或用 batch API) | +| **Cycle 0:年龄矩阵** | 256 个 1-bit `domain.state(width=1, name=f"age_{i}_{j}")` | +| **Cycle 0:就绪表** | 64 个 `domain.state(width=1, name=f"ready_ptag_{t}")` | +| **Cycle 0:issued_total** | `domain.state(width=16, name="issued_total")` | +| **Cycle 0:仲裁逻辑** | `_select_oldest_ready` 保持组合;`issq_config.py` 中 4 处 `if Wire else` → `mux()` | +| **Cycle 0:输出** | `iss{k}_*`, `enq{k}_ready`, `occupancy` 在 cycle 0 组合输出 | +| **`domain.next()`** | → **Cycle 1:状态更新** | +| **Cycle 1** | 全部 `.set()` 调用:entry[i].set(next), age[i][j].set(next), ready[t].set(next), issued_total.set(next) | +| **`@function` 保留** | 18 个 `@function` 保持 Wire 级;不在其中使用 CAS 对象 | + +**关键难点:** +- entry 是结构化类型(valid/src0.ptag/src0.ready/…),需将 `m.state(uop_spec)` 拆分为多个 `domain.state()` 或扩展 V5 API 支持 struct state +- `@function` 辅助函数内部不能使用 `CycleAwareSignal`,需在调用前 `.wire` 解包、返回后 `cas()` 重包 +- 年龄矩阵 256 个 1-bit state 的声明与更新循环需保持 Python 循环展开 + +| **难度** | ★★★(321 寄存器 + 18 个辅助函数 + 结构化状态) | + +--- + +### 3. 
BypassUnit (`designs/BypassUnit/bypass_unit.py`) — ★★☆ + +| 项目 | 内容 | +|------|------| +| **功能** | 8-lane 旁路网络:按 ptag+ptype 在 w1/w2/w3 写回级与 RF 数据之间做优先级选择 | +| **时序类型** | **纯组合**(0 寄存器) | +| **寄存器数** | **0** | +| **端口** | **160 输入**(3 stage × 8 lane × 4 域 + 8 lane × 2 src × 4 域) · **64 输出**(8 lane × 2 src × 4 域) | +| **JIT `if Wire`** | **14 处**(`_select_stage` 2 处 × 8 lane + `_resolve_src` 4×3=12) | +| **`@function`** | 3 个:`_not1`, `_select_stage`, `_resolve_src` | + +#### 旁路优先级结构 + +``` +对每条 lane i、每个 src (srcL/srcR): + + _resolve_src(src_valid, src_ptag, src_ptype, src_rf_data, + w1[0..7], w2[0..7], w3[0..7]) + ├── _select_stage(w3[0..7]) → 如果 ptag+ptype 匹配 → hit_w3, data_w3 + ├── _select_stage(w2[0..7]) → 如果 ptag+ptype 匹配 → hit_w2, data_w2 + ├── _select_stage(w1[0..7]) → 如果 ptag+ptype 匹配 → hit_w1, data_w1 + └── 优先级链(更晚的 stage 优先): + out_data = data_w3 if hit_w3 else (data_w2 if hit_w2 else (data_w1 if hit_w1 else rf_data)) + out_hit = hit_w3 | hit_w2 | hit_w1 + out_stage = 3 if hit_w3 else (2 if hit_w2 else (1 if hit_w1 else 0)) + + 同一 stage 内 lane 优先级:lane 0 > lane 1 > ... > lane 7(先匹配先胜) +``` + +#### V5 改造方案 + +| 步骤 | 改造内容 | +|------|---------| +| **签名** | `def build(m: CycleAwareCircuit, domain: CycleAwareDomain, ...)` | +| **Cycle 0(唯一 cycle)** | 全部 160 个输入 `cas()` 包装 | +| **Cycle 0** | 14 处 `if Wire else` 全部替换为 `mux()`:`out_data = mux(hit_w3, data_w3, mux(hit_w2, data_w2, mux(hit_w1, data_w1, rf_data)))` | +| **Cycle 0** | `_select_stage` 内 `take = match & ~has` → `sel_data = mux(take, lane_data, sel_data)` | +| **输出** | 全部组合输出,**无 `domain.next()`** | +| **`@function` 保留** | 3 个 `@function` 保持;内部 `if Wire else` → `mux()` | + +**关键难点:** +- `_select_stage` 和 `_resolve_src` 内的条件链必须保持优先级语义 +- 替换时注意 `mux(cond, true_val, false_val)` 的参数顺序与 `true_val if cond else false_val` 一致 + +| **难度** | ★★☆(14 处 `if Wire` → `mux()`,纯组合无时序风险) | + +--- + +## 二、示例设计(designs/examples/) + +### 4. 
counter — ★☆☆ 【单寄存器反馈 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 使能可控的上行计数器 | +| **时序类型** | 单寄存器反馈 | +| **寄存器** | 1 个 `m.out("count_q", width=width)`,enable 门控 `+1` | +| **端口** | 1 输入 `enable` · 1 输出 `count` | +| **JIT `if Wire`** | 0 | + +#### V5 周期结构 + +``` +┌─── Cycle 0 ─────────────────────┐ +│ enable = cas(m.input("enable")) │ +│ count = domain.state(width=W) │ +│ m.output("count", count.wire) │ +├─── domain.next() ───────────────┤ +├─── Cycle 1 ─────────────────────┐ +│ count.set(mux(enable, count+1, count)) │ +└─────────────────────────────────┘ +``` + +--- + +### 5. multiclock_regs — ★☆☆ 【多时钟域 · 各域 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 两个独立时钟域各一个自增计数器 | +| **时序类型** | 2 个独立时钟域,各含 1 个自增寄存器 | +| **寄存器** | 2 个 `m.out()`:`a_count_q`(clk_a 域)、`b_count_q`(clk_b 域) | +| **端口** | 2 clk + 2 rst → 2 输出(`a_count`, `b_count`) | +| **JIT `if Wire`** | 0 | + +#### V5 周期结构(每个域) + +``` +┌─── domain_a Cycle 0 ────────────┐ +│ a = domain_a.state(width=W) │ +│ m.output("a_count", a.wire) │ +├─── domain_a.next() ─────────────┤ +├─── domain_a Cycle 1 ────────────┐ +│ a.set(a + 1) │ +└──────────────────────────────────┘ +(domain_b 同理) +``` + +**注意:** 多时钟域需在 `build` 内手动 `m.create_domain()` 创建额外域 + +--- + +### 6. wire_ops — ★★☆ 【单寄存器 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 按 `sel` 选择 `a & b` 或 `a ^ b`,结果打入寄存器输出 | +| **时序类型** | 组合选择 → 单寄存器捕获 | +| **寄存器** | 1 个 `m.out("r")`:存储 mux 结果 | +| **端口** | 3 输入(a, b, sel) · 1 输出(y) | +| **JIT `if Wire`** | **1 处**:`a & b if sel else a ^ b` | + +#### V5 周期结构 + +``` +┌─── Cycle 0 ─────────────────────┐ +│ a, b, sel = cas(m.input(...)) │ +│ result = mux(sel, a & b, a ^ b) │ +├─── domain.next() ───────────────┤ +├─── Cycle 1 ─────────────────────┐ +│ r = domain.cycle(result, name="r") │ +│ m.output("y", r.wire) │ +└──────────────────────────────────┘ +``` + +--- + +### 7. 
jit_control_flow — ★★☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 按 `op` 对 `a, b` 做算术/逻辑运算,再固定轮数 `+1`,输出组合结果 | +| **时序类型** | **纯组合**(0 寄存器) | +| **寄存器** | 0 | +| **端口** | 3 输入(a, b, op) · 1 输出(result) | +| **JIT `if Wire`** | **4 处** `if/elif op == ...` | + +#### V5 周期结构 + +``` +┌─── Cycle 0(唯一 cycle)──────────────────────┐ +│ a, b, op = cas(m.input(...)) │ +│ r = mux(op==0, a+b, mux(op==1, a-b, ...)) │ +│ for _ in range(rounds): r = r + 1 # 展开 │ +│ m.output("result", r.wire) │ +│ 无 domain.next() │ +└────────────────────────────────────────────────┘ +``` + +--- + +### 8. fifo_loopback — ★☆☆ 【IP 封装 · 无自建寄存器】 + +| 项目 | 内容 | +|------|------| +| **功能** | `rv_queue` FIFO push/pop 回环测试 | +| **时序类型** | **IP 封装**(`m.rv_queue` 内含寄存器,外部无自建寄存器) | +| **自建寄存器** | 0(FIFO 寄存器在 `rv_queue` IP 内部) | +| **端口** | 3 输入(in_valid, in_data, out_ready) · 3 输出(in_ready, out_valid, out_data) | +| **JIT `if Wire`** | 0 | + +#### V5 周期结构 + +``` +┌─── Cycle 0 ──────────────────────────────────┐ +│ in_valid, in_data, out_ready = cas(m.input(...))│ +│ fifo = m.rv_queue(depth=2, width=W) │ +│ fifo 接口连接(Wire 级,保持不变) │ +│ m.output(...) │ +│ 无 domain.next()(IP 内部自管时序) │ +└──────────────────────────────────────────────┘ +``` + +--- + +### 9. hier_modules — ★☆☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 辅助函数串行 `+1` 共 `stages=3` 次(组合链) | +| **时序类型** | **纯组合**(0 寄存器) | +| **寄存器** | 0 | +| **端口** | 1 输入(x) · 1 输出(y) | + +#### V5 周期结构(改为真流水的方案) + +``` +┌─── Cycle 0 ──────────────┐ +│ val = cas(m.input("x")) │ +│ val = val + 1 │ +├─── domain.next() ────────┤ +├─── Cycle 1 ──────────────┐ +│ val = domain.cycle(val) │ +│ val = val + 1 │ +├─── domain.next() ────────┤ +├─── Cycle 2 ──────────────┐ +│ val = domain.cycle(val) │ +│ val = val + 1 │ +│ m.output("y", val.wire) │ +└──────────────────────────┘ +``` + +> 注:若保持纯组合语义,则无需 `domain.next()`,仅 `cas()` 包装输入即可 + +--- + +### 10. 
bf16_fmac — ★★★ 【4 级流水线 · 5 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | BF16×BF16 乘加 → FP32 累加器 | +| **时序类型** | **4 级流水线** + 反馈累加器 | +| **寄存器** | **30 个 `m.out()`** 手工管理的流水线寄存器 | +| **端口** | 4 输入(a_in, b_in, acc_in, valid) · 2 输出(result, out_valid) | +| **JIT `if Wire`** | **~20 处**(NaN/Inf/零/符号异常路径) | + +#### 实际流水线时序 + +``` +┌─── Cycle 0:Stage 1 解包 ─────────────────────────────────┐ +│ a_in, b_in, acc_in, valid = cas(m.input(...)) │ +│ 解包 BF16 → 指数 e_a/e_b、尾数 m_a/m_b、符号 s_a/s_b │ +│ 部分乘积启动;NaN/Inf/Zero 检测 │ +│ ~8 个流水寄存器锁存中间结果 │ +├─── domain.next() ────────────────────────────────────────┤ +├─── Cycle 1:Stage 2 乘法完成 ─────────────────────────────┐ +│ 完成 8×8 尾数乘 → 16-bit 乘积 │ +│ 指数相加 → 乘积指数 │ +│ ~8 个流水寄存器 │ +├─── domain.next() ────────────────────────────────────────┤ +├─── Cycle 2:Stage 3 对齐加减 ─────────────────────────────┐ +│ 指数对齐 → 尾数右移 │ +│ 尾数加减(同符号/异符号处理) │ +│ ~7 个流水寄存器 │ +├─── domain.next() ────────────────────────────────────────┤ +├─── Cycle 3:Stage 4 归一化打包 ───────────────────────────┐ +│ 前导零检测 → 归一化移位 │ +│ 舍入 → FP32 打包 │ +│ 异常优先级:NaN > Inf > Zero > Normal │ +│ m.output("result", ...) m.output("out_valid", ...) │ +│ acc 反馈 → domain.state() 或 domain.cycle() │ +│ ~7 个流水寄存器 │ +└──────────────────────────────────────────────────────────┘ +``` + +**改造要点:** +- 30 个 `m.out()` → `domain.cycle()` / `domain.state()` +- 3 个 `domain.next()` 分割 4 级流水 +- ~20 处 `if Wire else` → `mux()`(异常处理路径需嵌套 `mux`) +- 累加器反馈用 `domain.state()` + +--- + +### 11. 
digital_filter — ★★☆ 【移位寄存器 + 输出锁存 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 参数化 4-tap FIR 滤波器:移位寄存器 + MAC | +| **时序类型** | 移位寄存器链 + 组合 MAC + 输出锁存 | +| **寄存器** | **5 个 `m.out()`**:3 个延迟线 `tap[1..3]` + 1 个输出 `y` + 1 个 `y_valid` | +| **端口** | 2 输入(x_in, x_valid) · 2 输出(y_out, y_valid) | +| **JIT `if Wire`** | 0 | + +#### V5 周期结构 + +``` +┌─── Cycle 0:组合读取 + MAC ───────────────────────────────┐ +│ x_in, x_valid = cas(m.input(...)) │ +│ tap[1..3] = domain.state(width=W) × 3(tap[0] 即 x_in) │ +│ acc = Σ(coeff[i] * tap[i]) # 组合 MAC │ +│ m.output("y_out", acc.wire) │ +├─── domain.next() ────────────────────────────────────────┤ +├─── Cycle 1:移位 + 输出锁存 ─────────────────────────────┐ +│ tap[3].set(tap[2]); tap[2].set(tap[1]); tap[1].set(x_in) │ +│ y_valid_state.set(x_valid) │ +└──────────────────────────────────────────────────────────┘ +``` + +--- + +### 12. digital_clock — ★★★ 【FSM · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 1Hz 预分频 + 4 模式 FSM (RUN/SET_HOUR/SET_MIN/SET_SEC) + BCD 输出 | +| **时序类型** | **FSM + 多寄存器**(6 个状态寄存器,单 `domain.next()` 分相) | +| **寄存器** | **6 个 `m.out()`**:prescaler, seconds, minutes, hours, mode, blink_cnt | +| **端口** | 3 输入(btn_mode, btn_set, btn_inc) · 5 输出(hours_bcd, minutes_bcd, seconds_bcd, mode, blink) | +| **JIT `if Wire`** | **~22 处**(FSM 状态转换 + BCD 进位链) | +| **`@function`** | 若干 BCD/计时辅助函数 | + +#### V5 周期结构 + +``` +┌─── Cycle 0:FSM 次态计算 ────────────────────────────────┐ +│ btn_mode, btn_set, btn_inc = cas(m.input(...)) │ +│ prescaler, sec, min, hr, mode, blink = domain.state() × 6│ +│ tick = (prescaler == 0) # 1Hz 节拍 │ +│ FSM 次态逻辑(全部 ~22 处 if → mux()): │ +│ next_mode = mux(btn_mode_pressed, mode+1, mode) │ +│ next_sec = mux(tick & is_RUN, sec+1, mux(...)) │ +│ ...(进位、设时、BCD 转换) │ +│ m.output("hours_bcd", ...) 
等 │ +├─── domain.next() ────────────────────────────────────────┤ +├─── Cycle 1:状态更新 ───────────────────────────────────┐ +│ prescaler.set(next_prescaler) │ +│ sec.set(next_sec); min.set(next_min); hr.set(next_hr) │ +│ mode.set(next_mode); blink.set(next_blink) │ +└──────────────────────────────────────────────────────────┘ +``` + +--- + +### 13. calculator — ★★★ 【FSM · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 16-bit 十进制计算器:数字输入/四则运算/等号/全清 | +| **时序类型** | **FSM**(输入模式 → 运算 → 输出) | +| **寄存器** | **5 个 `m.out()`**:lhs, rhs, op, display, input_state | +| **端口** | 2 输入(key_code, key_valid) · 2 输出(display, overflow) | +| **JIT `if Wire`** | **~14 处**(数字/运算符/等号判断) | + +#### V5 周期结构 + +``` +┌─── Cycle 0:组合计算 ─────────────────────────────────────┐ +│ key_code, key_valid = cas(m.input(...)) │ +│ lhs, rhs, op, display, state = domain.state() × 5 │ +│ is_digit = (key_code < 10) │ +│ is_op = ...; is_eq = ...; is_ac = ... │ +│ next_lhs = mux(is_digit & is_lhs_mode, lhs*10+key, ...) │ +│ next_display = mux(is_eq, result, mux(is_ac, 0, display))│ +│ m.output("display", display.wire) │ +├─── domain.next() ─────────────────────────────────────────┤ +├─── Cycle 1 ──────────────────────────────────────────────┐ +│ lhs.set(next_lhs); rhs.set(next_rhs); op.set(next_op) │ +│ display.set(next_display); state.set(next_state) │ +└──────────────────────────────────────────────────────────┘ +``` + +--- + +### 14. 
traffic_lights_ce — ★★★ 【FSM · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 交通灯:4 相倒计时 (EW_GREEN→EW_YELLOW→NS_GREEN→NS_YELLOW) + 紧急覆盖 + 黄灯闪烁 | +| **时序类型** | **FSM + 多寄存器** | +| **寄存器** | **5 个 `m.out()`**:phase, countdown, prescaler, emergency_latch, blink_cnt | +| **端口** | 2 输入(emergency, pause) · 8 输出(ew_red/yellow/green, ns_red/yellow/green, countdown_bcd, phase) | +| **JIT `if Wire`** | **~27 处**(相位判断 + 紧急/暂停逻辑 + BCD) | + +#### V5 周期结构 + +``` +┌─── Cycle 0:次态逻辑 ──────────────────────────────────┐ +│ emergency, pause = cas(m.input(...)) │ +│ phase, countdown, prescaler, emg, blink = domain.state()×5│ +│ ~27 处 if Wire → mux() 链 │ +│ m.output(灯光信号 + BCD + phase) │ +├─── domain.next() ────────────────────────────────────────┤ +├─── Cycle 1 ───────────────────────────────────────────────┐ +│ phase.set(next_phase); countdown.set(next_countdown) │ +│ prescaler.set(next_prescaler); emg.set(next_emg) │ +│ blink.set(next_blink) │ +└──────────────────────────────────────────────────────────┘ +``` + +--- + +### 15–16. dodgeball_game — ★★★ 【FSM + VGA · 各 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | `lab_final_VGA.py`:VGA 640×480@60Hz 时序;`lab_final_top.py`:3 态游戏 FSM + VGA + 碰撞 | +| **时序类型** | **FSM + 计数器** | +| **寄存器** | VGA: **2** (h_count, v_count) · Top: **14** (game_state, player_x/y, obstacle_x/y, score, tick_div, pixel_div, …) | +| **JIT `if Wire`** | VGA **7 处** + Top **~7 处** = **~14 处** | + +#### V5 周期结构(lab_final_VGA) + +``` +┌─── Cycle 0 ─────────────────────────┐ +│ h_count, v_count = domain.state()×2 │ +│ 组合输出:hsync, vsync, active, x, y │ +│ m.output(...) 
│ +├─── domain.next() ────────────────────┤ +├─── Cycle 1 ─────────────────────────┐ +│ h_count.set(mux(h_end, 0, h+1)) │ +│ v_count.set(mux(h_end, mux(v_end, 0, v+1), v))│ +└──────────────────────────────────────┘ +``` + +#### V5 周期结构(lab_final_top) + +``` +┌─── Cycle 0 ──────────────────────────────────────────┐ +│ btns, switches = cas(m.input(...)) │ +│ 14 个 domain.state() 声明 │ +│ game_state FSM (IDLE/PLAY/GAMEOVER) → mux() 链 │ +│ 碰撞检测、移动逻辑、VGA 子模块实例化(m.new 保留) │ +│ RGB 输出 → mux() 链 │ +│ m.output(vga_signals + rgb + score + leds) │ +├─── domain.next() ─────────────────────────────────────┤ +├─── Cycle 1 ──────────────────────────────────────────┐ +│ game_state.set(next_state); player_x.set(next_px); ...│ +└──────────────────────────────────────────────────────┘ +``` + +--- + +### 17. obs_points — ★☆☆ 【单寄存器 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **寄存器** | 1 个 `m.out()`:采样保持 | +| **V5** | `r = domain.state(...)` → Cycle 0 读/输出 → `domain.next()` → Cycle 1 `r.set(x+1)` | + +--- + +### 18. net_resolution_depth_smoke — ★☆☆ 【单寄存器 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **寄存器** | 1 个 `m.out()`:4 级组合加法后锁存 | +| **V5** | `r = domain.state(...)` → Cycle 0 组合 x+4 → `domain.next()` → Cycle 1 `r.set(...)` | + +--- + +### 19–20. mem_rdw_olddata / sync_mem_init_zero — ★☆☆ 【IP 封装 · 无自建寄存器】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | `m.sync_mem` 同步存储器 IP 封装(IP 内部含寄存器) | +| **自建寄存器** | 0 | +| **V5** | 输入 `cas()` → IP 保持不变 → 输出 `cas()` | + +--- + +### 21. 
jit_pipeline_vec — ★★☆ 【3 级流水线 · 4 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | `stages=3` 级寄存器流水,每级含 `if sel` 选择 | +| **时序类型** | **3 级流水线** | +| **寄存器** | **6 个 `m.out()`**:每级 1 tag(1b) + 1 data(16b) = 2 寄存器/级 | +| **JIT `if Wire`** | **1 处/级** → `mux(sel, a & b, a ^ b)` | + +#### V5 周期结构 + +``` +┌─── Cycle 0 ──────────────────────────────┐ +│ a, b, sel = cas(m.input(...)) │ +│ compare = (a < b) │ +│ data0 = mux(sel, a & b, a ^ b) │ +├─── domain.next() ────────────────────────┤ +├─── Cycle 1 ──────────────────────────────┐ +│ tag1 = domain.cycle(compare, name="t1") │ +│ data1 = domain.cycle(data0, name="d1") │ +│ data1 = mux(tag1, data1 | 0xFF, data1) │ +├─── domain.next() ────────────────────────┤ +├─── Cycle 2 ──────────────────────────────┐ +│ tag2 = domain.cycle(tag1, name="t2") │ +│ data2 = domain.cycle(data1, name="d2") │ +│ data2 = mux(tag2, data2 & 0xFF, data2) │ +├─── domain.next() ────────────────────────┤ +├─── Cycle 3 ──────────────────────────────┐ +│ tag3 = domain.cycle(tag2, name="t3") │ +│ data3 = domain.cycle(data2, name="d3") │ +│ m.output("lo8", data3.wire[0:8]) │ +│ m.output("hi8", data3.wire[8:16]) │ +│ m.output("tag_out", tag3.wire) │ +└──────────────────────────────────────────┘ +``` + +**这是 `domain.next()` 流水的最佳示范设计。** + +--- + +### 22–23. xz_value_model_smoke / reset_invalidate_order_smoke — ★☆☆ 【单寄存器 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **寄存器** | 各 1 个 `m.out()` | +| **V5** | `domain.state()` → Cycle 0 读 → `domain.next()` → Cycle 1 `.set()` | + +--- + +### 24. 
pipeline_builder — ★★☆ 【2 级流水线 · 3 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | `spec.struct` 载荷两级流水 | +| **时序类型** | **2 级流水线**(`m.state()` 管理) | +| **寄存器** | **2 个 `m.state()`**(st0 捕获输入、st1 对 payload.word+1) | +| **端口** | 2 输入(struct) · 2 输出(struct) | +| **`@const`** | 1 个(struct 定义) | + +#### V5 周期结构 + +``` +┌─── Cycle 0:输入 ─────────────────┐ +│ in_ctrl, in_payload = cas(m.input(...)) │ +├─── domain.next() ────────────────┤ +├─── Cycle 1:Stage 0 ─────────────┐ +│ st0 = domain.cycle(...) │ +├─── domain.next() ────────────────┤ +├─── Cycle 2:Stage 1 ─────────────┐ +│ word_plus_1 = st0.word + 1 │ +│ st1 = domain.cycle(word_plus_1) │ +│ m.output(...) │ +└──────────────────────────────────┘ +``` + +--- + +### 25. struct_transform — ★☆☆ 【单级寄存器 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **寄存器** | 1 个 `m.state()`(struct 格式) | +| **V5** | Cycle 0 输入 + 变换 → `domain.next()` → Cycle 1 `domain.cycle()` 锁存 | + +--- + +### 26. module_collection — ★★☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 纯组合(8 路并行子模块 + 累加) | +| **寄存器** | 0 | +| **V5** | `build` 签名改 V5;`@module` 子模块保留;顶层 `cas()` 包装 + `m.array` 保留 | + +--- + +### 27. interface_wiring — ★☆☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 纯组合(struct 接口绑定) | +| **寄存器** | 0 | +| **V5** | `build` 签名改 V5;`m.new` 保留 | + +--- + +### 28. instance_map — ★☆☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 纯组合(3 类子模块实例累加) | +| **寄存器** | 0 | +| **V5** | 同 module_collection | + +--- + +### 29. 
huge_hierarchy_stress — ★★★ 【层次化 · 叶子 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 32 个 `_node` 实例树(深度=2, fanout=2),叶子含 `pipe` + `acc` 寄存器;顶层含 Cache(4-way, 64-set) | +| **时序类型** | **层次化** — 叶子有寄存器,节点/顶层为组合连接 | +| **寄存器** | 叶子每个含 `m.out("acc")` 1 个 + `m.pipe` 内部寄存器 | +| **`@module`** | `_leaf` + `_node` 各 1 个 | + +#### V5 改造 + +| 层级 | 改造 | +|------|------| +| `_leaf` | `acc` → `domain.state()` + `domain.next()` + `.set()` | +| `_node` | 保持 `@module(structural=True)` + `m.new` | +| 顶层 `build` | 签名改 V5;`cas()` 包装 `seed` 输入;`m.array` + Cache IP 保留 | + +--- + +### 30. fastfwd — ★☆☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 纯组合直通 | +| **寄存器** | 0 | +| **端口** | 20 输入 · 29 输出 | +| **V5** | `cas()` 包装输入输出即可 | + +--- + +### 31. decode_rules — ★★☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 纯组合优先级解码 | +| **寄存器** | 0 | +| **JIT `if Wire`** | **6 处**(3 条规则各 2 处 `if hit else`) | +| **V5** | 规则命中链改 `mux()` — **注意保持优先级不反转** | + +--- + +### 32. cache_params — ★☆☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 纯组合参数推导 | +| **寄存器** | 0 | +| **V5** | `cas()` 包装输入;`@const` 保留 | + +--- + +### 33. bundle_probe_expand — ★☆☆ 【Stub · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 占位(仅声明端口,无逻辑) | +| **V5** | `cas()` 包装输入;probe 基础设施不变 | + +--- + +### 34. boundary_value_ports — ★★☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 纯组合(3 个 `_lane` 子模块各有 gain/bias/enable 值参数) | +| **寄存器** | 0 | +| **JIT `if Wire`** | **1 处**(`_lane` 内 `if enable else`) | +| **V5** | `_lane` 内 `if` → `mux()`;`build` 签名改 V5;`@module` + `m.new` 保留 | + +--- + +### 35. arith — ★☆☆ 【纯组合 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 纯组合加法 + 常量配置 | +| **寄存器** | 0 | +| **V5** | `cas()` 包装输入;`@const` 保留 | + +--- + +### 36. 
issue_queue_2picker — ★★☆ 【队列寄存器 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 4 槽移位队列,双 pop 口,单 push 口 | +| **时序类型** | **寄存器队列**(移位 + 仲裁) | +| **寄存器** | **8 个 `m.out()`**:4 slot × (valid + data) | +| **端口** | 4 输入(push/pop 控制 + data) · 5 输出(valid/data + in_ready) | +| **JIT `if Wire`** | **~20 处**(移位/pop/push 条件) | + +#### V5 周期结构 + +``` +┌─── Cycle 0:仲裁 + 移位计算 ──────────────────────────────┐ +│ push_valid, push_data, pop0_ready, pop1_ready = cas(...) │ +│ slot[0..3] = domain.state() × 8(valid+data 各 4) │ +│ 组合逻辑:pop0 取 slot[0], pop1 取 slot[1] │ +│ 移位:根据 pop 数量计算 slot[i] 的下一值 │ +│ push:向首个空位写入 │ +│ 全部 ~20 处 if Wire → mux() │ +│ m.output(pop0_valid/data, pop1_valid/data, in_ready) │ +├─── domain.next() ─────────────────────────────────────────┤ +├─── Cycle 1:状态更新 ──────────────────────────────────────┐ +│ slot[0].set(next_slot0); ... slot[3].set(next_slot3) │ +└──────────────────────────────────────────────────────────┘ +``` + +--- + +### 37. trace_dsl_smoke — ★☆☆ 【子模块寄存器 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **时序类型** | 2 个 `leaf` 子模块实例,每个含 1 寄存器 | +| **寄存器** | 2 个(均在 `@module leaf` 内) | +| **V5** | `build` 改 V5 签名;`leaf` 保留 `@module`(`@probe(target=leaf)` 依赖);`m.new` 保留 | +| **注意** | leaf 内部需:`r = domain.state(...)` → `domain.next()` → `r.set(in_x)` | + +--- + +### 38. npu_node (fm16) — ★★☆ 【FIFO IP + 组合路由 · 单 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | NPU 节点:HBM 注入 + 4 端口双向网络,按 dst 路由 | +| **时序类型** | **4 个 rv_queue IP**(内含寄存器) + 组合路由逻辑 | +| **自建寄存器** | 0(FIFO 在 IP 内部) | +| **JIT `if Wire`** | **~20 处**(路由 dst→port 匹配 + 合并 push) | + +#### V5 周期结构 + +``` +┌─── Cycle 0(唯一 cycle)──────────────────────────────────┐ +│ hbm_in, port[0..3]_in = cas(m.input(...)) │ +│ 4 × m.rv_queue(depth=8) → IP 保留 │ +│ 路由逻辑:dst mod 4 → push 到目标 FIFO │ +│ ~20 处 if Wire → mux() │ +│ m.output(port[0..3]_out, hbm_out, ...) │ +│ 无 domain.next()(IP 内部自管时序) │ +└──────────────────────────────────────────────────────────┘ +``` + +--- + +### 39. 
sw5809s (fm16) — ★★☆ 【FIFO + RR 仲裁寄存器 · 2 cycle】 + +| 项目 | 内容 | +|------|------| +| **功能** | 4×4 交叉开关:16 个 VOQ 队列 + round-robin 仲裁 | +| **时序类型** | **16 个 rv_queue IP** + **4 个 RR 指针寄存器** | +| **自建寄存器** | **4 个 `m.out()`**(rr_ptr[0..3],每个 2-bit) | +| **JIT `if Wire`** | **~52 处**(VOQ push/pop 条件 + RR 仲裁链) | + +#### V5 周期结构 + +``` +┌─── Cycle 0:仲裁 + 路由 ──────────────────────────────────┐ +│ port[0..3]_in = cas(m.input(...)) │ +│ 16 × m.rv_queue() → IP 保留 │ +│ rr_ptr[0..3] = domain.state() × 4 │ +│ 对每个输出端口:RR 扫描 4 个 VOQ → 选择非空最优先 │ +│ ~52 处 if Wire → mux() │ +│ m.output(port[0..3]_out, ...) │ +├─── domain.next() ─────────────────────────────────────────┤ +├─── Cycle 1:RR 指针更新 ─────────────────────────────────┐ +│ rr_ptr[i].set(mux(grant_valid, next_rr, rr_ptr[i])) │ +└──────────────────────────────────────────────────────────┘ +``` + +--- + +### 40. fm16_system — ⊘ 无需迁移 + +| 项目 | 内容 | +|------|------| +| **类型** | **纯 Python 行为级仿真器**,不使用 pycircuit 硬件构造 | +| **内容** | `class NPUNode`, `class SW5809s`, `class FM16System`, `class SW16System` — 全为 Python 类的功能模型 | +| **结论** | **无需迁移**,不属于硬件设计 | + +--- + +## 优先级与执行顺序 + +### Phase 1:★☆☆ 简单设计(15 个,预计 1–2 天) + +| # | 设计 | 时序类型 | 寄存器 | 要点 | +|---|------|---------|--------|------| +| 1 | counter | 单寄存器 | 1 | `domain.state()` + `domain.next()` + `.set()` | +| 2 | obs_points | 单寄存器 | 1 | 同上 | +| 3 | net_resolution_depth_smoke | 单寄存器 | 1 | 同上 | +| 4 | xz_value_model_smoke | 单寄存器 | 1 | 同上 | +| 5 | reset_invalidate_order_smoke | 单寄存器 | 1 | 同上 | +| 6 | struct_transform | 单级 m.state | 1 | `domain.state()` + `domain.next()` | +| 7 | fifo_loopback | IP 封装 | 0 | `cas()` 包装,IP 不动 | +| 8 | mem_rdw_olddata | IP 封装 | 0 | `cas()` 包装,IP 不动 | +| 9 | sync_mem_init_zero | IP 封装 | 0 | `cas()` 包装,IP 不动 | +| 10 | fastfwd | 纯组合 | 0 | `cas()` 包装 | +| 11 | cache_params | 纯组合 | 0 | `cas()` + `@const` 保留 | +| 12 | arith | 纯组合 | 0 | `cas()` + `@const` 保留 | +| 13 | bundle_probe_expand | Stub | 0 | `cas()` 包装 | +| 14 | interface_wiring | 纯组合 | 0 | 签名改 
V5,`m.new` 保留 | +| 15 | instance_map | 纯组合 | 0 | 签名改 V5,`m.array` 保留 | + +### Phase 2:★★☆ 中等设计(13 个,预计 3–4 天) + +| # | 设计 | 时序类型 | 寄存器 | 核心改动 | +|---|------|---------|--------|----------| +| 1 | wire_ops | 单寄存器 | 1 | 1 处 `if` → `mux()` + `domain.next()` | +| 2 | multiclock_regs | 多时钟 | 2 | 双域 `domain.state()` + 各自 `domain.next()` | +| 3 | hier_modules | 纯组合→可选真流水 | 0→3 | 可选 `domain.cycle()` × `stages` | +| 4 | jit_control_flow | 纯组合 | 0 | 4 处 `if/elif` → `mux()` 嵌套链 | +| 5 | BypassUnit | **纯组合** | **0** | **14 处** `if` → `mux()`;优先级链语义 | +| 6 | digital_filter | 移位寄存器 | 5 | 延迟线 `domain.state()` × 3 + `domain.next()` | +| 7 | **jit_pipeline_vec** | **3 级流水** | **6** | `domain.next()` × 3 循环流水 **(示范设计)** | +| 8 | pipeline_builder | 2 级流水 | 2 | `domain.next()` × 2 + struct `domain.cycle()` | +| 9 | decode_rules | 纯组合 | 0 | 6 处 `if hit` → `mux()`,保持优先级 | +| 10 | module_collection | 纯组合 | 0 | `@module` 子模块保留;`cas()` 规约 | +| 11 | boundary_value_ports | 纯组合 | 0 | `_lane` 内 1 处 `if` → `mux()` | +| 12 | npu_node | FIFO IP | 0 | ~20 处路由 `if` → `mux()`,rv_queue 保留 | +| 13 | trace_dsl_smoke | 子模块寄存器 | 2 | `@module` leaf 内 `domain.state()` + `domain.next()` | + +### Phase 3:★★★ 复杂设计(9 个,预计 5–8 天) + +| # | 设计 | 时序类型 | 寄存器 | 核心挑战 | +|---|------|---------|--------|----------| +| 1 | **bf16_fmac** | **4 级流水** | **30** | 3 × `domain.next()` 分割流水 + ~20 处 `mux()` + 异常路径 | +| 2 | **IssueQueue** | 单周期状态机 | **321** | 大量 `domain.state()` + struct 状态 + 18 个 `@function` | +| 3 | **issue_queue_2picker** | 队列寄存器 | **8** | ~20 处 `if Wire` → `mux()` 移位逻辑 | +| 4 | **sw5809s** | FIFO+RR | **4** | **~52 处** `if Wire` → `mux()`;16 个 VOQ | +| 5 | **calculator** | FSM | 5 | ~14 处 `if Wire` → `mux()` FSM 链 | +| 6 | **digital_clock** | FSM | 6 | ~22 处 `if Wire` → `mux()` + BCD 进位 | +| 7 | **traffic_lights_ce** | FSM | 5 | **~27 处** `if Wire` → `mux()` | +| 8 | **dodgeball_game** (2 files) | FSM+VGA | **16** | ~14 处 `if Wire` + VGA 时序 + 碰撞 | +| 9 | **huge_hierarchy_stress** | 层次化 | ~32+ | `@module` 叶子 
`domain.state()` + Cache IP 接口 | + +--- + +## 通用改造模板 + +### 模板 A:纯组合设计(无寄存器) + +```python +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, ...) -> None: + # Cycle 0: inputs + a = cas(domain, m.input("a", width=W), cycle=0) + b = cas(domain, m.input("b", width=W), cycle=0) + sel = cas(domain, m.input("sel", width=1), cycle=0) + + # Cycle 0: combinational logic + result = mux(sel, a + b, a - b) + + m.output("out", result.wire) +``` + +### 模板 B:单寄存器反馈(计数器/累加器) + +```python +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, ...) -> None: + # Cycle 0: inputs + state + enable = cas(domain, m.input("en", width=1), cycle=0) + count = domain.state(width=8, reset_value=0, name="count") + + m.output("count", count.wire) + + # Cycle 1: update + domain.next() + count.set(mux(enable, count + 1, count)) +``` + +### 模板 C:多级流水 + +```python +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, ...) -> None: + # Cycle 0: inputs + data = cas(domain, m.input("data", width=W), cycle=0) + valid = cas(domain, m.input("valid", width=1), cycle=0) + + # Cycle 0 → 1: Stage 1 + s1_data = data + 1 + domain.next() + s1_reg = domain.cycle(s1_data, name="s1") + s1_valid = domain.cycle(valid, name="s1_valid") + + # Cycle 1 → 2: Stage 2 + s2_data = s1_reg * 2 + domain.next() + s2_reg = domain.cycle(s2_data, name="s2") + + m.output("out", s2_reg.wire) +``` + +### 模板 D:FSM(状态机) + +```python +def build(m: CycleAwareCircuit, domain: CycleAwareDomain, ...) 
-> None: + # Cycle 0: inputs + state + cmd = cas(domain, m.input("cmd", width=2), cycle=0) + state = domain.state(width=2, reset_value=0, name="fsm") + + IDLE, RUN, DONE = 0, 1, 2 + + # Next-state logic (combinational) + is_idle = state == cas(domain, m.const(IDLE, width=2), cycle=0) + is_run = state == cas(domain, m.const(RUN, width=2), cycle=0) + start = cmd == cas(domain, m.const(1, width=2), cycle=0) + + next_state = state # default: hold + next_state = mux(is_idle & start, cas(domain, m.const(RUN, width=2), cycle=0), next_state) + next_state = mux(is_run, cas(domain, m.const(DONE, width=2), cycle=0), next_state) + + m.output("state", state.wire) + + # Cycle 1: update + domain.next() + state.set(next_state) +``` + +--- + +## 验证策略(总则) + +每个设计改造后**必须**通过以下三关: + +1. **MLIR 结构对比**:新旧版 `pyc.reg` 数量一致、端口签名(`arg_names` / `result_names`)一致 +2. **功能仿真**(如有 `tb_*.py`):全部 `t.expect` 通过,无新增 FAIL +3. **性能基准**(如有 `emulate_*.py`):100K 周期吞吐无回归(±5%) + +--- + +## 各设计现有验证资产 & 升级后验证计划 + +> **图例** +> - TB = `tb_*.py` testbench(`@testbench` + `Tb` API) +> - CFG = `*_config.py`(含 `DEFAULT_PARAMS` / `TB_PRESETS`) +> - EMU = `emulate_*.py` / `test_*.py`(RTL 仿真/基准) +> - SVA = `t.sva_assert` 断言 +> - E(N) = N 次 `t.expect` 调用 + +### 一、大型设计 + +#### 1. RegisterFile + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_regfile.py`:10 个周期序列,每周期 10 读口 `t.expect`,共 **E(100)** | +| CFG | 无(参数内联于 TB) | +| EMU | `emulate_regfile.py`:ctypes RTL 仿真——功能正确性 29 项 + **100K 周期性能基准** | +| SVA | 无 | + +**升级后验证计划:** +- [x] **已验证**:V5 改造后 29/29 功能测试 PASS,100K 仿真 57.4 Kcycles/s(无回归) +- [ ] MLIR 对比:`pyc.reg` 数量 = 256(128×2 bank),端口签名不变 +- [ ] TB 编译:`compile_cycle_aware(build, name="tb_regfile_top", eager=True)` 出 MLIR 成功 + +--- + +#### 2. 
IssueQueue + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_issq.py`:Python 黄金模型 `_tb_step` 生成最多 512 周期入队/发射轨迹;**E(1)** `occupancy` 初始为 0 | +| CFG | `issq_config.py`:`IqCfg` 规格 + `TbState`/`TbUop` 参考模型(无 `DEFAULT_PARAMS`/`TB_PRESETS`) | +| EMU | 无 | +| SVA | 无 | + +**升级后验证计划:** +- [ ] MLIR 对比:entry 数 × (valid+age+ready+ptag+payload) 寄存器总数不变 +- [ ] TB 编译通过,占用量为 0 的初始检查仍 PASS +- [ ] **新增**:在 TB 中对 `issued_total` 添加终态 `t.expect`,确认发射总量 = 入队总量 + +--- + +#### 3. BypassUnit + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_bypass_unit.py`:8 手写场景 + 系统化 sweep(184 周期),**E(11776)** + **SVA(1344)** | +| CFG | 无 | +| EMU | 无 | +| SVA | `t.sva_assert`:同 stage 禁止双命中 | + +**升级后验证计划:** +- [ ] MLIR 对比:纯组合设计,`pyc.reg` = 0,端口数不变 +- [ ] TB 全部 11776 次 `t.expect` 通过 +- [ ] SVA 全部 1344 条 `t.sva_assert` 通过 +- [ ] **关键**:`if Wire else` → `mux()` 改造后,每个旁路优先级链须逐一验证 + +--- + +### 二、示例设计 + +#### 4. counter + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_counter.py`:**E(5)**(`count` 每周期 +1) | +| CFG | `counter_config.py`:`DEFAULT_PARAMS = {width: 8}`,smoke/nightly | + +**升级后验证计划:** +- [ ] TB 5 次 `t.expect` 全部 PASS +- [ ] MLIR:1 个 `pyc.reg`(计数器),端口 `clk/rst/enable → count` + +--- + +#### 5. multiclock_regs + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_multiclock_regs.py`:双时钟驱动,**E(0)**(仅驱动无检查) | +| CFG | `multiclock_regs_config.py`:`DEFAULT_PARAMS = {}`,smoke/nightly | + +**升级后验证计划:** +- [ ] MLIR 对比:2 个 `pyc.reg`(`a_q`/`b_q`),4 个 clock/reset 端口 +- [ ] **新增**:在 TB 中追加 `t.expect("a_count", 3, at=5)` 等基本计数检查 + +--- + +#### 6. wire_ops + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_wire_ops.py`:**E(1)** | +| CFG | `wire_ops_config.py`:smoke/nightly | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS +- [ ] **关键**:`if sel else` → `mux(sel, a & b, a ^ b)` 的语义等价 + +--- + +#### 7. 
jit_control_flow + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_jit_control_flow.py`:**E(1)**(`result == 7`) | +| CFG | `jit_control_flow_config.py`:`rounds: 4` | + +**升级后验证计划:** +- [ ] TB 组合结果 `t.expect` PASS +- [ ] **关键**:多分支 `if/elif op ==` → `mux()` 嵌套链的等价验证 + +--- + +#### 8. fifo_loopback + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_fifo_loopback.py`:**E(0)**(仅驱动) | +| CFG | `fifo_loopback_config.py`:`depth: 2` | + +**升级后验证计划:** +- [ ] MLIR 编译通过(`m.rv_queue` IP 接口不变) +- [ ] **新增**:追加 `t.expect("out_data", ...)` 验证 FIFO 先入先出行为 + +--- + +#### 9. hier_modules + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_hier_modules.py`:**E(1)** | +| CFG | `hier_modules_config.py`:`width`/`stages` | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS +- [ ] 若改为真流水(`domain.cycle()` × stages),MLIR `pyc.reg` 数应 = `stages` + +--- + +#### 10. bf16_fmac + +| 验证资产 | 详情 | +|---------|------| +| TB | 无标准 `tb_*.py`;有 `test_bf16_fmac.py`:ctypes RTL 100 用例,BF16 乘加与 Python 对比(≤2% 误差) | +| CFG | 无 | +| EMU | 无 | + +**升级后验证计划:** +- [ ] `test_bf16_fmac.py` 100 用例全部 PASS(误差阈值不变) +- [ ] MLIR:4 级流水寄存器总数不变 +- [ ] **关键**:50+ 处 `if Wire` → `mux()` 改造后须全量回归 + +--- + +#### 11. digital_filter + +| 验证资产 | 详情 | +|---------|------| +| TB | 无标准 `tb_*.py` | +| EMU | `emulate_filter.py`:4-tap FIR RTL 终端动画 | + +**升级后验证计划:** +- [ ] MLIR:`TAPS-1` 个延迟寄存器 + 1 个输出寄存器 + 1 个 valid 寄存器 +- [ ] `emulate_filter.py` 运行无崩溃 +- [ ] **新增**:编写 `tb_digital_filter.py`,对已知输入序列验证 FIR 输出 + +--- + +#### 12. digital_clock + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_digital_clock.py`:**E(1)**(复位后 `seconds_bcd`) | +| CFG | `digital_clock_config.py`:`clk_freq: 50_000_000` | +| EMU | `emulate_digital_clock.py`:RTL 动画时钟 | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS +- [ ] `emulate_digital_clock.py` 运行无崩溃 +- [ ] **关键**:FSM `if` 链 → `mux()` 链须保持状态转换语义 + +--- + +#### 13. 
calculator + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_calculator.py`:**E(1)**(`display`) | +| CFG | `calculator_config.py`:`DEFAULT_PARAMS = {}` | +| EMU | `emulate_calculator.py`:RTL 动画计算器 | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS +- [ ] `emulate_calculator.py` 运行无崩溃 +- [ ] **新增**:在 TB 中追加 `1+2=3`、`9*9=81` 等算术序列检查 + +--- + +#### 14. traffic_lights_ce + +| 验证资产 | 详情 | +|---------|------| +| TB | 无标准 `tb_*.py` | +| EMU | `emulate_traffic_lights.py`:RTL 可视化(含 `stimuli/` 激励) | + +**升级后验证计划:** +- [ ] `emulate_traffic_lights.py` 运行无崩溃 +- [ ] **新增**:编写 `tb_traffic_lights_ce.py`,验证相位切换、紧急模式覆盖、倒计时归零 + +--- + +#### 15–16. dodgeball_game (lab_final_VGA + lab_final_top) + +| 验证资产 | 详情 | +|---------|------| +| TB | 无标准 `tb_*.py` | +| EMU | `emulate_dodgeball.py`:RTL 游戏可视化(含 `stimuli/`) | + +**升级后验证计划:** +- [ ] `emulate_dodgeball.py` 运行无崩溃 +- [ ] MLIR 编译通过 +- [ ] **新增**:编写 `tb_lab_final_VGA.py`,验证 hsync/vsync 时序(640×480@60Hz 标准值) + +--- + +#### 17. obs_points + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_obs_points.py`:**E(6)**(`y`/`q` 的 pre/post 观测点) | +| CFG | `obs_points_config.py`:`width: 8` | + +**升级后验证计划:** +- [ ] TB 6 次 `t.expect` 全部 PASS + +--- + +#### 18. net_resolution_depth_smoke + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_net_resolution_depth_smoke.py`:**E(4)** | +| CFG | `net_resolution_depth_smoke_config.py`:`width: 8` | + +**升级后验证计划:** +- [ ] TB 4 次 `t.expect` PASS + +--- + +#### 19. mem_rdw_olddata + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_mem_rdw_olddata.py`:**E(2)**(同址读写返回旧值,再读新值) | +| CFG | `mem_rdw_olddata_config.py`:`depth/data_width/addr_width` | + +**升级后验证计划:** +- [ ] TB 2 次 `t.expect` PASS(旧数据语义) + +--- + +#### 20. sync_mem_init_zero + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_sync_mem_init_zero.py`:**E(2)**(读未写地址应为 0) | +| CFG | `sync_mem_init_zero_config.py` | + +**升级后验证计划:** +- [ ] TB 2 次 `t.expect` PASS + +--- + +#### 21. 
jit_pipeline_vec + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_jit_pipeline_vec.py`:**E(0)**(仅驱动 `a/b/sel`) | +| CFG | `jit_pipeline_vec_config.py`:`stages: 3` | + +**升级后验证计划:** +- [ ] MLIR:`pyc.reg` 数 = `stages`(tag 链)+ `stages`(data 链) +- [ ] **新增**:在 TB 中追加 `t.expect` 验证 `stages` 拍延迟后的 `lo8` 输出值 + +--- + +#### 22. xz_value_model_smoke + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_xz_value_model_smoke.py`:**E(4)** | +| CFG | `xz_value_model_smoke_config.py`:`width: 8` | + +**升级后验证计划:** +- [ ] TB 4 次 `t.expect` PASS + +--- + +#### 23. reset_invalidate_order_smoke + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_reset_invalidate_order_smoke.py`:**E(4)** | +| CFG | `reset_invalidate_order_smoke_config.py`:`width: 8` | + +**升级后验证计划:** +- [ ] TB 4 次 `t.expect` PASS + +--- + +#### 24. pipeline_builder + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_pipeline_builder.py`:**E(1)**(`out_ctrl_valid` 流水级差) | +| CFG | `pipeline_builder_config.py`:`width: 32` | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS +- [ ] MLIR:2 级流水寄存器数不变 + +--- + +#### 25. struct_transform + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_struct_transform.py`:**E(2)**(bundle 位域变换) | +| CFG | `struct_transform_config.py`:`width: 32` | + +**升级后验证计划:** +- [ ] TB 2 次 `t.expect` PASS + +--- + +#### 26. module_collection + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_module_collection.py`:**E(1)**(`acc` 规约) | +| CFG | `module_collection_config.py`:`width`/`lanes` | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS +- [ ] `m.array` 子模块实例数不变 + +--- + +#### 27. interface_wiring + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_interface_wiring.py`:**E(2)** | +| CFG | `interface_wiring_config.py`:`width: 16` | + +**升级后验证计划:** +- [ ] TB 2 次 `t.expect` PASS + +--- + +#### 28. instance_map + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_instance_map.py`:**E(4)** | +| CFG | `instance_map_config.py`:`width: 32` | + +**升级后验证计划:** +- [ ] TB 4 次 `t.expect` PASS + +--- + +#### 29. 
huge_hierarchy_stress + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_huge_hierarchy_stress.py`:**E(0)**(应力测试,仅驱动 `seed`) | +| CFG | `huge_hierarchy_stress_config.py`:`SIM_TIER: "heavy"`,`module_count/hierarchy_depth/fanout/cache_ways/cache_sets` | + +**升级后验证计划:** +- [ ] MLIR 编译通过(深层次 + Cache 实例化无报错) +- [ ] `pyc.reg` 总数不变 +- [ ] **新增**:追加 `t.expect("out", ...)` 在固定 `seed` 下对 `out` 做 golden 比对 + +--- + +#### 30. fastfwd + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_fastfwd.py`:**E(1)**(`pkt_in_bkpr`) | +| CFG | `fastfwd_config.py`:`DEFAULT_PARAMS = {}` | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS + +--- + +#### 31. decode_rules + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_decode_rules.py`:**E(2)**(`op`/`len`) | +| CFG | `decode_rules_config.py` | + +**升级后验证计划:** +- [ ] TB 2 次 `t.expect` PASS +- [ ] **关键**:规则 `if hit else` → `mux()` 链优先级不能反转 + +--- + +#### 32. cache_params + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_cache_params.py`:**E(3)**(`tag`/`line_words`/`tag_bits`) | +| CFG | `cache_params_config.py`:`ways/sets/line_bytes/addr_width/data_width` | + +**升级后验证计划:** +- [ ] TB 3 次 `t.expect` PASS + +--- + +#### 33. bundle_probe_expand + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_bundle_probe_expand.py`:**E(4)**(bundle 展开 pre/post) | +| CFG | `bundle_probe_expand_config.py` | + +**升级后验证计划:** +- [ ] TB 4 次 `t.expect` PASS + +--- + +#### 34. boundary_value_ports + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_boundary_value_ports.py`:**E(1)** | +| CFG | `boundary_value_ports_config.py`:`width: 32` | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS +- [ ] `_lane` 子模块 `if enable else` → `mux()` 验证 + +--- + +#### 35. arith + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_arith.py`:**E(3)**(`sum/lane_mask/acc_width`) | +| CFG | `arith_config.py`:`lanes/lane_width` | + +**升级后验证计划:** +- [ ] TB 3 次 `t.expect` PASS + +--- + +#### 36. 
issue_queue_2picker + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_issue_queue_2picker.py`:**E(1)**(`in_ready` backpressure) | +| CFG | `issue_queue_2picker_config.py` | + +**升级后验证计划:** +- [ ] TB `t.expect` PASS +- [ ] **关键**:队列移位逻辑 `if pop else` → `mux()` 等价验证 + +--- + +#### 37. trace_dsl_smoke + +| 验证资产 | 详情 | +|---------|------| +| TB | `tb_trace_dsl_smoke.py`:**E(12)**(双输出 trace 多周期 pre/post) | +| CFG | `trace_dsl_smoke_config.py`:smoke `timeout: 16, finish: 3` | + +**升级后验证计划:** +- [ ] TB 12 次 `t.expect` PASS +- [ ] **注意**:`leaf` 必须保留 `@module`(`@probe(target=leaf)` 依赖) + +--- + +#### 38. npu_node (fm16) + +| 验证资产 | 详情 | +|---------|------| +| TB | 无 | +| EMU | `fm16_system.py`(系统级整合脚本) | + +**升级后验证计划:** +- [ ] MLIR 编译通过 +- [ ] **新增**:编写 `tb_npu_node.py`,验证单端口 push/pop 数据一致性 + +--- + +#### 39. sw5809s (fm16) + +| 验证资产 | 详情 | +|---------|------| +| TB | 无 | +| EMU | 共享 `fm16_system.py` | + +**升级后验证计划:** +- [ ] MLIR 编译通过 +- [ ] **新增**:编写 `tb_sw5809s.py`,验证 RR 仲裁公平性(每端口等概率获得授权) + +--- + +## 验证缺口汇总 & 补充测试计划 + +以下设计在升级前**缺少充分的功能验证**,改造时应同步补充: + +| 设计 | 现有验证 | 需补充 | +|------|---------|--------| +| multiclock_regs | TB 仅驱动,E(0) | 追加计数值 `t.expect` | +| fifo_loopback | TB 仅驱动,E(0) | 追加 FIFO 读出 `t.expect` | +| jit_pipeline_vec | TB 仅驱动,E(0) | 追加延迟后输出 `t.expect` | +| huge_hierarchy_stress | TB 仅驱动,E(0) | 追加固定 seed 输出 golden | +| digital_filter | 无 TB | 新建 `tb_digital_filter.py` | +| traffic_lights_ce | 无 TB | 新建 `tb_traffic_lights_ce.py` | +| dodgeball_game | 无 TB | 新建 `tb_lab_final_VGA.py` | +| fm16 (npu_node) | 无 TB | 新建 `tb_npu_node.py` | +| fm16 (sw5809s) | 无 TB | 新建 `tb_sw5809s.py` | + +--- + +## 自动化回归脚本 + +改造完成后,应创建一键回归脚本 `scripts/regress_v5.sh`: + +```bash +#!/bin/bash +set -e +export PYTHONPATH=pyCircuit/compiler/frontend + +echo "=== Phase 1: MLIR 编译检查 ===" +for design in designs/RegisterFile/regfile.py \ + designs/IssueQueue/issq.py \ + designs/BypassUnit/bypass_unit.py \ + designs/examples/*/[!t]*.py; do + echo " Compiling $design ..." 
+ python3 "$design" > /dev/null 2>&1 +done + +echo "=== Phase 2: Testbench 编译 ===" +for tb in designs/*/tb_*.py designs/examples/*/tb_*.py; do + [ -f "$tb" ] || continue + echo " Compiling $tb ..." + python3 "$tb" > /dev/null 2>&1 +done + +echo "=== Phase 3: Emulation 烟雾 ===" +for emu in designs/RegisterFile/emulate_regfile.py; do + echo " Running $emu ..." + python3 "$emu" +done + +echo "ALL PASSED" +``` + +--- + +**Copyright (C) 2024-2026 PyCircuit Contributors** diff --git a/docs/pyCircuit_Tutorial.md b/docs/pyCircuit_Tutorial.md deleted file mode 100644 index 66204d2..0000000 --- a/docs/pyCircuit_Tutorial.md +++ /dev/null @@ -1,212 +0,0 @@ -# pyCircuit v4.0 Tutorial (Hard-Break) - -This tutorial is the v4.0 (`pyc0.40`) guide for authoring, building, and testing -pyCircuit designs. - -CycleAware APIs were removed in pyc4.0 and are not part of v4 authoring. - -## 1. What pyc4.0 enforces - -- `@module` defines hierarchy boundaries that lower to `pyc.instance`. -- Simulation follows a two-phase model: - - `tick()` computes next state - - `transfer()` commits state -- Python control-flow is allowed during authoring, but backend IR must be static - hardware (no residual dynamic SCF/index in backend lanes). -- DFX/probe behavior is first-class and controlled by hardened metadata + trace DSL. - -Authoritative references: - -- `docs/rfcs/pyc4.0-decisions.md` -- `docs/updatePLAN.md` -- `docs/FRONTEND_API.md` -- `docs/TESTBENCH.md` -- `designs/examples/README.md` - -## 2. Environment and quick gate loop - -Build `pycc`: - -```bash -bash /Users/zhoubot/pyCircuit/flows/scripts/pyc build -``` - -Run compiler smoke: - -```bash -bash /Users/zhoubot/pyCircuit/flows/scripts/run_examples.sh -``` - -Run simulation smoke: - -```bash -bash /Users/zhoubot/pyCircuit/flows/scripts/run_sims.sh -``` - -Run semantic regression lane: - -```bash -bash /Users/zhoubot/pyCircuit/flows/scripts/run_semantic_regressions_v40.sh -``` - -## 3. 
Minimal module - -```python -from pycircuit import Circuit, module, u - -@module -def build(m: Circuit, width: int = 8) -> None: - clk = m.clock("clk") - rst = m.reset("rst") - en = m.input("enable", width=1) - - count = m.out("count_q", clk=clk, rst=rst, width=width, init=u(width, 0)) - count.set(count.out() + 1, when=en) - m.output("count", count) -``` - -Key points: - -- `m.out(...)` creates explicit sequential state. -- `.out()` reads current state. -- `.set(next, when=...)` sets next state with hold-by-default behavior. - -## 4. Authoring with Python control flow - -You can use `if` and `for` in `@module` bodies as authoring sugar. - -```python -from pycircuit import Circuit, module, u - -@module -def build(m: Circuit, rounds: int = 4) -> None: - a = m.input("a", width=8) - b = m.input("b", width=8) - op = m.input("op", width=2) - - acc = a - if op == u(2, 0): - acc = a + b - elif op == u(2, 1): - acc = a - b - elif op == u(2, 2): - acc = a ^ b - else: - acc = a & b - - for _ in range(rounds): - acc = acc + 1 - - m.output("result", acc) -``` - -The compiler must lower this to static hardware before backend emission. - -## 5. Structured interfaces - -For larger modules, prefer `spec` + structured IO to keep port conventions -stable and tool-visible. - -```python -from pycircuit import Circuit, module, spec - -Pair = spec.struct("pair").field("x", width=8).field("y", width=8).build() - -@module -def build(m: Circuit) -> None: - ins = m.inputs(Pair, prefix="in_") - m.outputs(Pair, {"x": ins["x"], "y": ins["y"]}, prefix="out_") -``` - -See `docs/SPEC_STRUCTURES.md` and `docs/SPEC_COLLECTIONS.md` for full patterns. - -## 6. 
Testbench flow - -Write a host-side `@testbench` with `Tb`: - -```python -from pycircuit import Tb, testbench - -@testbench -def tb(t: Tb) -> None: - t.clock("clk") - t.reset("rst", cycles_asserted=2, cycles_deasserted=1) - t.timeout(64) - - t.drive("enable", 1, at=0) - t.expect("count", 1, at=0, phase="pre") - t.expect("count", 1, at=0, phase="post") - t.finish(at=8) -``` - -Observation points: - -- `phase="pre"` = TICK-OBS -- `phase="post"` = XFER-OBS - -## 7. End-to-end build via CLI - -Build a device + TB project: - -```bash -PYTHONPATH=/Users/zhoubot/pyCircuit/compiler/frontend \ -python3 -m pycircuit.cli build \ - /Users/zhoubot/pyCircuit/designs/examples/counter/tb_counter.py \ - --out-dir /tmp/pyc_counter \ - --target both \ - --jobs 8 -``` - -Important artifacts: - -- `project_manifest.json` -- `device/modules/*.pyc` -- `device/cpp/**` and/or `device/verilog/**` -- `trace_plan.json` (when trace config is enabled) -- `probe_manifest.json` - -## 8. Trace and probe workflow - -When trace config is enabled, pyc4.0 emits binary `.pyctrace` plus manifest data. - -Decode with external manifest mode: - -```bash -python3 /Users/zhoubot/pyCircuit/flows/tools/dump_pyctrace.py \ - /tmp/pyc_counter/tb_tb_counter_top/tb_tb_counter_top.pyctrace \ - --manifest /tmp/pyc_counter/probe_manifest.json -``` - -Use `designs/examples/trace_dsl_smoke/*`, -`designs/examples/bundle_probe_expand/*`, -`designs/examples/xz_value_model_smoke/*`, and -`designs/examples/reset_invalidate_order_smoke/*` as reference patterns. - -## 9. Required gate mindset (v4.0) - -For semantic or IR-contract changes: - -1. Update verifier/pass gates first. -2. Implement behavior in dialect/passes, not backend-only fixups. -3. Re-run smoke + simulation gates and preserve logs. -4. Keep decision status current in `docs/gates/decision_status_v40.md`. -5. Run semantic closure regressions (`run_semantic_regressions_v40.sh`) before status promotion. - -## 10. 
Troubleshooting checklist - -- `pycc` not found: run `flows/scripts/pyc build` or set `PYCC`. -- Backend IR legality failures: inspect `pyc-check-no-dynamic` and - `pyc-check-flat-types` diagnostics. -- Hierarchy contract failures: ensure module boundaries are authored with - `@module` and instance creation paths remain explicit. -- Trace decoding failures: verify `.pyctrace` header + `probe_manifest.json` - consistency. - -## 11. Next reading - -- `docs/QUICKSTART.md` -- `docs/FRONTEND_API.md` -- `docs/TESTBENCH.md` -- `docs/IR_SPEC.md` -- `docs/tutorial/index.md` -- `designs/examples/README.md` diff --git a/docs/simulation.md b/docs/simulation.md new file mode 100644 index 0000000..ff99efb --- /dev/null +++ b/docs/simulation.md @@ -0,0 +1,512 @@ +# pyCircuit C++ 仿真引擎架构 + +## 1. 概述 + +pyCircuit 的 C++ 仿真引擎采用 **静态编译-直接执行 (Compiled-Code Simulation)** 模型, +而非传统 Verilog/VHDL 仿真器常用的 **事件驱动 (Event-Driven Simulation)** 模型。 + +整个 RTL 设计被编译为一个 **单一 C++ 结构体**,内含所有信号(`Wire`)、 +寄存器实例(`pyc_reg`)以及组合逻辑求值函数(`eval()`/`tick()`)。 +仿真通过反复调用这些方法来推进时钟周期,在主机 CPU 上直接执行原生 C++ 代码。 + +``` +┌─────────────────────────────────────────────────────────┐ +│ Python 测试驱动 (ctypes) │ +│ 设置输入 → 调用 C API → 读取输出 │ +├─────────────────────────────────────────────────────────┤ +│ C API 封装层 (*_capi.cpp) │ +│ rf_create / rf_reset / rf_tick / rf_get_rdata / ... │ +├─────────────────────────────────────────────────────────┤ +│ Testbench (pyc_tb.hpp) │ +│ 时钟管理 / reset 协议 / VCD 波形 / 二进制 Trace │ +├─────────────────────────────────────────────────────────┤ +│ 生成的 DUT 结构体 (*_gen.hpp) │ +│ Wire 信号成员 / eval() / tick() │ +├─────────────────────────────────────────────────────────┤ +│ 运行时库 (pyc_bits.hpp, pyc_primitives.hpp, ...) │ +│ Wire 位向量 / pyc_reg / pyc_fifo / pyc_sync_mem │ +└─────────────────────────────────────────────────────────┘ +``` + +## 2. 
核心数据结构 + +### 2.1 `Wire` (pyc_bits.hpp) + +所有信号(无论组合还是寄存器输出)都用 `Wire` 表示,它是固定宽度的无符号位向量。 + +```cpp +template <unsigned Width> +class Bits { + static constexpr unsigned kWords = (Width + 63) / 64; + std::array<std::uint64_t, kWords> words_{}; +}; +template <unsigned Width> +using Wire = Bits<Width>; +``` + +- 存储以 64-bit word 为单元,小端序(word[0] = bits[63:0]) +- 所有运算符(`+`, `-`, `*`, `&`, `|`, `^`, `~`, 比较等)直接在 word 数组上操作 +- 宽度 ≤ 64 bit 的信号仅占 1 个 word,零额外开销 +- 宽度 > 64 bit 的信号(如 RegisterFile 的 640-bit rdata_bus)自动扩展为多 word + +### 2.2 `pyc_reg` (pyc_primitives.hpp) + +寄存器原语,实现两阶段更新协议: + +```cpp +template <unsigned Width> +class pyc_reg { + Wire<1> &clk, &rst, &en; + Wire<Width> &d, &init, &q; + Wire<Width> qNext{}; + bool pending = false; + + void tick_compute(); // 阶段1: 检测上升沿,计算 qNext + void tick_commit(); // 阶段2: 原子提交 q = qNext +}; +``` + +## 3. 单周期执行流程 + +每个仿真步(half-cycle step)按固定顺序执行,**没有事件队列**: + +``` +┌───────────────────────────────────────────────────────────────┐ +│ Testbench::step() [pyc_tb.hpp:130] │ +│ │ +│ 1. eval() — 组合逻辑前向求值(输入→输出) │ +│ 2. clock toggle — 翻转时钟信号 │ +│ 3. tick() — 时序逻辑更新 │ +│ 3a. tick_compute() × 所有寄存器 — 计算下一状态 │ +│ 3b. tick_commit() × 所有寄存器 — 原子写入 │ +│ 4. eval() — 组合逻辑重新稳定(反映新寄存器值) │ +│ 5. VCD dump (可选) │ +└───────────────────────────────────────────────────────────────┘ +``` + +快速路径 `runPosedgeCyclesFast()` 对单时钟设计做了优化, +将上升沿和下降沿合并处理,每个完整周期执行: + +``` +comb → clk=1 → tick_posedge → transfer → comb → clk=0 → tick_negedge → transfer +``` + +快速路径支持 SFINAE 检测 DUT 的 `tick_posedge()` / `tick_negedge()` 方法。 +如果 DUT 提供了分离的时钟边沿方法,下降沿仅执行轻量级 `clkPrev` 更新, +避免对所有寄存器执行完整的 `tick_compute()` 检查。 + +### 3.1 eval() 组合逻辑求值 + +`eval()` 是编译器生成的纯函数,按 **拓扑排序** 展开所有组合逻辑节点。 +编译器在 MLIR 层已完成数据流分析和调度,将组合逻辑分割为多个 +`eval_comb_N()` 内联函数,顺序调用: + +```cpp +void eval() { + eval_comb_11(); // 解码 / 地址匹配 + rf_bank0_0 = pyc_reg_271; // 寄存器输出赋值 + rf_bank0_1 = pyc_reg_272; + ... + eval_comb_12(); // 写使能 / MUX 选择 + eval_comb_13(); + ... 
+ rdata_bus = pyc_comb_8234; // 最终输出 +} +``` + +**关键特性**: 默认模式下,每个周期对所有组合节点做完整求值。 +通过可选的 **信号变化检测 (Change Detection)** 机制,可以在输入未变化时 +跳过 `eval()` 调用,形成混合 compiled/event 模型(参见 §5.6)。 + +### 3.2 tick() 时序更新 + +`tick()` 采用经典的 **两阶段更新协议**(compute-then-commit), +确保寄存器间无顺序依赖: + +```cpp +void tick() { + // Phase 1: 所有寄存器并行计算下一状态 + pyc_reg_271_inst.tick_compute(); + pyc_reg_272_inst.tick_compute(); + ... // × 256 个寄存器 + // Phase 2: 所有寄存器原子提交 + pyc_reg_271_inst.tick_commit(); + pyc_reg_272_inst.tick_commit(); + ... // × 256 个寄存器 +} +``` + +## 4. 与事件驱动仿真的对比 + +| 特性 | pyCircuit (Compiled-Code) | 事件驱动 (如 Verilator/iverilog) | +|---|---|---| +| **调度模型** | 无事件队列;支持可选变化检测 | 全局事件队列 + 敏感列表 | +| **Delta 周期** | 无;拓扑排序保证单遍收敛 | 需要 delta 迭代直到稳定 | +| **信号变化检测** | 可选 InputFingerprint 跳过 eval | 仅重新评估受影响的进程 | +| **时间模型** | 周期精确 (cycle-accurate) | 支持精细时间步 (time-step) | +| **代码生成** | 单一 C++ 结构体 + 内联函数 | 多线程调度器 + 进程模型 | +| **延迟建模** | 不支持门级延迟 | 支持 inertial/transport delay | +| **适用场景** | RTL 功能验证、高吞吐仿真 | 门级仿真、精确时序分析 | + +**pyCircuit 没有采用全局事件队列。** 它的核心是一个确定性的 +"对所有组合逻辑做一次完整拓扑排序求值 → 两阶段寄存器更新"循环。 +这种设计使得每个周期的执行路径完全确定,指令缓存友好,分支预测友好。 + +## 5. 
RegisterFile RTL 仿真基准测试 + +### 5.1 设计规格 + +| 参数 | 值 | +|---|---| +| 条目数 (ptag_count) | 256 | +| 常量 ROM 条目 (const_count) | 128 | +| 读端口 (nr) | 10 | +| 写端口 (nw) | 5 | +| 数据宽度 | 64 bit | +| 存储组织 | 2 bank × 128 entry × 32 bit | + +### 5.2 生成代码统计 + +| 指标 | 值 | +|---|---| +| 生成 C++ 行数 | 33,113 | +| Wire 信号成员 | ~17,590 | +| 寄存器实例 (pyc_reg) | 256 | +| 组合逻辑函数 (eval_comb) | 131 | +| tick_compute/commit 调用 | 各 256 次 | + +### 5.3 性能数据 + +测试环境:Apple M1 (arm64),macOS (darwin 25.2.0),Apple Clang 17。 +工作负载:每周期混合随机 10-路读 + 5-路写流量,100K cycles,取 5 次最优。 + +| 配置 | __TEXT 大小 | 耗时 | 吞吐量 | 加速比 | +|---|---|---|---|---| +| `-O2` baseline | 278 KB | 3.21 s | 31.2 Kcps | 1.00x | +| `-Os` (size-opt) | 246 KB | 2.46 s | 40.7 Kcps | 1.31x | +| `-Os` + SIMD + reg-opt | 262 KB | 2.58 s | 38.7 Kcps | 1.24x | +| `-O3 -flto` | 278 KB | 3.62 s | 27.7 Kcps | 0.89x | +| **PGO + `-O2` + SIMD** | **213 KB** | **1.69 s** | **59.1 Kcps** | **1.90x** | + +最佳配置(PGO + O2 + SIMD + pyc_reg 优化)实现了 **1.90x 加速**。 + +### 5.3.1 优化前后实测对比 + +| 指标 | 优化前 (`-O2`) | 优化后 (PGO+SIMD) | 提升 | +|---|---|---|---| +| 100K cycles 耗时 | 3.21 s | **1.69 s** | -47% | +| 吞吐量 | 31.2 Kcycles/s | **59.1 Kcycles/s** | +90% | +| 单周期耗时 | 32.10 μs | **16.93 μs** | -47% | +| __TEXT 代码大小 | 278 KB | **213 KB** | -23% | + +### 5.4 性能瓶颈分析与优化 + +**瓶颈诊断**: 生成代码的 `__TEXT` 段为 278 KB,远超 Apple M1 的 L1 +I-cache (192 KB/core)。`eval()` 函数体包含 131 个 eval_comb 子函数, +执行约 17,000 个信号赋值/MUX/位操作。这导致: + +1. **L1 I-cache thrashing**: eval() 代码无法完全放入 I-cache +2. **分支预测失效**: 大量 MUX 三元操作(`sel ? a : b`)创建不可预测分支 +3. 
**D-cache 压力**: ~17,590 个 Wire 成员 + 256 个 pyc_reg 实例,总计 > 100 KB + +**已实施的优化**: + +#### (1) NEON SIMD 向量化 (`pyc_bits.hpp`) + +为 `Wire` 的多 word(kWords ≥ 2,即宽度 > 64 bit)操作添加了 +ARM NEON 加速路径。每次处理 128 bit(2 × uint64_t): + +```cpp +// AND/OR/XOR: vld1q_u64 → vandq_u64/vorrq_u64/veorq_u64 → vst1q_u64 +// EQ compare: vceqq_u64 → lane reduce +// MUX select: vbslq_u64 (bitwise select, branch-free) +``` + +适用信号:`raddr_bus`(80b), `wdata_bus`(320b), `rdata_bus`(640b)。 +对此设计影响有限(多数操作在 ≤64b 信号上),但对宽数据路径设计显著有效。 + +#### (2) pyc_reg 优化 (`pyc_primitives.hpp`) + +- 使用 `__builtin_expect` 标注分支概率(negedge 远多于 posedge) +- 减少 `tick_compute` 中的分支数量 +- `tick_commit` 仅在 `pending` 时执行写入 + +#### (3) Profile-Guided Optimization (PGO) + +PGO 是最大的单一优化因素。流程: + +``` +# 1. 带插桩编译 +c++ -Os -fprofile-instr-generate ... -o lib_instr.dylib + +# 2. 运行训练负载(50K cycles) +LLVM_PROFILE_FILE=regfile.profraw python benchmark.py + +# 3. 合并 profile 数据 +xcrun llvm-profdata merge -output=regfile.profdata regfile.profraw + +# 4. 使用 profile 重新编译 +c++ -O2 -fprofile-instr-use=regfile.profdata ... 
-o lib_pgo.dylib +``` + +PGO 的效果: +- 编译器将冷路径(从未执行的 MUX 分支)优化为 size +- 热路径保持高度优化,布局紧凑 +- `__TEXT` 从 278 KB 降至 213 KB(-23%) +- 分支预测准确率大幅提升 + +#### (4) `-Os` 代码大小优化 + +`-O3` 反而比 `-O2` 慢(-11%),因为激进内联增大了 I-cache 压力。 +`-Os` 减少 `__TEXT` 至 246 KB 即获得 31% 加速,证实瓶颈是 I-cache。 + +### 5.5 优化因素分解 + +| 因素 | 单独贡献 | 说明 | +|---|---|---| +| PGO | ~1.86x | 解决 I-cache + 分支预测 | +| `-Os` 编译 | ~1.31x | 减少代码体积 | +| NEON SIMD | ~1.01x | 窄信号设计受益有限 | +| pyc_reg 优化 | ~1.01x | tick 仅占周期 <10% | + +**结论**: 对大型生成代码(> L1 I-cache),PGO 和代码大小优化比 +SIMD 向量化更有效。SIMD 的价值体现在宽数据路径密集的设计中。 + +### 5.6 信号变化检测 (Change Detection) + +**已实现。** 在 `pyc_change_detect.hpp` 中引入了混合 compiled/event 模型基础设施。 + +#### 核心组件 + +**`InputFingerprint`** — 跟踪一组输入信号的变化状态。 +使用 XOR-fold 哈希做快速拒绝,memcmp 做精确比较: + +```cpp +InputFingerprint<80, 5, 40, 320> fp(dut.raddr_bus, dut.wen_bus, + dut.waddr_bus, dut.wdata_bus); +// 每周期: +if (fp.check_and_capture()) { + dut.eval(); // 输入变化,必须重新求值 +} else { + // 输入未变化,跳过 eval() — 节省 ~17K 操作 +} +``` + +**`ChangeDetector`** — 跟踪单个 Wire 的变化(轻量级快照对比)。 + +**`EvalGuard`** — 包装 eval_comb 函数调用,仅在输入 +变化时执行(为编译器后端自动生成 guard 做准备)。 + +**`pyc_reg::posedge_tick_compute()` / `negedge_update()`** — 分离的 +时钟边沿方法。posedge 路径跳过 clkPrev 检查(调用者保证上升沿), +negedge 路径仅更新 clkPrev 标记,避免 256 次无效的 tick_compute 调用。 + +#### RegisterFile 变化检测实测数据 + +工作负载:100K cycles,按活动率混合随机/空闲周期。 + +| 活动率 | 100% active (baseline) | 50% active | 25% active | 10% active | 1% active | +|---|---|---|---|---|---| +| 耗时 (s) | 1.72 | 1.35 | 1.17 | 1.05 | 0.99 | +| 吞吐量 (Kcps) | 58.0 | 73.8 | 85.6 | 94.8 | 101.0 | +| 相对加速 | 1.00x | 1.27x | 1.48x | 1.63x | 1.74x | + +**结论**: 对活动率 50% 的设计(典型 CPU 流水线 stall 场景), +变化检测可提升 27%。对活动率 10% 的设计(外设/总线控制器), +可提升 63%。100% 活动时无额外开销(fingerprint 检查被内联后极轻量)。 + +### 5.7 自动化 PGO 构建 (pycircuit pgo-build) + +**已实现。** PGO 流程已集成到 `pycircuit.cli` 工具链,一条命令完成全流程。 + +#### 使用方式 + +```bash +# 基本用法(自动生成训练负载) +pycircuit pgo-build regfile_capi.cpp -o libregfile_sim.dylib -I include + +# 自定义训练命令 + 训练周期数 +pycircuit pgo-build regfile_capi.cpp -o 
libregfile_sim.dylib -I include \ + --train-cycles 50000 \ + --train-command "python3 my_benchmark.py" + +# 保留中间产物用于调试 +pycircuit pgo-build regfile_capi.cpp -o libregfile_sim.dylib -I include \ + --prof-dir ./pgo_profiles --keep-profiles + +# 指定编译器和优化标志 +pycircuit pgo-build regfile_capi.cpp -o libregfile_sim.dylib -I include \ + --cxx clang++ --opt-flags "-Os" --extra-flags "-march=native" +``` + +#### 自动化流程 + +``` +┌──────────────────────────────────────────────────────────────┐ +│ pycircuit pgo-build │ +│ │ +│ Step 1: 插桩编译 c++ -fprofile-generate → libinstr.dylib │ +│ Step 2: 训练运行 python3 _pgo_train.py (或自定义命令) │ +│ Step 3: Profile 合并 llvm-profdata merge → merged.profdata │ +│ Step 4: PGO 编译 c++ -fprofile-use → output.dylib │ +└──────────────────────────────────────────────────────────────┘ +``` + +#### CLI 参数 + +| 参数 | 默认值 | 说明 | +|---|---|---| +| `capi_source` | (必需) | C++ CAPI 封装源文件 | +| `-o, --output` | (必需) | 输出 .dylib / .so 路径 | +| `-I, --include-dir` | 自动检测 | 额外头文件目录 (可重复) | +| `--cxx` | `$CXX` 或 `c++` | C++ 编译器 | +| `--opt-flags` | `-O2` | 优化标志 | +| `--extra-flags` | (空) | 额外编译标志 | +| `--train-command` | 自动生成 | 自定义训练 shell 命令 | +| `--train-cycles` | 10000 | 自动训练的周期数 | +| `--prof-dir` | 临时目录 | Profile 数据存放目录 | +| `--keep-profiles` | false | 保留中间产物 | + +## 6. 多线程仿真可行性分析 + +### 6.1 当前架构的约束 + +当前仿真引擎是 **严格单线程** 的: + +1. **周期间串行依赖**: 周期 N+1 的 `eval()` 依赖周期 N 的 `tick_commit()` 结果, + 无法跨周期并行 +2. **周期内数据依赖**: `eval()` 内的 eval_comb 函数按拓扑排序调用, + 后序函数依赖前序函数的输出 +3. 
**共享状态**: 所有 Wire 信号是同一结构体的成员变量,没有内存隔离 + +### 6.2 可行的多线程改造方向 + +#### 方向 A: eval() 内部并行化(周期内并行) + +``` +eval_comb_0 ──┐ +eval_comb_1 ──┼── 独立子图 → Thread 0 +eval_comb_2 ──┘ +eval_comb_3 ──┐ +eval_comb_4 ──┼── 独立子图 → Thread 1 +eval_comb_5 ──┘ + └── barrier ──→ 依赖汇合 +eval_comb_6 ──── 需要两个子图的结果 → 单线程 +``` + +**可行性**: 中等。需要编译器在 MLIR 层做数据流分析, +识别不相互依赖的 eval_comb 子图,插入 barrier 同步点。 + +**挑战**: +- 线程同步开销(barrier、原子操作)每周期至少数百纳秒, + 而当前单周期仅 ~32 μs,同步开销占比可达 1-5% +- 对于像 RegisterFile 这样高度交叉的 MUX 网络, + 独立子图较少,可并行度有限 +- 需要保证 Wire 成员的缓存行对齐(避免 false sharing) + +**预期收益**: 对大型设计(eval 耗时 > 100 μs/cycle)可能有 1.5-3× 加速。 +对 RegisterFile 规模的设计,预期收益有限。 + +#### 方向 B: tick() 内部并行化(寄存器更新并行) + +``` +Thread 0: tick_compute() for reg[0..127] +Thread 1: tick_compute() for reg[128..255] +──── barrier ──── +Thread 0: tick_commit() for reg[0..127] +Thread 1: tick_commit() for reg[128..255] +``` + +**可行性**: 高。寄存器的 tick_compute 互相独立(只读共享状态, +写入各自的 qNext),天然适合数据并行。 + +**挑战**: +- tick() 通常只占每周期执行时间的一小部分(< 10%), + 大部分时间在 eval() +- 256 个寄存器的 tick_compute 每个仅几十纳秒, + 线程池调度开销可能 > 实际计算 + +**预期收益**: 微乎其微(< 5%)。除非寄存器数量极大(> 10K)。 + +#### 方向 C: 模块级并行化(多模块 SoC 设计) + +``` +┌──────────┐ ┌──────────┐ ┌──────────┐ +│ CPU Core │ │ RegFile │ │ Cache │ +│ Thread 0 │ │ Thread 1 │ │ Thread 2 │ +└─────┬─────┘ └─────┬─────┘ └────┬─────┘ + │ │ │ + └────── interface sync ──────────┘ +``` + +**可行性**: 低-中。需要 `pyc.instance` 保留层次边界, +各模块独立求值,接口处插入同步。 + +**挑战**: +- 当前 `pyc-compile` 会内联所有子模块(不支持 `pyc.instance`) +- 模块间组合路径(如 bypass 网络)跨越边界,需要迭代稳定 +- 需要重新设计编译器后端以保留层次结构 + +**预期收益**: 对大型 SoC(数十个模块)可能有 2-8× 加速, +但需要大量编译器和运行时工程。 + +#### 方向 D: SIMD 向量化(已实现) + +已在 `pyc_bits.hpp` 中为 ARM NEON 添加了加速路径: + +```cpp +// kWords >= 2 时自动使用 NEON (128-bit = 2×uint64) +// AND: vandq_u64, OR: vorrq_u64, XOR: veorq_u64, NOT: vmvnq_u8 +// EQ: vceqq_u64 + lane reduce +// MUX: vbslq_u64 (bitwise select, branch-free) +``` + +**实测结果**: 对以窄信号(≤64b)为主的 RegisterFile 设计, +SIMD 贡献约 1.01x。对宽数据路径密集的设计(如 512-bit AXI 总线), +预期 1.5-2x 加速。 + +#### 方向 E: Profile-Guided 
Optimization(已实现,效果最佳) + +PGO 让编译器基于实际运行 profile 优化代码布局: +- 将冷路径压缩(-Os),热路径保持优化 +- 改善分支预测准确率 +- `__TEXT` 从 278 KB 降至 213 KB(-23%) + +**实测结果**: 单独贡献 **1.86x 加速**,是目前最有效的单一优化手段。 + +### 6.3 总结与建议 + +| 方向 | 可行性 | 改造成本 | 实测/预期加速 | 适用规模 | +|---|---|---|---|---| +| **E: PGO** | **高** | **低 (CLI 已自动化)** | **1.86x (实测)** | **所有大型设计** | +| **F: 变化检测** | **高** | **低 (已实现)** | **1.27-1.74x (实测)** | **活动率 < 100%** | +| D: SIMD 向量化 | 高 | 中 (运行时) | 1.01x (窄) / ~2x (宽) | 宽数据路径 | +| `-Os` 编译 | 高 | 无 | 1.31x (实测) | __TEXT > L1 I$ | +| A: eval 内部并行 | 中 | 高 (编译器) | 1.5-3× (预期) | > 100 μs/cycle | +| B: tick 并行 | 高 | 低 (运行时) | < 1.1× (预期) | > 10K 寄存器 | +| C: 模块级并行 | 低-中 | 很高 (全栈) | 2-8× (预期) | SoC 级 | + +**已完成优化** (总加速 1.90x; 变化检测对低活动率设计可达 1.74x): +1. **PGO 构建流程**: `fprofile-instr-generate` → 训练 → `fprofile-instr-use` +2. **NEON SIMD**: `Wire` 多 word 位操作向量化 +3. **pyc_reg 优化**: `__builtin_expect` 分支提示 + posedge/negedge 分离 +4. **`-Os` 编译标志**: 作为非 PGO 场景的推荐默认 +5. ✅ **信号变化检测**: `InputFingerprint` / `ChangeDetector` / `EvalGuard` + 基础设施,跳过输入未变化周期的 `eval()` 调用。 + 实测:10% 活动率时 +63%,50% 活动率时 +27% +6. ✅ **自动化 PGO 构建**: `pycircuit pgo-build` CLI 子命令, + 一条命令完成 instrumented build → training → profile merge → PGO build + +**短期建议**: +7. 编译器后端自动生成 **per-eval_comb guard**, + 利用 `EvalGuard` 实现细粒度变化检测(当前为 DUT 级粗粒度) +8. 为大型设计启用 **编译期常量折叠**,消除 const ROM 的运行时求值 + +**中期建议**: +9. 在编译器中实现 **eval 子图分区**,为方向 A 做准备 +10. 编译器后端自动生成 `tick_posedge()` / `tick_negedge()` 方法 + +**长期建议**: +11. 实现模块级并行(方向 C),需要重新设计编译后端的实例化策略 +12. 探索 **GPU 加速仿真**:将宽位操作和 MUX 树映射到 GPU compute shader, + 适合极大规模(> 1M gate)的全芯片仿真 diff --git a/docs/tutorial/cycle-aware-computing.md b/docs/tutorial/cycle-aware-computing.md index b8c1cf3..d8e8541 100644 --- a/docs/tutorial/cycle-aware-computing.md +++ b/docs/tutorial/cycle-aware-computing.md @@ -37,3 +37,27 @@ See `docs/TESTBENCH.md` for the full `Tb` API. These contracts are enforced via MLIR-level verifiers/passes (see `docs/updatePLAN.md`). 
+## Occurrence cycles on combinational assigns + +**Primary style:** `clk = m.clock(...)` returns a **`ClockHandle`**. Use +**`clk.next()`** to advance the domain’s **current occurrence cycle**. Assigns +to **`named_wire`** targets then get **`dst_cycle = clk.cycle`** and +**`src_cycle`** from the RHS expression; `pycc` runs **`pyc-cycle-balance`** to +insert shared `pyc.reg` delays when needed. + +```python +clk = m.clock("clk") +raw = m.input("x", width=8) +clk.next() +w = m.named_wire("stage1_view", width=8) +m.assign(w, raw) +``` + +**Explicit** metadata is still supported: + +```python +m.assign(w, raw, dst_cycle=1, src_cycle=0) +``` + +See `docs/pyCircuit_Tutorial.md` §3.1 and `docs/cycle_balance_improvement.md`. + diff --git a/flows/scripts/lib.sh b/flows/scripts/lib.sh index 2f06840..7815d55 100755 --- a/flows/scripts/lib.sh +++ b/flows/scripts/lib.sh @@ -152,8 +152,8 @@ pyc_pythonpath() { fi # Prefer editable install (`pip install -e .`), but fall back to PYTHONPATH for - # repo-local runs. - echo "${PYC_ROOT_DIR}/compiler/frontend:${PYC_ROOT_DIR}/designs" + # repo-local runs. iplib/ is the standard IP library (RegFile, FIFO, Cache, …). 
+ echo "${PYC_ROOT_DIR}/compiler/frontend:${PYC_ROOT_DIR}/designs:${PYC_ROOT_DIR}" } pyc_out_root() { diff --git a/include/cpp/pyc_async_fifo.hpp b/include/cpp/pyc_async_fifo.hpp new file mode 120000 index 0000000..19a114e --- /dev/null +++ b/include/cpp/pyc_async_fifo.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_async_fifo.hpp \ No newline at end of file diff --git a/include/cpp/pyc_bits.hpp b/include/cpp/pyc_bits.hpp new file mode 120000 index 0000000..7078b7f --- /dev/null +++ b/include/cpp/pyc_bits.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_bits.hpp \ No newline at end of file diff --git a/include/cpp/pyc_byte_mem.hpp b/include/cpp/pyc_byte_mem.hpp new file mode 120000 index 0000000..03cba51 --- /dev/null +++ b/include/cpp/pyc_byte_mem.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_byte_mem.hpp \ No newline at end of file diff --git a/include/cpp/pyc_cdc_sync.hpp b/include/cpp/pyc_cdc_sync.hpp new file mode 120000 index 0000000..959ede8 --- /dev/null +++ b/include/cpp/pyc_cdc_sync.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_cdc_sync.hpp \ No newline at end of file diff --git a/include/cpp/pyc_clock.hpp b/include/cpp/pyc_clock.hpp new file mode 120000 index 0000000..632c3b4 --- /dev/null +++ b/include/cpp/pyc_clock.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_clock.hpp \ No newline at end of file diff --git a/include/cpp/pyc_connector.hpp b/include/cpp/pyc_connector.hpp new file mode 120000 index 0000000..8a65a48 --- /dev/null +++ b/include/cpp/pyc_connector.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_connector.hpp \ No newline at end of file diff --git a/include/cpp/pyc_debug.hpp b/include/cpp/pyc_debug.hpp new file mode 120000 index 0000000..3fcb688 --- /dev/null +++ b/include/cpp/pyc_debug.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_debug.hpp \ No newline at end of file diff --git a/include/cpp/pyc_konata.hpp b/include/cpp/pyc_konata.hpp new file mode 120000 index 0000000..56ca660 --- /dev/null +++ b/include/cpp/pyc_konata.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_konata.hpp \ No 
newline at end of file diff --git a/include/cpp/pyc_linxtrace.hpp b/include/cpp/pyc_linxtrace.hpp new file mode 120000 index 0000000..1312b38 --- /dev/null +++ b/include/cpp/pyc_linxtrace.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_linxtrace.hpp \ No newline at end of file diff --git a/include/cpp/pyc_ops.hpp b/include/cpp/pyc_ops.hpp new file mode 120000 index 0000000..3d89f85 --- /dev/null +++ b/include/cpp/pyc_ops.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_ops.hpp \ No newline at end of file diff --git a/include/cpp/pyc_primitives.hpp b/include/cpp/pyc_primitives.hpp new file mode 120000 index 0000000..ed3a650 --- /dev/null +++ b/include/cpp/pyc_primitives.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_primitives.hpp \ No newline at end of file diff --git a/include/cpp/pyc_print.hpp b/include/cpp/pyc_print.hpp new file mode 120000 index 0000000..85fde40 --- /dev/null +++ b/include/cpp/pyc_print.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_print.hpp \ No newline at end of file diff --git a/include/cpp/pyc_probe_registry.hpp b/include/cpp/pyc_probe_registry.hpp new file mode 120000 index 0000000..5252b94 --- /dev/null +++ b/include/cpp/pyc_probe_registry.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_probe_registry.hpp \ No newline at end of file diff --git a/include/cpp/pyc_runtime.hpp b/include/cpp/pyc_runtime.hpp new file mode 120000 index 0000000..9cde197 --- /dev/null +++ b/include/cpp/pyc_runtime.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_runtime.hpp \ No newline at end of file diff --git a/include/cpp/pyc_sim.hpp b/include/cpp/pyc_sim.hpp new file mode 120000 index 0000000..80dbb3c --- /dev/null +++ b/include/cpp/pyc_sim.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_sim.hpp \ No newline at end of file diff --git a/include/cpp/pyc_sync_mem.hpp b/include/cpp/pyc_sync_mem.hpp new file mode 120000 index 0000000..6e4b0e5 --- /dev/null +++ b/include/cpp/pyc_sync_mem.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_sync_mem.hpp \ No newline at end of file diff --git a/include/cpp/pyc_tb.hpp 
b/include/cpp/pyc_tb.hpp new file mode 120000 index 0000000..3c6ec83 --- /dev/null +++ b/include/cpp/pyc_tb.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_tb.hpp \ No newline at end of file diff --git a/include/cpp/pyc_trace_bin.hpp b/include/cpp/pyc_trace_bin.hpp new file mode 120000 index 0000000..e286555 --- /dev/null +++ b/include/cpp/pyc_trace_bin.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_trace_bin.hpp \ No newline at end of file diff --git a/include/cpp/pyc_vcd.hpp b/include/cpp/pyc_vcd.hpp new file mode 120000 index 0000000..b2b3ec4 --- /dev/null +++ b/include/cpp/pyc_vcd.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_vcd.hpp \ No newline at end of file diff --git a/include/cpp/pyc_vec.hpp b/include/cpp/pyc_vec.hpp new file mode 120000 index 0000000..9f67c5e --- /dev/null +++ b/include/cpp/pyc_vec.hpp @@ -0,0 +1 @@ +../../runtime/cpp/pyc_vec.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_async_fifo.hpp b/include/pyc/cpp/pyc_async_fifo.hpp new file mode 120000 index 0000000..7b79737 --- /dev/null +++ b/include/pyc/cpp/pyc_async_fifo.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_async_fifo.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_bits.hpp b/include/pyc/cpp/pyc_bits.hpp new file mode 120000 index 0000000..1dac521 --- /dev/null +++ b/include/pyc/cpp/pyc_bits.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_bits.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_byte_mem.hpp b/include/pyc/cpp/pyc_byte_mem.hpp new file mode 120000 index 0000000..a71d0eb --- /dev/null +++ b/include/pyc/cpp/pyc_byte_mem.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_byte_mem.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_cdc_sync.hpp b/include/pyc/cpp/pyc_cdc_sync.hpp new file mode 120000 index 0000000..c1eb654 --- /dev/null +++ b/include/pyc/cpp/pyc_cdc_sync.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_cdc_sync.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_change_detect.hpp b/include/pyc/cpp/pyc_change_detect.hpp 
new file mode 120000 index 0000000..4c2a946 --- /dev/null +++ b/include/pyc/cpp/pyc_change_detect.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_change_detect.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_clock.hpp b/include/pyc/cpp/pyc_clock.hpp new file mode 120000 index 0000000..d5e2ab7 --- /dev/null +++ b/include/pyc/cpp/pyc_clock.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_clock.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_connector.hpp b/include/pyc/cpp/pyc_connector.hpp new file mode 120000 index 0000000..269946d --- /dev/null +++ b/include/pyc/cpp/pyc_connector.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_connector.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_debug.hpp b/include/pyc/cpp/pyc_debug.hpp new file mode 120000 index 0000000..bfd4137 --- /dev/null +++ b/include/pyc/cpp/pyc_debug.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_debug.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_konata.hpp b/include/pyc/cpp/pyc_konata.hpp new file mode 120000 index 0000000..309e539 --- /dev/null +++ b/include/pyc/cpp/pyc_konata.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_konata.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_linxtrace.hpp b/include/pyc/cpp/pyc_linxtrace.hpp new file mode 120000 index 0000000..f4a0136 --- /dev/null +++ b/include/pyc/cpp/pyc_linxtrace.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_linxtrace.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_ops.hpp b/include/pyc/cpp/pyc_ops.hpp new file mode 120000 index 0000000..b4da006 --- /dev/null +++ b/include/pyc/cpp/pyc_ops.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_ops.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_primitives.hpp b/include/pyc/cpp/pyc_primitives.hpp new file mode 120000 index 0000000..334ab64 --- /dev/null +++ b/include/pyc/cpp/pyc_primitives.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_primitives.hpp \ No newline at end of file diff --git 
a/include/pyc/cpp/pyc_print.hpp b/include/pyc/cpp/pyc_print.hpp new file mode 120000 index 0000000..fc5d8c1 --- /dev/null +++ b/include/pyc/cpp/pyc_print.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_print.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_probe_registry.hpp b/include/pyc/cpp/pyc_probe_registry.hpp new file mode 120000 index 0000000..dcd0884 --- /dev/null +++ b/include/pyc/cpp/pyc_probe_registry.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_probe_registry.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_runtime.hpp b/include/pyc/cpp/pyc_runtime.hpp new file mode 120000 index 0000000..c793f10 --- /dev/null +++ b/include/pyc/cpp/pyc_runtime.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_runtime.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_sim.hpp b/include/pyc/cpp/pyc_sim.hpp new file mode 120000 index 0000000..c1117d0 --- /dev/null +++ b/include/pyc/cpp/pyc_sim.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_sim.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_sync_mem.hpp b/include/pyc/cpp/pyc_sync_mem.hpp new file mode 120000 index 0000000..77fd3e3 --- /dev/null +++ b/include/pyc/cpp/pyc_sync_mem.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_sync_mem.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_tb.hpp b/include/pyc/cpp/pyc_tb.hpp new file mode 120000 index 0000000..7040494 --- /dev/null +++ b/include/pyc/cpp/pyc_tb.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_tb.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_trace_bin.hpp b/include/pyc/cpp/pyc_trace_bin.hpp new file mode 120000 index 0000000..534ee2d --- /dev/null +++ b/include/pyc/cpp/pyc_trace_bin.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_trace_bin.hpp \ No newline at end of file diff --git a/include/pyc/cpp/pyc_vcd.hpp b/include/pyc/cpp/pyc_vcd.hpp new file mode 120000 index 0000000..06a57b5 --- /dev/null +++ b/include/pyc/cpp/pyc_vcd.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_vcd.hpp \ No newline at 
end of file diff --git a/include/pyc/cpp/pyc_vec.hpp b/include/pyc/cpp/pyc_vec.hpp new file mode 120000 index 0000000..2b0557f --- /dev/null +++ b/include/pyc/cpp/pyc_vec.hpp @@ -0,0 +1 @@ +../../../runtime/cpp/pyc_vec.hpp \ No newline at end of file diff --git a/iplib/__init__.py b/iplib/__init__.py new file mode 100644 index 0000000..3158a8f --- /dev/null +++ b/iplib/__init__.py @@ -0,0 +1,17 @@ +from .cache import Cache +from .mem2port import Mem2Port +from .picker import Picker +from .queue import FIFO +from .regfile import RegFile +from .sram import SRAM +from .stream import StreamSig + +__all__ = [ + "Cache", + "FIFO", + "Mem2Port", + "Picker", + "RegFile", + "SRAM", + "StreamSig", +] diff --git a/iplib/cache.py b/iplib/cache.py new file mode 100644 index 0000000..e5a5fe9 --- /dev/null +++ b/iplib/cache.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from pycircuit.hw import Circuit, ClockDomain, Wire +from pycircuit.literals import u + + +def Cache( + m: Circuit, + cd: ClockDomain, + req_valid: Wire, + req_addr: Wire, + req_write: Wire, + req_wdata: Wire, + req_wmask: Wire, + *, + ways: int = 4, + sets: int = 64, + line_bytes: int = 64, + addr_width: int = 64, + data_width: int = 64, + write_back: bool = True, + write_allocate: bool = True, + replacement: str = "plru", +): + """Structural cache baseline. + + Default policy contract: + - write_back=True + - write_allocate=True + - replacement="plru" + + This pyc4.0 baseline is intentionally compact and hierarchy-preserving; it keeps + state visible to the compiler flow without flattening into primitive wires. 
+ """ + + _ = (line_bytes, write_back, write_allocate, replacement) + clk_v = cd.clk + rst_v = cd.rst + + req_valid_w = req_valid + req_addr_w = req_addr + req_write_w = req_write + req_wdata_w = req_wdata + _req_wmask_w = req_wmask + _ = _req_wmask_w + ways_i = max(1, int(ways)) + sets_i = max(1, int(sets)) + set_bits = max(1, (sets_i - 1).bit_length()) + tag_bits = max(1, int(addr_width) - set_bits) + plru_bits = max(1, ways_i - 1) + way_idx_bits = max(1, (ways_i - 1).bit_length()) + + tags = [m.out(f"cache_tag_{i}", domain=cd, width=tag_bits, init=0) for i in range(ways_i)] + valids = [m.out(f"cache_valid_{i}", domain=cd, width=1, init=0) for i in range(ways_i)] + dirty = [m.out(f"cache_dirty_{i}", domain=cd, width=1, init=0) for i in range(ways_i)] + data = [m.out(f"cache_data_{i}", domain=cd, width=int(data_width), init=0) for i in range(ways_i)] + plru = m.out("cache_plru", domain=cd, width=plru_bits, init=0) + + req_tag = req_addr_w[set_bits : set_bits + tag_bits] + + hit = u(1, 0) + hit_data = u(int(data_width), 0) + hit_way = u(way_idx_bits, 0) + + for i in range(ways_i): + way_hit = valids[i].out() & (tags[i].out() == req_tag) + hit_data = way_hit._select_internal(data[i].out(), hit_data) + hit_way = way_hit._select_internal(u(way_idx_bits, i), hit_way) + hit = hit | way_hit + + victim_way = plru.out()[0:way_idx_bits] + + do_alloc = req_valid_w & (~hit) + do_write_hit = req_valid_w & req_write_w & hit + do_write_alloc = req_valid_w & req_write_w & do_alloc + + for i in range(ways_i): + sel_hit = hit & (hit_way == i) + sel_victim = do_alloc & (victim_way == i) + + tags[i].set(req_tag, when=sel_victim) + valids[i].set(1, when=sel_victim) + + data[i].set(req_wdata_w, when=sel_hit & req_write_w) + data[i].set(req_wdata_w, when=sel_victim & req_write_w) + + dirty[i].set(1, when=sel_hit & req_write_w) + dirty[i].set(do_write_alloc, when=sel_victim) + + plru.set(plru.out() + 1, when=req_valid_w) + + resp_valid = req_valid_w + resp_ready = req_valid_w + resp_hit 
= hit + resp_data = hit._select_internal(hit_data, u(int(data_width), 0)) + miss = req_valid_w & (~hit) + + return m.bundle_connector( + resp_valid=resp_valid, + resp_ready=resp_ready, + resp_hit=resp_hit, + resp_data=resp_data, + miss=miss, + ) diff --git a/iplib/mem2port.py b/iplib/mem2port.py new file mode 100644 index 0000000..e138aca --- /dev/null +++ b/iplib/mem2port.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from pycircuit.dsl import Signal +from pycircuit.hw import Circuit, ClockDomain, Wire + + +class Mem2PortError(ValueError): + pass + + +def Mem2Port( + m: Circuit, + cd: ClockDomain, + ren0: Wire, + raddr0: Wire, + ren1: Wire, + raddr1: Wire, + wvalid: Wire, + waddr: Wire, + wdata: Wire, + wstrb: Wire, + *, + depth: int, +): + clk_v = cd.clk + rst_v = cd.rst + if not isinstance(clk_v, Signal) or clk_v.ty != "!pyc.clock": + raise Mem2PortError("Mem2Port domain clk must be !pyc.clock") + if not isinstance(rst_v, Signal) or rst_v.ty != "!pyc.reset": + raise Mem2PortError("Mem2Port domain rst must be !pyc.reset") + + ren0_w = ren0 + ren1_w = ren1 + wvalid_w = wvalid + raddr0_w = raddr0 + raddr1_w = raddr1 + waddr_w = waddr + wdata_w = wdata + wstrb_w = wstrb + if ren0_w.ty != "i1" or ren1_w.ty != "i1" or wvalid_w.ty != "i1": + raise Mem2PortError("Mem2Port ren0/ren1/wvalid must be i1") + + rdata0, rdata1 = m.sync_mem_dp( + clk_v, + rst_v, + ren0=ren0_w, + raddr0=raddr0_w, + ren1=ren1_w, + raddr1=raddr1_w, + wvalid=wvalid_w, + waddr=waddr_w, + wdata=wdata_w, + wstrb=wstrb_w, + depth=int(depth), + name="mem", + ) + + return m.bundle_connector( + rdata0=rdata0, + rdata1=rdata1, + ) diff --git a/iplib/picker.py b/iplib/picker.py new file mode 100644 index 0000000..f2ab8a7 --- /dev/null +++ b/iplib/picker.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from pycircuit.hw import Circuit, Wire +from pycircuit.literals import u + + +def Picker( + m: Circuit, + req: Wire, + *, + width: int | None = None, +): + req_w = req + if not 
hasattr(req_w, "ty") or not str(req_w.ty).startswith("i"): + raise ValueError("Picker.req must be an integer wire") + w = int(width) if width is not None else int(req_w.width) + if w <= 0: + raise ValueError("Picker width must be > 0") + + idx_w = max(1, (w - 1).bit_length()) + grant = req_w & 0 + index = req_w[0:idx_w] & 0 + found = req_w[0] & 0 + + for i in range(w): + take = req_w[i] & ~found + grant = take._select_internal(u(w, 1 << i), grant) + index = take._select_internal(u(idx_w, i), index) + found = found | req_w[i] + + return m.bundle_connector( + valid=found, + grant=grant, + index=index, + ) diff --git a/iplib/queue.py b/iplib/queue.py new file mode 100644 index 0000000..9abca70 --- /dev/null +++ b/iplib/queue.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from pycircuit.dsl import Signal +from pycircuit.hw import Circuit, ClockDomain, Wire + + +class FIFOError(ValueError): + pass + + +def FIFO( + m: Circuit, + cd: ClockDomain, + in_valid: Wire, + in_data: Wire, + out_ready: Wire, + *, + depth: int = 2, +): + clk_v = cd.clk + rst_v = cd.rst + if not isinstance(clk_v, Signal) or clk_v.ty != "!pyc.clock": + raise FIFOError("FIFO domain clk must be !pyc.clock") + if not isinstance(rst_v, Signal) or rst_v.ty != "!pyc.reset": + raise FIFOError("FIFO domain rst must be !pyc.reset") + + in_valid_w = in_valid + in_data_w = in_data + out_ready_w = out_ready + + if not isinstance(in_valid_w, Wire) or in_valid_w.ty != "i1": + raise FIFOError("FIFO.in_valid must be i1") + if not isinstance(in_data_w, Wire): + raise FIFOError("FIFO.in_data must be integer wire") + if not isinstance(out_ready_w, Wire) or out_ready_w.ty != "i1": + raise FIFOError("FIFO.out_ready must be i1") + + in_ready, out_valid, out_data = m.fifo( + clk_v, + rst_v, + in_valid=in_valid_w, + in_data=in_data_w, + out_ready=out_ready_w, + depth=int(depth), + ) + + return m.bundle_connector( + in_ready=in_ready, + out_valid=out_valid, + out_data=out_data, + ) diff --git a/iplib/regfile.py 
b/iplib/regfile.py new file mode 100644 index 0000000..7b9342c --- /dev/null +++ b/iplib/regfile.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from pycircuit.dsl import Signal +from pycircuit.hw import Circuit, ClockDomain, Wire +from pycircuit.literals import u + + +class RegFileError(ValueError): + """Invalid RegFile port wiring.""" + + +def RegFile( + m: Circuit, + cd: ClockDomain, + raddr_bus: Wire, + wen_bus: Wire, + waddr_bus: Wire, + wdata_bus: Wire, + *, + ptag_count: int = 256, + const_count: int = 128, + nr: int = 10, + nw: int = 5, +): + ptag_n = int(ptag_count) + const_n = int(const_count) + nr_n = int(nr) + nw_n = int(nw) + if ptag_n <= 0: + raise ValueError("RegFile ptag_count must be > 0") + if const_n < 0 or const_n > ptag_n: + raise ValueError("RegFile const_count must satisfy 0 <= const_count <= ptag_count") + if nr_n <= 0: + raise ValueError("RegFile nr must be > 0") + if nw_n <= 0: + raise ValueError("RegFile nw must be > 0") + ptag_w = max(1, (ptag_n - 1).bit_length()) + + clk_v = cd.clk + rst_v = cd.rst + if not isinstance(clk_v, Signal) or clk_v.ty != "!pyc.clock": + raise RegFileError("RegFile domain clk must be !pyc.clock") + if not isinstance(rst_v, Signal) or rst_v.ty != "!pyc.reset": + raise RegFileError("RegFile domain rst must be !pyc.reset") + + raddr_bus_w = raddr_bus + wen_bus_w = wen_bus + waddr_bus_w = waddr_bus + wdata_bus_w = wdata_bus + + exp_raddr_w = nr_n * ptag_w + exp_wen_w = nw_n + exp_waddr_w = nw_n * ptag_w + exp_wdata_w = nw_n * 64 + + if raddr_bus_w.width != exp_raddr_w: + raise RegFileError(f"RegFile.raddr_bus must be i{exp_raddr_w}") + if wen_bus_w.width != exp_wen_w: + raise RegFileError(f"RegFile.wen_bus must be i{exp_wen_w}") + if waddr_bus_w.width != exp_waddr_w: + raise RegFileError(f"RegFile.waddr_bus must be i{exp_waddr_w}") + if wdata_bus_w.width != exp_wdata_w: + raise RegFileError(f"RegFile.wdata_bus must be i{exp_wdata_w}") + + storage_depth = ptag_n - const_n + bank0 = 
[m.out(f"rf_bank0_{i}", domain=cd, width=32, init=u(32, 0)) for i in range(storage_depth)] + bank1 = [m.out(f"rf_bank1_{i}", domain=cd, width=32, init=u(32, 0)) for i in range(storage_depth)] + + raddr_lanes = [raddr_bus_w[i * ptag_w : (i + 1) * ptag_w] for i in range(nr_n)] + wen_lanes = [wen_bus_w[i] for i in range(nw_n)] + waddr_lanes = [waddr_bus_w[i * ptag_w : (i + 1) * ptag_w] for i in range(nw_n)] + wdata_lanes = [wdata_bus_w[i * 64 : (i + 1) * 64] for i in range(nw_n)] + wdata_lo = [w[0:32] for w in wdata_lanes] + wdata_hi = [w[32:64] for w in wdata_lanes] + + # Multiple writes to the same storage PTAG in one cycle are intentionally + # left undefined by contract (strict no-conflict mode). + for sidx in range(storage_depth): + ptag = const_n + sidx + we_any = u(1, 0) + next_lo = bank0[sidx].out() + next_hi = bank1[sidx].out() + for lane in range(nw_n): + hit = wen_lanes[lane] & (waddr_lanes[lane] == u(ptag_w, ptag)) + we_any = we_any | hit + next_lo = hit._select_internal(wdata_lo[lane], next_lo) + next_hi = hit._select_internal(wdata_hi[lane], next_hi) + bank0[sidx].set(next_lo, when=we_any) + bank1[sidx].set(next_hi, when=we_any) + + cmp_w = ptag_w + 1 + rdata_lanes = [] + for lane in range(nr_n): + raddr_i = raddr_lanes[lane] + raddr_ext = raddr_i + u(cmp_w, 0) + is_valid = raddr_ext < u(cmp_w, ptag_n) + is_const = raddr_ext < u(cmp_w, const_n) + + if raddr_i.width > 32: + const32 = raddr_i[0:32] + else: + const32 = raddr_i + u(32, 0) + const64 = m.cat(const32, const32) + + store_lo = u(32, 0) + store_hi = u(32, 0) + for sidx in range(storage_depth): + ptag = const_n + sidx + hit = raddr_i == u(ptag_w, ptag) + store_lo = hit._select_internal(bank0[sidx].out(), store_lo) + store_hi = hit._select_internal(bank1[sidx].out(), store_hi) + store64 = m.cat(store_hi, store_lo) + + lane_data = is_const._select_internal(const64, store64) + lane_data = is_valid._select_internal(lane_data, u(64, 0)) + rdata_lanes.append(lane_data) + + rdata_bus_out = rdata_lanes[0] 
+ for lane in range(1, nr_n): + rdata_bus_out = m.cat(rdata_lanes[lane], rdata_bus_out) + + return m.bundle_connector( + rdata_bus=rdata_bus_out, + ) diff --git a/iplib/sram.py b/iplib/sram.py new file mode 100644 index 0000000..95f67f3 --- /dev/null +++ b/iplib/sram.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from pycircuit.dsl import Signal +from pycircuit.hw import Circuit, ClockDomain, Wire + + +class SRAMError(ValueError): + pass + + +def SRAM( + m: Circuit, + cd: ClockDomain, + ren: Wire, + raddr: Wire, + wvalid: Wire, + waddr: Wire, + wdata: Wire, + wstrb: Wire, + *, + depth: int, +): + clk_v = cd.clk + rst_v = cd.rst + if not isinstance(clk_v, Signal) or clk_v.ty != "!pyc.clock": + raise SRAMError("SRAM domain clk must be !pyc.clock") + if not isinstance(rst_v, Signal) or rst_v.ty != "!pyc.reset": + raise SRAMError("SRAM domain rst must be !pyc.reset") + + ren_w = ren + wvalid_w = wvalid + raddr_w = raddr + waddr_w = waddr + wdata_w = wdata + wstrb_w = wstrb + if ren_w.ty != "i1" or wvalid_w.ty != "i1": + raise SRAMError("SRAM ren/wvalid must be i1") + + rdata = m.sync_mem( + clk_v, + rst_v, + ren=ren_w, + raddr=raddr_w, + wvalid=wvalid_w, + waddr=waddr_w, + wdata=wdata_w, + wstrb=wstrb_w, + depth=int(depth), + name="mem", + ) + + return m.bundle_connector( + rdata=rdata, + ) diff --git a/iplib/stream.py b/iplib/stream.py new file mode 100644 index 0000000..af46b25 --- /dev/null +++ b/iplib/stream.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from pycircuit.spec.types import BundleSpec, SignatureSpec, StructSpec + + +def StreamSig( + *, + name: str = "stream", + payload: StructSpec | BundleSpec | None = None, + payload_prefix: str = "payload", + valid_name: str = "valid", + ready_name: str = "ready", +) -> SignatureSpec: + """Create a strict ready/valid stream signature (producer perspective). 
+ + Producer perspective: + - `valid`: out + - `ready`: in + - `payload.*`: out + + Use `StreamSig(...).flip()` for the consumer perspective. + """ + + leaves: dict[str, tuple[str, int, bool]] = { + str(valid_name): ("out", 1, False), + str(ready_name): ("in", 1, False), + } + + if payload is not None: + if isinstance(payload, StructSpec): + for path, fld in payload.flatten_fields(): + leaves[f"{payload_prefix}.{path}"] = ("out", int(fld.width or 0), bool(fld.signed)) + elif isinstance(payload, BundleSpec): + for f in payload.fields: + leaves[f"{payload_prefix}.{f.name}"] = ("out", int(f.width), bool(f.signed)) + else: + raise TypeError(f"StreamSig payload must be StructSpec or BundleSpec, got {type(payload).__name__}") + + return SignatureSpec.from_leaf_map(name=str(name), fields=leaves) diff --git a/janus/docs/TMU_SPEC.md b/janus/docs/TMU_SPEC.md new file mode 100644 index 0000000..6ae6a53 --- /dev/null +++ b/janus/docs/TMU_SPEC.md @@ -0,0 +1,781 @@ +# Janus TMU (Tile Management Unit) 微架构规格书 + +> 版本: 1.0 +> 日期: 2026-02-10 +> 实现代码: `janus/pyc/janus/tmu/janus_tmu_pyc.py` + +--- + +## 1. 概述 + +### 1.1 TMU 在 Janus 中的定位 + +Janus 是一个 AI 执行单元,由以下五个核心模块组成: + +| 模块 | 全称 | 功能 | +|------|------|------| +| **BCC** | Block Control Core | 标量控制核,负责指令调度与流程控制 | +| **TMU** | Tile Management Unit | Tile 寄存器文件管理单元,通过 Ring 互联提供高带宽数据访问 | +| **VectorCore** | 向量执行核 | 执行向量运算(load/store 通过 TMU 访问 TileReg) | +| **Cube** | 矩阵乘计算单元 | 基于 Systolic Array 的矩阵乘法引擎 | +| **TMA** | Tile Memory Access | 负责 TileReg 与外部 DDR 之间的数据搬运 | + +TMU 是 Janus 的**片上数据枢纽**,管理一块名为 **TileReg** 的可配置 SRAM 缓冲区(默认 1MB),通过 **8 站点双向 Ring 互联网络**为各个计算核提供高带宽、低延迟的数据读写服务。 + +### 1.2 设计目标 + +- **峰值带宽**: 256B x 8 / cycle = 2048B/cycle +- **低延迟**: 本地访问(node 访问自身 pipe)仅需 4 cycle +- **确定性路由**: 静态最短路径路由,无动态路由 +- **无活锁/饿死**: 通过 Tag 机制和 Round-Robin 仲裁保证公平性 +- **可配置容量**: TileReg 大小可通过参数配置(默认 1MB) + +--- + +## 2. 
顶层架构 + +### 2.1 系统框图 + +``` + ┌─────────────────────────────────────────────┐ + │ TMU │ + │ │ + Vector port0 ──── │── node0 ──── pipe0 (128KB SRAM) │ + Cube port0 ──── │── node1 ──── pipe1 (128KB SRAM) │ + Vector port1 ──── │── node2 ──── pipe2 (128KB SRAM) │ + Cube port1 ──── │── node3 ──── pipe3 (128KB SRAM) │ + Vector port2 ──── │── node4 ──── pipe4 (128KB SRAM) │ + TMA port0 ──── │── node5 ──── pipe5 (128KB SRAM) │ + BCC/CSU ──── │── node6 ──── pipe6 (128KB SRAM) │ + TMA port1 ──── │── node7 ──── pipe7 (128KB SRAM) │ + │ │ + │ Ring Interconnect (CW/CC) │ + └─────────────────────────────────────────────┘ +``` + +### 2.2 Node-Pipe 映射关系 + +| Pipe | Node | 外部连接 | 用途 | +|------|------|----------|------| +| pipe0 | node0 | Vector port0 | Vector 内部 load 指令的访问通道 | +| pipe1 | node1 | Cube port0 | Cube 的读数据通道 | +| pipe2 | node2 | Vector port1 | Vector 内部 load 指令的访问通道 | +| pipe3 | node3 | Cube port1 | Cube 的写数据通道 | +| pipe4 | node4 | Vector port2 | Vector 内部 store 指令的访问通道 | +| pipe5 | node5 | TMA port0 | TMA 读数据通道(TStore: TileReg -> DDR) | +| pipe6 | node6 | BCC/CSU | 预留给 BCC 命令/响应或 CSU | +| pipe7 | node7 | TMA port1 | TMA 写数据通道(TLoad: DDR -> TileReg) | + +### 2.3 每个 CS (Station) 的能力 + +- 每个 CS 支持挂载**最多 3 个节点**(当前实现每个 CS 挂载 1 个节点) +- 每个 CS 支持**同拍上下 Ring**(请求 Ring 和响应 Ring 完全独立并行) +- 每个 CS 可同时向 CW 和 CC 两个方向各发出/接收一个 flit + +--- + +## 3. 
Ring 互联网络 + +### 3.1 拓扑结构 + +Ring 采用**双向环形拓扑**,8 个 station 按以下物理顺序连接: + +``` +RING_ORDER = [0, 1, 3, 5, 7, 6, 4, 2] +``` + +即 node 之间的连接关系为: + +``` +node0 <-> node1 <-> node3 <-> node5 <-> node7 <-> node6 <-> node4 <-> node2 <-> node0 +``` + +用环形图表示: + +``` + node0 + / \ + node2 node1 + | | + node4 node3 + | | + node6 node5 + \ / + node7 +``` + +### 3.2 双向车道 + +Ring 支持两个方向的数据流动: + +| 方向 | 缩写 | 含义 | +|------|------|------| +| Clockwise | CW | 顺时针方向:沿 RING_ORDER 正序流动 (0→1→3→5→7→6→4→2→0) | +| Counter-Clockwise | CC | 逆时针方向:沿 RING_ORDER 逆序流动 (0→2→4→6→7→5→3→1→0) | + +### 3.3 独立 Ring 通道 + +TMU 内部包含**四条独立的 Ring 通道**: + +| Ring 通道 | 方向 | 用途 | +|-----------|------|------| +| req_cw | CW | 请求 Ring 顺时针通道 | +| req_cc | CC | 请求 Ring 逆时针通道 | +| rsp_cw | CW | 响应 Ring 顺时针通道 | +| rsp_cc | CC | 响应 Ring 逆时针通道 | + +请求 Ring 和响应 Ring 完全解耦,可并行工作。 + +### 3.4 路由策略 + +采用**静态最短路径路由**,在编译时预计算每对 (src, dst) 的最优方向: + +```python +CW_PREF[src][dst] = 1 # 如果 CW 方向跳数 <= CC 方向跳数 +CW_PREF[src][dst] = 0 # 如果 CC 方向跳数更短 +``` + +**路由规则**: +- 不允许动态路由 +- 当 CW 和 CC 距离相等时,优先选择 CW +- 路由方向在请求注入 Ring 时确定,传输过程中不改变 + +### 3.5 Ring 跳数表 + +基于 RING_ORDER = [0, 1, 3, 5, 7, 6, 4, 2],各 node 之间的 Ring 跳数(最短路径): + +| src\dst | n0 | n1 | n2 | n3 | n4 | n5 | n6 | n7 | +|---------|----|----|----|----|----|----|----|----| +| **n0** | 0 | 1 | 1 | 2 | 2 | 3 | 3 | 4 | +| **n1** | 1 | 0 | 2 | 1 | 3 | 2 | 4 | 3 | +| **n2** | 1 | 2 | 0 | 3 | 1 | 4 | 2 | 3 | +| **n3** | 2 | 1 | 3 | 0 | 4 | 1 | 3 | 2 | +| **n4** | 2 | 3 | 1 | 4 | 0 | 3 | 1 | 2 | +| **n5** | 3 | 2 | 4 | 1 | 3 | 0 | 2 | 1 | +| **n6** | 3 | 4 | 2 | 3 | 1 | 2 | 0 | 1 | +| **n7** | 4 | 3 | 3 | 2 | 2 | 1 | 1 | 0 | + +--- + +## 4. 
Flit 格式 + +### 4.1 数据粒度 + +Ring 上传输的数据粒度为 **256 Bytes**(一个 cacheline),由 32 个 64-bit word 组成: + +``` +Flit Data = 32 x 64-bit words = 256 Bytes +``` + +### 4.2 请求 Flit Meta 格式 + +请求 flit 的 meta 信息打包在一个 64-bit 字段中: + +``` +[63 REQ_ADDR_LSB] [REQ_TAG_LSB] [REQ_DST_LSB] [REQ_SRC_LSB] [0] +|<------------- addr (20b) ---------->|<- tag (8b) ->|<- dst (3b) ->|<- src (3b) ->|<- write (1b) ->| +``` + +| 字段 | 位宽 | LSB | 含义 | +|------|------|-----|------| +| write | 1 | 0 | 读/写标志(1=写,0=读) | +| src | 3 (node_bits) | 1 | 源节点编号 | +| dst | 3 (node_bits) | 4 | 目的节点编号(= pipe 编号) | +| tag | 8 | 7 | 请求标签,用于匹配响应 | +| addr | 20 (addr_bits) | 15 | 字节地址 | + +### 4.3 响应 Flit Meta 格式 + +``` +[63 RSP_TAG_LSB] [RSP_DST_LSB] [RSP_SRC_LSB] [0] +|<-------- tag (8b) -------->|<- dst (3b) ->|<- src (3b) ->|<- write (1b) ->| +``` + +| 字段 | 位宽 | LSB | 含义 | +|------|------|-----|------| +| write | 1 | 0 | 原始请求的读/写标志 | +| src | 3 | 1 | 响应源(= pipe 编号) | +| dst | 3 | 4 | 响应目的(= 原始请求的 src) | +| tag | 8 | 7 | 原始请求的 tag,原样返回 | + +--- + +## 5. 
TileReg 存储结构 + +### 5.1 容量与划分 + +TileReg 是 TMU 管理的片上 SRAM 缓冲区: + +- **默认总容量**: 1MB (1,048,576 Bytes),可通过 `tile_bytes` 参数配置 +- **划分方式**: 均分为 8 个 **pipe**,每个 pipe 对应一块独立 SRAM +- **每 pipe 容量**: tile_bytes / 8 = 128KB(默认配置下) +- **每 pipe 行数**: pipe_bytes / 256 = 512 行(默认配置下) +- **每行大小**: 256 Bytes = 32 x 64-bit words + +``` +TileReg (1MB) +├── pipe0: 128KB SRAM (512 lines x 256B) ── node0 +├── pipe1: 128KB SRAM (512 lines x 256B) ── node1 +├── pipe2: 128KB SRAM (512 lines x 256B) ── node2 +├── pipe3: 128KB SRAM (512 lines x 256B) ── node3 +├── pipe4: 128KB SRAM (512 lines x 256B) ── node4 +├── pipe5: 128KB SRAM (512 lines x 256B) ── node5 +├── pipe6: 128KB SRAM (512 lines x 256B) ── node6 +└── pipe7: 128KB SRAM (512 lines x 256B) ── node7 +``` + +每个 pipe 内部由 32 个独立的 `byte_mem` 实例组成(每个 word 一个),支持单周期读写。 + +### 5.2 地址编码 + +以 1MB 容量为例,使用 20-bit 字节地址: + +``` +地址格式: [19:11] [10:8] [7:0] + index pipe offset + 9-bit 3-bit 8-bit +``` + +| 字段 | 位域 | 位宽 | 含义 | +|------|------|------|------| +| offset | [7:0] | 8 | 256B cacheline 内部的字节偏移 | +| pipe | [10:8] | 3 | 目标 pipe 编号(0~7),决定数据存储在哪个 SRAM | +| index | [19:11] | 9 | cacheline 在对应 pipe 中的行号(0~511) | + +**地址解码过程**: +1. 从请求地址中提取 `pipe = addr[10:8]`,确定目标 pipe(同时也是目标 node) +2. 提取 `index = addr[19:11]`,确定 pipe 内的行号 +3. `offset = addr[7:0]` 在当前实现中用于 256B 粒度内的字节定位 + +### 5.3 可配置性 + +| 参数 | 默认值 | 约束 | +|------|--------|------| +| `tile_bytes` | 1MB (2^20) | 必须是 8 x 256 = 2048 的整数倍 | +| `tag_bits` | 8 | 请求标签位宽 | +| `spb_depth` | 4 | SPB FIFO 深度 | +| `mgb_depth` | 4 | MGB FIFO 深度 | + +地址位宽根据 `tile_bytes` 自动计算: +``` +addr_bits = ceil(log2(tile_bytes)) # 20 for 1MB +offset_bits = ceil(log2(256)) = 8 +pipe_bits = ceil(log2(8)) = 3 +index_bits = addr_bits - offset_bits - pipe_bits # 9 for 1MB +``` + +--- + +## 6. 
节点微架构 + +每个 node 包含以下组件: + +``` + ┌──────────────────────────────────┐ + │ Node i │ + │ │ + 外部请求 ──req_valid──> │ ┌─────────┐ ┌─────────┐ │ + (valid/ready) │ │ SPB_CW │ │ SPB_CC │ │ + req_write ────────────> │ │ depth=4 │ │ depth=4 │ │ + req_addr ─────────────> │ │ 1W2R │ │ 1W2R │ │ + req_tag ──────────────> │ └────┬────┘ └────┬────┘ │ + req_data[0:31] ───────> │ │ │ │ + <──── req_ready ─────── │ v v │ + │ ┌──────────────────────┐ │ + │ │ Request Ring │ │ + │ │ CW/CC 注入/转发 │ │ + │ └──────────────────────┘ │ + │ │ + │ ┌──────────────────────┐ │ + │ │ Pipe SRAM │ │ + │ │ (32 x byte_mem) │ │ + │ └──────────────────────┘ │ + │ │ + │ ┌──────────────────────┐ │ + │ │ Response Ring │ │ + │ │ CW/CC 注入/转发 │ │ + │ └──────────────────────┘ │ + │ │ │ │ + │ ┌────┴────┐ ┌────┴────┐ │ + │ │ MGB_CW │ │ MGB_CC │ │ + │ │ depth=4 │ │ depth=4 │ │ + │ │ 2W1R │ │ 2W1R │ │ + │ └────┬────┘ └────┬────┘ │ + │ │ RR 仲裁 │ │ + │ └──────┬───────┘ │ + <──── resp_valid ────── │ │ │ + <──── resp_tag ──────── │ v │ + <──── resp_data[0:31] ─ │ resp output │ + <──── resp_is_write ─── │ │ + ──── resp_ready ──────> │ │ + └──────────────────────────────────┘ +``` + +### 6.1 节点外部接口 + +每个 node 对外暴露以下信号: + +**请求通道(外部 -> TMU)**: + +| 信号 | 位宽 | 方向 | 含义 | +|------|------|------|------| +| `n{i}_req_valid` | 1 | input | 请求有效 | +| `n{i}_req_write` | 1 | input | 1=写请求,0=读请求 | +| `n{i}_req_addr` | 20 | input | 字节地址 | +| `n{i}_req_tag` | 8 | input | 请求标签(用于匹配响应) | +| `n{i}_req_data_w{0..31}` | 64 each | input | 写数据(32 个 64-bit word) | +| `n{i}_req_ready` | 1 | output | 请求就绪(反压信号) | + +**响应通道(TMU -> 外部)**: + +| 信号 | 位宽 | 方向 | 含义 | +|------|------|------|------| +| `n{i}_resp_valid` | 1 | output | 响应有效 | +| `n{i}_resp_tag` | 8 | output | 响应标签(与请求 tag 匹配) | +| `n{i}_resp_data_w{0..31}` | 64 each | output | 响应数据 | +| `n{i}_resp_is_write` | 1 | output | 标识原始请求是否为写操作 | +| `n{i}_resp_ready` | 1 | input | 外部准备好接收响应 | + +**握手协议**: 标准 valid/ready 握手。当 `valid & ready` 同时为高时,传输发生。 + +--- + +## 7. 
SPB (Send/Post Buffer) + +### 7.1 功能概述 + +SPB 是请求上 Ring 的缓冲区,位于每个 node 的请求注入端。每个 node 有两个 SPB: +- **SPB_CW**: 缓存将要向 CW 方向发送的请求 +- **SPB_CC**: 缓存将要向 CC 方向发送的请求 + +### 7.2 SPB 规格 + +| 参数 | 值 | +|------|-----| +| 深度 | 4 entries | +| 端口 | 1 写 2 读(一拍可同时 pick CW 和 CC 各一个请求上 Ring) | +| Bypass | **不支持** bypass SPB 上 Ring(请求必须先入 SPB 再注入 Ring) | +| 反压 | SPB 满时,`req_ready` 拉低,反压外部请求 | + +### 7.3 SPB 工作流程 + +1. 外部请求到达 node,根据 `CW_PREF[src][dst]` 确定方向 +2. 请求被写入对应方向的 SPB(CW 或 CC) +3. 当 Ring 对应方向的 slot 空闲时,SPB 头部的请求被注入 Ring +4. Ring 上已有 flit 优先前递(forward),SPB 注入优先级低于 Ring 转发 + +### 7.4 SPB 注入仲裁 + +``` +if ring_slot_has_flit: + forward flit (优先) + SPB 不注入 +else: + if SPB 非空 and 目的不是本地: + 注入 SPB 头部请求到 Ring +``` + +**本地请求优化**: 如果 SPB 头部请求的目的 node 就是本 node(即 src == dst),则该请求直接被弹出送往本地 pipe,不经过 Ring 传输。 + +--- + +## 8. MGB (Merge Buffer) + +### 8.1 功能概述 + +MGB 是响应下 Ring 的缓冲区,位于每个 node 的响应接收端。每个 node 有两个 MGB: +- **MGB_CW**: 缓存从 CW 方向到达的响应 +- **MGB_CC**: 缓存从 CC 方向到达的响应 + +### 8.2 MGB 规格 + +| 参数 | 值 | +|------|-----| +| 深度 | 4 entries | +| 端口 | 2 写 1 读(一拍可同时接收 CW 和 CC 各一个 flit,单路出队) | +| Bypass | **支持** bypass 下 Ring(队列为空且仅一个方向到达时可 bypass) | +| 反压 | MGB 满时,反压 Ring 上的响应注入 | + +### 8.3 MGB Bypass 机制 + +当满足以下条件时,响应可以 bypass MGB 直接输出: +- MGB 队列为空 +- 仅有一个方向(CW 或 CC)有到达的响应 +- 外部 `resp_ready` 为高 + +### 8.4 MGB 出队仲裁 + +当 CW 和 CC 两个 MGB 都有数据时,采用 **Round-Robin (RR)** 仲裁: + +``` +rr_reg: 1-bit 寄存器,每次出队后翻转 +if only CW has data: pick CW +if only CC has data: pick CC +if both have data: rr_reg==0 ? pick CW : pick CC +``` + +RR 仲裁确保两个方向的响应不会饿死。 + +--- + +## 9. 
请求 Ring 数据通路 + +### 9.1 请求处理流水线 + +``` +外部请求 → SPB入队(1 cycle) → Ring传输(N hops) → Pipe SRAM访问(1 cycle) → 响应注入 +``` + +### 9.2 请求 Ring 每站逻辑 + +对于 Ring 上的每个 station(按 RING_ORDER 遍历),每拍执行以下逻辑: + +**Step 1: 检查到达的 Ring flit** +``` +cw_in = 从 CW 方向前一站到达的 flit +cc_in = 从 CC 方向后一站到达的 flit +``` + +**Step 2: 判断是否为本地请求(需要弹出到 pipe)** +``` +ring_cw_local = cw_in.valid AND (cw_in.dst == 本站 node_id) +ring_cc_local = cc_in.valid AND (cc_in.dst == 本站 node_id) +spb_cw_local = spb_cw.valid AND (spb_cw.dst == 本站 node_id) +spb_cc_local = spb_cc.valid AND (spb_cc.dst == 本站 node_id) +``` + +**Step 3: 优先级仲裁(弹出到 pipe)** +``` +优先级从高到低: +1. Ring CW 方向到达的本地请求 +2. Ring CC 方向到达的本地请求 +3. SPB CW 中目的为本地的请求 +4. SPB CC 中目的为本地的请求 +``` + +**Step 4: Ring 转发与 SPB 注入** +``` +CW 方向: + if cw_in 非本地: 转发 cw_in(优先) + else if SPB_CW 非空且非本地: 注入 SPB_CW 头部 + +CC 方向: + if cc_in 非本地: 转发 cc_in(优先) + else if SPB_CC 非空且非本地: 注入 SPB_CC 头部 +``` + +--- + +## 10. Pipe SRAM 访问 + +### 10.1 Pipe Stage 寄存器 + +从请求 Ring 弹出的请求先经过一级 **pipe stage 寄存器**(1 cycle 延迟),然后访问 SRAM: + +``` +pipe_req_valid → [pipe_stage_valid reg] → SRAM 读/写 +pipe_req_meta → [pipe_stage_meta reg] → 地址解码 +pipe_req_data → [pipe_stage_data reg] → 写数据 +``` + +### 10.2 SRAM 读写操作 + +**写操作**: +- 条件: `pipe_stage_valid & write` +- 将 32 个 64-bit word 写入对应 pipe 的 SRAM +- 写掩码: 全字节写入 (wstrb = 0xFF) +- 响应数据: 返回写入的数据本身 + +**读操作**: +- 条件: `pipe_stage_valid & ~write` +- 从对应 pipe 的 SRAM 读出 32 个 64-bit word +- 响应数据: 返回读出的数据 + +### 10.3 响应生成 + +SRAM 访问完成后,生成响应 flit: +``` +rsp_meta = pack(write, src=pipe_id, dst=原始请求的src, tag=原始请求的tag) +rsp_data = write ? 写入数据 : 读出数据 +rsp_dir = CW_PREF[pipe_id][原始请求的src] # 响应方向 +``` + +响应被送入对应方向的响应注入 FIFO(深度=4),等待注入响应 Ring。 + +--- + +## 11. 
响应 Ring 数据通路 + +### 11.1 响应 Ring 每站逻辑 + +与请求 Ring 类似,但弹出目标是 MGB 而非 pipe: + +**Step 1: 检查到达的 Ring flit** +``` +cw_in = 从 CW 方向前一站到达的响应 flit +cc_in = 从 CC 方向后一站到达的响应 flit +``` + +**Step 2: 判断是否为本地响应** +``` +ring_cw_local = cw_in.valid AND (cw_in.dst == 本站 node_id) +ring_cc_local = cc_in.valid AND (cc_in.dst == 本站 node_id) +``` + +**Step 3: 本地响应送入 MGB** +``` +cw_local = ring_cw_local OR rsp_inject_cw_local +cc_local = ring_cc_local OR rsp_inject_cc_local +→ 分别送入 MGB_CW 和 MGB_CC +``` + +**Step 4: Ring 转发与响应注入** +``` +CW 方向: + if cw_in 非本地: 转发(优先) + else if rsp_inject_cw 非空且非本地: 注入 + +CC 方向: + if cc_in 非本地: 转发(优先) + else if rsp_inject_cc 非空且非本地: 注入 +``` + +### 11.2 MGB 出队到外部 + +``` +MGB_CW 和 MGB_CC 通过 RR 仲裁选择一个输出 +→ resp_valid, resp_tag, resp_data, resp_is_write +← resp_ready (外部反压) +``` + +--- + +## 12. 时序分析 + +### 12.1 延迟模型 + +一次完整的读/写操作延迟由以下阶段组成: + +| 阶段 | 延迟 | 说明 | +|------|------|------| +| SPB 入队 | 1 cycle | 请求写入 SPB | +| 请求 Ring 传输 | H hops | H = src 到 dst 的最短跳数 | +| Pipe Stage | 1 cycle | pipe stage 寄存器 | +| SRAM 访问 | 0 cycle | 与 pipe stage 同拍完成 | +| 响应 Ring 传输 | H hops | H = dst 到 src 的最短跳数(与请求相同) | +| MGB bypass/出队 | 1 cycle | 响应输出(bypass 时为 0) | + +**总延迟公式**: `Latency = 4 + 2 * H` cycles(最优情况,无竞争) + +其中 H 为 Ring 上的跳数。 + +### 12.2 典型延迟示例 + +**最短路径示例(Vector 访问 pipe2,H=0,自访问)**: + +``` +Cycle 1: Vector 请求到达 node2 → SPB 入队 +Cycle 2: SPB 头部请求目的为本地(H=0,自访问)→ 直接弹出送往本地 pipe2,不经过 Ring +Cycle 3: Pipe stage 寄存器 + SRAM 访问 +Cycle 4: 响应 bypass MGB 输出 → 数据可用 +总延迟: 4 cycles = 4 + 2*0 +``` + +**跨节点示例(node0 访问 pipe2,H=1)**: + +``` +Cycle 1: node0 请求 → SPB 入队 +Cycle 2: SPB 注入请求 Ring(CC 方向,node0→node2 跳 1 hop) +Cycle 3: 请求到达 node2 → 弹出到 pipe2 → pipe stage +Cycle 4: SRAM 访问完成 → 响应注入响应 Ring +Cycle 5: 响应传输 1 hop(node2→node0) +Cycle 6: 响应到达 node0 → MGB bypass 输出 +总延迟: 6 cycles = 4 + 2*1 +``` + +**远距离示例(node0 访问 pipe7,H=4)**: + +``` +总延迟: 4 + 2*4 = 12 cycles +``` + +### 12.3 各 node 自访问延迟 + +| 操作 | 延迟 | +|------|------| +| node_i 访问 pipe_i(自身 pipe) | 4 cycles | +| node_i 访问相邻 pipe(H=1) | 6 cycles | +| 
node_i 访问 H=2 的 pipe | 8 cycles | +| node_i 访问 H=3 的 pipe | 10 cycles | +| node_i 访问 H=4 的 pipe(最远) | 12 cycles | + +--- + +## 13. 反压与流控 + +### 13.1 请求侧反压 + +``` +req_ready = dir_cw ? SPB_CW.in_ready : SPB_CC.in_ready +``` + +当对应方向的 SPB 满(4 entries)时,`req_ready` 拉低,外部请求被阻塞。 + +### 13.2 Ring 反压 + +Ring 上的 flit 转发优先于 SPB 注入。当 Ring slot 被占用时,SPB 无法注入,但不会丢失数据(SPB 保持 flit 直到 slot 空闲)。 + +### 13.3 响应侧反压 + +MGB 满时,Ring 上到达本站的响应无法弹出,会继续在 Ring 上流转(实际上会阻塞 Ring 转发)。 + +外部 `resp_ready` 为低时,MGB 不出队,可能导致 MGB 满。 + +--- + +## 14. 防活锁/饿死机制 + +### 14.1 Tag 机制 + +- 每个请求携带 8-bit tag,响应原样返回 +- Tag 用于请求-响应匹配,确保外部可以区分不同请求的响应 +- Tag 不参与 Ring 路由决策 + +### 14.2 FIFO 顺序保证 + +- SPB 和 MGB 均为 FIFO 结构,保证同方向的请求/响应按序处理 +- 避免了乱序导致的活锁问题 + +### 14.3 Round-Robin 仲裁 + +- MGB 出队采用 RR 仲裁,确保 CW 和 CC 两个方向的响应公平出队 +- Pipe 访问时,Ring CW/CC 和 SPB CW/CC 四路请求按固定优先级仲裁 +- Ring 转发优先于 SPB 注入,保证 Ring 上的 flit 不会被无限阻塞 + +### 14.4 静态路由 + +- 最短路径静态路由消除了动态路由可能引入的活锁 +- 请求和响应走独立的 Ring,避免请求-响应死锁 + +--- + +## 15. 调试接口 + +TMU 提供以下调试输出信号,用于波形观察和可视化: + +| 信号 | 位宽 | 含义 | +|------|------|------| +| `dbg_req_cw_v{i}` | 1 | 请求 Ring CW 方向 node_i 处 link 寄存器 valid | +| `dbg_req_cc_v{i}` | 1 | 请求 Ring CC 方向 node_i 处 link 寄存器 valid | +| `dbg_req_cw_meta{i}` | variable | 请求 Ring CW 方向 node_i 处 meta 信息 | +| `dbg_req_cc_meta{i}` | variable | 请求 Ring CC 方向 node_i 处 meta 信息 | +| `dbg_rsp_cw_v{i}` | 1 | 响应 Ring CW 方向 node_i 处 link 寄存器 valid | +| `dbg_rsp_cc_v{i}` | 1 | 响应 Ring CC 方向 node_i 处 link 寄存器 valid | +| `dbg_rsp_cw_meta{i}` | variable | 响应 Ring CW 方向 node_i 处 meta 信息 | +| `dbg_rsp_cc_meta{i}` | variable | 响应 Ring CC 方向 node_i 处 meta 信息 | + +配套工具: +- `janus/tools/plot_tmu_trace.py`: 将 trace CSV 渲染为 SVG 时序图 +- `janus/tools/animate_tmu_trace.py`: 生成 Ring 拓扑动画 SVG +- `janus/tools/animate_tmu_ring_vcd.py`: 从 VCD 波形生成 Ring 动画 + +--- + +## 16. 
实现代码结构 + +### 16.1 源文件 + +| 文件 | 用途 | +|------|------| +| `janus/pyc/janus/tmu/janus_tmu_pyc.py` | TMU RTL 实现(pyCircuit DSL) | +| `janus/tb/tb_janus_tmu_pyc.cpp` | C++ cycle-accurate 测试平台 | +| `janus/tb/tb_janus_tmu_pyc.sv` | SystemVerilog 测试平台 | +| `janus/tools/run_janus_tmu_pyc_cpp.sh` | C++ 仿真运行脚本 | +| `janus/tools/run_janus_tmu_pyc_verilator.sh` | Verilator 仿真运行脚本 | +| `janus/tools/update_tmu_generated.sh` | 重新生成 RTL 脚本 | +| `janus/generated/janus_tmu_pyc/` | 生成的 Verilog 和 C++ header | + +### 16.2 代码关键函数/区域 + +| 代码区域 | 行号范围 | 功能 | +|----------|----------|------| +| `RING_ORDER`, `CW_PREF` | L12-L34 | Ring 拓扑定义与路由表 | +| `_dir_cw()` | L37-L40 | 运行时路由方向选择 | +| `_build_bundle_fifo()` | L82-L129 | FIFO bundle 构建(SPB/MGB 共用) | +| `NodeIo` | L132-L144 | 节点 IO 定义 | +| `build()` 参数处理 | L147-L177 | 可配置参数与地址位宽计算 | +| Node IO 实例化 | L203-L232 | 8 个节点的 IO 端口创建 | +| SPB 构建 | L234-L290 | 每节点 CW/CC 两个 SPB | +| Ring link 寄存器 | L292-L331 | 请求/响应 Ring 的 link 寄存器 | +| 请求 Ring 遍历 | L338-L408 | 请求 Ring 每站逻辑(弹出/转发/注入) | +| Pipe stage 寄存器 | L410-L426 | Pipe 访问前的寄存器级 | +| 响应注入 FIFO | L428-L503 | Pipe 访问后的响应注入缓冲 | +| 响应 Ring 遍历 | L505-L630 | 响应 Ring 每站逻辑 + MGB | +| 调试输出 | L632-L654 | 调试信号输出 | + +--- + +## 17. 测试验证 + +### 17.1 基础测试用例 + +测试平台(`tb_janus_tmu_pyc.cpp` / `tb_janus_tmu_pyc.sv`)包含以下测试: + +**Test 1: 本地读写(每个 node 访问自身 pipe)** +``` +for each node n in [0..7]: + 1. node_n 写 pipe_n: addr = makeAddr(n, n, 0), data = seed(n+1) + 2. 等待写响应,验证 tag 和 data 匹配 + 3. node_n 读 pipe_n: 同一地址 + 4. 等待读响应,验证读回数据 == 写入数据 +``` + +**Test 2: 跨节点读写(node0 访问 pipe2)** +``` +1. node0 写 pipe2: addr = makeAddr(5, 2, 0), data = seed(0xAA), tag = 0x55 +2. 等待写响应 +3. node0 读 pipe2: 同一地址, tag = 0x56 +4. 
等待读响应,验证读回数据 == 写入数据 +``` + +### 17.2 验证要点 + +- Tag 匹配:响应的 tag 必须与请求的 tag 一致 +- 数据完整性:读回的 32 个 64-bit word 必须与写入完全一致 +- resp_is_write:正确反映原始请求类型 +- 超时检测:2000 cycle 内未收到响应则报错 + +--- + +## 附录 A: CW_PREF 路由偏好表 + +基于 RING_ORDER = [0, 1, 3, 5, 7, 6, 4, 2],预计算的路由偏好(1=CW, 0=CC;CW 与 CC 跳数相等或 src == dst 时取 CW): + +| src\dst | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | +|---------|---|---|---|---|---|---|---|---| +| **0** | 1 | 1 | 0 | 1 | 0 | 1 | 0 | 1 | +| **1** | 0 | 1 | 0 | 1 | 0 | 1 | 1 | 1 | +| **2** | 1 | 1 | 1 | 1 | 0 | 1 | 0 | 0 | +| **3** | 0 | 0 | 0 | 1 | 1 | 1 | 1 | 1 | +| **4** | 1 | 1 | 1 | 1 | 1 | 0 | 0 | 0 | +| **5** | 0 | 0 | 1 | 0 | 1 | 1 | 1 | 1 | +| **6** | 1 | 1 | 1 | 0 | 1 | 0 | 1 | 0 | +| **7** | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 1 | + +## 附录 B: 术语表 + +| 术语 | 全称 | 含义 | +|------|------|------| +| TMU | Tile Management Unit | Tile 管理单元 | +| TileReg | Tile Register File | Tile 寄存器文件(片上 SRAM 缓冲区) | +| Ring | Ring Interconnect | 环形互联网络 | +| CS | Circuit Station | 环上的站点 | +| CW | Clockwise | 顺时针方向 | +| CC | Counter-Clockwise | 逆时针方向 | +| SPB | Send/Post Buffer | 发送缓冲区(请求上 Ring) | +| MGB | Merge Buffer | 合并缓冲区(响应下 Ring) | +| Flit | Flow control unit | 流控单元(Ring 上传输的最小数据单位) | +| Pipe | Pipeline SRAM | TileReg 的一个分区(128KB) | +| BCC | Block Control Core | 块控制核 | +| TMA | Tile Memory Access | Tile 存储访问单元 | +| RR | Round-Robin | 轮询仲裁 | \ No newline at end of file diff --git a/janus/pyc/janus/tmu/janus_tmu_pyc.py b/janus/pyc/janus/tmu/janus_tmu_pyc.py new file mode 100644 index 0000000..a8be20d --- /dev/null +++ b/janus/pyc/janus/tmu/janus_tmu_pyc.py @@ -0,0 +1,657 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass + +from pycircuit import Circuit, Reg, Wire +from pycircuit.hw import cat + +from janus.bcc.ooo.helpers import mux_by_uindex + + +RING_ORDER = [0, 1, 3, 5, 7, 6, 4, 2] +NODE_COUNT = 8 + + +def _build_cw_pref() -> list[list[int]]: + order = RING_ORDER + n = len(order) + pos = {node: i for i, node in enumerate(order)} + prefs: list[list[int]] = [[0 for _ in range(n)] 
for _ in range(n)] + for s in range(n): + for d in range(n): + if s == d: + prefs[s][d] = 1 + continue + s_pos = pos[s] + d_pos = pos[d] + cw = (d_pos - s_pos) % n + cc = (s_pos - d_pos) % n + prefs[s][d] = 1 if cw <= cc else 0 + return prefs + + +CW_PREF = _build_cw_pref() + + +def _dir_cw(m: Circuit, *, src: int, dst: Wire) -> Wire: + c = m.const + items = [c(1 if CW_PREF[src][i] else 0, width=1) for i in range(NODE_COUNT)] + return mux_by_uindex(m, idx=dst, items=items, default=c(1, width=1)) + + +def _field(w: Wire, *, lsb: int, width: int) -> Wire: + return w.slice(lsb=lsb, width=width) + + +def _and_all(m: Circuit, items: list[Wire]) -> Wire: + out = m.const(1, width=1) + for it in items: + out = out & it + return out + + +def _select_words(sel: Wire, a_words: list[Wire], b_words: list[Wire]) -> list[Wire]: + return [sel.select(a, b) for a, b in zip(a_words, b_words)] + + +def _select4_words( + sel_a: Wire, + sel_b: Wire, + sel_c: Wire, + sel_d: Wire, + wa: list[Wire], + wb: list[Wire], + wc: list[Wire], + wd: list[Wire], +) -> list[Wire]: + out: list[Wire] = [] + for a, b, c, d in zip(wa, wb, wc, wd): + out.append(sel_a.select(a, sel_b.select(b, sel_c.select(c, d)))) + return out + + +@dataclass(frozen=True) +class BundleFifo: + in_ready: Wire + out_valid: Wire + out_meta: Wire + out_data: list[Wire] + + +def _build_bundle_fifo( + m: Circuit, + *, + clk: Wire, + rst: Wire, + in_valid: Wire, + in_meta: Wire, + in_data: list[Wire], + out_ready: Wire, + depth: int, + name: str, +) -> BundleFifo: + push = m.named_wire(f"{name}__push", width=1) + pop = m.named_wire(f"{name}__pop", width=1) + + meta_in_ready, meta_out_valid, meta_out_data = m.fifo( + clk, + rst, + in_valid=push, + in_data=in_meta, + out_ready=pop, + depth=depth, + ) + + data_in_ready: list[Wire] = [] + data_out_valid: list[Wire] = [] + data_out_data: list[Wire] = [] + + for wi, word in enumerate(in_data): + in_ready_w, out_valid_w, out_data_w = m.fifo( + clk, + rst, + in_valid=push, + 
in_data=word, + out_ready=pop, + depth=depth, + ) + data_in_ready.append(in_ready_w) + data_out_valid.append(out_valid_w) + data_out_data.append(out_data_w) + + bundle_in_ready = _and_all(m, [meta_in_ready, *data_in_ready]) + bundle_out_valid = _and_all(m, [meta_out_valid, *data_out_valid]) + + m.assign(push, in_valid & bundle_in_ready) + m.assign(pop, out_ready & bundle_out_valid) + + return BundleFifo(in_ready=bundle_in_ready, out_valid=bundle_out_valid, out_meta=meta_out_data, out_data=data_out_data) + + +@dataclass(frozen=True) +class NodeIo: + req_valid: Wire + req_write: Wire + req_addr: Wire + req_tag: Wire + req_data_words: list[Wire] + req_ready: Wire + resp_ready: Wire + resp_valid: Wire + resp_tag: Wire + resp_data_words: list[Wire] + resp_is_write: Wire + + +def build( + m: Circuit, + *, + tile_bytes: int | None = None, + tag_bits: int = 8, + spb_depth: int = 4, + mgb_depth: int = 4, +) -> None: + if tile_bytes is None: + tile_bytes = int(os.getenv("JANUS_TMU_TILE_BYTES", 1 << 20)) + if tile_bytes <= 0: + raise ValueError("tile_bytes must be > 0") + + line_bytes = 256 + line_words = line_bytes // 8 + pipe_count = NODE_COUNT + + if tile_bytes % (pipe_count * line_bytes) != 0: + raise ValueError("tile_bytes must be divisible by 8 * 256") + + addr_bits = (tile_bytes - 1).bit_length() + offset_bits = (line_bytes - 1).bit_length() + pipe_bits = (pipe_count - 1).bit_length() + if addr_bits < offset_bits + pipe_bits: + raise ValueError("tile_bytes too small for pipe addressing") + + index_bits = addr_bits - offset_bits - pipe_bits + lines_per_pipe = tile_bytes // (pipe_count * line_bytes) + + c = m.const + node_bits = pipe_bits + + clk = m.clock("clk") + rst = m.reset("rst") + + # Meta layouts (packed into 64-bit). 
+ REQ_WRITE_LSB = 0 + REQ_SRC_LSB = REQ_WRITE_LSB + 1 + REQ_DST_LSB = REQ_SRC_LSB + node_bits + REQ_TAG_LSB = REQ_DST_LSB + node_bits + REQ_ADDR_LSB = REQ_TAG_LSB + tag_bits + + RSP_WRITE_LSB = 0 + RSP_SRC_LSB = RSP_WRITE_LSB + 1 + RSP_DST_LSB = RSP_SRC_LSB + node_bits + RSP_TAG_LSB = RSP_DST_LSB + node_bits + + def pack_req_meta(write: Wire, src: Wire, dst: Wire, tag: Wire, addr: Wire) -> Wire: + meta = cat(addr, tag, dst, src, write) + return meta.zext(width=64) + + def pack_rsp_meta(write: Wire, src: Wire, dst: Wire, tag: Wire) -> Wire: + meta = cat(tag, dst, src, write) + return meta.zext(width=64) + + # --- Node IOs --- + nodes: list[NodeIo] = [] + for i in range(NODE_COUNT): + req_valid = m.input(f"n{i}_req_valid", width=1) + req_write = m.input(f"n{i}_req_write", width=1) + req_addr = m.input(f"n{i}_req_addr", width=addr_bits) + req_tag = m.input(f"n{i}_req_tag", width=tag_bits) + req_data_words = [m.input(f"n{i}_req_data_w{wi}", width=64) for wi in range(line_words)] + resp_ready = m.input(f"n{i}_resp_ready", width=1) + + req_ready = m.named_wire(f"n{i}_req_ready", width=1) + resp_valid = m.named_wire(f"n{i}_resp_valid", width=1) + resp_tag = m.named_wire(f"n{i}_resp_tag", width=tag_bits) + resp_data_words = [m.named_wire(f"n{i}_resp_data_w{wi}", width=64) for wi in range(line_words)] + resp_is_write = m.named_wire(f"n{i}_resp_is_write", width=1) + + nodes.append( + NodeIo( + req_valid=req_valid, + req_write=req_write, + req_addr=req_addr, + req_tag=req_tag, + req_data_words=req_data_words, + req_ready=req_ready, + resp_ready=resp_ready, + resp_valid=resp_valid, + resp_tag=resp_tag, + resp_data_words=resp_data_words, + resp_is_write=resp_is_write, + ) + ) + + # --- Build SPB bundles per node (cw/cc) --- + spb_cw: list[BundleFifo] = [] + spb_cc: list[BundleFifo] = [] + spb_cw_out_ready: list[Wire] = [] + spb_cc_out_ready: list[Wire] = [] + + req_meta: list[Wire] = [] + req_words: list[list[Wire]] = [] + req_dir_cw: list[Wire] = [] + + for i, node in 
enumerate(nodes): + dst = node.req_addr.slice(lsb=offset_bits, width=pipe_bits) + src = c(i, width=node_bits) + meta = pack_req_meta(node.req_write, src, dst, node.req_tag, node.req_addr) + req_meta.append(meta) + words = node.req_data_words + req_words.append(words) + + dir_cw = _dir_cw(m, src=i, dst=dst) + req_dir_cw.append(dir_cw) + + in_valid_cw = node.req_valid & dir_cw + in_valid_cc = node.req_valid & (~dir_cw) + + cw_ready = m.named_wire(f"spb{i}_cw_out_ready", width=1) + cc_ready = m.named_wire(f"spb{i}_cc_out_ready", width=1) + spb_cw_out_ready.append(cw_ready) + spb_cc_out_ready.append(cc_ready) + + spb_cw.append( + _build_bundle_fifo( + m, + clk=clk, + rst=rst, + in_valid=in_valid_cw, + in_meta=meta, + in_data=words, + out_ready=cw_ready, + depth=spb_depth, + name=f"spb{i}_cw", + ) + ) + spb_cc.append( + _build_bundle_fifo( + m, + clk=clk, + rst=rst, + in_valid=in_valid_cc, + in_meta=meta, + in_data=words, + out_ready=cc_ready, + depth=spb_depth, + name=f"spb{i}_cc", + ) + ) + + m.assign(node.req_ready, dir_cw.select(spb_cw[i].in_ready, spb_cc[i].in_ready)) + + # --- Ring link registers (request + response, cw/cc) --- + req_cw_link_valid: list[Reg] = [] + req_cw_link_meta: list[Reg] = [] + req_cw_link_data: list[list[Reg]] = [] + req_cc_link_valid: list[Reg] = [] + req_cc_link_meta: list[Reg] = [] + req_cc_link_data: list[list[Reg]] = [] + + rsp_cw_link_valid: list[Reg] = [] + rsp_cw_link_meta: list[Reg] = [] + rsp_cw_link_data: list[list[Reg]] = [] + rsp_cc_link_valid: list[Reg] = [] + rsp_cc_link_meta: list[Reg] = [] + rsp_cc_link_data: list[list[Reg]] = [] + + with m.scope("req_ring"): + for i in range(NODE_COUNT): + req_cw_link_valid.append(m.out(f"cw_v{i}", clk=clk, rst=rst, width=1, init=0, en=1)) + req_cw_link_meta.append(m.out(f"cw_m{i}", clk=clk, rst=rst, width=64, init=0, en=1)) + req_cw_link_data.append( + [m.out(f"cw_d{i}_w{wi}", clk=clk, rst=rst, width=64, init=0, en=1) for wi in range(line_words)] + ) + 
req_cc_link_valid.append(m.out(f"cc_v{i}", clk=clk, rst=rst, width=1, init=0, en=1)) + req_cc_link_meta.append(m.out(f"cc_m{i}", clk=clk, rst=rst, width=64, init=0, en=1)) + req_cc_link_data.append( + [m.out(f"cc_d{i}_w{wi}", clk=clk, rst=rst, width=64, init=0, en=1) for wi in range(line_words)] + ) + + with m.scope("rsp_ring"): + for i in range(NODE_COUNT): + rsp_cw_link_valid.append(m.out(f"cw_v{i}", clk=clk, rst=rst, width=1, init=0, en=1)) + rsp_cw_link_meta.append(m.out(f"cw_m{i}", clk=clk, rst=rst, width=64, init=0, en=1)) + rsp_cw_link_data.append( + [m.out(f"cw_d{i}_w{wi}", clk=clk, rst=rst, width=64, init=0, en=1) for wi in range(line_words)] + ) + rsp_cc_link_valid.append(m.out(f"cc_v{i}", clk=clk, rst=rst, width=1, init=0, en=1)) + rsp_cc_link_meta.append(m.out(f"cc_m{i}", clk=clk, rst=rst, width=64, init=0, en=1)) + rsp_cc_link_data.append( + [m.out(f"cc_d{i}_w{wi}", clk=clk, rst=rst, width=64, init=0, en=1) for wi in range(line_words)] + ) + + # --- Pipe request wires --- + pipe_req_valid: list[Wire] = [c(0, width=1) for _ in range(NODE_COUNT)] + pipe_req_meta: list[Wire] = [c(0, width=64) for _ in range(NODE_COUNT)] + pipe_req_data: list[list[Wire]] = [[c(0, width=64) for _ in range(line_words)] for _ in range(NODE_COUNT)] + + # --- Request ring traversal + ejection to pipes --- + for pos in range(NODE_COUNT): + nid = RING_ORDER[pos] + node_const = c(nid, width=node_bits) + + prev_pos = (pos - 1) % NODE_COUNT + next_pos = (pos + 1) % NODE_COUNT + + cw_in_valid = req_cw_link_valid[prev_pos].out() + cw_in_meta = req_cw_link_meta[prev_pos].out() + cw_in_data = [r.out() for r in req_cw_link_data[prev_pos]] + + cc_in_valid = req_cc_link_valid[next_pos].out() + cc_in_meta = req_cc_link_meta[next_pos].out() + cc_in_data = [r.out() for r in req_cc_link_data[next_pos]] + + cw_in_dst = _field(cw_in_meta, lsb=REQ_DST_LSB, width=node_bits) + cc_in_dst = _field(cc_in_meta, lsb=REQ_DST_LSB, width=node_bits) + + ring_cw_local = cw_in_valid & cw_in_dst.eq(node_const) 
+ ring_cc_local = cc_in_valid & cc_in_dst.eq(node_const) + + spb_cw_head_meta = spb_cw[nid].out_meta + spb_cc_head_meta = spb_cc[nid].out_meta + spb_cw_head_data = spb_cw[nid].out_data + spb_cc_head_data = spb_cc[nid].out_data + + spb_cw_dst = _field(spb_cw_head_meta, lsb=REQ_DST_LSB, width=node_bits) + spb_cc_dst = _field(spb_cc_head_meta, lsb=REQ_DST_LSB, width=node_bits) + + spb_cw_local = spb_cw[nid].out_valid & spb_cw_dst.eq(node_const) + spb_cc_local = spb_cc[nid].out_valid & spb_cc_dst.eq(node_const) + + sel_ring_cw = ring_cw_local + sel_ring_cc = (~sel_ring_cw) & ring_cc_local + sel_spb_cw = (~sel_ring_cw) & (~sel_ring_cc) & spb_cw_local + sel_spb_cc = (~sel_ring_cw) & (~sel_ring_cc) & (~sel_spb_cw) & spb_cc_local + + pipe_req_valid[nid] = sel_ring_cw | sel_ring_cc | sel_spb_cw | sel_spb_cc + pipe_req_meta[nid] = sel_ring_cw.select( + cw_in_meta, + sel_ring_cc.select(cc_in_meta, sel_spb_cw.select(spb_cw_head_meta, spb_cc_head_meta)), + ) + pipe_req_data[nid] = _select4_words(sel_ring_cw, sel_ring_cc, sel_spb_cw, sel_spb_cc, cw_in_data, cc_in_data, spb_cw_head_data, spb_cc_head_data) + + cw_forward_valid = cw_in_valid & (~sel_ring_cw) + cw_can_inject = ~cw_forward_valid + cw_inject_valid = spb_cw[nid].out_valid & (~spb_cw_local) & cw_can_inject + cw_out_valid = cw_forward_valid | cw_inject_valid + cw_out_meta = cw_forward_valid.select(cw_in_meta, spb_cw_head_meta) + cw_out_data = _select_words(cw_forward_valid, cw_in_data, spb_cw_head_data) + + cc_forward_valid = cc_in_valid & (~sel_ring_cc) + cc_can_inject = ~cc_forward_valid + cc_inject_valid = spb_cc[nid].out_valid & (~spb_cc_local) & cc_can_inject + cc_out_valid = cc_forward_valid | cc_inject_valid + cc_out_meta = cc_forward_valid.select(cc_in_meta, spb_cc_head_meta) + cc_out_data = _select_words(cc_forward_valid, cc_in_data, spb_cc_head_data) + + req_cw_link_valid[pos].set(cw_out_valid) + req_cw_link_meta[pos].set(cw_out_meta) + for wi in range(line_words): + 
req_cw_link_data[pos][wi].set(cw_out_data[wi]) + + req_cc_link_valid[pos].set(cc_out_valid) + req_cc_link_meta[pos].set(cc_out_meta) + for wi in range(line_words): + req_cc_link_data[pos][wi].set(cc_out_data[wi]) + + m.assign(spb_cw_out_ready[nid], sel_spb_cw | cw_inject_valid) + m.assign(spb_cc_out_ready[nid], sel_spb_cc | cc_inject_valid) + + # --- Pipe stage regs --- + pipe_stage_valid: list[Reg] = [] + pipe_stage_meta: list[Reg] = [] + pipe_stage_data: list[list[Reg]] = [] + + for p in range(pipe_count): + with m.scope(f"pipe{p}_stage"): + pipe_stage_valid.append(m.out("v", clk=clk, rst=rst, width=1, init=0, en=1)) + pipe_stage_meta.append(m.out("m", clk=clk, rst=rst, width=64, init=0, en=1)) + pipe_stage_data.append( + [m.out(f"d_w{wi}", clk=clk, rst=rst, width=64, init=0, en=1) for wi in range(line_words)] + ) + + pipe_stage_valid[p].set(pipe_req_valid[p]) + pipe_stage_meta[p].set(pipe_req_meta[p]) + for wi in range(line_words): + pipe_stage_data[p][wi].set(pipe_req_data[p][wi]) + + # --- Response inject bundles (per pipe, cw/cc) --- + rsp_cw: list[BundleFifo] = [] + rsp_cc: list[BundleFifo] = [] + rsp_cw_out_ready: list[Wire] = [] + rsp_cc_out_ready: list[Wire] = [] + + for p in range(pipe_count): + st_valid = pipe_stage_valid[p].out() + st_meta = pipe_stage_meta[p].out() + st_data_words = [r.out() for r in pipe_stage_data[p]] + + st_write = _field(st_meta, lsb=REQ_WRITE_LSB, width=1) + st_src = _field(st_meta, lsb=REQ_SRC_LSB, width=node_bits) + st_tag = _field(st_meta, lsb=REQ_TAG_LSB, width=tag_bits) + st_addr = _field(st_meta, lsb=REQ_ADDR_LSB, width=addr_bits) + + line_idx = st_addr.slice(lsb=offset_bits + pipe_bits, width=index_bits) + byte_addr = cat(line_idx, c(0, width=3)) + depth_bytes = lines_per_pipe * 8 + + read_words: list[Wire] = [] + wvalid = st_valid & st_write + wstrb = c(0xFF, width=8) + + for wi in range(line_words): + rdata = m.byte_mem( + clk=clk, + rst=rst, + raddr=byte_addr, + wvalid=wvalid, + waddr=byte_addr, + 
wdata=st_data_words[wi], + wstrb=wstrb, + depth=depth_bytes, + name=f"tmu_p{p}_w{wi}", + ) + read_words.append(rdata) + + rsp_meta = pack_rsp_meta(st_write, c(p, width=node_bits), st_src, st_tag) + rsp_words = [st_write.select(st_data_words[wi], read_words[wi]) for wi in range(line_words)] + + rsp_dir = _dir_cw(m, src=p, dst=st_src) + in_valid_cw = st_valid & rsp_dir + in_valid_cc = st_valid & (~rsp_dir) + + cw_ready = m.named_wire(f"rsp{p}_cw_out_ready", width=1) + cc_ready = m.named_wire(f"rsp{p}_cc_out_ready", width=1) + rsp_cw_out_ready.append(cw_ready) + rsp_cc_out_ready.append(cc_ready) + + rsp_cw.append( + _build_bundle_fifo( + m, + clk=clk, + rst=rst, + in_valid=in_valid_cw, + in_meta=rsp_meta, + in_data=rsp_words, + out_ready=cw_ready, + depth=spb_depth, + name=f"rsp{p}_cw", + ) + ) + rsp_cc.append( + _build_bundle_fifo( + m, + clk=clk, + rst=rst, + in_valid=in_valid_cc, + in_meta=rsp_meta, + in_data=rsp_words, + out_ready=cc_ready, + depth=spb_depth, + name=f"rsp{p}_cc", + ) + ) + + # --- Response ring traversal + MGB buffers --- + for pos in range(NODE_COUNT): + nid = RING_ORDER[pos] + node_const = c(nid, width=node_bits) + + prev_pos = (pos - 1) % NODE_COUNT + next_pos = (pos + 1) % NODE_COUNT + + cw_in_valid = rsp_cw_link_valid[prev_pos].out() + cw_in_meta = rsp_cw_link_meta[prev_pos].out() + cw_in_data = [r.out() for r in rsp_cw_link_data[prev_pos]] + + cc_in_valid = rsp_cc_link_valid[next_pos].out() + cc_in_meta = rsp_cc_link_meta[next_pos].out() + cc_in_data = [r.out() for r in rsp_cc_link_data[next_pos]] + + cw_in_dst = _field(cw_in_meta, lsb=RSP_DST_LSB, width=node_bits) + cc_in_dst = _field(cc_in_meta, lsb=RSP_DST_LSB, width=node_bits) + + ring_cw_local = cw_in_valid & cw_in_dst.eq(node_const) + ring_cc_local = cc_in_valid & cc_in_dst.eq(node_const) + + rsp_cw_head_meta = rsp_cw[nid].out_meta + rsp_cc_head_meta = rsp_cc[nid].out_meta + rsp_cw_head_data = rsp_cw[nid].out_data + rsp_cc_head_data = rsp_cc[nid].out_data + + rsp_cw_dst = 
_field(rsp_cw_head_meta, lsb=RSP_DST_LSB, width=node_bits) + rsp_cc_dst = _field(rsp_cc_head_meta, lsb=RSP_DST_LSB, width=node_bits) + + rsp_cw_local = rsp_cw[nid].out_valid & rsp_cw_dst.eq(node_const) + rsp_cc_local = rsp_cc[nid].out_valid & rsp_cc_dst.eq(node_const) + + cw_local_valid = ring_cw_local | rsp_cw_local + cc_local_valid = ring_cc_local | rsp_cc_local + cw_local_meta = ring_cw_local.select(cw_in_meta, rsp_cw_head_meta) + cc_local_meta = ring_cc_local.select(cc_in_meta, rsp_cc_head_meta) + cw_local_data = _select_words(ring_cw_local, cw_in_data, rsp_cw_head_data) + cc_local_data = _select_words(ring_cc_local, cc_in_data, rsp_cc_head_data) + + # MGB buffers. + mgb_cw_ready = m.named_wire(f"mgb{nid}_cw_out_ready", width=1) + mgb_cc_ready = m.named_wire(f"mgb{nid}_cc_out_ready", width=1) + + mgb_cw = _build_bundle_fifo( + m, + clk=clk, + rst=rst, + in_valid=cw_local_valid, + in_meta=cw_local_meta, + in_data=cw_local_data, + out_ready=mgb_cw_ready, + depth=mgb_depth, + name=f"mgb{nid}_cw", + ) + mgb_cc = _build_bundle_fifo( + m, + clk=clk, + rst=rst, + in_valid=cc_local_valid, + in_meta=cc_local_meta, + in_data=cc_local_data, + out_ready=mgb_cc_ready, + depth=mgb_depth, + name=f"mgb{nid}_cc", + ) + + rr = m.out(f"mgb{nid}_rr", clk=clk, rst=rst, width=1, init=0, en=1) + + any_cw = mgb_cw.out_valid + any_cc = mgb_cc.out_valid + both = any_cw & any_cc + pick_cw = (any_cw & (~any_cc)) | (both & (~rr.out())) + pick_cc = (any_cc & (~any_cw)) | (both & rr.out()) + + resp_ready = nodes[nid].resp_ready + resp_fire = (pick_cw | pick_cc) & resp_ready + + m.assign(mgb_cw_ready, pick_cw & resp_ready) + m.assign(mgb_cc_ready, pick_cc & resp_ready) + + rr_next = rr.out() + rr_next = resp_fire.select(~rr_next, rr_next) + rr.set(rr_next) + + resp_meta = pick_cw.select(mgb_cw.out_meta, mgb_cc.out_meta) + resp_words = _select_words(pick_cw, mgb_cw.out_data, mgb_cc.out_data) + + m.assign(nodes[nid].resp_valid, resp_fire) + m.assign(nodes[nid].resp_tag, _field(resp_meta, 
lsb=RSP_TAG_LSB, width=tag_bits)) + m.assign(nodes[nid].resp_is_write, _field(resp_meta, lsb=RSP_WRITE_LSB, width=1)) + for wi in range(line_words): + m.assign(nodes[nid].resp_data_words[wi], resp_words[wi]) + + # Forward or inject on response cw lane. + cw_forward_valid = cw_in_valid & (~ring_cw_local) + cc_forward_valid = cc_in_valid & (~ring_cc_local) + + cw_can_inject = ~cw_forward_valid + cc_can_inject = ~cc_forward_valid + + cw_inject_valid = rsp_cw[nid].out_valid & (~rsp_cw_local) & cw_can_inject + cc_inject_valid = rsp_cc[nid].out_valid & (~rsp_cc_local) & cc_can_inject + + cw_out_valid = cw_forward_valid | cw_inject_valid + cc_out_valid = cc_forward_valid | cc_inject_valid + + cw_out_meta = cw_forward_valid.select(cw_in_meta, rsp_cw_head_meta) + cc_out_meta = cc_forward_valid.select(cc_in_meta, rsp_cc_head_meta) + cw_out_data = _select_words(cw_forward_valid, cw_in_data, rsp_cw_head_data) + cc_out_data = _select_words(cc_forward_valid, cc_in_data, rsp_cc_head_data) + + rsp_cw_link_valid[pos].set(cw_out_valid) + rsp_cw_link_meta[pos].set(cw_out_meta) + for wi in range(line_words): + rsp_cw_link_data[pos][wi].set(cw_out_data[wi]) + + rsp_cc_link_valid[pos].set(cc_out_valid) + rsp_cc_link_meta[pos].set(cc_out_meta) + for wi in range(line_words): + rsp_cc_link_data[pos][wi].set(cc_out_data[wi]) + + rsp_cw_local_pop = rsp_cw_local & (~ring_cw_local) & mgb_cw.in_ready + rsp_cc_local_pop = rsp_cc_local & (~ring_cc_local) & mgb_cc.in_ready + m.assign(rsp_cw_out_ready[nid], rsp_cw_local_pop | cw_inject_valid) + m.assign(rsp_cc_out_ready[nid], rsp_cc_local_pop | cc_inject_valid) + + # --- Debug ring metadata outputs (for visualization) --- + for pos in range(NODE_COUNT): + nid = RING_ORDER[pos] + req_meta = req_cw_link_meta[pos].out().slice(lsb=0, width=REQ_ADDR_LSB + addr_bits) + req_meta_cc = req_cc_link_meta[pos].out().slice(lsb=0, width=REQ_ADDR_LSB + addr_bits) + rsp_meta = rsp_cw_link_meta[pos].out().slice(lsb=0, width=RSP_TAG_LSB + tag_bits) + rsp_meta_cc = 
rsp_cc_link_meta[pos].out().slice(lsb=0, width=RSP_TAG_LSB + tag_bits) + m.output(f"dbg_req_cw_v{nid}", req_cw_link_valid[pos].out()) + m.output(f"dbg_req_cc_v{nid}", req_cc_link_valid[pos].out()) + m.output(f"dbg_req_cw_meta{nid}", req_meta) + m.output(f"dbg_req_cc_meta{nid}", req_meta_cc) + m.output(f"dbg_rsp_cw_v{nid}", rsp_cw_link_valid[pos].out()) + m.output(f"dbg_rsp_cc_v{nid}", rsp_cc_link_valid[pos].out()) + m.output(f"dbg_rsp_cw_meta{nid}", rsp_meta) + m.output(f"dbg_rsp_cc_meta{nid}", rsp_meta_cc) + + for i, node in enumerate(nodes): + m.output(f"n{i}_req_ready", node.req_ready) + m.output(f"n{i}_resp_valid", node.resp_valid) + m.output(f"n{i}_resp_tag", node.resp_tag) + for wi in range(line_words): + m.output(f"n{i}_resp_data_w{wi}", node.resp_data_words[wi]) + m.output(f"n{i}_resp_is_write", node.resp_is_write) + + +build.__pycircuit_name__ = "janus_tmu_pyc" diff --git a/janus/tb/tb_janus_tmu_pyc.cpp b/janus/tb/tb_janus_tmu_pyc.cpp new file mode 100644 index 0000000..eda498d --- /dev/null +++ b/janus/tb/tb_janus_tmu_pyc.cpp @@ -0,0 +1,286 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "janus_tmu_pyc_gen.hpp" + +using pyc::cpp::Testbench; +using pyc::cpp::Wire; + +namespace { + +constexpr int kNodes = 8; +constexpr int kAddrBits = 20; +constexpr int kTagBits = 8; +constexpr int kWords = 32; + +using DataWord = Wire<64>; +using DataLine = std::array; + +struct NodePorts { + Wire<1> *req_valid = nullptr; + Wire<1> *req_write = nullptr; + Wire *req_addr = nullptr; + Wire *req_tag = nullptr; + std::array req_data{}; + Wire<1> *req_ready = nullptr; + Wire<1> *resp_ready = nullptr; + Wire<1> *resp_valid = nullptr; + Wire *resp_tag = nullptr; + std::array resp_data{}; + Wire<1> *resp_is_write = nullptr; +}; + +static bool envFlag(const char *name) { + const char *v = std::getenv(name); + if (!v) + return false; + return !(v[0] == '0' && v[1] == '\0'); +} + +static std::uint32_t makeAddr(std::uint32_t 
index, std::uint32_t pipe, std::uint32_t offset = 0) { + return (index << 11) | (pipe << 8) | (offset & 0xFFu); +} + +static DataLine makeData(std::uint32_t seed) { + DataLine out{}; + for (unsigned i = 0; i < kWords; i++) { + std::uint64_t word = (static_cast(seed) << 32) | i; + out[i] = DataWord(word); + } + return out; +} + +static void zeroReq(NodePorts &n) { + *n.req_valid = Wire<1>(0); + *n.req_write = Wire<1>(0); + *n.req_addr = Wire(0); + *n.req_tag = Wire(0); + for (auto *w : n.req_data) + *w = DataWord(0); +} + +static void setRespReady(NodePorts &n, bool ready) { *n.resp_ready = Wire<1>(ready ? 1u : 0u); } + +static void sendReq(Testbench &tb, + NodePorts &n, + std::uint64_t &cycle, + int node_id, + bool write, + std::uint32_t addr, + std::uint8_t tag, + const DataLine &data, + std::ofstream &trace) { + *n.req_write = Wire<1>(write ? 1u : 0u); + *n.req_addr = Wire(addr); + *n.req_tag = Wire(tag); + for (unsigned i = 0; i < kWords; i++) + *n.req_data[i] = data[i]; + *n.req_valid = Wire<1>(1); + while (true) { + tb.runCycles(1); + cycle++; + if (n.req_ready->toBool()) { + trace << cycle << ",accept" + << "," << node_id << "," << unsigned(tag) << "," << (write ? 1 : 0) << ",0x" << std::hex << addr + << std::dec << ",0x" + << std::hex << data[0].value() << std::dec << "\n"; + break; + } + } + *n.req_valid = Wire<1>(0); +} + +static void waitResp(Testbench &tb, + NodePorts &n, + std::uint64_t &cycle, + int node_id, + std::uint8_t tag, + bool expect_write, + const DataLine &expect_data, + std::ofstream &trace) { + for (std::uint64_t i = 0; i < 2000; i++) { + tb.runCycles(1); + cycle++; + if (!n.resp_valid->toBool()) + continue; + if (n.resp_tag->value() != tag) { + std::cerr << "FAIL: tag mismatch. 
got=" << std::hex << n.resp_tag->value() << " exp=" << unsigned(tag) << std::dec + << "\n"; + std::exit(1); + } + if (n.resp_is_write->toBool() != expect_write) { + std::cerr << "FAIL: resp_is_write mismatch\n"; + std::exit(1); + } + for (unsigned i = 0; i < kWords; i++) { + if (n.resp_data[i]->value() != expect_data[i].value()) { + std::cerr << "FAIL: resp_data mismatch\n"; + std::exit(1); + } + } + trace << cycle << ",resp" + << "," << node_id << "," << unsigned(tag) << "," << (expect_write ? 1 : 0) << ",0x" << std::hex + << n.resp_data[0]->value() + << std::dec << "\n"; + return; + } + std::cerr << "FAIL: timeout waiting for response tag=0x" << std::hex << unsigned(tag) << std::dec << "\n"; + std::exit(1); +} + +} // namespace + +int main() { + pyc::gen::janus_tmu_pyc dut{}; + Testbench tb(dut); + + const bool trace_log = envFlag("PYC_TRACE"); + const bool trace_vcd = envFlag("PYC_VCD"); + + std::filesystem::path out_dir{}; + if (trace_log || trace_vcd) { + const char *trace_dir_env = std::getenv("PYC_TRACE_DIR"); + out_dir = trace_dir_env ? 
std::filesystem::path(trace_dir_env) : std::filesystem::path("janus/generated/janus_tmu_pyc"); + std::filesystem::create_directories(out_dir); + } + + if (trace_log) { + tb.enableLog((out_dir / "tb_janus_tmu_pyc_cpp.log").string()); + } + + if (trace_vcd) { + tb.enableVcd((out_dir / "tb_janus_tmu_pyc_cpp.vcd").string(), /*top=*/"tb_janus_tmu_pyc_cpp"); + tb.vcdTrace(dut.clk, "clk"); + tb.vcdTrace(dut.rst, "rst"); + tb.vcdTrace(dut.n0_req_valid, "n0_req_valid"); + tb.vcdTrace(dut.n0_req_ready, "n0_req_ready"); + tb.vcdTrace(dut.n0_resp_valid, "n0_resp_valid"); + tb.vcdTrace(dut.n0_resp_is_write, "n0_resp_is_write"); + tb.vcdTrace(dut.n0_resp_tag, "n0_resp_tag"); + tb.vcdTrace(dut.n0_req_data_w0, "n0_req_data_w0"); + tb.vcdTrace(dut.n0_resp_data_w0, "n0_resp_data_w0"); + tb.vcdTrace(dut.dbg_req_cw_v0, "dbg_req_cw_v0"); + tb.vcdTrace(dut.dbg_req_cc_v0, "dbg_req_cc_v0"); + tb.vcdTrace(dut.dbg_rsp_cw_v0, "dbg_rsp_cw_v0"); + tb.vcdTrace(dut.dbg_rsp_cc_v0, "dbg_rsp_cc_v0"); + tb.vcdTrace(dut.dbg_req_cw_v1, "dbg_req_cw_v1"); + tb.vcdTrace(dut.dbg_req_cc_v1, "dbg_req_cc_v1"); + tb.vcdTrace(dut.dbg_rsp_cw_v1, "dbg_rsp_cw_v1"); + tb.vcdTrace(dut.dbg_rsp_cc_v1, "dbg_rsp_cc_v1"); + tb.vcdTrace(dut.dbg_req_cw_v2, "dbg_req_cw_v2"); + tb.vcdTrace(dut.dbg_req_cc_v2, "dbg_req_cc_v2"); + tb.vcdTrace(dut.dbg_rsp_cw_v2, "dbg_rsp_cw_v2"); + tb.vcdTrace(dut.dbg_rsp_cc_v2, "dbg_rsp_cc_v2"); + tb.vcdTrace(dut.dbg_req_cw_v3, "dbg_req_cw_v3"); + tb.vcdTrace(dut.dbg_req_cc_v3, "dbg_req_cc_v3"); + tb.vcdTrace(dut.dbg_rsp_cw_v3, "dbg_rsp_cw_v3"); + tb.vcdTrace(dut.dbg_rsp_cc_v3, "dbg_rsp_cc_v3"); + tb.vcdTrace(dut.dbg_req_cw_v4, "dbg_req_cw_v4"); + tb.vcdTrace(dut.dbg_req_cc_v4, "dbg_req_cc_v4"); + tb.vcdTrace(dut.dbg_rsp_cw_v4, "dbg_rsp_cw_v4"); + tb.vcdTrace(dut.dbg_rsp_cc_v4, "dbg_rsp_cc_v4"); + tb.vcdTrace(dut.dbg_req_cw_v5, "dbg_req_cw_v5"); + tb.vcdTrace(dut.dbg_req_cc_v5, "dbg_req_cc_v5"); + tb.vcdTrace(dut.dbg_rsp_cw_v5, "dbg_rsp_cw_v5"); + tb.vcdTrace(dut.dbg_rsp_cc_v5, 
"dbg_rsp_cc_v5"); + tb.vcdTrace(dut.dbg_req_cw_v6, "dbg_req_cw_v6"); + tb.vcdTrace(dut.dbg_req_cc_v6, "dbg_req_cc_v6"); + tb.vcdTrace(dut.dbg_rsp_cw_v6, "dbg_rsp_cw_v6"); + tb.vcdTrace(dut.dbg_rsp_cc_v6, "dbg_rsp_cc_v6"); + tb.vcdTrace(dut.dbg_req_cw_v7, "dbg_req_cw_v7"); + tb.vcdTrace(dut.dbg_req_cc_v7, "dbg_req_cc_v7"); + tb.vcdTrace(dut.dbg_rsp_cw_v7, "dbg_rsp_cw_v7"); + tb.vcdTrace(dut.dbg_rsp_cc_v7, "dbg_rsp_cc_v7"); + } + + tb.addClock(dut.clk, /*halfPeriodSteps=*/1); + tb.reset(dut.rst, /*cyclesAsserted=*/2, /*cyclesDeasserted=*/1); + + std::ofstream trace; + if (trace_log) { + trace.open(out_dir / "tmu_trace.csv", std::ios::out | std::ios::trunc); + trace << "cycle,event,node,tag,write,addr_or_word0,data_word0\n"; + } + + std::array nodes = {{ + {&dut.n0_req_valid, &dut.n0_req_write, &dut.n0_req_addr, &dut.n0_req_tag, + {&dut.n0_req_data_w0, &dut.n0_req_data_w1, &dut.n0_req_data_w2, &dut.n0_req_data_w3, &dut.n0_req_data_w4, &dut.n0_req_data_w5, &dut.n0_req_data_w6, &dut.n0_req_data_w7, &dut.n0_req_data_w8, &dut.n0_req_data_w9, &dut.n0_req_data_w10, &dut.n0_req_data_w11, &dut.n0_req_data_w12, &dut.n0_req_data_w13, &dut.n0_req_data_w14, &dut.n0_req_data_w15, &dut.n0_req_data_w16, &dut.n0_req_data_w17, &dut.n0_req_data_w18, &dut.n0_req_data_w19, &dut.n0_req_data_w20, &dut.n0_req_data_w21, &dut.n0_req_data_w22, &dut.n0_req_data_w23, &dut.n0_req_data_w24, &dut.n0_req_data_w25, &dut.n0_req_data_w26, &dut.n0_req_data_w27, &dut.n0_req_data_w28, &dut.n0_req_data_w29, &dut.n0_req_data_w30, &dut.n0_req_data_w31}, &dut.n0_req_ready, &dut.n0_resp_ready, &dut.n0_resp_valid, &dut.n0_resp_tag, + {&dut.n0_resp_data_w0, &dut.n0_resp_data_w1, &dut.n0_resp_data_w2, &dut.n0_resp_data_w3, &dut.n0_resp_data_w4, &dut.n0_resp_data_w5, &dut.n0_resp_data_w6, &dut.n0_resp_data_w7, &dut.n0_resp_data_w8, &dut.n0_resp_data_w9, &dut.n0_resp_data_w10, &dut.n0_resp_data_w11, &dut.n0_resp_data_w12, &dut.n0_resp_data_w13, &dut.n0_resp_data_w14, &dut.n0_resp_data_w15, &dut.n0_resp_data_w16, 
&dut.n0_resp_data_w17, &dut.n0_resp_data_w18, &dut.n0_resp_data_w19, &dut.n0_resp_data_w20, &dut.n0_resp_data_w21, &dut.n0_resp_data_w22, &dut.n0_resp_data_w23, &dut.n0_resp_data_w24, &dut.n0_resp_data_w25, &dut.n0_resp_data_w26, &dut.n0_resp_data_w27, &dut.n0_resp_data_w28, &dut.n0_resp_data_w29, &dut.n0_resp_data_w30, &dut.n0_resp_data_w31}, &dut.n0_resp_is_write}, + {&dut.n1_req_valid, &dut.n1_req_write, &dut.n1_req_addr, &dut.n1_req_tag, + {&dut.n1_req_data_w0, &dut.n1_req_data_w1, &dut.n1_req_data_w2, &dut.n1_req_data_w3, &dut.n1_req_data_w4, &dut.n1_req_data_w5, &dut.n1_req_data_w6, &dut.n1_req_data_w7, &dut.n1_req_data_w8, &dut.n1_req_data_w9, &dut.n1_req_data_w10, &dut.n1_req_data_w11, &dut.n1_req_data_w12, &dut.n1_req_data_w13, &dut.n1_req_data_w14, &dut.n1_req_data_w15, &dut.n1_req_data_w16, &dut.n1_req_data_w17, &dut.n1_req_data_w18, &dut.n1_req_data_w19, &dut.n1_req_data_w20, &dut.n1_req_data_w21, &dut.n1_req_data_w22, &dut.n1_req_data_w23, &dut.n1_req_data_w24, &dut.n1_req_data_w25, &dut.n1_req_data_w26, &dut.n1_req_data_w27, &dut.n1_req_data_w28, &dut.n1_req_data_w29, &dut.n1_req_data_w30, &dut.n1_req_data_w31}, &dut.n1_req_ready, &dut.n1_resp_ready, &dut.n1_resp_valid, &dut.n1_resp_tag, + {&dut.n1_resp_data_w0, &dut.n1_resp_data_w1, &dut.n1_resp_data_w2, &dut.n1_resp_data_w3, &dut.n1_resp_data_w4, &dut.n1_resp_data_w5, &dut.n1_resp_data_w6, &dut.n1_resp_data_w7, &dut.n1_resp_data_w8, &dut.n1_resp_data_w9, &dut.n1_resp_data_w10, &dut.n1_resp_data_w11, &dut.n1_resp_data_w12, &dut.n1_resp_data_w13, &dut.n1_resp_data_w14, &dut.n1_resp_data_w15, &dut.n1_resp_data_w16, &dut.n1_resp_data_w17, &dut.n1_resp_data_w18, &dut.n1_resp_data_w19, &dut.n1_resp_data_w20, &dut.n1_resp_data_w21, &dut.n1_resp_data_w22, &dut.n1_resp_data_w23, &dut.n1_resp_data_w24, &dut.n1_resp_data_w25, &dut.n1_resp_data_w26, &dut.n1_resp_data_w27, &dut.n1_resp_data_w28, &dut.n1_resp_data_w29, &dut.n1_resp_data_w30, &dut.n1_resp_data_w31}, &dut.n1_resp_is_write}, + {&dut.n2_req_valid, 
&dut.n2_req_write, &dut.n2_req_addr, &dut.n2_req_tag, + {&dut.n2_req_data_w0, &dut.n2_req_data_w1, &dut.n2_req_data_w2, &dut.n2_req_data_w3, &dut.n2_req_data_w4, &dut.n2_req_data_w5, &dut.n2_req_data_w6, &dut.n2_req_data_w7, &dut.n2_req_data_w8, &dut.n2_req_data_w9, &dut.n2_req_data_w10, &dut.n2_req_data_w11, &dut.n2_req_data_w12, &dut.n2_req_data_w13, &dut.n2_req_data_w14, &dut.n2_req_data_w15, &dut.n2_req_data_w16, &dut.n2_req_data_w17, &dut.n2_req_data_w18, &dut.n2_req_data_w19, &dut.n2_req_data_w20, &dut.n2_req_data_w21, &dut.n2_req_data_w22, &dut.n2_req_data_w23, &dut.n2_req_data_w24, &dut.n2_req_data_w25, &dut.n2_req_data_w26, &dut.n2_req_data_w27, &dut.n2_req_data_w28, &dut.n2_req_data_w29, &dut.n2_req_data_w30, &dut.n2_req_data_w31}, &dut.n2_req_ready, &dut.n2_resp_ready, &dut.n2_resp_valid, &dut.n2_resp_tag, + {&dut.n2_resp_data_w0, &dut.n2_resp_data_w1, &dut.n2_resp_data_w2, &dut.n2_resp_data_w3, &dut.n2_resp_data_w4, &dut.n2_resp_data_w5, &dut.n2_resp_data_w6, &dut.n2_resp_data_w7, &dut.n2_resp_data_w8, &dut.n2_resp_data_w9, &dut.n2_resp_data_w10, &dut.n2_resp_data_w11, &dut.n2_resp_data_w12, &dut.n2_resp_data_w13, &dut.n2_resp_data_w14, &dut.n2_resp_data_w15, &dut.n2_resp_data_w16, &dut.n2_resp_data_w17, &dut.n2_resp_data_w18, &dut.n2_resp_data_w19, &dut.n2_resp_data_w20, &dut.n2_resp_data_w21, &dut.n2_resp_data_w22, &dut.n2_resp_data_w23, &dut.n2_resp_data_w24, &dut.n2_resp_data_w25, &dut.n2_resp_data_w26, &dut.n2_resp_data_w27, &dut.n2_resp_data_w28, &dut.n2_resp_data_w29, &dut.n2_resp_data_w30, &dut.n2_resp_data_w31}, &dut.n2_resp_is_write}, + {&dut.n3_req_valid, &dut.n3_req_write, &dut.n3_req_addr, &dut.n3_req_tag, + {&dut.n3_req_data_w0, &dut.n3_req_data_w1, &dut.n3_req_data_w2, &dut.n3_req_data_w3, &dut.n3_req_data_w4, &dut.n3_req_data_w5, &dut.n3_req_data_w6, &dut.n3_req_data_w7, &dut.n3_req_data_w8, &dut.n3_req_data_w9, &dut.n3_req_data_w10, &dut.n3_req_data_w11, &dut.n3_req_data_w12, &dut.n3_req_data_w13, &dut.n3_req_data_w14, 
&dut.n3_req_data_w15, &dut.n3_req_data_w16, &dut.n3_req_data_w17, &dut.n3_req_data_w18, &dut.n3_req_data_w19, &dut.n3_req_data_w20, &dut.n3_req_data_w21, &dut.n3_req_data_w22, &dut.n3_req_data_w23, &dut.n3_req_data_w24, &dut.n3_req_data_w25, &dut.n3_req_data_w26, &dut.n3_req_data_w27, &dut.n3_req_data_w28, &dut.n3_req_data_w29, &dut.n3_req_data_w30, &dut.n3_req_data_w31}, &dut.n3_req_ready, &dut.n3_resp_ready, &dut.n3_resp_valid, &dut.n3_resp_tag, + {&dut.n3_resp_data_w0, &dut.n3_resp_data_w1, &dut.n3_resp_data_w2, &dut.n3_resp_data_w3, &dut.n3_resp_data_w4, &dut.n3_resp_data_w5, &dut.n3_resp_data_w6, &dut.n3_resp_data_w7, &dut.n3_resp_data_w8, &dut.n3_resp_data_w9, &dut.n3_resp_data_w10, &dut.n3_resp_data_w11, &dut.n3_resp_data_w12, &dut.n3_resp_data_w13, &dut.n3_resp_data_w14, &dut.n3_resp_data_w15, &dut.n3_resp_data_w16, &dut.n3_resp_data_w17, &dut.n3_resp_data_w18, &dut.n3_resp_data_w19, &dut.n3_resp_data_w20, &dut.n3_resp_data_w21, &dut.n3_resp_data_w22, &dut.n3_resp_data_w23, &dut.n3_resp_data_w24, &dut.n3_resp_data_w25, &dut.n3_resp_data_w26, &dut.n3_resp_data_w27, &dut.n3_resp_data_w28, &dut.n3_resp_data_w29, &dut.n3_resp_data_w30, &dut.n3_resp_data_w31}, &dut.n3_resp_is_write}, + {&dut.n4_req_valid, &dut.n4_req_write, &dut.n4_req_addr, &dut.n4_req_tag, + {&dut.n4_req_data_w0, &dut.n4_req_data_w1, &dut.n4_req_data_w2, &dut.n4_req_data_w3, &dut.n4_req_data_w4, &dut.n4_req_data_w5, &dut.n4_req_data_w6, &dut.n4_req_data_w7, &dut.n4_req_data_w8, &dut.n4_req_data_w9, &dut.n4_req_data_w10, &dut.n4_req_data_w11, &dut.n4_req_data_w12, &dut.n4_req_data_w13, &dut.n4_req_data_w14, &dut.n4_req_data_w15, &dut.n4_req_data_w16, &dut.n4_req_data_w17, &dut.n4_req_data_w18, &dut.n4_req_data_w19, &dut.n4_req_data_w20, &dut.n4_req_data_w21, &dut.n4_req_data_w22, &dut.n4_req_data_w23, &dut.n4_req_data_w24, &dut.n4_req_data_w25, &dut.n4_req_data_w26, &dut.n4_req_data_w27, &dut.n4_req_data_w28, &dut.n4_req_data_w29, &dut.n4_req_data_w30, &dut.n4_req_data_w31}, &dut.n4_req_ready, 
&dut.n4_resp_ready, &dut.n4_resp_valid, &dut.n4_resp_tag, + {&dut.n4_resp_data_w0, &dut.n4_resp_data_w1, &dut.n4_resp_data_w2, &dut.n4_resp_data_w3, &dut.n4_resp_data_w4, &dut.n4_resp_data_w5, &dut.n4_resp_data_w6, &dut.n4_resp_data_w7, &dut.n4_resp_data_w8, &dut.n4_resp_data_w9, &dut.n4_resp_data_w10, &dut.n4_resp_data_w11, &dut.n4_resp_data_w12, &dut.n4_resp_data_w13, &dut.n4_resp_data_w14, &dut.n4_resp_data_w15, &dut.n4_resp_data_w16, &dut.n4_resp_data_w17, &dut.n4_resp_data_w18, &dut.n4_resp_data_w19, &dut.n4_resp_data_w20, &dut.n4_resp_data_w21, &dut.n4_resp_data_w22, &dut.n4_resp_data_w23, &dut.n4_resp_data_w24, &dut.n4_resp_data_w25, &dut.n4_resp_data_w26, &dut.n4_resp_data_w27, &dut.n4_resp_data_w28, &dut.n4_resp_data_w29, &dut.n4_resp_data_w30, &dut.n4_resp_data_w31}, &dut.n4_resp_is_write}, + {&dut.n5_req_valid, &dut.n5_req_write, &dut.n5_req_addr, &dut.n5_req_tag, + {&dut.n5_req_data_w0, &dut.n5_req_data_w1, &dut.n5_req_data_w2, &dut.n5_req_data_w3, &dut.n5_req_data_w4, &dut.n5_req_data_w5, &dut.n5_req_data_w6, &dut.n5_req_data_w7, &dut.n5_req_data_w8, &dut.n5_req_data_w9, &dut.n5_req_data_w10, &dut.n5_req_data_w11, &dut.n5_req_data_w12, &dut.n5_req_data_w13, &dut.n5_req_data_w14, &dut.n5_req_data_w15, &dut.n5_req_data_w16, &dut.n5_req_data_w17, &dut.n5_req_data_w18, &dut.n5_req_data_w19, &dut.n5_req_data_w20, &dut.n5_req_data_w21, &dut.n5_req_data_w22, &dut.n5_req_data_w23, &dut.n5_req_data_w24, &dut.n5_req_data_w25, &dut.n5_req_data_w26, &dut.n5_req_data_w27, &dut.n5_req_data_w28, &dut.n5_req_data_w29, &dut.n5_req_data_w30, &dut.n5_req_data_w31}, &dut.n5_req_ready, &dut.n5_resp_ready, &dut.n5_resp_valid, &dut.n5_resp_tag, + {&dut.n5_resp_data_w0, &dut.n5_resp_data_w1, &dut.n5_resp_data_w2, &dut.n5_resp_data_w3, &dut.n5_resp_data_w4, &dut.n5_resp_data_w5, &dut.n5_resp_data_w6, &dut.n5_resp_data_w7, &dut.n5_resp_data_w8, &dut.n5_resp_data_w9, &dut.n5_resp_data_w10, &dut.n5_resp_data_w11, &dut.n5_resp_data_w12, &dut.n5_resp_data_w13, 
&dut.n5_resp_data_w14, &dut.n5_resp_data_w15, &dut.n5_resp_data_w16, &dut.n5_resp_data_w17, &dut.n5_resp_data_w18, &dut.n5_resp_data_w19, &dut.n5_resp_data_w20, &dut.n5_resp_data_w21, &dut.n5_resp_data_w22, &dut.n5_resp_data_w23, &dut.n5_resp_data_w24, &dut.n5_resp_data_w25, &dut.n5_resp_data_w26, &dut.n5_resp_data_w27, &dut.n5_resp_data_w28, &dut.n5_resp_data_w29, &dut.n5_resp_data_w30, &dut.n5_resp_data_w31}, &dut.n5_resp_is_write}, + {&dut.n6_req_valid, &dut.n6_req_write, &dut.n6_req_addr, &dut.n6_req_tag, + {&dut.n6_req_data_w0, &dut.n6_req_data_w1, &dut.n6_req_data_w2, &dut.n6_req_data_w3, &dut.n6_req_data_w4, &dut.n6_req_data_w5, &dut.n6_req_data_w6, &dut.n6_req_data_w7, &dut.n6_req_data_w8, &dut.n6_req_data_w9, &dut.n6_req_data_w10, &dut.n6_req_data_w11, &dut.n6_req_data_w12, &dut.n6_req_data_w13, &dut.n6_req_data_w14, &dut.n6_req_data_w15, &dut.n6_req_data_w16, &dut.n6_req_data_w17, &dut.n6_req_data_w18, &dut.n6_req_data_w19, &dut.n6_req_data_w20, &dut.n6_req_data_w21, &dut.n6_req_data_w22, &dut.n6_req_data_w23, &dut.n6_req_data_w24, &dut.n6_req_data_w25, &dut.n6_req_data_w26, &dut.n6_req_data_w27, &dut.n6_req_data_w28, &dut.n6_req_data_w29, &dut.n6_req_data_w30, &dut.n6_req_data_w31}, &dut.n6_req_ready, &dut.n6_resp_ready, &dut.n6_resp_valid, &dut.n6_resp_tag, + {&dut.n6_resp_data_w0, &dut.n6_resp_data_w1, &dut.n6_resp_data_w2, &dut.n6_resp_data_w3, &dut.n6_resp_data_w4, &dut.n6_resp_data_w5, &dut.n6_resp_data_w6, &dut.n6_resp_data_w7, &dut.n6_resp_data_w8, &dut.n6_resp_data_w9, &dut.n6_resp_data_w10, &dut.n6_resp_data_w11, &dut.n6_resp_data_w12, &dut.n6_resp_data_w13, &dut.n6_resp_data_w14, &dut.n6_resp_data_w15, &dut.n6_resp_data_w16, &dut.n6_resp_data_w17, &dut.n6_resp_data_w18, &dut.n6_resp_data_w19, &dut.n6_resp_data_w20, &dut.n6_resp_data_w21, &dut.n6_resp_data_w22, &dut.n6_resp_data_w23, &dut.n6_resp_data_w24, &dut.n6_resp_data_w25, &dut.n6_resp_data_w26, &dut.n6_resp_data_w27, &dut.n6_resp_data_w28, &dut.n6_resp_data_w29, &dut.n6_resp_data_w30, 
&dut.n6_resp_data_w31}, &dut.n6_resp_is_write}, + {&dut.n7_req_valid, &dut.n7_req_write, &dut.n7_req_addr, &dut.n7_req_tag, + {&dut.n7_req_data_w0, &dut.n7_req_data_w1, &dut.n7_req_data_w2, &dut.n7_req_data_w3, &dut.n7_req_data_w4, &dut.n7_req_data_w5, &dut.n7_req_data_w6, &dut.n7_req_data_w7, &dut.n7_req_data_w8, &dut.n7_req_data_w9, &dut.n7_req_data_w10, &dut.n7_req_data_w11, &dut.n7_req_data_w12, &dut.n7_req_data_w13, &dut.n7_req_data_w14, &dut.n7_req_data_w15, &dut.n7_req_data_w16, &dut.n7_req_data_w17, &dut.n7_req_data_w18, &dut.n7_req_data_w19, &dut.n7_req_data_w20, &dut.n7_req_data_w21, &dut.n7_req_data_w22, &dut.n7_req_data_w23, &dut.n7_req_data_w24, &dut.n7_req_data_w25, &dut.n7_req_data_w26, &dut.n7_req_data_w27, &dut.n7_req_data_w28, &dut.n7_req_data_w29, &dut.n7_req_data_w30, &dut.n7_req_data_w31}, &dut.n7_req_ready, &dut.n7_resp_ready, &dut.n7_resp_valid, &dut.n7_resp_tag, + {&dut.n7_resp_data_w0, &dut.n7_resp_data_w1, &dut.n7_resp_data_w2, &dut.n7_resp_data_w3, &dut.n7_resp_data_w4, &dut.n7_resp_data_w5, &dut.n7_resp_data_w6, &dut.n7_resp_data_w7, &dut.n7_resp_data_w8, &dut.n7_resp_data_w9, &dut.n7_resp_data_w10, &dut.n7_resp_data_w11, &dut.n7_resp_data_w12, &dut.n7_resp_data_w13, &dut.n7_resp_data_w14, &dut.n7_resp_data_w15, &dut.n7_resp_data_w16, &dut.n7_resp_data_w17, &dut.n7_resp_data_w18, &dut.n7_resp_data_w19, &dut.n7_resp_data_w20, &dut.n7_resp_data_w21, &dut.n7_resp_data_w22, &dut.n7_resp_data_w23, &dut.n7_resp_data_w24, &dut.n7_resp_data_w25, &dut.n7_resp_data_w26, &dut.n7_resp_data_w27, &dut.n7_resp_data_w28, &dut.n7_resp_data_w29, &dut.n7_resp_data_w30, &dut.n7_resp_data_w31}, &dut.n7_resp_is_write}, + }}; + + for (auto &n : nodes) { + zeroReq(n); + setRespReady(n, true); + } + + std::uint64_t cycle = 0; + + for (int n = 0; n < kNodes; n++) { + const auto addr = makeAddr(static_cast(n), static_cast(n)); + const auto data = makeData(static_cast(n + 1)); + const std::uint8_t tag_w = static_cast(n); + const std::uint8_t tag_r = 
static_cast(0x80 | n); + + sendReq(tb, nodes[n], cycle, n, true, addr, tag_w, data, trace); + waitResp(tb, nodes[n], cycle, n, tag_w, true, data, trace); + + sendReq(tb, nodes[n], cycle, n, false, addr, tag_r, DataLine{}, trace); + waitResp(tb, nodes[n], cycle, n, tag_r, false, data, trace); + } + + // Cross-node: node0 writes to pipe2, then reads it back. + { + const auto addr = makeAddr(5, 2); + const auto data = makeData(0xAA); + sendReq(tb, nodes[0], cycle, 0, true, addr, 0x55, data, trace); + waitResp(tb, nodes[0], cycle, 0, 0x55, true, data, trace); + sendReq(tb, nodes[0], cycle, 0, false, addr, 0x56, DataLine{}, trace); + waitResp(tb, nodes[0], cycle, 0, 0x56, false, data, trace); + } + + // Ring traffic: each node accesses a non-local pipe to exercise ring flow. + for (int n = 0; n < kNodes; n++) { + const int dst_pipe = (n + 2) % kNodes; + const auto addr = makeAddr(16 + n, static_cast(dst_pipe)); + const auto data = makeData(0x100 + n); + const std::uint8_t tag_w = static_cast(0x20 + n); + const std::uint8_t tag_r = static_cast(0xA0 + n); + + sendReq(tb, nodes[n], cycle, n, true, addr, tag_w, data, trace); + waitResp(tb, nodes[n], cycle, n, tag_w, true, data, trace); + sendReq(tb, nodes[n], cycle, n, false, addr, tag_r, DataLine{}, trace); + waitResp(tb, nodes[n], cycle, n, tag_r, false, data, trace); + } + + std::cout << "PASS: TMU tests\n"; + return 0; +} diff --git a/janus/tb/tb_janus_tmu_pyc.sv b/janus/tb/tb_janus_tmu_pyc.sv new file mode 100644 index 0000000..3df2527 --- /dev/null +++ b/janus/tb/tb_janus_tmu_pyc.sv @@ -0,0 +1,744 @@ +module tb_janus_tmu_pyc; + logic clk; + logic rst; + + logic req_valid [0:7]; + logic req_write [0:7]; + logic [19:0] req_addr [0:7]; + logic [7:0] req_tag [0:7]; + logic [63:0] req_data [0:7][0:31]; + logic req_ready [0:7]; + + logic resp_ready [0:7]; + logic resp_valid [0:7]; + logic [7:0] resp_tag [0:7]; + logic [63:0] resp_data [0:7][0:31]; + logic resp_is_write [0:7]; + + logic [63:0] line_data [0:31]; + logic 
[63:0] line_zero [0:31]; + + janus_tmu_pyc dut ( + .clk(clk), + .rst(rst), + .n0_req_valid(req_valid[0]), + .n0_req_write(req_write[0]), + .n0_req_addr(req_addr[0]), + .n0_req_tag(req_tag[0]), + .n0_req_data_w0(req_data[0][0]), + .n0_req_data_w1(req_data[0][1]), + .n0_req_data_w2(req_data[0][2]), + .n0_req_data_w3(req_data[0][3]), + .n0_req_data_w4(req_data[0][4]), + .n0_req_data_w5(req_data[0][5]), + .n0_req_data_w6(req_data[0][6]), + .n0_req_data_w7(req_data[0][7]), + .n0_req_data_w8(req_data[0][8]), + .n0_req_data_w9(req_data[0][9]), + .n0_req_data_w10(req_data[0][10]), + .n0_req_data_w11(req_data[0][11]), + .n0_req_data_w12(req_data[0][12]), + .n0_req_data_w13(req_data[0][13]), + .n0_req_data_w14(req_data[0][14]), + .n0_req_data_w15(req_data[0][15]), + .n0_req_data_w16(req_data[0][16]), + .n0_req_data_w17(req_data[0][17]), + .n0_req_data_w18(req_data[0][18]), + .n0_req_data_w19(req_data[0][19]), + .n0_req_data_w20(req_data[0][20]), + .n0_req_data_w21(req_data[0][21]), + .n0_req_data_w22(req_data[0][22]), + .n0_req_data_w23(req_data[0][23]), + .n0_req_data_w24(req_data[0][24]), + .n0_req_data_w25(req_data[0][25]), + .n0_req_data_w26(req_data[0][26]), + .n0_req_data_w27(req_data[0][27]), + .n0_req_data_w28(req_data[0][28]), + .n0_req_data_w29(req_data[0][29]), + .n0_req_data_w30(req_data[0][30]), + .n0_req_data_w31(req_data[0][31]), + .n0_req_ready(req_ready[0]), + .n0_resp_ready(resp_ready[0]), + .n0_resp_valid(resp_valid[0]), + .n0_resp_tag(resp_tag[0]), + .n0_resp_data_w0(resp_data[0][0]), + .n0_resp_data_w1(resp_data[0][1]), + .n0_resp_data_w2(resp_data[0][2]), + .n0_resp_data_w3(resp_data[0][3]), + .n0_resp_data_w4(resp_data[0][4]), + .n0_resp_data_w5(resp_data[0][5]), + .n0_resp_data_w6(resp_data[0][6]), + .n0_resp_data_w7(resp_data[0][7]), + .n0_resp_data_w8(resp_data[0][8]), + .n0_resp_data_w9(resp_data[0][9]), + .n0_resp_data_w10(resp_data[0][10]), + .n0_resp_data_w11(resp_data[0][11]), + .n0_resp_data_w12(resp_data[0][12]), + 
.n0_resp_data_w13(resp_data[0][13]), + .n0_resp_data_w14(resp_data[0][14]), + .n0_resp_data_w15(resp_data[0][15]), + .n0_resp_data_w16(resp_data[0][16]), + .n0_resp_data_w17(resp_data[0][17]), + .n0_resp_data_w18(resp_data[0][18]), + .n0_resp_data_w19(resp_data[0][19]), + .n0_resp_data_w20(resp_data[0][20]), + .n0_resp_data_w21(resp_data[0][21]), + .n0_resp_data_w22(resp_data[0][22]), + .n0_resp_data_w23(resp_data[0][23]), + .n0_resp_data_w24(resp_data[0][24]), + .n0_resp_data_w25(resp_data[0][25]), + .n0_resp_data_w26(resp_data[0][26]), + .n0_resp_data_w27(resp_data[0][27]), + .n0_resp_data_w28(resp_data[0][28]), + .n0_resp_data_w29(resp_data[0][29]), + .n0_resp_data_w30(resp_data[0][30]), + .n0_resp_data_w31(resp_data[0][31]), + .n0_resp_is_write(resp_is_write[0]), + + .n1_req_valid(req_valid[1]), + .n1_req_write(req_write[1]), + .n1_req_addr(req_addr[1]), + .n1_req_tag(req_tag[1]), + .n1_req_data_w0(req_data[1][0]), + .n1_req_data_w1(req_data[1][1]), + .n1_req_data_w2(req_data[1][2]), + .n1_req_data_w3(req_data[1][3]), + .n1_req_data_w4(req_data[1][4]), + .n1_req_data_w5(req_data[1][5]), + .n1_req_data_w6(req_data[1][6]), + .n1_req_data_w7(req_data[1][7]), + .n1_req_data_w8(req_data[1][8]), + .n1_req_data_w9(req_data[1][9]), + .n1_req_data_w10(req_data[1][10]), + .n1_req_data_w11(req_data[1][11]), + .n1_req_data_w12(req_data[1][12]), + .n1_req_data_w13(req_data[1][13]), + .n1_req_data_w14(req_data[1][14]), + .n1_req_data_w15(req_data[1][15]), + .n1_req_data_w16(req_data[1][16]), + .n1_req_data_w17(req_data[1][17]), + .n1_req_data_w18(req_data[1][18]), + .n1_req_data_w19(req_data[1][19]), + .n1_req_data_w20(req_data[1][20]), + .n1_req_data_w21(req_data[1][21]), + .n1_req_data_w22(req_data[1][22]), + .n1_req_data_w23(req_data[1][23]), + .n1_req_data_w24(req_data[1][24]), + .n1_req_data_w25(req_data[1][25]), + .n1_req_data_w26(req_data[1][26]), + .n1_req_data_w27(req_data[1][27]), + .n1_req_data_w28(req_data[1][28]), + .n1_req_data_w29(req_data[1][29]), + 
.n1_req_data_w30(req_data[1][30]), + .n1_req_data_w31(req_data[1][31]), + .n1_req_ready(req_ready[1]), + .n1_resp_ready(resp_ready[1]), + .n1_resp_valid(resp_valid[1]), + .n1_resp_tag(resp_tag[1]), + .n1_resp_data_w0(resp_data[1][0]), + .n1_resp_data_w1(resp_data[1][1]), + .n1_resp_data_w2(resp_data[1][2]), + .n1_resp_data_w3(resp_data[1][3]), + .n1_resp_data_w4(resp_data[1][4]), + .n1_resp_data_w5(resp_data[1][5]), + .n1_resp_data_w6(resp_data[1][6]), + .n1_resp_data_w7(resp_data[1][7]), + .n1_resp_data_w8(resp_data[1][8]), + .n1_resp_data_w9(resp_data[1][9]), + .n1_resp_data_w10(resp_data[1][10]), + .n1_resp_data_w11(resp_data[1][11]), + .n1_resp_data_w12(resp_data[1][12]), + .n1_resp_data_w13(resp_data[1][13]), + .n1_resp_data_w14(resp_data[1][14]), + .n1_resp_data_w15(resp_data[1][15]), + .n1_resp_data_w16(resp_data[1][16]), + .n1_resp_data_w17(resp_data[1][17]), + .n1_resp_data_w18(resp_data[1][18]), + .n1_resp_data_w19(resp_data[1][19]), + .n1_resp_data_w20(resp_data[1][20]), + .n1_resp_data_w21(resp_data[1][21]), + .n1_resp_data_w22(resp_data[1][22]), + .n1_resp_data_w23(resp_data[1][23]), + .n1_resp_data_w24(resp_data[1][24]), + .n1_resp_data_w25(resp_data[1][25]), + .n1_resp_data_w26(resp_data[1][26]), + .n1_resp_data_w27(resp_data[1][27]), + .n1_resp_data_w28(resp_data[1][28]), + .n1_resp_data_w29(resp_data[1][29]), + .n1_resp_data_w30(resp_data[1][30]), + .n1_resp_data_w31(resp_data[1][31]), + .n1_resp_is_write(resp_is_write[1]), + + .n2_req_valid(req_valid[2]), + .n2_req_write(req_write[2]), + .n2_req_addr(req_addr[2]), + .n2_req_tag(req_tag[2]), + .n2_req_data_w0(req_data[2][0]), + .n2_req_data_w1(req_data[2][1]), + .n2_req_data_w2(req_data[2][2]), + .n2_req_data_w3(req_data[2][3]), + .n2_req_data_w4(req_data[2][4]), + .n2_req_data_w5(req_data[2][5]), + .n2_req_data_w6(req_data[2][6]), + .n2_req_data_w7(req_data[2][7]), + .n2_req_data_w8(req_data[2][8]), + .n2_req_data_w9(req_data[2][9]), + .n2_req_data_w10(req_data[2][10]), + 
.n2_req_data_w11(req_data[2][11]), + .n2_req_data_w12(req_data[2][12]), + .n2_req_data_w13(req_data[2][13]), + .n2_req_data_w14(req_data[2][14]), + .n2_req_data_w15(req_data[2][15]), + .n2_req_data_w16(req_data[2][16]), + .n2_req_data_w17(req_data[2][17]), + .n2_req_data_w18(req_data[2][18]), + .n2_req_data_w19(req_data[2][19]), + .n2_req_data_w20(req_data[2][20]), + .n2_req_data_w21(req_data[2][21]), + .n2_req_data_w22(req_data[2][22]), + .n2_req_data_w23(req_data[2][23]), + .n2_req_data_w24(req_data[2][24]), + .n2_req_data_w25(req_data[2][25]), + .n2_req_data_w26(req_data[2][26]), + .n2_req_data_w27(req_data[2][27]), + .n2_req_data_w28(req_data[2][28]), + .n2_req_data_w29(req_data[2][29]), + .n2_req_data_w30(req_data[2][30]), + .n2_req_data_w31(req_data[2][31]), + .n2_req_ready(req_ready[2]), + .n2_resp_ready(resp_ready[2]), + .n2_resp_valid(resp_valid[2]), + .n2_resp_tag(resp_tag[2]), + .n2_resp_data_w0(resp_data[2][0]), + .n2_resp_data_w1(resp_data[2][1]), + .n2_resp_data_w2(resp_data[2][2]), + .n2_resp_data_w3(resp_data[2][3]), + .n2_resp_data_w4(resp_data[2][4]), + .n2_resp_data_w5(resp_data[2][5]), + .n2_resp_data_w6(resp_data[2][6]), + .n2_resp_data_w7(resp_data[2][7]), + .n2_resp_data_w8(resp_data[2][8]), + .n2_resp_data_w9(resp_data[2][9]), + .n2_resp_data_w10(resp_data[2][10]), + .n2_resp_data_w11(resp_data[2][11]), + .n2_resp_data_w12(resp_data[2][12]), + .n2_resp_data_w13(resp_data[2][13]), + .n2_resp_data_w14(resp_data[2][14]), + .n2_resp_data_w15(resp_data[2][15]), + .n2_resp_data_w16(resp_data[2][16]), + .n2_resp_data_w17(resp_data[2][17]), + .n2_resp_data_w18(resp_data[2][18]), + .n2_resp_data_w19(resp_data[2][19]), + .n2_resp_data_w20(resp_data[2][20]), + .n2_resp_data_w21(resp_data[2][21]), + .n2_resp_data_w22(resp_data[2][22]), + .n2_resp_data_w23(resp_data[2][23]), + .n2_resp_data_w24(resp_data[2][24]), + .n2_resp_data_w25(resp_data[2][25]), + .n2_resp_data_w26(resp_data[2][26]), + .n2_resp_data_w27(resp_data[2][27]), + 
.n2_resp_data_w28(resp_data[2][28]), + .n2_resp_data_w29(resp_data[2][29]), + .n2_resp_data_w30(resp_data[2][30]), + .n2_resp_data_w31(resp_data[2][31]), + .n2_resp_is_write(resp_is_write[2]), + + .n3_req_valid(req_valid[3]), + .n3_req_write(req_write[3]), + .n3_req_addr(req_addr[3]), + .n3_req_tag(req_tag[3]), + .n3_req_data_w0(req_data[3][0]), + .n3_req_data_w1(req_data[3][1]), + .n3_req_data_w2(req_data[3][2]), + .n3_req_data_w3(req_data[3][3]), + .n3_req_data_w4(req_data[3][4]), + .n3_req_data_w5(req_data[3][5]), + .n3_req_data_w6(req_data[3][6]), + .n3_req_data_w7(req_data[3][7]), + .n3_req_data_w8(req_data[3][8]), + .n3_req_data_w9(req_data[3][9]), + .n3_req_data_w10(req_data[3][10]), + .n3_req_data_w11(req_data[3][11]), + .n3_req_data_w12(req_data[3][12]), + .n3_req_data_w13(req_data[3][13]), + .n3_req_data_w14(req_data[3][14]), + .n3_req_data_w15(req_data[3][15]), + .n3_req_data_w16(req_data[3][16]), + .n3_req_data_w17(req_data[3][17]), + .n3_req_data_w18(req_data[3][18]), + .n3_req_data_w19(req_data[3][19]), + .n3_req_data_w20(req_data[3][20]), + .n3_req_data_w21(req_data[3][21]), + .n3_req_data_w22(req_data[3][22]), + .n3_req_data_w23(req_data[3][23]), + .n3_req_data_w24(req_data[3][24]), + .n3_req_data_w25(req_data[3][25]), + .n3_req_data_w26(req_data[3][26]), + .n3_req_data_w27(req_data[3][27]), + .n3_req_data_w28(req_data[3][28]), + .n3_req_data_w29(req_data[3][29]), + .n3_req_data_w30(req_data[3][30]), + .n3_req_data_w31(req_data[3][31]), + .n3_req_ready(req_ready[3]), + .n3_resp_ready(resp_ready[3]), + .n3_resp_valid(resp_valid[3]), + .n3_resp_tag(resp_tag[3]), + .n3_resp_data_w0(resp_data[3][0]), + .n3_resp_data_w1(resp_data[3][1]), + .n3_resp_data_w2(resp_data[3][2]), + .n3_resp_data_w3(resp_data[3][3]), + .n3_resp_data_w4(resp_data[3][4]), + .n3_resp_data_w5(resp_data[3][5]), + .n3_resp_data_w6(resp_data[3][6]), + .n3_resp_data_w7(resp_data[3][7]), + .n3_resp_data_w8(resp_data[3][8]), + .n3_resp_data_w9(resp_data[3][9]), + 
.n3_resp_data_w10(resp_data[3][10]), + .n3_resp_data_w11(resp_data[3][11]), + .n3_resp_data_w12(resp_data[3][12]), + .n3_resp_data_w13(resp_data[3][13]), + .n3_resp_data_w14(resp_data[3][14]), + .n3_resp_data_w15(resp_data[3][15]), + .n3_resp_data_w16(resp_data[3][16]), + .n3_resp_data_w17(resp_data[3][17]), + .n3_resp_data_w18(resp_data[3][18]), + .n3_resp_data_w19(resp_data[3][19]), + .n3_resp_data_w20(resp_data[3][20]), + .n3_resp_data_w21(resp_data[3][21]), + .n3_resp_data_w22(resp_data[3][22]), + .n3_resp_data_w23(resp_data[3][23]), + .n3_resp_data_w24(resp_data[3][24]), + .n3_resp_data_w25(resp_data[3][25]), + .n3_resp_data_w26(resp_data[3][26]), + .n3_resp_data_w27(resp_data[3][27]), + .n3_resp_data_w28(resp_data[3][28]), + .n3_resp_data_w29(resp_data[3][29]), + .n3_resp_data_w30(resp_data[3][30]), + .n3_resp_data_w31(resp_data[3][31]), + .n3_resp_is_write(resp_is_write[3]), + + .n4_req_valid(req_valid[4]), + .n4_req_write(req_write[4]), + .n4_req_addr(req_addr[4]), + .n4_req_tag(req_tag[4]), + .n4_req_data_w0(req_data[4][0]), + .n4_req_data_w1(req_data[4][1]), + .n4_req_data_w2(req_data[4][2]), + .n4_req_data_w3(req_data[4][3]), + .n4_req_data_w4(req_data[4][4]), + .n4_req_data_w5(req_data[4][5]), + .n4_req_data_w6(req_data[4][6]), + .n4_req_data_w7(req_data[4][7]), + .n4_req_data_w8(req_data[4][8]), + .n4_req_data_w9(req_data[4][9]), + .n4_req_data_w10(req_data[4][10]), + .n4_req_data_w11(req_data[4][11]), + .n4_req_data_w12(req_data[4][12]), + .n4_req_data_w13(req_data[4][13]), + .n4_req_data_w14(req_data[4][14]), + .n4_req_data_w15(req_data[4][15]), + .n4_req_data_w16(req_data[4][16]), + .n4_req_data_w17(req_data[4][17]), + .n4_req_data_w18(req_data[4][18]), + .n4_req_data_w19(req_data[4][19]), + .n4_req_data_w20(req_data[4][20]), + .n4_req_data_w21(req_data[4][21]), + .n4_req_data_w22(req_data[4][22]), + .n4_req_data_w23(req_data[4][23]), + .n4_req_data_w24(req_data[4][24]), + .n4_req_data_w25(req_data[4][25]), + .n4_req_data_w26(req_data[4][26]), + 
.n4_req_data_w27(req_data[4][27]), + .n4_req_data_w28(req_data[4][28]), + .n4_req_data_w29(req_data[4][29]), + .n4_req_data_w30(req_data[4][30]), + .n4_req_data_w31(req_data[4][31]), + .n4_req_ready(req_ready[4]), + .n4_resp_ready(resp_ready[4]), + .n4_resp_valid(resp_valid[4]), + .n4_resp_tag(resp_tag[4]), + .n4_resp_data_w0(resp_data[4][0]), + .n4_resp_data_w1(resp_data[4][1]), + .n4_resp_data_w2(resp_data[4][2]), + .n4_resp_data_w3(resp_data[4][3]), + .n4_resp_data_w4(resp_data[4][4]), + .n4_resp_data_w5(resp_data[4][5]), + .n4_resp_data_w6(resp_data[4][6]), + .n4_resp_data_w7(resp_data[4][7]), + .n4_resp_data_w8(resp_data[4][8]), + .n4_resp_data_w9(resp_data[4][9]), + .n4_resp_data_w10(resp_data[4][10]), + .n4_resp_data_w11(resp_data[4][11]), + .n4_resp_data_w12(resp_data[4][12]), + .n4_resp_data_w13(resp_data[4][13]), + .n4_resp_data_w14(resp_data[4][14]), + .n4_resp_data_w15(resp_data[4][15]), + .n4_resp_data_w16(resp_data[4][16]), + .n4_resp_data_w17(resp_data[4][17]), + .n4_resp_data_w18(resp_data[4][18]), + .n4_resp_data_w19(resp_data[4][19]), + .n4_resp_data_w20(resp_data[4][20]), + .n4_resp_data_w21(resp_data[4][21]), + .n4_resp_data_w22(resp_data[4][22]), + .n4_resp_data_w23(resp_data[4][23]), + .n4_resp_data_w24(resp_data[4][24]), + .n4_resp_data_w25(resp_data[4][25]), + .n4_resp_data_w26(resp_data[4][26]), + .n4_resp_data_w27(resp_data[4][27]), + .n4_resp_data_w28(resp_data[4][28]), + .n4_resp_data_w29(resp_data[4][29]), + .n4_resp_data_w30(resp_data[4][30]), + .n4_resp_data_w31(resp_data[4][31]), + .n4_resp_is_write(resp_is_write[4]), + + .n5_req_valid(req_valid[5]), + .n5_req_write(req_write[5]), + .n5_req_addr(req_addr[5]), + .n5_req_tag(req_tag[5]), + .n5_req_data_w0(req_data[5][0]), + .n5_req_data_w1(req_data[5][1]), + .n5_req_data_w2(req_data[5][2]), + .n5_req_data_w3(req_data[5][3]), + .n5_req_data_w4(req_data[5][4]), + .n5_req_data_w5(req_data[5][5]), + .n5_req_data_w6(req_data[5][6]), + .n5_req_data_w7(req_data[5][7]), + 
.n5_req_data_w8(req_data[5][8]), + .n5_req_data_w9(req_data[5][9]), + .n5_req_data_w10(req_data[5][10]), + .n5_req_data_w11(req_data[5][11]), + .n5_req_data_w12(req_data[5][12]), + .n5_req_data_w13(req_data[5][13]), + .n5_req_data_w14(req_data[5][14]), + .n5_req_data_w15(req_data[5][15]), + .n5_req_data_w16(req_data[5][16]), + .n5_req_data_w17(req_data[5][17]), + .n5_req_data_w18(req_data[5][18]), + .n5_req_data_w19(req_data[5][19]), + .n5_req_data_w20(req_data[5][20]), + .n5_req_data_w21(req_data[5][21]), + .n5_req_data_w22(req_data[5][22]), + .n5_req_data_w23(req_data[5][23]), + .n5_req_data_w24(req_data[5][24]), + .n5_req_data_w25(req_data[5][25]), + .n5_req_data_w26(req_data[5][26]), + .n5_req_data_w27(req_data[5][27]), + .n5_req_data_w28(req_data[5][28]), + .n5_req_data_w29(req_data[5][29]), + .n5_req_data_w30(req_data[5][30]), + .n5_req_data_w31(req_data[5][31]), + .n5_req_ready(req_ready[5]), + .n5_resp_ready(resp_ready[5]), + .n5_resp_valid(resp_valid[5]), + .n5_resp_tag(resp_tag[5]), + .n5_resp_data_w0(resp_data[5][0]), + .n5_resp_data_w1(resp_data[5][1]), + .n5_resp_data_w2(resp_data[5][2]), + .n5_resp_data_w3(resp_data[5][3]), + .n5_resp_data_w4(resp_data[5][4]), + .n5_resp_data_w5(resp_data[5][5]), + .n5_resp_data_w6(resp_data[5][6]), + .n5_resp_data_w7(resp_data[5][7]), + .n5_resp_data_w8(resp_data[5][8]), + .n5_resp_data_w9(resp_data[5][9]), + .n5_resp_data_w10(resp_data[5][10]), + .n5_resp_data_w11(resp_data[5][11]), + .n5_resp_data_w12(resp_data[5][12]), + .n5_resp_data_w13(resp_data[5][13]), + .n5_resp_data_w14(resp_data[5][14]), + .n5_resp_data_w15(resp_data[5][15]), + .n5_resp_data_w16(resp_data[5][16]), + .n5_resp_data_w17(resp_data[5][17]), + .n5_resp_data_w18(resp_data[5][18]), + .n5_resp_data_w19(resp_data[5][19]), + .n5_resp_data_w20(resp_data[5][20]), + .n5_resp_data_w21(resp_data[5][21]), + .n5_resp_data_w22(resp_data[5][22]), + .n5_resp_data_w23(resp_data[5][23]), + .n5_resp_data_w24(resp_data[5][24]), + 
.n5_resp_data_w25(resp_data[5][25]), + .n5_resp_data_w26(resp_data[5][26]), + .n5_resp_data_w27(resp_data[5][27]), + .n5_resp_data_w28(resp_data[5][28]), + .n5_resp_data_w29(resp_data[5][29]), + .n5_resp_data_w30(resp_data[5][30]), + .n5_resp_data_w31(resp_data[5][31]), + .n5_resp_is_write(resp_is_write[5]), + + .n6_req_valid(req_valid[6]), + .n6_req_write(req_write[6]), + .n6_req_addr(req_addr[6]), + .n6_req_tag(req_tag[6]), + .n6_req_data_w0(req_data[6][0]), + .n6_req_data_w1(req_data[6][1]), + .n6_req_data_w2(req_data[6][2]), + .n6_req_data_w3(req_data[6][3]), + .n6_req_data_w4(req_data[6][4]), + .n6_req_data_w5(req_data[6][5]), + .n6_req_data_w6(req_data[6][6]), + .n6_req_data_w7(req_data[6][7]), + .n6_req_data_w8(req_data[6][8]), + .n6_req_data_w9(req_data[6][9]), + .n6_req_data_w10(req_data[6][10]), + .n6_req_data_w11(req_data[6][11]), + .n6_req_data_w12(req_data[6][12]), + .n6_req_data_w13(req_data[6][13]), + .n6_req_data_w14(req_data[6][14]), + .n6_req_data_w15(req_data[6][15]), + .n6_req_data_w16(req_data[6][16]), + .n6_req_data_w17(req_data[6][17]), + .n6_req_data_w18(req_data[6][18]), + .n6_req_data_w19(req_data[6][19]), + .n6_req_data_w20(req_data[6][20]), + .n6_req_data_w21(req_data[6][21]), + .n6_req_data_w22(req_data[6][22]), + .n6_req_data_w23(req_data[6][23]), + .n6_req_data_w24(req_data[6][24]), + .n6_req_data_w25(req_data[6][25]), + .n6_req_data_w26(req_data[6][26]), + .n6_req_data_w27(req_data[6][27]), + .n6_req_data_w28(req_data[6][28]), + .n6_req_data_w29(req_data[6][29]), + .n6_req_data_w30(req_data[6][30]), + .n6_req_data_w31(req_data[6][31]), + .n6_req_ready(req_ready[6]), + .n6_resp_ready(resp_ready[6]), + .n6_resp_valid(resp_valid[6]), + .n6_resp_tag(resp_tag[6]), + .n6_resp_data_w0(resp_data[6][0]), + .n6_resp_data_w1(resp_data[6][1]), + .n6_resp_data_w2(resp_data[6][2]), + .n6_resp_data_w3(resp_data[6][3]), + .n6_resp_data_w4(resp_data[6][4]), + .n6_resp_data_w5(resp_data[6][5]), + .n6_resp_data_w6(resp_data[6][6]), + 
.n6_resp_data_w7(resp_data[6][7]), + .n6_resp_data_w8(resp_data[6][8]), + .n6_resp_data_w9(resp_data[6][9]), + .n6_resp_data_w10(resp_data[6][10]), + .n6_resp_data_w11(resp_data[6][11]), + .n6_resp_data_w12(resp_data[6][12]), + .n6_resp_data_w13(resp_data[6][13]), + .n6_resp_data_w14(resp_data[6][14]), + .n6_resp_data_w15(resp_data[6][15]), + .n6_resp_data_w16(resp_data[6][16]), + .n6_resp_data_w17(resp_data[6][17]), + .n6_resp_data_w18(resp_data[6][18]), + .n6_resp_data_w19(resp_data[6][19]), + .n6_resp_data_w20(resp_data[6][20]), + .n6_resp_data_w21(resp_data[6][21]), + .n6_resp_data_w22(resp_data[6][22]), + .n6_resp_data_w23(resp_data[6][23]), + .n6_resp_data_w24(resp_data[6][24]), + .n6_resp_data_w25(resp_data[6][25]), + .n6_resp_data_w26(resp_data[6][26]), + .n6_resp_data_w27(resp_data[6][27]), + .n6_resp_data_w28(resp_data[6][28]), + .n6_resp_data_w29(resp_data[6][29]), + .n6_resp_data_w30(resp_data[6][30]), + .n6_resp_data_w31(resp_data[6][31]), + .n6_resp_is_write(resp_is_write[6]), + + .n7_req_valid(req_valid[7]), + .n7_req_write(req_write[7]), + .n7_req_addr(req_addr[7]), + .n7_req_tag(req_tag[7]), + .n7_req_data_w0(req_data[7][0]), + .n7_req_data_w1(req_data[7][1]), + .n7_req_data_w2(req_data[7][2]), + .n7_req_data_w3(req_data[7][3]), + .n7_req_data_w4(req_data[7][4]), + .n7_req_data_w5(req_data[7][5]), + .n7_req_data_w6(req_data[7][6]), + .n7_req_data_w7(req_data[7][7]), + .n7_req_data_w8(req_data[7][8]), + .n7_req_data_w9(req_data[7][9]), + .n7_req_data_w10(req_data[7][10]), + .n7_req_data_w11(req_data[7][11]), + .n7_req_data_w12(req_data[7][12]), + .n7_req_data_w13(req_data[7][13]), + .n7_req_data_w14(req_data[7][14]), + .n7_req_data_w15(req_data[7][15]), + .n7_req_data_w16(req_data[7][16]), + .n7_req_data_w17(req_data[7][17]), + .n7_req_data_w18(req_data[7][18]), + .n7_req_data_w19(req_data[7][19]), + .n7_req_data_w20(req_data[7][20]), + .n7_req_data_w21(req_data[7][21]), + .n7_req_data_w22(req_data[7][22]), + .n7_req_data_w23(req_data[7][23]), + 
.n7_req_data_w24(req_data[7][24]), + .n7_req_data_w25(req_data[7][25]), + .n7_req_data_w26(req_data[7][26]), + .n7_req_data_w27(req_data[7][27]), + .n7_req_data_w28(req_data[7][28]), + .n7_req_data_w29(req_data[7][29]), + .n7_req_data_w30(req_data[7][30]), + .n7_req_data_w31(req_data[7][31]), + .n7_req_ready(req_ready[7]), + .n7_resp_ready(resp_ready[7]), + .n7_resp_valid(resp_valid[7]), + .n7_resp_tag(resp_tag[7]), + .n7_resp_data_w0(resp_data[7][0]), + .n7_resp_data_w1(resp_data[7][1]), + .n7_resp_data_w2(resp_data[7][2]), + .n7_resp_data_w3(resp_data[7][3]), + .n7_resp_data_w4(resp_data[7][4]), + .n7_resp_data_w5(resp_data[7][5]), + .n7_resp_data_w6(resp_data[7][6]), + .n7_resp_data_w7(resp_data[7][7]), + .n7_resp_data_w8(resp_data[7][8]), + .n7_resp_data_w9(resp_data[7][9]), + .n7_resp_data_w10(resp_data[7][10]), + .n7_resp_data_w11(resp_data[7][11]), + .n7_resp_data_w12(resp_data[7][12]), + .n7_resp_data_w13(resp_data[7][13]), + .n7_resp_data_w14(resp_data[7][14]), + .n7_resp_data_w15(resp_data[7][15]), + .n7_resp_data_w16(resp_data[7][16]), + .n7_resp_data_w17(resp_data[7][17]), + .n7_resp_data_w18(resp_data[7][18]), + .n7_resp_data_w19(resp_data[7][19]), + .n7_resp_data_w20(resp_data[7][20]), + .n7_resp_data_w21(resp_data[7][21]), + .n7_resp_data_w22(resp_data[7][22]), + .n7_resp_data_w23(resp_data[7][23]), + .n7_resp_data_w24(resp_data[7][24]), + .n7_resp_data_w25(resp_data[7][25]), + .n7_resp_data_w26(resp_data[7][26]), + .n7_resp_data_w27(resp_data[7][27]), + .n7_resp_data_w28(resp_data[7][28]), + .n7_resp_data_w29(resp_data[7][29]), + .n7_resp_data_w30(resp_data[7][30]), + .n7_resp_data_w31(resp_data[7][31]), + .n7_resp_is_write(resp_is_write[7]) + ); + + function automatic [19:0] make_addr(input int index, input int pipe, input int offset); + make_addr = {index[8:0], pipe[2:0], offset[7:0]}; + endfunction + + task automatic fill_data(output logic [63:0] data[0:31], input int seed); + integer i; + begin + for (i = 0; i < 32; i = i + 1) begin + data[i] = 
{seed[31:0], i[31:0]}; + end + end + endtask + + task automatic clear_line(output logic [63:0] data[0:31]); + integer i; + begin + for (i = 0; i < 32; i = i + 1) begin + data[i] = 64'd0; + end + end + endtask + + task automatic clear_reqs(); + integer i; + integer j; + begin + for (i = 0; i < 8; i = i + 1) begin + req_valid[i] = 1'b0; + req_write[i] = 1'b0; + req_addr[i] = 20'd0; + req_tag[i] = 8'd0; + resp_ready[i] = 1'b1; + for (j = 0; j < 32; j = j + 1) begin + req_data[i][j] = 64'd0; + end + end + end + endtask + + task automatic send_req( + input int node, + input bit write, + input logic [19:0] addr, + input logic [7:0] tag, + input logic [63:0] data[0:31] + ); + integer i; + begin + req_write[node] = write; + req_addr[node] = addr; + req_tag[node] = tag; + for (i = 0; i < 32; i = i + 1) begin + req_data[node][i] = data[i]; + end + req_valid[node] = 1'b1; + while (req_ready[node] !== 1'b1) begin + @(posedge clk); + end + @(posedge clk); + req_valid[node] = 1'b0; + end + endtask + + task automatic wait_resp( + input int node, + input logic [7:0] tag, + input bit expect_write, + input logic [63:0] expect_data[0:31] + ); + integer timeout; + integer i; + begin + timeout = 2000; + while (timeout > 0) begin + @(posedge clk); + if (resp_valid[node]) begin + if (resp_tag[node] !== tag) $fatal(1, "tag mismatch"); + if (resp_is_write[node] !== expect_write) $fatal(1, "resp_is_write mismatch"); + for (i = 0; i < 32; i = i + 1) begin + if (resp_data[node][i] !== expect_data[i]) $fatal(1, "resp_data mismatch"); + end + return; + end + timeout = timeout - 1; + end + $fatal(1, "timeout waiting resp"); + end + endtask + + initial begin + clk = 1'b0; + rst = 1'b1; + clear_reqs(); + repeat (2) @(posedge clk); + rst = 1'b0; + repeat (1) @(posedge clk); + + for (int n = 0; n < 8; n = n + 1) begin + fill_data(line_data, n + 1); + clear_line(line_zero); + send_req(n, 1'b1, make_addr(n, n, 0), n[7:0], line_data); + wait_resp(n, n[7:0], 1'b1, line_data); + send_req(n, 1'b0, 
def parse_vcd(path: Path, watch_names, max_cycles=None, skip_cycles=0):
    """Sample watched scalar signals at each rising edge of ``clk`` in a VCD.

    Only ``$var`` declarations whose reference name is in *watch_names* are
    tracked, and only scalar value changes (``0``, ``1``, ``x``, ``z``) are
    read; vector/real changes and timestamp lines are ignored. ``x``/``z``
    are recorded as ``"0"``.

    Args:
        path: VCD file to read (a ``Path`` or anything ``Path()`` accepts).
        watch_names: Iterable of signal names to track; must include
            ``"clk"`` for any snapshot to be produced.
        max_cycles: Stop after this many snapshots. ``None`` means no limit;
            zero or a negative value yields no snapshots.
        skip_cycles: Number of initial rising edges to discard.

    Returns:
        A list with one ``{name: "0" | "1"}`` dict per sampled rising edge,
        covering every watched name (signals never seen stay ``"0"``).

    NOTE(review): the snapshot is taken as soon as the ``clk`` change is read,
    so changes listed after it under the same timestamp are not included —
    confirm this is the intended sampling point.
    """
    watch_names = set(watch_names)
    if max_cycles is not None and max_cycles <= 0:
        return []

    id_to_name = {}
    values = dict.fromkeys(watch_names, "0")
    snapshots = []

    with Path(path).open() as f:
        in_header = True
        for raw in f:
            line = raw.strip()
            if not line:
                continue

            if in_header:
                if line.startswith("$var"):
                    # $var <type> <width> <id-code> <name> ... $end
                    parts = line.split()
                    if len(parts) >= 5 and parts[4] in watch_names:
                        id_to_name[parts[3]] = parts[4]
                elif line.startswith("$enddefinitions"):
                    in_header = False
                continue

            # Scalar change lines are "<value><id-code>"; everything else
            # (timestamps "#N", vectors "b...", "$dumpvars", ...) is skipped.
            val = line[0]
            if val not in "01xXzZ":
                continue
            name = id_to_name.get(line[1:])
            if name is None:
                continue

            previous = values[name]
            values[name] = "0" if val in "xXzZ" else val

            # A rising edge needs clk to actually change to 1; a re-dumped
            # "1" (e.g. from a checkpoint section) must not count as a cycle.
            if name != "clk" or val != "1" or previous == "1":
                continue
            if skip_cycles > 0:
                skip_cycles -= 1
                continue
            snapshots.append(dict(values))
            if max_cycles is not None and len(snapshots) >= max_cycles:
                break

    return snapshots
cyc, snap in enumerate(snapshots): + begin = cyc * cycle_time + dur = cycle_time + for nid in range(8): + # requests on inner ring + if snap.get(f"dbg_req_cw_v{nid}") == "1": + start = req_pos[nid] + end = req_pos[next_map[nid]] + emit_token( + lines, + f"req_cw_{cyc}_{nid}", + start, + end, + begin, + dur, + "#38bdf8", + "circle", + f"req cw node={nid} cycle={cyc}", + "glow_req", + ) + if snap.get(f"dbg_req_cc_v{nid}") == "1": + start = req_pos[nid] + end = req_pos[prev_map[nid]] + emit_token( + lines, + f"req_cc_{cyc}_{nid}", + start, + end, + begin, + dur, + "#22d3ee", + "circle", + f"req cc node={nid} cycle={cyc}", + "glow_req", + ) + + # responses on outer ring + if snap.get(f"dbg_rsp_cw_v{nid}") == "1": + start = rsp_pos[nid] + end = rsp_pos[next_map[nid]] + emit_token( + lines, + f"rsp_cw_{cyc}_{nid}", + start, + end, + begin, + dur, + "#22c55e", + "diamond", + f"rsp cw node={nid} cycle={cyc}", + "glow_rsp", + ) + if snap.get(f"dbg_rsp_cc_v{nid}") == "1": + start = rsp_pos[nid] + end = rsp_pos[prev_map[nid]] + emit_token( + lines, + f"rsp_cc_{cyc}_{nid}", + start, + end, + begin, + dur, + "#a3e635", + "diamond", + f"rsp cc node={nid} cycle={cyc}", + "glow_rsp", + ) + + lines.append("") + out_path.write_text("\n".join(lines)) + + +def main(): + parser = argparse.ArgumentParser(description="Animate TMU ring flows from VCD debug signals.") + parser.add_argument("vcd", type=Path, help="Path to VCD (tb_janus_tmu_pyc_cpp.vcd)") + parser.add_argument("-o", "--out", type=Path, default=Path("tmu_flow_real.svg"), help="Output SVG") + parser.add_argument("--cycle", type=float, default=0.20, help="Seconds per cycle") + parser.add_argument("--max-cycles", type=int, default=None, help="Limit cycles") + parser.add_argument("--skip-cycles", type=int, default=0, help="Skip initial cycles") + args = parser.parse_args() + + watch = ["clk"] + for n in range(8): + watch.append(f"dbg_req_cw_v{n}") + watch.append(f"dbg_req_cc_v{n}") + watch.append(f"dbg_rsp_cw_v{n}") + 
def parse_int(text: str) -> int:
    """Parse *text* as an integer, accepting a ``0x``/``0X`` hex prefix.

    Surrounding whitespace is ignored. Raises ``ValueError`` when the text is
    not a valid decimal or prefixed-hex integer.
    """
    text = text.strip()
    if text.startswith(("0x", "0X")):
        return int(text, 16)
    return int(text, 10)


def load_transactions(path: Path):
    """Pair ``accept`` and ``resp`` rows of a TMU trace CSV into transactions.

    Rows are matched per ``(node, tag)`` in FIFO order: each ``resp`` row
    consumes the oldest outstanding ``accept`` with the same node and tag.
    A ``resp`` without a pending ``accept`` is ignored.

    Rows whose ``cycle``/``node``/``tag``/``write`` field is not an integer,
    or is missing because the row is shorter than the header, are skipped.
    An unparsable or missing ``addr_or_word0`` on an ``accept`` row is treated
    as address 0.

    Args:
        path: CSV file with a header row naming at least ``cycle``, ``node``,
            ``tag``, ``write``, ``event`` and ``addr_or_word0``.

    Returns:
        ``(transactions, max_cycle)`` where each transaction is a dict with
        keys ``src``, ``dst``, ``cycle_accept``, ``cycle_resp``, ``tag`` and
        ``write``; ``dst`` is bits [10:8] of the accepted address.
        ``max_cycle`` is the largest cycle seen on any parsed row (0 if none).
    """
    accepts = defaultdict(deque)
    transactions = []
    max_cycle = 0

    # newline="" lets the csv module handle line endings itself.
    with path.open(newline="") as f:
        for row in csv.DictReader(f):
            if not row:
                continue
            try:
                cycle = int(row.get("cycle", "0"))
                node = int(row.get("node", "0"))
                tag = int(row.get("tag", "0"))
                write = int(row.get("write", "0"))
            except (TypeError, ValueError):
                # TypeError: DictReader fills fields missing from a short row
                # with None, which int() rejects.
                continue

            max_cycle = max(max_cycle, cycle)
            event = row.get("event", "")
            key = (node, tag)

            if event == "accept":
                try:
                    addr = parse_int(row.get("addr_or_word0") or "0")
                except ValueError:
                    addr = 0
                accepts[key].append({
                    "cycle": cycle,
                    "node": node,
                    "tag": tag,
                    "write": write,
                    "addr": addr,
                })
            elif event == "resp":
                pending = accepts.get(key)
                if not pending:
                    continue
                acc = pending.popleft()
                transactions.append({
                    "src": acc["node"],
                    "dst": (acc["addr"] >> 8) & 0x7,
                    "cycle_accept": acc["cycle"],
                    "cycle_resp": cycle,
                    "tag": tag,
                    "write": acc["write"],
                })

    return transactions, max_cycle
def path_nodes(src, dst):
    """Return the nodes visited going from *src* to *dst* along the ring.

    The ring is walked in whichever direction needs fewer hops; when both
    directions are equally long the forward (``RING_ORDER``) direction is
    used. The result includes both endpoints, and is ``[src]`` when the two
    are the same node.
    """
    if src == dst:
        return [src]

    ring_len = len(RING_ORDER)
    slot_of = dict(zip(RING_ORDER, range(ring_len)))
    start = slot_of[src]
    stop = slot_of[dst]

    forward_hops = (stop - start) % ring_len
    backward_hops = (start - stop) % ring_len
    if forward_hops <= backward_hops:
        direction, hops = 1, forward_hops
    else:
        direction, hops = -1, backward_hops

    return [RING_ORDER[(start + direction * k) % ring_len] for k in range(hops + 1)]
for node, (x, y) in positions.items(): + lines.append(f"") + lines.append(f"n{node}") + + tpc = cycle_time + for idx, tr in enumerate(transactions): + src = tr["src"] + dst = tr["dst"] + c0 = tr["cycle_accept"] + c1 = tr["cycle_resp"] + tag = tr["tag"] + write = tr["write"] + + req_nodes = path_nodes(src, dst) + req_coords = [positions[n] for n in req_nodes] + req_hops = max(len(req_coords) - 1, 1) + req_dur = req_hops * tpc + req_begin = c0 * tpc + req_label = f"req tag={tag} src={src} dst={dst} w={write}" + emit_token( + lines, + f"req_{idx}", + req_coords, + req_begin, + req_dur, + "#38bdf8", + "circle", + req_label, + ) + + rsp_nodes = path_nodes(dst, src) + rsp_coords = [positions[n] for n in rsp_nodes] + rsp_hops = max(len(rsp_coords) - 1, 1) + rsp_dur = rsp_hops * tpc + rsp_end = c1 * tpc + rsp_begin = max(req_begin + req_dur, rsp_end - rsp_dur) + rsp_label = f"resp tag={tag} src={dst} dst={src} w={write}" + emit_token( + lines, + f"rsp_{idx}", + rsp_coords, + rsp_begin, + rsp_dur, + "#22c55e", + "diamond", + rsp_label, + ) + + lines.append("") + out_path.write_text("\n".join(lines)) + + +def main(): + parser = argparse.ArgumentParser(description="Render animated SVG for TMU ring flows.") + parser.add_argument("csv", type=Path, help="Path to tmu_trace.csv") + parser.add_argument("-o", "--out", type=Path, default=Path("tmu_flow.svg"), help="Output SVG") + parser.add_argument("--cycle", type=float, default=0.06, help="Seconds per cycle") + args = parser.parse_args() + + transactions, max_cycle = load_transactions(args.csv) + if not transactions: + raise SystemExit("no transactions found in CSV") + render_svg(transactions, max_cycle, args.out, args.cycle) + + +if __name__ == "__main__": + main() diff --git a/janus/tools/plot_tmu_trace.py b/janus/tools/plot_tmu_trace.py new file mode 100755 index 0000000..1d57e30 --- /dev/null +++ b/janus/tools/plot_tmu_trace.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +import argparse +import csv +from pathlib import Path + + 
def load_events(path: Path):
    """Read a TMU trace CSV into a flat event list plus its extents.

    Rows whose ``cycle`` or ``node`` field is not an integer, or is missing
    because the row is shorter than the header, are skipped. The ``event``,
    ``tag`` and ``write`` fields are kept as text; a field missing from a
    short row becomes ``""``.

    Args:
        path: CSV file with a header row naming at least ``cycle`` and
            ``node`` (and normally ``event``, ``tag``, ``write``).

    Returns:
        ``(events, max_cycle, max_node)`` where ``events`` is a list of
        ``(cycle, node, event, tag, write)`` tuples in file order and the two
        maxima are taken over the kept rows (0 when there are none).
    """
    events = []
    max_cycle = 0
    max_node = 0
    # newline="" lets the csv module handle line endings itself.
    with path.open(newline="") as f:
        for row in csv.DictReader(f):
            try:
                cycle = int(row.get("cycle", "0"))
                node = int(row.get("node", "0"))
            except (TypeError, ValueError):
                # TypeError: DictReader fills fields missing from a short row
                # with None, which int() rejects.
                continue
            # `or ""` maps the None that DictReader uses for missing fields
            # to the same empty string an absent column would give.
            event = row.get("event", "") or ""
            tag = row.get("tag", "") or ""
            write = row.get("write", "") or ""
            events.append((cycle, node, event, tag, write))
            max_cycle = max(max_cycle, cycle)
            max_node = max(max_node, node)
    return events, max_cycle, max_node
#!/usr/bin/env bash
# Build and run the C++ testbench for the generated Janus TMU model.
#
# Regenerates the TMU outputs first when the generated header is missing or
# older than any Python source under janus/pyc/janus/tmu, then compiles
# janus/tb/tb_janus_tmu_pyc.cpp into a temporary directory and runs it.
# Any arguments given to this script are forwarded to the testbench binary.
#
# Environment:
#   CXX  C++ compiler to use (default: clang++)
set -euo pipefail

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR="$(cd -- "${SCRIPT_DIR}/../.." && pwd)"
# shellcheck source=../../scripts/lib.sh
source "${ROOT_DIR}/scripts/lib.sh"
pyc_find_pyc_compile

GEN_DIR="${ROOT_DIR}/janus/generated/janus_tmu_pyc"
HDR="${GEN_DIR}/janus_tmu_pyc_gen.hpp"

need_regen=0
if [[ ! -f "${HDR}" ]]; then
  need_regen=1
# Test find's output directly rather than piping into `grep -q`: under
# `pipefail`, grep exiting on the first match can kill find with SIGPIPE and
# make the whole pipeline report failure, hiding a stale header.
elif [[ -n "$(find "${ROOT_DIR}/janus/pyc/janus/tmu" -name '*.py' -newer "${HDR}")" ]]; then
  need_regen=1
fi

if [[ "${need_regen}" -ne 0 ]]; then
  bash "${ROOT_DIR}/janus/tools/update_tmu_generated.sh"
fi

# Build in a scratch directory that is removed on every exit path.
WORK_DIR="$(mktemp -d -t janus_tmu_pyc_tb.XXXXXX)"
trap 'rm -rf "${WORK_DIR}"' EXIT

"${CXX:-clang++}" -std=c++17 -O2 \
  -I "${ROOT_DIR}/include" \
  -I "${GEN_DIR}" \
  -o "${WORK_DIR}/tb_janus_tmu_pyc" \
  "${ROOT_DIR}/janus/tb/tb_janus_tmu_pyc.cpp"

if [[ $# -gt 0 ]]; then
  "${WORK_DIR}/tb_janus_tmu_pyc" "$@"
else
  "${WORK_DIR}/tb_janus_tmu_pyc"
fi
#!/usr/bin/env bash
# Regenerate the Janus TMU outputs (Verilog and C++ header).
#
# Emits the pycircuit design janus/pyc/janus/tmu/janus_tmu_pyc.py to a
# temporary .pyc file, runs ${PYC_COMPILE} on it once per backend, and leaves
# the results in janus/generated/janus_tmu_pyc/ as janus_tmu_pyc.v and
# janus_tmu_pyc_gen.hpp.
set -euo pipefail

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR="$(cd -- "${SCRIPT_DIR}/../.." && pwd)"
# shellcheck source=../../scripts/lib.sh
source "${ROOT_DIR}/scripts/lib.sh"
pyc_find_pyc_compile

OUT_ROOT="${ROOT_DIR}/janus/generated/janus_tmu_pyc"
mkdir -p "${OUT_ROOT}"

tmp_pyc="$(mktemp -t "pycircuit.janus.tmu.XXXXXX.pyc")"
# Remove the intermediate file on every exit path (success or failure) so
# repeated runs do not accumulate files in the temp directory.
trap 'rm -f "${tmp_pyc}"' EXIT

PYTHONDONTWRITEBYTECODE=1 PYTHONPATH="$(pyc_pythonpath):${ROOT_DIR}/janus/pyc" \
  python3 -m pycircuit.cli emit "${ROOT_DIR}/janus/pyc/janus/tmu/janus_tmu_pyc.py" -o "${tmp_pyc}"

"${PYC_COMPILE}" "${tmp_pyc}" --emit=verilog -o "${OUT_ROOT}/janus_tmu_pyc.v"
"${PYC_COMPILE}" "${tmp_pyc}" --emit=cpp -o "${OUT_ROOT}/janus_tmu_pyc.hpp"

# The C++ testbench includes the header under its *_gen.hpp name.
mv -f "${OUT_ROOT}/janus_tmu_pyc.hpp" "${OUT_ROOT}/janus_tmu_pyc_gen.hpp"

pyc_log "ok: wrote TMU outputs under ${OUT_ROOT}"
namespace simd {

#if PYC_SIMD_NEON
// Word-array helpers used by the wide-Bits operators. Each one processes two
// 64-bit words per NEON instruction and finishes an odd trailing word with
// plain scalar code. `dst`, `a`, `b` and `mask` must each point to at least
// `nWords` words.

// dst[i] = a[i] & b[i]
inline void bitwise_and(std::uint64_t *dst, const std::uint64_t *a,
                        const std::uint64_t *b, unsigned nWords) {
  unsigned i = 0;
  for (; i + 2 <= nWords; i += 2) {
    uint64x2_t va = vld1q_u64(a + i);
    uint64x2_t vb = vld1q_u64(b + i);
    vst1q_u64(dst + i, vandq_u64(va, vb));
  }
  for (; i < nWords; i++)
    dst[i] = a[i] & b[i];
}

// dst[i] = a[i] | b[i]
inline void bitwise_or(std::uint64_t *dst, const std::uint64_t *a,
                       const std::uint64_t *b, unsigned nWords) {
  unsigned i = 0;
  for (; i + 2 <= nWords; i += 2) {
    uint64x2_t va = vld1q_u64(a + i);
    uint64x2_t vb = vld1q_u64(b + i);
    vst1q_u64(dst + i, vorrq_u64(va, vb));
  }
  for (; i < nWords; i++)
    dst[i] = a[i] | b[i];
}

// dst[i] = a[i] ^ b[i]
inline void bitwise_xor(std::uint64_t *dst, const std::uint64_t *a,
                        const std::uint64_t *b, unsigned nWords) {
  unsigned i = 0;
  for (; i + 2 <= nWords; i += 2) {
    uint64x2_t va = vld1q_u64(a + i);
    uint64x2_t vb = vld1q_u64(b + i);
    vst1q_u64(dst + i, veorq_u64(va, vb));
  }
  for (; i < nWords; i++)
    dst[i] = a[i] ^ b[i];
}

// dst[i] = ~a[i]
inline void bitwise_not(std::uint64_t *dst, const std::uint64_t *a,
                        unsigned nWords) {
  unsigned i = 0;
  for (; i + 2 <= nWords; i += 2) {
    uint64x2_t va = vld1q_u64(a + i);
    // NEON has no 64-bit-lane NOT, so invert as bytes and reinterpret the
    // result back to the uint64x2_t that vst1q_u64 expects.
    uint8x16_t inverted = vmvnq_u8(vreinterpretq_u8_u64(va));
    vst1q_u64(dst + i, vreinterpretq_u64_u8(inverted));
  }
  for (; i < nWords; i++)
    dst[i] = ~a[i];
}

// Returns true when a[0..nWords) and b[0..nWords) are identical.
inline bool bitwise_eq(const std::uint64_t *a, const std::uint64_t *b,
                       unsigned nWords) {
  unsigned i = 0;
  for (; i + 2 <= nWords; i += 2) {
    uint64x2_t va = vld1q_u64(a + i);
    uint64x2_t vb = vld1q_u64(b + i);
    // vceqq_u64 sets a lane to all-ones when equal, zero otherwise.
    uint64x2_t cmp = vceqq_u64(va, vb);
    if (vgetq_lane_u64(cmp, 0) != ~std::uint64_t{0} ||
        vgetq_lane_u64(cmp, 1) != ~std::uint64_t{0})
      return false;
  }
  for (; i < nWords; i++)
    if (a[i] != b[i])
      return false;
  return true;
}

// Bitwise select: dst[i] = mask[i] ? a[i] : b[i] (per-bit)
inline void bitwise_sel(std::uint64_t *dst, const std::uint64_t *mask,
                        const std::uint64_t *a, const std::uint64_t *b,
                        unsigned nWords) {
  unsigned i = 0;
  for (; i + 2 <= nWords; i += 2) {
    uint64x2_t vm = vld1q_u64(mask + i);
    uint64x2_t va = vld1q_u64(a + i);
    uint64x2_t vb = vld1q_u64(b + i);
    vst1q_u64(dst + i, vbslq_u64(vm, va, vb));
  }
  for (; i < nWords; i++)
    dst[i] = (a[i] & mask[i]) | (b[i] & ~mask[i]);
}
#endif

} // namespace simd
out.maskTop(); + return out; + } +#endif for (unsigned i = 0; i < kWords; i++) out.words_[i] = a.words_[i] ^ b.words_[i]; out.maskTop(); return out; } - friend constexpr Bits operator~(Bits a) { + friend Bits operator~(Bits a) { Bits out; +#if PYC_SIMD_NEON + if constexpr (kWords >= 2) { + simd::bitwise_not(out.words_.data(), a.words_.data(), kWords); + out.maskTop(); + return out; + } +#endif for (unsigned i = 0; i < kWords; i++) out.words_[i] = ~a.words_[i]; out.maskTop(); return out; } - friend constexpr bool operator==(Bits a, Bits b) { + friend bool operator==(Bits a, Bits b) { +#if PYC_SIMD_NEON + if constexpr (kWords >= 2) + return simd::bitwise_eq(a.words_.data(), b.words_.data(), kWords); +#endif for (unsigned i = 0; i < kWords; i++) { if (a.words_[i] != b.words_[i]) return false; @@ -158,7 +288,7 @@ class Bits { return true; } - friend constexpr bool operator!=(Bits a, Bits b) { return !(a == b); } + friend bool operator!=(Bits a, Bits b) { return !(a == b); } friend constexpr bool operator<(Bits a, Bits b) { for (unsigned i = 0; i < kWords; i++) { @@ -190,6 +320,32 @@ class Bits { template using Wire = Bits; +// SIMD-accelerated MUX: returns sel ? a : b (branch-free for wide wires) +template +inline Wire mux(Wire<1> sel, Wire a, Wire b) { +#if PYC_SIMD_NEON + if constexpr (Wire::kWords >= 2) { + Wire out; + // Broadcast sel to all bits: 0 or all-ones mask + std::uint64_t smask = sel.toBool() ? ~std::uint64_t{0} : std::uint64_t{0}; + uint64x2_t vm = vdupq_n_u64(smask); + const auto *pa = a.data(); + const auto *pb = b.data(); + auto *po = out.data(); + unsigned i = 0; + for (; i + 2 <= Wire::kWords; i += 2) { + uint64x2_t va = vld1q_u64(pa + i); + uint64x2_t vb = vld1q_u64(pb + i); + vst1q_u64(po + i, vbslq_u64(vm, va, vb)); + } + for (; i < Wire::kWords; i++) + po[i] = sel.toBool() ? pa[i] : pb[i]; + return out; + } +#endif + return sel.toBool() ? 
a : b; +} + template inline void appendPackedWireWords(std::array &dst, std::size_t &offset, Wire v) { for (unsigned i = 0; i < Wire::kWords; ++i) diff --git a/runtime/cpp/pyc_change_detect.hpp b/runtime/cpp/pyc_change_detect.hpp new file mode 100644 index 0000000..454cf81 --- /dev/null +++ b/runtime/cpp/pyc_change_detect.hpp @@ -0,0 +1,166 @@ +#pragma once + +#include +#include +#include + +#include "pyc_bits.hpp" + +namespace pyc::cpp { + +// --------------------------------------------------------------------------- +// ChangeDetector — lightweight snapshot-based change detection for +// individual Wire signals. Compares current value against a cached +// snapshot taken at the previous observation point. +// --------------------------------------------------------------------------- + +template +class ChangeDetector { +public: + explicit ChangeDetector(const Wire &target) : target_(target) { + snapshot_ = target; + } + + bool changed() const { return !(target_ == snapshot_); } + + void capture() { snapshot_ = target_; } + + bool check_and_capture() { + bool c = changed(); + snapshot_ = target_; + return c; + } + +private: + const Wire &target_; + Wire snapshot_{}; +}; + +// --------------------------------------------------------------------------- +// InputFingerprint — tracks whether *any* of a set of primary inputs changed +// since the last capture. Uses a simple XOR-fold hash over raw words for +// O(1) fast-path rejection, with a full comparison fallback. +// +// Usage (in a CAPI wrapper or testbench): +// InputFingerprint<80, 5, 40, 320> fp(dut.raddr_bus, dut.wen_bus, ...); +// ... 
+// if (fp.check_and_capture()) { dut.eval(); } +// --------------------------------------------------------------------------- + +namespace detail { + +template +inline void xor_fold(const Wire &w, std::uint64_t &acc) { + for (unsigned i = 0; i < Wire::kWords; i++) + acc ^= w.word(i) * (0x9E3779B97F4A7C15ULL + i); +} + +template +inline std::size_t wire_bytes() { + return Wire::kWords * sizeof(std::uint64_t); +} + +} // namespace detail + +template +class InputFingerprint { +public: + static constexpr std::size_t kTotalWords = ((Wire::kWords + ... + 0)); + + explicit InputFingerprint(const Wire &...wires) + : ptrs_{wires.data()...}, sizes_{Wire::kWords...} { + do_capture(); + } + + bool changed() const { + std::uint64_t h = 0; + std::size_t idx = 0; + auto fold = [&](const std::uint64_t *p, unsigned nw) { + for (unsigned i = 0; i < nw; i++) + h ^= p[i] * (0x9E3779B97F4A7C15ULL + idx++); + }; + for (unsigned k = 0; k < sizeof...(Widths); k++) + fold(ptrs_[k], sizes_[k]); + + if (h != hash_) + return true; + + idx = 0; + for (unsigned k = 0; k < sizeof...(Widths); k++) { + if (std::memcmp(ptrs_[k], &snapshot_[idx], + sizes_[k] * sizeof(std::uint64_t)) != 0) + return true; + idx += sizes_[k]; + } + return false; + } + + void capture() { do_capture(); } + + bool check_and_capture() { + bool c = changed(); + do_capture(); + return c; + } + +private: + void do_capture() { + hash_ = 0; + std::size_t idx = 0; + std::size_t fold_idx = 0; + for (unsigned k = 0; k < sizeof...(Widths); k++) { + for (unsigned i = 0; i < sizes_[k]; i++) { + snapshot_[idx] = ptrs_[k][i]; + hash_ ^= ptrs_[k][i] * (0x9E3779B97F4A7C15ULL + fold_idx++); + idx++; + } + } + } + + const std::uint64_t *ptrs_[sizeof...(Widths)]; + unsigned sizes_[sizeof...(Widths)]; + std::uint64_t hash_ = 0; + std::uint64_t snapshot_[kTotalWords]{}; +}; + +// --------------------------------------------------------------------------- +// EvalGuard — wraps an eval_comb function call, only executing if at least +// one 
input Wire changed since the last invocation. +// +// Template parameters: +// Fn — callable (lambda / function pointer) for the eval_comb body +// InputWidths — widths of the input Wires tracked by this guard +// +// Usage: +// EvalGuard guard([&]{ dut.eval_comb_0(); }, dut.raddr_bus, dut.wen_bus); +// guard.eval(); // only calls eval_comb_0 if raddr_bus or wen_bus changed +// --------------------------------------------------------------------------- + +template +class EvalGuard { +public: + explicit EvalGuard(Fn fn, const Wire &...inputs) + : fn_(fn), fp_(inputs...) {} + + bool eval() { + if (fp_.check_and_capture()) { + fn_(); + return true; + } + return false; + } + + void force_eval() { + fp_.capture(); + fn_(); + } + +private: + Fn fn_; + InputFingerprint fp_; +}; + +template +EvalGuard(Fn, const Wire &...) -> EvalGuard; + +} // namespace pyc::cpp diff --git a/runtime/cpp/pyc_primitives.hpp b/runtime/cpp/pyc_primitives.hpp index ad648b2..7d7ff43 100644 --- a/runtime/cpp/pyc_primitives.hpp +++ b/runtime/cpp/pyc_primitives.hpp @@ -71,33 +71,53 @@ class pyc_reg { pyc_reg(Wire<1> &clk, Wire<1> &rst, Wire<1> &en, Wire &d, Wire &init, Wire &q) : clk(clk), rst(rst), en(en), d(d), init(init), q(q) {} - void tick_compute() { + // Branch-optimized two-phase update. + // tick_compute: sample inputs; tick_commit: apply. + inline void tick_compute() { bool clkNow = clk.toBool(); - bool posedge = (!clkPrev) && clkNow; + bool posedge = (!clkPrev) & clkNow; clkPrev = clkNow; - pending = false; - if (!posedge) - return; - - if (rst.toBool()) { - pending = true; - qNext = init; - return; - } - if (en.toBool()) { - pending = true; - qNext = d; + if (__builtin_expect(!posedge, 1)) { + pending = false; return; } + posedge_compute_inner(); } - void tick_commit() { - if (!pending) - return; - q = qNext; + // Direct posedge path — caller guarantees a 0→1 edge just occurred. + // Saves the clkPrev read + posedge check (~2 branches per register). 
+ inline void posedge_tick_compute() { + clkPrev = true; + posedge_compute_inner(); + } + + // Negedge bookkeeping — just reset clkPrev so next posedge is detected. + // Avoids running the full tick_compute logic on the falling edge. + inline void negedge_update() { + clkPrev = false; pending = false; } + inline void tick_commit() { + if (__builtin_expect(pending, 0)) { + q = qNext; + pending = false; + } + } + +private: + inline void posedge_compute_inner() { + bool r = rst.toBool(); + bool e = en.toBool(); + pending = r | e; + if (r) + qNext = init; + else + qNext = d; + } + +public: + Wire<1> &clk; Wire<1> &rst; Wire<1> &en; diff --git a/runtime/cpp/pyc_sim.hpp b/runtime/cpp/pyc_sim.hpp index 8e882c7..ec0a9ee 100644 --- a/runtime/cpp/pyc_sim.hpp +++ b/runtime/cpp/pyc_sim.hpp @@ -1,6 +1,7 @@ #pragma once #include "pyc_bits.hpp" +#include "pyc_change_detect.hpp" #include "pyc_clock.hpp" #include "pyc_connector.hpp" #include "pyc_cdc_sync.hpp" diff --git a/runtime/cpp/pyc_tb.hpp b/runtime/cpp/pyc_tb.hpp index ab30a96..b09522c 100644 --- a/runtime/cpp/pyc_tb.hpp +++ b/runtime/cpp/pyc_tb.hpp @@ -60,6 +60,19 @@ struct has_transfer : std::false_type {}; template struct has_transfer().transfer())>> : std::true_type {}; +// DUT may provide split posedge/negedge tick for faster simulation. 
+template +struct has_tick_posedge : std::false_type {}; + +template +struct has_tick_posedge().tick_posedge())>> : std::true_type {}; + +template +struct has_tick_negedge : std::false_type {}; + +template +struct has_tick_negedge().tick_negedge())>> : std::true_type {}; + template inline void maybe_comb(T &dut) { if constexpr (has_comb::value) { @@ -76,6 +89,24 @@ inline void maybe_transfer(T &dut) { } } +template +inline void maybe_tick_posedge(T &dut) { + if constexpr (has_tick_posedge::value) { + dut.tick_posedge(); + } else { + dut.tick(); + } +} + +template +inline void maybe_tick_negedge(T &dut) { + if constexpr (has_tick_negedge::value) { + dut.tick_negedge(); + } else { + dut.tick(); + } +} + } // namespace detail template @@ -297,7 +328,7 @@ class Testbench { // Posedge phase. detail::maybe_comb(dut_); c.set(true); - dut_.tick(); + detail::maybe_tick_posedge(dut_); detail::maybe_transfer(dut_); detail::maybe_comb(dut_); if (shouldDumpVcd(time_)) @@ -306,7 +337,7 @@ class Testbench { // Negedge bookkeeping (no extra combinational settle needed here). c.set(false); - dut_.tick(); + detail::maybe_tick_negedge(dut_); detail::maybe_transfer(dut_); if (shouldDumpVcd(time_)) vcd_->dump(time_); @@ -318,12 +349,12 @@ class Testbench { for (std::uint64_t i = 0; i < cycles; i++) { detail::maybe_comb(dut_); c.set(true); - dut_.tick(); + detail::maybe_tick_posedge(dut_); detail::maybe_transfer(dut_); detail::maybe_comb(dut_); time_++; c.set(false); - dut_.tick(); + detail::maybe_tick_negedge(dut_); detail::maybe_transfer(dut_); time_++; }