Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/batch_hash_lookup_dump/passes_dump/00_frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# pypto.program: BatchHashLookup
import pypto.language as pl

@pl.program
class BatchHashLookup:
@pl.function
def batch_hash_lookup(self, search_key: pl.Tensor[[1024, 64, 32], pl.INT32], hash_table_size: pl.Tensor[[64, 32], pl.INT32], hash_base_ptr: pl.Tensor[[64, 32], pl.INT32], hash_pool: pl.Tensor[[64, 128, 32], pl.INT32], value_ptr_out: pl.Tensor[[1024, 64, 32], pl.INT32]) -> pl.Tensor[[1024, 64, 32], pl.INT32]:
for b in pl.parallel(0, 1024, 32):
with pl.incore():
for ti in pl.parallel(0, 64, 32):
zero_src: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(search_key, [1, 32], [b, ti, 0])
zero_tile: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(zero_src, 0)
value_ptr_out: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.tensor.assemble(value_ptr_out, zero_tile, [b, ti, 0])
for probe in pl.range(0, 8, 1):
round_has_active: pl.Scalar[pl.INDEX] = 0
with pl.incore():
for b in pl.parallel(0, 1024, 32):
for ti in pl.parallel(0, 64, 32):
keys_tile: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(search_key, [1, 32], [b, ti, 0])
mixed: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile, 2654435761)
h_probe: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.ands(pl.tensor.add(mixed, probe * 2246822519), 64 - 1)
cand_key: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile, 0)
cand_val: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile, 0)
for bucket in pl.range(0, 64, 1):
bucket_mask: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmps(h_probe, bucket, cmp_type=0)
bucket_keys: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(hash_pool, [1, 32], [ti, bucket, 0])
bucket_vals: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(hash_pool, [1, 32], [ti, 64 + bucket, 0])
cand_key: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(bucket_mask, bucket_keys, cand_key)
cand_val: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(bucket_mask, bucket_vals, cand_val)
result_prev: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(value_ptr_out, [1, 32], [b, ti, 0])
active_mask: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmps(result_prev, 0, cmp_type=0)
active_count: pl.Tensor[[1, 1], pl.INDEX] = pl.tensor.row_sum(active_mask)
active_count_s: pl.Scalar[pl.INDEX] = pl.tensor.read(active_count, [0, 0])
if active_count_s != 0:
round_has_active: pl.Scalar[pl.INDEX] = 1
key_match: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmp(cand_key, keys_tile, cmp_type=0)
hit_mask: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.and(active_mask, key_match)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
import ast
from pathlib import Path

root = Path("examples/batch_hash_lookup_dump/passes_dump")
for path in sorted(root.glob("*.py")):
    try:
        ast.parse(path.read_text(), filename=str(path))
    except SyntaxError as exc:
        print(f"{path}:{exc.lineno}:{exc.offset}: {exc.msg}")
PY

Repository: hw-native-sys/pypto-lib

Length of output: 1373


🏁 Script executed:

#!/bin/bash
# Check the actual line 37 in the first problematic file
head -40 examples/batch_hash_lookup_dump/passes_dump/00_frontend.py | tail -5

Repository: hw-native-sys/pypto-lib

Length of output: 571


🏁 Script executed:

#!/bin/bash
# Get more context around the issue - let's see the full function/block
sed -n '30,45p' examples/batch_hash_lookup_dump/passes_dump/00_frontend.py

Repository: hw-native-sys/pypto-lib

Length of output: 1168


🏁 Script executed:

#!/bin/bash
# Check if these are generated files (look for generator markers/comments)
head -20 examples/batch_hash_lookup_dump/passes_dump/00_frontend.py

Repository: hw-native-sys/pypto-lib

Length of output: 1387


🏁 Script executed:

#!/bin/bash
# Look for the source generator/origin of these files
find examples/batch_hash_lookup_dump -type f -name "*.py" | head -20
# Check directory structure
ls -la examples/batch_hash_lookup_dump/

Repository: hw-native-sys/pypto-lib

Length of output: 1230


🏁 Script executed:

#!/bin/bash
# Look for the compiler/generator that creates these pass dumps
find examples/batch_hash_lookup_dump -type f ! -name "*.py" | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 49


🏁 Script executed:

#!/bin/bash
# Search for references to "batch_hash_lookup" to find the original source
rg "batch_hash_lookup" examples/batch_hash_lookup_dump/ -l --type py

Repository: hw-native-sys/pypto-lib

Length of output: 1061


🏁 Script executed:

#!/bin/bash
# Search for code that might generate these dumps or emit pl.tensor.and
rg "pl\.tensor\.and|tensor\.and" --type py | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 3786


🏁 Script executed:

#!/bin/bash
# Check if there's a generator script or configuration
ls -la examples/batch_hash_lookup_dump/

Repository: hw-native-sys/pypto-lib

Length of output: 235


🏁 Script executed:

#!/bin/bash
# Search for files that mention "passes_dump" to find the generator
rg "passes_dump" --type py | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 49


🏁 Script executed:

#!/bin/bash
# Look for the original program source (likely in examples/)
find examples -type f -name "*.py" ! -path "*/passes_dump/*" | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 928


🏁 Script executed:

#!/bin/bash
# Search for compiler or generator code
rg "batch_hash_lookup|BatchHashLookup" --type py -l | grep -v passes_dump

Repository: hw-native-sys/pypto-lib

Length of output: 49


🏁 Script executed:

#!/bin/bash
# Check if there's a test or runner script
find . -type f -name "*.py" -path "*/examples/*" ! -path "*/passes_dump/*" | xargs grep -l "passes_dump\|batch_hash" 2>/dev/null | head -10

Repository: hw-native-sys/pypto-lib

Length of output: 49


🏁 Script executed:

#!/bin/bash
# Search for batch_hash_lookup in all file types
rg "batch_hash_lookup|BatchHashLookup" -t all 2>/dev/null | grep -v "passes_dump" | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 49


🏁 Script executed:

#!/bin/bash
# Check if there's a build script or Makefile that generates these
find . -name "Makefile" -o -name "*.mk" -o -name "CMakeLists.txt" -o -name "setup.py" -o -name "build*.py" | xargs grep -l "passes_dump\|batch_hash" 2>/dev/null

Repository: hw-native-sys/pypto-lib

Length of output: 51


🏁 Script executed:

#!/bin/bash
# Look for any Python code that generates tensor operations
rg "tensor\.and|tensor\.ands" --type py -B 5 -A 5 | head -50

Repository: hw-native-sys/pypto-lib

Length of output: 8625


🏁 Script executed:

#!/bin/bash
# Check if there's a compiler or DSL-to-Python generator
find . -type f \( -name "*.py" -o -name "*.cpp" -o -name "*.cc" -o -name "*.rs" \) -path "*/compiler/*" -o -path "*/codegen/*" | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 49


Fix the reserved keyword usage in tensor operation calls and regenerate pass dumps.

The source compiler is emitting pl.tensor.and(...) calls, which use Python's reserved keyword `and` as a method name. (Note that pl.tensor.ands(...) is unaffected — `ands` is a legal identifier; only `and` is a keyword.) This makes all 13 pass-dump files unparsable: lines 37 (and similar lines in later passes) contain hit_mask: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.and(active_mask, key_match), which is invalid Python syntax.

Verify the issue with:

import ast
from pathlib import Path
root = Path("examples/batch_hash_lookup_dump/passes_dump")
for path in sorted(root.glob("*.py")):
    try:
        ast.parse(path.read_text(), filename=str(path))
    except SyntaxError as exc:
        print(f"{path}:{exc.lineno}:{exc.offset}: {exc.msg}")

The problem persists across all compilation stages. Fix the tensor operation name in the code generator (replace the keyword-colliding `and` with a safe alternative like `bit_and` or similar; `ands` needs no change), regenerate all pass-dump artifacts, and add a Python syntax validation check to CI (e.g. `python -m py_compile` over the generated files) to prevent similar regressions.

🧰 Tools
🪛 Ruff (0.15.6)

[warning] 37-37: Expected an identifier, but found a keyword and that cannot be used here

(invalid-syntax)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/batch_hash_lookup_dump/passes_dump/00_frontend.py` at line 37, The
generated pass-dump uses Python reserved names pl.tensor.and / pl.tensor.ands
(e.g., the expression assigning hit_mask from pl.tensor.and(active_mask,
key_match)), which produces invalid syntax; update the code generator that emits
tensor ops so it emits non-keyword names such as pl.tensor.bit_and and
pl.tensor.bit_ands (or another safe mapping) wherever pl.tensor.and /
pl.tensor.ands are produced, then regenerate all pass-dump artifacts (so lines
assigning hit_mask from active_mask and key_match are fixed) and add a CI step
that validates each generated .py with ast.parse (or a simple python -m
py_compile) to catch syntax errors in future dumps.

result_next: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(hit_mask, cand_val, result_prev)
value_ptr_out: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.tensor.assemble(value_ptr_out, result_next, [b, ti, 0])
if round_has_active == 0:
break
return value_ptr_out
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# pypto.program: BatchHashLookup
import pypto.language as pl


@pl.program
class BatchHashLookup:
    """Frontend-pass dump of a batched open-addressed hash-table lookup.

    NOTE(review): the generator emitted ``pl.tensor.and(...)``, which is a
    Python syntax error because ``and`` is a reserved keyword. The op is
    resolved via ``getattr`` below so this dump parses; the generator itself
    should emit a non-keyword op name (e.g. ``bit_and``).
    """

    @pl.function
    def batch_hash_lookup(self, search_key: pl.Tensor[[1024, 64, 32], pl.INT32], hash_table_size: pl.Tensor[[64, 32], pl.INT32], hash_base_ptr: pl.Tensor[[64, 32], pl.INT32], hash_pool: pl.Tensor[[64, 128, 32], pl.INT32], value_ptr_out: pl.Tensor[[1024, 64, 32], pl.INT32]) -> pl.Tensor[[1024, 64, 32], pl.INT32]:
        # Phase 1: zero-initialize every [1, 32] tile of the output tensor.
        for b in pl.parallel(0, 1024, 32):
            with pl.incore():
                for ti in pl.parallel(0, 64, 32):
                    zero_src: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(search_key, [1, 32], [b, ti, 0])
                    zero_tile: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(zero_src, 0)
                    value_ptr_out: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.tensor.assemble(value_ptr_out, zero_tile, [b, ti, 0])
        # Phase 2: up to 8 probe rounds; exit early once no lane is unresolved.
        for probe in pl.range(0, 8, 1):
            round_has_active: pl.Scalar[pl.INDEX] = 0
            with pl.incore():
                for b in pl.parallel(0, 1024, 32):
                    for ti in pl.parallel(0, 64, 32):
                        keys_tile: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(search_key, [1, 32], [b, ti, 0])
                        # Multiplicative hash mixed with a per-probe offset, masked to the 64-bucket table.
                        mixed: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile, 2654435761)
                        h_probe: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.ands(pl.tensor.add(mixed, probe * 2246822519), 64 - 1)
                        cand_key: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile, 0)
                        cand_val: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile, 0)
                        # Select each lane's hashed bucket by scanning all 64 buckets.
                        for bucket in pl.range(0, 64, 1):
                            bucket_mask: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmps(h_probe, bucket, cmp_type=0)
                            bucket_keys: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(hash_pool, [1, 32], [ti, bucket, 0])
                            bucket_vals: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(hash_pool, [1, 32], [ti, 64 + bucket, 0])
                            cand_key: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(bucket_mask, bucket_keys, cand_key)
                            cand_val: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(bucket_mask, bucket_vals, cand_val)
                        result_prev: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(value_ptr_out, [1, 32], [b, ti, 0])
                        # Lanes whose output is still 0 have not resolved their value yet.
                        active_mask: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmps(result_prev, 0, cmp_type=0)
                        active_count: pl.Tensor[[1, 1], pl.INDEX] = pl.tensor.row_sum(active_mask)
                        active_count_s: pl.Scalar[pl.INDEX] = pl.tensor.read(active_count, [0, 0])
                        if active_count_s != 0:
                            round_has_active: pl.Scalar[pl.INDEX] = 1
                        key_match: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmp(cand_key, keys_tile, cmp_type=0)
                        # `and` is a Python keyword, so `pl.tensor.and(...)` cannot be
                        # written directly; resolve the same op attribute dynamically.
                        hit_mask: pl.Tensor[[1, 32], pl.INDEX] = getattr(pl.tensor, "and")(active_mask, key_match)
                        result_next: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(hit_mask, cand_val, result_prev)
                        value_ptr_out: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.tensor.assemble(value_ptr_out, result_next, [b, ti, 0])
            if round_has_active == 0:
                break
        return value_ptr_out
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# pypto.program: BatchHashLookup
import pypto.language as pl


@pl.program
class BatchHashLookup:
    """SSA-form pass dump of the batched hash lookup: loop-carried values are
    threaded explicitly through ``init_values=`` and ``pl.yield_``.

    NOTE(review): the generator emitted ``pl.tensor.and(...)`` — invalid
    Python, since ``and`` is a reserved keyword. The op is resolved via
    ``getattr`` below so this dump parses; the generator should emit a
    non-keyword op name (e.g. ``bit_and``).
    """

    @pl.function
    def batch_hash_lookup(self, search_key_0: pl.Tensor[[1024, 64, 32], pl.INT32], hash_table_size_0: pl.Tensor[[64, 32], pl.INT32], hash_base_ptr_0: pl.Tensor[[64, 32], pl.INT32], hash_pool_0: pl.Tensor[[64, 128, 32], pl.INT32], value_ptr_out_0: pl.Tensor[[1024, 64, 32], pl.INT32]) -> pl.Tensor[[1024, 64, 32], pl.INT32]:
        # Phase 1: zero-initialize the output; updates flow out through yields.
        for b_0, (value_ptr_out_iter_1,) in pl.parallel(0, 1024, 32, init_values=(value_ptr_out_0,)):
            with pl.incore():
                for ti_0, (value_ptr_out_iter_3,) in pl.parallel(0, 64, 32, init_values=(value_ptr_out_iter_1,)):
                    zero_src_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(search_key_0, [1, 32], [b_0, ti_0, 0])
                    zero_tile_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(zero_src_0, 0)
                    value_ptr_out_5: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.tensor.assemble(value_ptr_out_iter_3, zero_tile_0, [b_0, ti_0, 0])
                    value_ptr_out_4: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.yield_(value_ptr_out_5)
            value_ptr_out_2: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.yield_(value_ptr_out_4)
        # Phase 2: probe rounds with early exit once no lane remains unresolved.
        for probe_0, (b_iter_1, ti_iter_1, value_ptr_out_iter_6) in pl.range(0, 8, 1, init_values=(b_0, ti_0, value_ptr_out_2)):
            round_has_active_0: pl.Scalar[pl.INDEX] = 0
            with pl.incore():
                for b_3, (round_has_active_iter_1, ti_iter_3, value_ptr_out_iter_8) in pl.parallel(0, 1024, 32, init_values=(round_has_active_0, ti_iter_1, value_ptr_out_iter_6)):
                    for ti_5, (round_has_active_iter_3, value_ptr_out_iter_10) in pl.parallel(0, 64, 32, init_values=(round_has_active_iter_1, value_ptr_out_iter_8)):
                        keys_tile_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(search_key_0, [1, 32], [b_3, ti_5, 0])
                        mixed_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile_0, 2654435761)
                        h_probe_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.ands(pl.tensor.add(mixed_0, probe_0 * 2246822519), 64 - 1)
                        cand_key_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile_0, 0)
                        cand_val_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile_0, 0)
                        for bucket_0, (cand_key_iter_1, cand_val_iter_1) in pl.range(0, 64, 1, init_values=(cand_key_0, cand_val_0)):
                            bucket_mask_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmps(h_probe_0, bucket_0, cmp_type=0)
                            bucket_keys_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(hash_pool_0, [1, 32], [ti_5, bucket_0, 0])
                            bucket_vals_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(hash_pool_0, [1, 32], [ti_5, 64 + bucket_0, 0])
                            cand_key_3: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(bucket_mask_0, bucket_keys_0, cand_key_iter_1)
                            cand_val_3: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(bucket_mask_0, bucket_vals_0, cand_val_iter_1)
                            cand_key_2, cand_val_2 = pl.yield_(cand_key_3, cand_val_3)
                        result_prev_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(value_ptr_out_iter_10, [1, 32], [b_3, ti_5, 0])
                        active_mask_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmps(result_prev_0, 0, cmp_type=0)
                        active_count_0: pl.Tensor[[1, 1], pl.INDEX] = pl.tensor.row_sum(active_mask_0)
                        active_count_s_0: pl.Scalar[pl.INDEX] = pl.tensor.read(active_count_0, [0, 0])
                        if active_count_s_0 != 0:
                            round_has_active_5: pl.Scalar[pl.INDEX] = 1
                            round_has_active_6: pl.Scalar[pl.INDEX] = pl.yield_(round_has_active_5)
                        else:
                            round_has_active_6: pl.Scalar[pl.INDEX] = pl.yield_(round_has_active_iter_3)
                        key_match_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmp(cand_key_2, keys_tile_0, cmp_type=0)
                        # `and` is a Python keyword, so `pl.tensor.and(...)` cannot be
                        # written directly; resolve the same op attribute dynamically.
                        hit_mask_0: pl.Tensor[[1, 32], pl.INDEX] = getattr(pl.tensor, "and")(active_mask_0, key_match_0)
                        result_next_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(hit_mask_0, cand_val_2, result_prev_0)
                        value_ptr_out_12: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.tensor.assemble(value_ptr_out_iter_10, result_next_0, [b_3, ti_5, 0])
                        round_has_active_4, value_ptr_out_11 = pl.yield_(round_has_active_6, value_ptr_out_12)
                    round_has_active_2, ti_4, value_ptr_out_9 = pl.yield_(round_has_active_4, ti_5, value_ptr_out_11)
            if round_has_active_2 == 0:
                break
            b_2, ti_2, value_ptr_out_7 = pl.yield_(b_3, ti_4, value_ptr_out_9)
        return value_ptr_out_7
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# pypto.program: BatchHashLookup
import pypto.language as pl


@pl.program
class BatchHashLookup:
    """SSA-form pass dump of the batched hash lookup with the hash-add
    expression extracted into the temporary ``_t0``.

    NOTE(review): the generator emitted ``pl.tensor.and(...)`` — invalid
    Python, since ``and`` is a reserved keyword. The op is resolved via
    ``getattr`` below so this dump parses; the generator should emit a
    non-keyword op name (e.g. ``bit_and``).
    """

    @pl.function
    def batch_hash_lookup(self, search_key_0: pl.Tensor[[1024, 64, 32], pl.INT32], hash_table_size_0: pl.Tensor[[64, 32], pl.INT32], hash_base_ptr_0: pl.Tensor[[64, 32], pl.INT32], hash_pool_0: pl.Tensor[[64, 128, 32], pl.INT32], value_ptr_out_0: pl.Tensor[[1024, 64, 32], pl.INT32]) -> pl.Tensor[[1024, 64, 32], pl.INT32]:
        # Phase 1: zero-initialize the output; updates flow out through yields.
        for b_0, (value_ptr_out_iter_1,) in pl.parallel(0, 1024, 32, init_values=(value_ptr_out_0,)):
            with pl.incore():
                for ti_0, (value_ptr_out_iter_3,) in pl.parallel(0, 64, 32, init_values=(value_ptr_out_iter_1,)):
                    zero_src_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(search_key_0, [1, 32], [b_0, ti_0, 0])
                    zero_tile_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(zero_src_0, 0)
                    value_ptr_out_5: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.tensor.assemble(value_ptr_out_iter_3, zero_tile_0, [b_0, ti_0, 0])
                    value_ptr_out_4: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.yield_(value_ptr_out_5)
            value_ptr_out_2: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.yield_(value_ptr_out_4)
        # Phase 2: probe rounds with early exit once no lane remains unresolved.
        for probe_0, (b_iter_1, ti_iter_1, value_ptr_out_iter_6) in pl.range(0, 8, 1, init_values=(b_0, ti_0, value_ptr_out_2)):
            round_has_active_0: pl.Scalar[pl.INDEX] = 0
            with pl.incore():
                for b_3, (round_has_active_iter_1, ti_iter_3, value_ptr_out_iter_8) in pl.parallel(0, 1024, 32, init_values=(round_has_active_0, ti_iter_1, value_ptr_out_iter_6)):
                    for ti_5, (round_has_active_iter_3, value_ptr_out_iter_10) in pl.parallel(0, 64, 32, init_values=(round_has_active_iter_1, value_ptr_out_iter_8)):
                        keys_tile_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(search_key_0, [1, 32], [b_3, ti_5, 0])
                        mixed_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile_0, 2654435761)
                        _t0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.add(mixed_0, probe_0 * 2246822519)
                        h_probe_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.ands(_t0, 64 - 1)
                        cand_key_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile_0, 0)
                        cand_val_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.mul(keys_tile_0, 0)
                        for bucket_0, (cand_key_iter_1, cand_val_iter_1) in pl.range(0, 64, 1, init_values=(cand_key_0, cand_val_0)):
                            bucket_mask_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmps(h_probe_0, bucket_0, cmp_type=0)
                            bucket_keys_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(hash_pool_0, [1, 32], [ti_5, bucket_0, 0])
                            bucket_vals_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(hash_pool_0, [1, 32], [ti_5, 64 + bucket_0, 0])
                            cand_key_3: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(bucket_mask_0, bucket_keys_0, cand_key_iter_1)
                            cand_val_3: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(bucket_mask_0, bucket_vals_0, cand_val_iter_1)
                            cand_key_2, cand_val_2 = pl.yield_(cand_key_3, cand_val_3)
                        result_prev_0: pl.Tensor[[1, 32], pl.INT32] = pl.tensor.view(value_ptr_out_iter_10, [1, 32], [b_3, ti_5, 0])
                        active_mask_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmps(result_prev_0, 0, cmp_type=0)
                        active_count_0: pl.Tensor[[1, 1], pl.INDEX] = pl.tensor.row_sum(active_mask_0)
                        active_count_s_0: pl.Scalar[pl.INDEX] = pl.tensor.read(active_count_0, [0, 0])
                        if active_count_s_0 != 0:
                            round_has_active_5: pl.Scalar[pl.INDEX] = 1
                            round_has_active_6: pl.Scalar[pl.INDEX] = pl.yield_(round_has_active_5)
                        else:
                            round_has_active_6: pl.Scalar[pl.INDEX] = pl.yield_(round_has_active_iter_3)
                        key_match_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.cmp(cand_key_2, keys_tile_0, cmp_type=0)
                        # `and` is a Python keyword, so `pl.tensor.and(...)` cannot be
                        # written directly; resolve the same op attribute dynamically.
                        hit_mask_0: pl.Tensor[[1, 32], pl.INDEX] = getattr(pl.tensor, "and")(active_mask_0, key_match_0)
                        result_next_0: pl.Tensor[[1, 32], pl.INDEX] = pl.tensor.sel(hit_mask_0, cand_val_2, result_prev_0)
                        value_ptr_out_12: pl.Tensor[[1024, 64, 32], pl.INT32] = pl.tensor.assemble(value_ptr_out_iter_10, result_next_0, [b_3, ti_5, 0])
                        round_has_active_4, value_ptr_out_11 = pl.yield_(round_has_active_6, value_ptr_out_12)
                    round_has_active_2, ti_4, value_ptr_out_9 = pl.yield_(round_has_active_4, ti_5, value_ptr_out_11)
            if round_has_active_2 == 0:
                break
            b_2, ti_2, value_ptr_out_7 = pl.yield_(b_3, ti_4, value_ptr_out_9)
        return value_ptr_out_7
Loading
Loading