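```python
import jax
from penzai import pz

# A linear layer that maps named axis "in_axis" (size 2) to "out_axis" (size 3).
my_layer = pz.nn.Linear.from_config(
    "my_layer", jax.random.key(0), input_axes={"in_axis": 2}, output_axes={"out_axis": 3}
)
# A non-string axis name, used here as an extra batch axis of size 3.
nonstr_batch_axis = pz.nx.TmpPosAxisMarker()
my_layer(pz.nx.zeros({"in_axis": 2, nonstr_batch_axis: 3}))
```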
The code above raises an assertion error at this line of `penzai/penzai/core/shapecheck.py` (line 621 in commit `aac7808`):

```python
assert isinstance(subkey, str)
```

I don't think any of shapecheck's assumptions require axis names to be strings, and this check limits the ability to name axes automatically without worrying about overlaps with existing axis names.
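The overlap concern can be made concrete with a minimal, library-independent sketch in plain Python. It does not use penzai, and `AxisMarker`, `fresh_string_axis` and `fresh_marker_axis` are hypothetical names for illustration only. With string-only axis names, an automatically generated name has to be searched for among names already in use. An object used as an axis key is compared by identity, so it can never clash with a user-chosen string name.

```python
import itertools


class AxisMarker:
  """Hypothetical stand-in for a non-string axis name (compared by identity)."""

  def __repr__(self):
    return f"AxisMarker(id={id(self):#x})"


def fresh_string_axis(existing: set, prefix: str = "tmp") -> str:
  """Returns a string axis name that is not already in `existing`.

  With string-only names, uniqueness is enforced by searching for an
  unused name; the result can still collide with axes introduced later.
  """
  for i in itertools.count():
    candidate = f"{prefix}_{i}"
    if candidate not in existing:
      return candidate


def fresh_marker_axis() -> AxisMarker:
  """Returns a new marker object; distinct markers never compare equal."""
  return AxisMarker()


# A named shape that already uses "tmp_0" as a user-chosen axis name.
named_shape = {"in_axis": 2, "tmp_0": 5}

str_axis = fresh_string_axis(set(named_shape))
marker_axis = fresh_marker_axis()

print(str_axis)  # tmp_1
print(marker_axis in named_shape)  # False
print(fresh_marker_axis() == fresh_marker_axis())  # False
```

The string helper only avoids names it can see at generation time, while the marker approach needs no bookkeeping at all, which is why allowing non-string axis names in shapecheck would be convenient.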