This is a work in progress.
- Teach a base model to do n-dimensional shape inference on tensors.
- Start with PyTorch transformation code: given tensor shape A and ops X, Y, Z, produce B = Z(Y(X(A))).
- Directly verifiable rewards, using TorchFX tracing to generate the intermediate primitives that can be turned into reasoning traces.
- The goal is not to design a compiler for torch shape ops, but to understand whether we can encode implicit rules for tensor transformations into the model that generalize beyond the dataset scope.
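The tracing side of this can be sketched with stock `torch.fx`: trace a small transformation, run `ShapeProp` over it, and every node yields a checkable `(op, shape)` primitive. The function and variable names below are illustrative, not the project's actual code.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

def transform(s):
    q = s[None, :]       # add a leading dimension: [16] -> [1, 16]
    f = q.expand(4, -1)  # broadcast dim 0:         [1, 16] -> [4, 16]
    return f

# Trace the function into a graph, then propagate real shapes through it.
gm = symbolic_trace(transform)
ShapeProp(gm).propagate(torch.randn(16))

# Each node now carries tensor metadata: one verifiable primitive per step.
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    if meta is not None:
        print(node.op, node.target, "->", list(meta.shape))
```

These per-node shapes are what make the reward directly verifiable: any intermediate step a model claims can be checked against the propagated metadata.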
For example, on the base Qwen3 4B:
Parsed Input:
Code: 'q = s[None, :]\nf = q.expand(4, -1)'
Args: {'s': [16]}
Target: f
============================================================
Model Response:
============================================================
Okay, let's try to figure out the shape of tensor f step by step. So, the code given is q = s[None, :] and then f = q.expand(4, -1). The input shape for s is [16], and we need to find the shape of f.
First, let's look at step 0. The operation here is s[None, :]. The s tensor has a shape of [16]. When you do [None, :], that's adding a new dimension at the beginning. In PyTorch, [None] is like adding a dimension of size 1. So s[None, :] would take the 16-element tensor and add a dimension at index 0, making it a 2D tensor with shape [1, 16]. Let me confirm: if s is [16], then s[None, :] is equivalent to adding a batch dimension, so yes, it becomes (1, 16). So step 0: s[None, :] -> [1, 16].
Next, step 1 is f = q.expand(4, -1). Here, q is the tensor from step 0, which is [1, 16]. The expand operation in PyTorch is used to expand the tensor to a new shape without copying data. The parameters here are (4, -1). The first dimension is 4, which means we're expanding the first dimension from 1 to 4. The second dimension is -1, which means it will calculate the size based on the original size. Since the original second dimension is 16, expanding the second dimension to -1 would keep it as 16. Wait, but expand is a bit different from reshape. Let me think.
In PyTorch, when you do expand(4, -1), the -1 is a dynamic dimension. So the original shape of q is [1, 16]. The expand operation takes the new shape as (4, 16) because -1 is replaced by the original size of the corresponding dimension. Wait, no. Wait, the expand function in PyTorch takes the new shape as a tuple. The -1 means "infer the size based on the original tensor's dimensions". So for example, if you have a tensor of shape [a, b], and you do expand(4, -1), it would become [4, b], because the -1
============================================================
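(For reference, the ground truth the base model is rambling toward can be checked directly in PyTorch:)

```python
import torch

s = torch.arange(16)   # shape [16]
q = s[None, :]         # None adds a leading dim of size 1 -> [1, 16]
f = q.expand(4, -1)    # -1 keeps dim 1 at its original 16 -> [4, 16]
print(list(f.shape))   # -> [4, 16]
```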
and with some finetuning:
============================================================
Model Response:
============================================================
step 0: add dimension at position 0 | [16] -> [1, 16]
step 1: expand(4) | [1, 16] -> [4, 16]
</think>
final: [4, 16]
============================================================
- Dataset (entirely synthetically generated).
- Cold-start-style supervised finetuning to enable "style transfer" of the primitives into the base model.
- Single-step CoT reasoning-based reinforcement learning.
- Simple local vLLM-based inference with a CLI.
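The verifiable final-answer reward can be sketched as follows: parse the model's `final: [...]` line and compare it against the shape obtained by actually executing the code. The function name and the binary reward scheme are assumptions, not the project's exact implementation.

```python
import ast
import re

def shape_reward(response: str, true_shape: list) -> float:
    """Return 1.0 if the response's 'final: [...]' shape matches ground truth."""
    match = re.search(r"final:\s*(\[[\d,\s]*\])", response)
    if match is None:
        return 0.0  # unparseable answer gets no reward
    predicted = ast.literal_eval(match.group(1))
    return 1.0 if predicted == true_shape else 0.0

print(shape_reward("step 0: ...\nfinal: [4, 16]", [4, 16]))  # -> 1.0
```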
Implemented with torchtune and torchforge to stay within the PyTorch ecosystem, and to experiment with torchforge asynchronous workers for vLLM policy generation, policy training, and reward generation. This project uses Qwen3 (Thinking) as the base model.
- RL rewards saturate at the final-answer reward without good reasoning steps. Likely cause is that the dataset is a bit crap, and I need to spend more time designing a more diverse and representative dataset with a larger variety in complexity. This is probably the most annoying part of this project, because hand-assembling a good dataset is pretty time-intensive!
- Needs more exploration into whether the manually implemented GRPO loss is sufficient.
- Pretty resource limited on GPUs.
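Since dataset diversity is the bottleneck, one direction is a generator that composes random ops and records ground-truth shapes by executing them, so every sample is self-verifying. The op pool, variable naming, and sample format below are all illustrative assumptions, not the project's actual generator.

```python
import random
import torch

# Illustrative op pool: each entry pairs a source snippet with the callable
# that performs it. All ops here are valid on any intermediate shape produced
# by the others, so random composition never fails.
OPS = [
    ("x[None, :]",      lambda t: t[None, :]),
    ("x.unsqueeze(-1)", lambda t: t.unsqueeze(-1)),
    ("x.reshape(-1)",   lambda t: t.reshape(-1)),
]

def make_sample(seed: int) -> dict:
    rng = random.Random(seed)
    shape = [rng.choice([2, 4, 8, 16])]
    t = torch.zeros(shape)
    lines, var = [], "x"
    for i, (src, fn) in enumerate(rng.sample(OPS, k=2)):
        t = fn(t)  # execute the op to get the true intermediate shape
        new_var = f"v{i}"
        lines.append(f"{new_var} = {src.replace('x', var)}")
        var = new_var
    return {"code": "\n".join(lines), "args": {"x": shape},
            "target": var, "target_shape": list(t.shape)}
```

Because the target shape comes from execution rather than hand labeling, scaling up variety (more ops, longer chains, multi-input graphs) only requires growing the op pool.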
The current SFT checkpoint is on Hugging Face at `arsalanu/rotbot-sft`.