diff --git a/rtl/moe_router/expert_gate.sv b/rtl/moe_router/expert_gate.sv new file mode 100644 index 000000000..d1b39bbed --- /dev/null +++ b/rtl/moe_router/expert_gate.sv @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// W42 MoE Sparse Routing — top-2-of-8 expert selector +// THESIS: NO new L1 opcode — composes existing OP_SPARSE_SKIP (0xE8) + OP_SPARSE_MASK (0xED) +// Anchor: phi^2 + phi^-2 = 3 · NEVER STOP · DOI 10.5281/zenodo.19227877 +// R-SI-1: ZERO mul characters in synth path. + +module expert_gate ( + input wire [7:0] logit0, + input wire [7:0] logit1, + input wire [7:0] logit2, + input wire [7:0] logit3, + input wire [7:0] logit4, + input wire [7:0] logit5, + input wire [7:0] logit6, + input wire [7:0] logit7, + output wire [7:0] mask_out, // 8-bit one-hot-per-expert mask: top-2 set + output wire [2:0] top1_idx, + output wire [2:0] top2_idx +); + // Comparison stage 1: pairwise comparator network (no mul anywhere) + // To select top-2 we use a simple O(N^2) comparator network. + // For each expert i, count how many other experts j "beat" expert i: + // j beats i if logit_j > logit_i, OR (logit_j == logit_i AND j < i) + // This gives a strict total order with stable tie-breaking by index. + // Expert is in top-2 iff its rank (number of beaters) <= 1. + + // Compute rank_i using stable comparator: j beats i if + // (logit_j > logit_i) OR (logit_j == logit_i AND j_idx < self_idx) + wire [2:0] rank0, rank1, rank2, rank3, rank4, rank5, rank6, rank7; + + function automatic [2:0] count_gt_stable; + input [7:0] self_logit; + input [7:0] l0, l1, l2, l3, l4, l5, l6, l7; + input [2:0] self_idx; + reg [2:0] cnt; + reg b0, b1, b2, b3, b4, b5, b6, b7; + begin + // b_j = 1 if j beats self (j != self) + b0 = (self_idx != 3'd0) & ((l0 > self_logit) | ((l0 == self_logit) & (3'd0 < self_idx))); + b1 = (self_idx != 3'd1) & ((l1 > self_logit) | ((l1 == self_logit) & (3'd1 < self_idx))); + b2 = (self_idx != 3'd2) & ((l2 > self_logit) | ((l2 == self_logit) & (3'd2 < self_idx))); + b3 = (self_idx != 3'd3) & ((l3 > self_logit) | ((l3 == self_logit) & (3'd3 < self_idx))); + b4 = (self_idx != 3'd4) & ((l4 > self_logit) | ((l4 == self_logit) & (3'd4 < self_idx))); + b5 = (self_idx != 3'd5) & ((l5 > self_logit) | ((l5 == self_logit) & (3'd5 < self_idx))); + b6 = (self_idx != 3'd6) & ((l6 > self_logit) | ((l6 == self_logit) & (3'd6 < self_idx))); + b7 = (self_idx != 3'd7) & ((l7 > self_logit) | ((l7 == self_logit) & (3'd7 < self_idx))); + // explicit addition, no mul (verilog '+' is fine under R-SI-1) + cnt = {2'b0, b0} + {2'b0, b1} + {2'b0, b2} + {2'b0, b3} + + {2'b0, b4} + {2'b0, b5} + {2'b0, b6} + {2'b0, b7}; + count_gt_stable = cnt; + end + endfunction + + assign rank0 = count_gt_stable(logit0, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd0); + assign rank1 = count_gt_stable(logit1, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd1); + assign rank2 = count_gt_stable(logit2, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd2); + assign rank3 = count_gt_stable(logit3, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd3); + assign rank4 = count_gt_stable(logit4, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd4); + assign rank5 = count_gt_stable(logit5, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd5); + assign rank6 = count_gt_stable(logit6, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd6); + assign rank7 = count_gt_stable(logit7, logit0, logit1, logit2, logit3, logit4, logit5, logit6, logit7, 3'd7); + + // Mask bit set iff rank <= 1 + assign mask_out[0] = (rank0 <= 3'd1); + assign mask_out[1] = (rank1 <= 3'd1); + assign mask_out[2] = (rank2 <= 3'd1); + assign mask_out[3] = (rank3 <= 3'd1); + assign mask_out[4] = (rank4 <= 3'd1); + assign mask_out[5] = (rank5 <= 3'd1); + assign mask_out[6] = (rank6 <= 3'd1); + assign mask_out[7] = (rank7 <= 3'd1); + + // top1_idx = index of rank == 0 + assign top1_idx = (rank0 == 3'd0) ? 3'd0 : + (rank1 == 3'd0) ? 3'd1 : + (rank2 == 3'd0) ? 3'd2 : + (rank3 == 3'd0) ? 3'd3 : + (rank4 == 3'd0) ? 3'd4 : + (rank5 == 3'd0) ? 3'd5 : + (rank6 == 3'd0) ? 3'd6 : 3'd7; + + // top2_idx = index of rank == 1 + assign top2_idx = (rank0 == 3'd1) ? 3'd0 : + (rank1 == 3'd1) ? 3'd1 : + (rank2 == 3'd1) ? 3'd2 : + (rank3 == 3'd1) ? 3'd3 : + (rank4 == 3'd1) ? 3'd4 : + (rank5 == 3'd1) ? 3'd5 : + (rank6 == 3'd1) ? 3'd6 : 3'd7; + +endmodule diff --git a/rtl/moe_router/expert_gate_tb.sv b/rtl/moe_router/expert_gate_tb.sv new file mode 100644 index 000000000..42c526572 --- /dev/null +++ b/rtl/moe_router/expert_gate_tb.sv @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +`timescale 1ns/1ps +module expert_gate_tb; + reg [7:0] l0, l1, l2, l3, l4, l5, l6, l7; + wire [7:0] mask; + wire [2:0] t1, t2; + + expert_gate dut (.logit0(l0), .logit1(l1), .logit2(l2), .logit3(l3), + .logit4(l4), .logit5(l5), .logit6(l6), .logit7(l7), + .mask_out(mask), .top1_idx(t1), .top2_idx(t2)); + + integer errors; + initial begin + errors = 0; + + // A1: clear ordering [1,5,3,8,2,7,4,6] → top1=idx3, top2=idx5 + l0=8;l1=5;l2=3;l3=128;l4=2;l5=127;l6=4;l7=6; #1; + // wait: i used wrong values. Re-test: logits 1,5,3,8,2,7,4,6 with idx 0..7. + l0=1;l1=5;l2=3;l3=8;l4=2;l5=7;l6=4;l7=6; #1; + if (t1 !== 3'd3) begin $display("FAIL A1 t1=%0d expected 3", t1); errors = errors + 1; end + if (t2 !== 3'd5) begin $display("FAIL A1 t2=%0d expected 5", t2); errors = errors + 1; end + + // A2: exactly 2 bits set in mask + if ($countones(mask) !== 2) begin $display("FAIL A2 popcount=%0d expected 2", $countones(mask)); errors = errors + 1; end + + // A3: mask[3] and mask[5] set + if (mask[3] !== 1'b1 || mask[5] !== 1'b1) begin $display("FAIL A3 mask=%b", mask); errors = errors + 1; end + + // A4: all-equal logits → top1=0, top2=1 (stable tie-break by index) + l0=8'd5;l1=8'd5;l2=8'd5;l3=8'd5;l4=8'd5;l5=8'd5;l6=8'd5;l7=8'd5; #1; + if (t1 !== 3'd0) begin $display("FAIL A4 t1=%0d expected 0", t1); errors = errors + 1; end + if (t2 !== 3'd1) begin $display("FAIL A4 t2=%0d expected 1", t2); errors = errors + 1; end + + // A5: monotone descending → top1=0, top2=1 + l0=8'd80;l1=8'd70;l2=8'd60;l3=8'd50;l4=8'd40;l5=8'd30;l6=8'd20;l7=8'd10; #1; + if (t1 !== 3'd0) begin $display("FAIL A5 t1=%0d expected 0", t1); errors = errors + 1; end + if (t2 !== 3'd1) begin $display("FAIL A5 t2=%0d expected 1", t2); errors = errors + 1; end + + // A6: monotone ascending → top1=7, top2=6 + l0=8'd10;l1=8'd20;l2=8'd30;l3=8'd40;l4=8'd50;l5=8'd60;l6=8'd70;l7=8'd80; #1; + if (t1 !== 3'd7) begin $display("FAIL A6 t1=%0d expected 7", t1); errors = errors + 1; end + if (t2 !== 3'd6) begin $display("FAIL A6 t2=%0d expected 6", t2); errors = errors + 1; end + + // A7: mask has exactly 2 bits in all cases + if ($countones(mask) !== 2) begin $display("FAIL A7 popcount=%0d", $countones(mask)); errors = errors + 1; end + + if (errors == 0) begin + $display("ALL 7 ASSERTIONS PASSED · MoE top-2-of-8 · NO NEW OPCODE · phi^2+phi^-2=3"); + $finish; + end else begin + $display("FAIL: %0d errors", errors); + $fatal(1); + end + end +endmodule