rotbot: Post-training LLMs for nD-shape inference

This is a work in progress.

The idea:

  • Teach a base model to do n-dimensional shape inference on tensors.
  • Start with PyTorch transformation code: given a tensor of shape A and ops X, Y, Z, produce B = Z(Y(X(A))).
  • Rewards are directly verifiable: TorchFX tracing produces the intermediate primitives, which can be expanded into reasoning traces.
  • The goal is not to build a compiler for torch shape ops, but to see whether the model can internalise rules for tensor transformations that generalise beyond the scope of the dataset.
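The shape rules themselves are simple enough to sketch without torch at all. A minimal, hand-rolled version of the two ops used in the example below (illustrative only, not the repo's code):

```python
def unsqueeze(shape, dim):
    """Shape rule for s[None, :]-style indexing: insert a size-1 dim."""
    out = list(shape)
    out.insert(dim, 1)
    return out

def expand(shape, *sizes):
    """Shape rule for Tensor.expand: -1 keeps the original size, and
    size-1 dims may be broadcast up to the requested size."""
    pad = [1] * (len(sizes) - len(shape))  # new dims are added on the left
    out = []
    for src, new in zip(pad + list(shape), sizes):
        if new == -1:
            out.append(src)
        else:
            assert src in (1, new), f"cannot expand {src} -> {new}"
            out.append(new)
    return out

# Ground truth for the example below: q = s[None, :]; f = q.expand(4, -1)
trace = [[16]]
trace.append(unsqueeze(trace[-1], 0))   # step 0: [16] -> [1, 16]
trace.append(expand(trace[-1], 4, -1))  # step 1: [1, 16] -> [4, 16]
print(trace)  # -> [[16], [1, 16], [4, 16]]
```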

For example, on the base Qwen3 4B:

Parsed Input:
  Code: 'q = s[None, :]\nf = q.expand(4, -1)'
  Args: {'s': [16]}
  Target: f

============================================================
Model Response:
============================================================
Okay, let's try to figure out the shape of tensor f step by step. So, the code given is q = s[None, :] and then f = q.expand(4, -1). The input shape for s is [16], and we need to find the shape of f.

First, let's look at step 0. The operation here is s[None, :]. The s tensor has a shape of [16]. When you do [None, :], that's adding a new dimension at the beginning. In PyTorch, [None] is like adding a dimension of size 1. So s[None, :] would take the 16-element tensor and add a dimension at index 0, making it a 2D tensor with shape [1, 16]. Let me confirm: if s is [16], then s[None, :] is equivalent to adding a batch dimension, so yes, it becomes (1, 16). So step 0: s[None, :] -> [1, 16].

Next, step 1 is f = q.expand(4, -1). Here, q is the tensor from step 0, which is [1, 16]. The expand operation in PyTorch is used to expand the tensor to a new shape without copying data. The parameters here are (4, -1). The first dimension is 4, which means we're expanding the first dimension from 1 to 4. The second dimension is -1, which means it will calculate the size based on the original size. Since the original second dimension is 16, expanding the second dimension to -1 would keep it as 16. Wait, but expand is a bit different from reshape. Let me think.

In PyTorch, when you do expand(4, -1), the -1 is a dynamic dimension. So the original shape of q is [1, 16]. The expand operation takes the new shape as (4, 16) because -1 is replaced by the original size of the corresponding dimension. Wait, no. Wait, the expand function in PyTorch takes the new shape as a tuple. The -1 means "infer the size based on the original tensor's dimensions". So for example, if you have a tensor of shape [a, b], and you do expand(4, -1), it would become [4, b], because the -1
============================================================

and with some finetuning:

============================================================
Model Response:
============================================================
step 0: add dimension at position 0 | [16] -> [1, 16]
step 1: expand(4) | [1, 16] -> [4, 16]
</think>
final: [4, 16]
============================================================
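The finetuned trace format lends itself to a simple verifiable reward. A sketch of a parser and scorer, assuming the `step i: op | [in] -> [out]` / `final: [...]` layout shown above (the repo's actual reward code may differ):

```python
import re
from ast import literal_eval

def parse_trace(response):
    """Extract intermediate output shapes and the final answer from a
    response in the `step i: op | [in] -> [out]` style (format details
    are an assumption, not lifted from the repo)."""
    steps = [literal_eval(m) for m in re.findall(r"->\s*(\[[\d,\s]*\])", response)]
    final = literal_eval(re.search(r"final:\s*(\[[\d,\s]*\])", response).group(1))
    return steps, final

response = """step 0: add dimension at position 0 | [16] -> [1, 16]
step 1: expand(4) | [1, 16] -> [4, 16]
</think>
final: [4, 16]"""

steps, final = parse_trace(response)

# A binary reward: the final answer must match, and every intermediate
# shape must match the ground truth derived from tracing the ops.
ground_truth = ([[1, 16], [4, 16]], [4, 16])
reward = 1.0 if (steps, final) == ground_truth else 0.0
print(reward)  # -> 1.0
```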

Components:

  • Dataset (entirely synthetically generated).
  • Cold-start-style supervised finetuning, to "style transfer" the primitives into the base model.
  • Single-step CoT-reasoning-based reinforcement learning.
  • Simple local vLLM-based inference with a CLI.
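For a concrete picture of the first component, one synthetic sample in the parsed-input style shown earlier might be generated like this (the field names and the op pool are my assumptions, not the repo's schema):

```python
import random

def make_sample(rng):
    """Generate one synthetic sample: a tiny op pipeline plus its
    ground-truth shape trace. Illustrative only; the real generator
    covers a much wider op pool and pipeline depth."""
    n = rng.choice([8, 16, 32, 64])
    reps = rng.randint(2, 8)
    code = "q = s[None, :]\nf = q.expand({}, -1)".format(reps)
    trace = [[n], [1, n], [reps, n]]  # shapes after each op
    return {
        "code": code,
        "args": {"s": [n]},
        "target": "f",
        "trace": trace,       # used to score intermediate reasoning steps
        "final": trace[-1],   # used for the final-answer reward
    }

sample = make_sample(random.Random(0))
print(sample["code"])
print(sample["final"])
```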

Implemented with torchtune and torchforge to stay within the PyTorch ecosystem, and to experiment with torchforge asynchronous workers for vLLM policy generation, policy training, and reward generation. The project uses Qwen3 (Thinking) as the base model.

Current issues:

  • RL rewards saturate at the final-answer reward without good reasoning steps. The likely cause is a weak dataset; I need to spend more time designing a more diverse and representative dataset with a wider range of complexity. This is probably the most annoying part of the project, because hand-assembling a good dataset is pretty time-intensive!
  • Needs more exploration into whether the manually implemented GRPO loss is sufficient.
  • Pretty resource-limited on GPUs.
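On the GRPO point: the group-relative advantage at the core of the loss is simple enough to sanity-check in isolation. A minimal sketch of the standard formulation (not the repo's implementation):

```python
from statistics import mean, pstdev

def grpo_advantages(rewards, eps=1e-6):
    """Group-relative advantages: normalise each sampled completion's
    reward against the mean/std of its own group, as in standard GRPO."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# A group of 4 completions where only one got the final answer right:
# the correct one gets a positive advantage, the rest negative, and
# the advantages sum to ~0 within the group.
print(grpo_advantages([1.0, 0.0, 0.0, 0.0]))
```

If every completion in a group gets the same reward (the saturation case described above), sigma is ~0 and all advantages collapse toward zero, which is one reason reward saturation stalls learning.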

The current SFT checkpoint is on Hugging Face at arsalanu/rotbot-sft.

About

Experimental synthetic dataset construction, SFT and CoT-RL for multidimensional shape inference problems
