
RuntimeError: Triton Error [CUDA]: context is destroyed #28

@huxiaopang666

Description


When I use your `KAT_Group` as a module, it runs normally on `cuda:0`; however, when I run it on any other GPU, I encounter the error mentioned in the title. I tested the code you provided and hit the same issue. Could you please let me know how to fix it?
```python
import torch
import torch.nn as nn
from kat_rational import KAT_Group


class KAN(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_cfg=dict(type="KAT", act_init=["identity", "gelu"]),
        bias=True,
        drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act1 = KAT_Group(mode=act_cfg['act_init'][0])
        self.drop1 = nn.Dropout(drop)
        self.act2 = KAT_Group(mode=act_cfg['act_init'][1])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.act1(x)
        x = self.drop1(x)
        x = self.fc1(x)
        x = self.act2(x)
        x = self.drop2(x)
        x = self.fc2(x)
        return x


N, C = 8, 64
input_tensor = torch.randn(N, C).to('cuda:1')
model = KAN(in_features=C, hidden_features=128, out_features=C).to('cuda:1')
output = model(input_tensor)
print(output.shape)
```
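A common cause of this class of error is a Triton kernel being launched while the *current* CUDA device is still `cuda:0`, even though the tensors live on `cuda:1`. Two standard workarounds are to make the target GPU the current device with `torch.cuda.set_device`, or to restrict visibility with `CUDA_VISIBLE_DEVICES` so the desired GPU becomes `cuda:0` inside the process. The sketch below shows both; whether either resolves this particular `kat_rational` issue is an assumption, not confirmed by the maintainers.

```python
import os

# Option 1 (assumption: set before torch initializes CUDA): expose only the
# physical GPU 1, so it appears as cuda:0 inside this process.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch

if torch.cuda.is_available():
    # Option 2: make the target GPU the current device so kernels are
    # launched in its context rather than cuda:0's.
    torch.cuda.set_device(0)  # index 0 now maps to the GPU exposed above
    device = torch.device("cuda", torch.cuda.current_device())
else:
    # Fall back to CPU on machines without CUDA.
    device = torch.device("cpu")

print(device)
```

With either option in place, the model and input can simply be moved to `device` instead of a hard-coded `'cuda:1'`.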

(Screenshots of the error traceback attached.)
