This repository was archived by the owner on Jun 21, 2024. It is now read-only.

Training with Hidet compiler #7

Description

@RameshArvind

Hello!
I was wondering whether anything extra needs to be done to get training working with the Hidet compiler.

Out of the box, I seem to be running into errors:

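import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True, # https://arxiv.org/abs/2205.14135
    cross_entropy_ignore_index = 0
).to(torch.bfloat16).cuda()

# compile the model with the Hidet backend
palm_opt = torch.compile(palm, backend='hidet')

# one random batch of token ids with shape (batch, seq_len) = (1, 1024)
seq = torch.randint(0, 20000, (1, 1024)).cuda()

# the forward pass returns the training loss; backward() exercises the backward graph
loss = palm_opt(seq, return_loss = True)
loss.backward()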

Some of the errors I hit were related to the use of rearrange here and here.
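
For context, a head-splitting rearrange of the kind commonly used in attention layers can be written with plain reshape and permute calls, which a compiler backend may handle more readily. This is a minimal sketch: the patterns 'b n (h d) -> b h n d' and 'b h n d -> b n (h d)' and the helper names split_heads / merge_heads are hypothetical and not necessarily the exact calls made in palm_rlhf_pytorch.

```python
import torch

def split_heads(t: torch.Tensor, heads: int) -> torch.Tensor:
    # Same result as einops: rearrange(t, 'b n (h d) -> b h n d', h=heads)
    # '(h d)' means h is the outer (slower-varying) factor of the last dim
    b, n, hd = t.shape
    return t.reshape(b, n, heads, hd // heads).permute(0, 2, 1, 3)

def merge_heads(t: torch.Tensor) -> torch.Tensor:
    # Same result as einops: rearrange(t, 'b h n d -> b n (h d)')
    b, h, n, d = t.shape
    # permute yields a non-contiguous tensor, so use reshape (not view)
    return t.permute(0, 2, 1, 3).reshape(b, n, h * d)

x = torch.randn(2, 16, 8 * 64)
q = split_heads(x, heads=8)
print(q.shape)  # torch.Size([2, 8, 16, 64])
assert torch.equal(merge_heads(q), x)  # round trip is an exact copy
```

Since the round trip only moves elements around, an exact equality check is appropriate here rather than a tolerance.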

It also seems like einsum isn't supported. Even after replacing those ops with equivalent alternatives, I'm still running into some reshape errors from Hidet:

AssertionError: , occurred when interpreting reshape with
  tensor_reshape(tensor(...), [1023])
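
The einsum calls mentioned above can usually be rewritten as batched matrix multiplies. The sketch below assumes two attention-score patterns: standard multi-head 'b h i d, b h j d -> b h i j' and a single shared key head 'b h i d, b j d -> b h i j'. Which patterns the PaLM code actually uses is an assumption, and the function names are hypothetical. The eager-mode comparison against torch.einsum only confirms that the rewrite is numerically equivalent; it says nothing about whether Hidet will accept it.

```python
import torch

def attn_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Equivalent to torch.einsum('b h i d, b h j d -> b h i j', q, k)
    return q @ k.transpose(-2, -1)

def attn_scores_shared_k(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Equivalent to torch.einsum('b h i d, b j d -> b h i j', q, k).
    # k has no head dim, so add one of size 1 and let matmul broadcast it
    return q @ k.unsqueeze(1).transpose(-2, -1)

q = torch.randn(2, 4, 5, 8)         # (batch, heads, query_len, dim_head)
k = torch.randn(2, 4, 6, 8)         # (batch, heads, key_len, dim_head)
k_shared = torch.randn(2, 6, 8)     # (batch, key_len, dim_head)

# Compare against einsum in eager mode; float32 sums may differ in the last bits
assert torch.allclose(attn_scores(q, k), torch.einsum("bhid,bhjd->bhij", q, k), atol=1e-5)
assert torch.allclose(attn_scores_shared_k(q, k_shared),
                      torch.einsum("bhid,bjd->bhij", q, k_shared), atol=1e-5)
```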

I can post additional info as needed, but I'm wondering whether you ran into the same errors or whether I'm doing something wrong.
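
One way to tell whether failures like these come from Hidet itself or from TorchDynamo's graph capture is to run the same training step under PyTorch's built-in debugging backends first. 'eager' and 'aot_eager' ship with torch.compile; 'aot_eager' also traces the backward graph through AOTAutograd, which is what a training step exercises. This is a sketch: the helper try_backends and the TinyLM stand-in model are hypothetical, and for the repro above one would instead call something like try_backends(palm, lambda m: m(seq, return_loss = True)).

```python
import torch
import torch._dynamo
import torch.nn as nn
import torch.nn.functional as F

def try_backends(model, run_step, backends=("eager", "aot_eager", "hidet")):
    """Run one forward + backward step per backend; record 'ok' or the error."""
    results = {}
    for backend in backends:
        model.zero_grad(set_to_none=True)  # avoid accumulating grads across runs
        try:
            torch._dynamo.reset()  # drop graphs cached by the previous backend
            compiled = torch.compile(model, backend=backend)
            loss = run_step(compiled)
            loss.backward()
            results[backend] = "ok"
        except Exception as e:  # record the failure and move on to the next backend
            results[backend] = f"{type(e).__name__}: {e}"
    return results

class TinyLM(nn.Module):
    # Hypothetical stand-in for PaLM so the sketch runs on CPU without hidet
    def __init__(self, num_tokens=100, dim=32):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        self.out = nn.Linear(dim, num_tokens)

    def forward(self, seq):
        inp, labels = seq[:, :-1], seq[:, 1:]  # next-token prediction
        logits = self.out(self.emb(inp))       # (batch, seq_len - 1, num_tokens)
        # cross_entropy expects the class dim second: (batch, num_tokens, seq_len - 1)
        return F.cross_entropy(logits.transpose(1, 2), labels)

model = TinyLM()
seq = torch.randint(0, 100, (1, 16))
for backend, status in try_backends(model, lambda m: m(seq), backends=("eager", "aot_eager")).items():
    print(backend, status)
```

If both built-in backends report ok and only 'hidet' fails, the problem is more likely in the Hidet backend than in the model code.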

Thanks!
