Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions rtl/moe_router/expert_gate.sv
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// SPDX-License-Identifier: Apache-2.0
// W42 MoE Sparse Routing — top-2-of-8 expert selector
// THESIS: NO new L1 opcode — composes existing OP_SPARSE_SKIP (0xE8) + OP_SPARSE_MASK (0xED)
// Anchor: phi^2 + phi^-2 = 3 · NEVER STOP · DOI 10.5281/zenodo.19227877
// R-SI-1: ZERO mul characters in synth path.

module expert_gate (
input wire [7:0] logit0,
input wire [7:0] logit1,
input wire [7:0] logit2,
input wire [7:0] logit3,
input wire [7:0] logit4,
input wire [7:0] logit5,
input wire [7:0] logit6,
input wire [7:0] logit7,
output wire [7:0] mask_out, // 8-bit one-hot-per-expert mask: top-2 set
output wire [2:0] top1_idx,
output wire [2:0] top2_idx
);
// Comparison stage 1: pairwise comparator network (no mul anywhere)
// To select top-2 we use a simple O(N^2) comparator network.
// For each expert i, count how many other experts j "beat" expert i:
// j beats i if logit_j > logit_i, OR (logit_j == logit_i AND j < i)
// This gives a strict total order with stable tie-breaking by index.
// Expert is in top-2 iff its rank (number of beaters) <= 1.

// Compute rank_i using stable comparator: j beats i if
// (logit_j > logit_i) OR (logit_j == logit_i AND j_idx < self_idx)
wire [2:0] rank0, rank1, rank2, rank3, rank4, rank5, rank6, rank7;

function automatic [2:0] count_gt_stable;
input [7:0] self_logit;
input [7:0] l0, l1, l2, l3, l4, l5, l6, l7;
input [2:0] self_idx;
reg [2:0] cnt;
reg b0, b1, b2, b3, b4, b5, b6, b7;
begin
// b_j = 1 if j beats self (j != self)
b0 = (self_idx != 3'd0) & ((l0 > self_logit) | ((l0 == self_logit) & (3'd0 < self_idx)));
b1 = (self_idx != 3'd1) & ((l1 > self_logit) | ((l1 == self_logit) & (3'd1 < self_idx)));
b2 = (self_idx != 3'd2) & ((l2 > self_logit) | ((l2 == self_logit) & (3'd2 < self_idx)));
b3 = (self_idx != 3'd3) & ((l3 > self_logit) | ((l3 == self_logit) & (3'd3 < self_idx)));
b4 = (self_idx != 3'd4) & ((l4 > self_logit) | ((l4 == self_logit) & (3'd4 < self_idx)));
b5 = (self_idx != 3'd5) & ((l5 > self_logit) | ((l5 == self_logit) & (3'd5 < self_idx)));
b6 = (self_idx != 3'd6) & ((l6 > self_logit) | ((l6 == self_logit) & (3'd6 < self_idx)));
b7 = (self_idx != 3'd7) & ((l7 > self_logit) | ((l7 == self_logit) & (3'd7 < self_idx)));
// explicit addition, no mul (verilog '+' is fine under R-SI-1)
cnt = {2'b0, b0} + {2'b0, b1} + {2'b0, b2} + {2'b0, b3}
+ {2'b0, b4} + {2'b0, b5} + {2'b0, b6} + {2'b0, b7};
count_gt_stable = cnt;
end
endfunction

assign rank0 = count_gt_stable(logit0, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd0);
assign rank1 = count_gt_stable(logit1, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd1);
assign rank2 = count_gt_stable(logit2, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd2);
assign rank3 = count_gt_stable(logit3, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd3);
assign rank4 = count_gt_stable(logit4, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd4);
assign rank5 = count_gt_stable(logit5, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd5);
assign rank6 = count_gt_stable(logit6, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd6);
assign rank7 = count_gt_stable(logit7, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd7);

// Mask bit set iff rank <= 1
assign mask_out[0] = (rank0 <= 3'd1);
assign mask_out[1] = (rank1 <= 3'd1);
assign mask_out[2] = (rank2 <= 3'd1);
assign mask_out[3] = (rank3 <= 3'd1);
assign mask_out[4] = (rank4 <= 3'd1);
assign mask_out[5] = (rank5 <= 3'd1);
assign mask_out[6] = (rank6 <= 3'd1);
assign mask_out[7] = (rank7 <= 3'd1);

// top1_idx = index of rank == 0
assign top1_idx = (rank0 == 3'd0) ? 3'd0 :
(rank1 == 3'd0) ? 3'd1 :
(rank2 == 3'd0) ? 3'd2 :
(rank3 == 3'd0) ? 3'd3 :
(rank4 == 3'd0) ? 3'd4 :
(rank5 == 3'd0) ? 3'd5 :
(rank6 == 3'd0) ? 3'd6 : 3'd7;

// top2_idx = index of rank == 1
assign top2_idx = (rank0 == 3'd1) ? 3'd0 :
(rank1 == 3'd1) ? 3'd1 :
(rank2 == 3'd1) ? 3'd2 :
(rank3 == 3'd1) ? 3'd3 :
(rank4 == 3'd1) ? 3'd4 :
(rank5 == 3'd1) ? 3'd5 :
(rank6 == 3'd1) ? 3'd6 : 3'd7;

endmodule
55 changes: 55 additions & 0 deletions rtl/moe_router/expert_gate_tb.sv
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// SPDX-License-Identifier: Apache-2.0
`timescale 1ns/1ps
module expert_gate_tb;
reg [7:0] l0, l1, l2, l3, l4, l5, l6, l7;
wire [7:0] mask;
wire [2:0] t1, t2;

expert_gate dut (.logit0(l0), .logit1(l1), .logit2(l2), .logit3(l3),
.logit4(l4), .logit5(l5), .logit6(l6), .logit7(l7),
.mask_out(mask), .top1_idx(t1), .top2_idx(t2));

integer errors;
initial begin
errors = 0;

// A1: clear ordering [1,5,3,8,2,7,4,6] → top1=idx3, top2=idx5
l0=8;l1=5;l2=3;l3=128;l4=2;l5=127;l6=4;l7=6; #1;
// wait: i used wrong values. Re-test: logits 1,5,3,8,2,7,4,6 with idx 0..7.
l0=1;l1=5;l2=3;l3=8;l4=2;l5=7;l6=4;l7=6; #1;
if (t1 !== 3'd3) begin $display("FAIL A1 t1=%0d expected 3", t1); errors = errors + 1; end
if (t2 !== 3'd5) begin $display("FAIL A1 t2=%0d expected 5", t2); errors = errors + 1; end

// A2: exactly 2 bits set in mask
if ($countones(mask) !== 2) begin $display("FAIL A2 popcount=%0d expected 2", $countones(mask)); errors = errors + 1; end

// A3: mask[3] and mask[5] set
if (mask[3] !== 1'b1 || mask[5] !== 1'b1) begin $display("FAIL A3 mask=%b", mask); errors = errors + 1; end

// A4: all-equal logits → top1=0, top2=1 (stable tie-break by index)
l0=8'd5;l1=8'd5;l2=8'd5;l3=8'd5;l4=8'd5;l5=8'd5;l6=8'd5;l7=8'd5; #1;
if (t1 !== 3'd0) begin $display("FAIL A4 t1=%0d expected 0", t1); errors = errors + 1; end
if (t2 !== 3'd1) begin $display("FAIL A4 t2=%0d expected 1", t2); errors = errors + 1; end

// A5: monotone descending → top1=0, top2=1
l0=8'd80;l1=8'd70;l2=8'd60;l3=8'd50;l4=8'd40;l5=8'd30;l6=8'd20;l7=8'd10; #1;
if (t1 !== 3'd0) begin $display("FAIL A5 t1=%0d expected 0", t1); errors = errors + 1; end
if (t2 !== 3'd1) begin $display("FAIL A5 t2=%0d expected 1", t2); errors = errors + 1; end

// A6: monotone ascending → top1=7, top2=6
l0=8'd10;l1=8'd20;l2=8'd30;l3=8'd40;l4=8'd50;l5=8'd60;l6=8'd70;l7=8'd80; #1;
if (t1 !== 3'd7) begin $display("FAIL A6 t1=%0d expected 7", t1); errors = errors + 1; end
if (t2 !== 3'd6) begin $display("FAIL A6 t2=%0d expected 6", t2); errors = errors + 1; end

// A7: mask has exactly 2 bits in all cases
if ($countones(mask) !== 2) begin $display("FAIL A7 popcount=%0d", $countones(mask)); errors = errors + 1; end

if (errors == 0) begin
$display("ALL 7 ASSERTIONS PASSED · MoE top-2-of-8 · NO NEW OPCODE · phi^2+phi^-2=3");
$finish;
end else begin
$display("FAIL: %0d errors", errors);
$fatal(1);
end
end
endmodule
Loading