Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions runner/.justfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ build-runner-binary:
#!/usr/bin/env bash
set -e
echo "Building runner for linux/amd64"
cd {{source_directory()}}/cmd/runner && GOOS=linux GOARCH=amd64 go build
cd {{source_directory()}}/cmd/runner && GOOS=linux GOARCH=amd64 go build -ldflags "-X 'main.Version=$version' -extldflags '-static'"
echo "Runner build complete!"

# Build shim
Expand All @@ -56,12 +56,12 @@ build-shim-binary:
cd {{source_directory()}}/cmd/shim
if [ -n "$shim_os" ] && [ -n "$shim_arch" ]; then
echo "Building shim for $shim_os/$shim_arch"
GOOS=$shim_os GOARCH=$shim_arch go build
GOOS=$shim_os GOARCH=$shim_arch go build -ldflags "-X 'main.Version=$version' -extldflags '-static'"
else
echo "Building shim for current platform"
go build
go build -ldflags "-X 'main.Version=$version' -extldflags '-static'"
fi
echo "Shim build complete!"
echo "Shim build (version: $version) complete!"

# Build both runner and shim
build-runner: build-runner-binary build-shim-binary
Expand Down
97 changes: 64 additions & 33 deletions runner/internal/shim/host/gpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,67 @@ type ttDeviceInfo struct {

type ttBoardInfo struct {
BoardType string `json:"board_type"`
BusID string `json:"bus_id"`
BoardID string `json:"board_id"`
}

func unmarshalTtSmiSnapshot(data []byte) (*ttSmiSnapshot, error) {
var snapshot ttSmiSnapshot
if err := json.Unmarshal(data, &snapshot); err != nil {
return nil, err
}
return &snapshot, nil
}

func getGpusFromTtSmiSnapshot(snapshot *ttSmiSnapshot) []GpuInfo {
// Group devices by board_id to aggregate memory for the same physical GPU
boardMap := make(map[string]*GpuInfo)
indexCounter := 0

for _, device := range snapshot.DeviceInfo {
boardID := device.BoardInfo.BoardID

// Extract board type without R/L suffix
boardType := strings.TrimSpace(device.BoardInfo.BoardType)
name := boardType

// Remove " R" or " L" suffix if present
if strings.HasSuffix(boardType, " R") {
name = boardType[:len(boardType)-2]
} else if strings.HasSuffix(boardType, " L") {
name = boardType[:len(boardType)-2]
}

// Determine base VRAM based on board type
baseVram := 0
if strings.HasPrefix(name, "n150") {
baseVram = 12 * 1024 // 12GB in MiB
} else if strings.HasPrefix(name, "n300") {
baseVram = 12 * 1024 // 12GB in MiB
}

if existingGpu, exists := boardMap[boardID]; exists {
// Aggregate VRAM for the same board_id
existingGpu.Vram += baseVram
} else {
// Create new GPU entry
boardMap[boardID] = &GpuInfo{
Vendor: common.GpuVendorTenstorrent,
Name: name,
Vram: baseVram,
ID: boardID,
Index: strconv.Itoa(indexCounter),
}
indexCounter++
}
}

// Convert map to slice
var gpus []GpuInfo
for _, gpu := range boardMap {
gpus = append(gpus, *gpu)
}

return gpus
}

func getTenstorrentGpuInfo(ctx context.Context) []GpuInfo {
Expand Down Expand Up @@ -218,43 +278,14 @@ func getTenstorrentGpuInfo(ctx context.Context) []GpuInfo {
return gpus
}

var ttSmiSnapshot ttSmiSnapshot
if err := json.Unmarshal([]byte(res.Stdout), &ttSmiSnapshot); err != nil {
ttSmiSnapshot, err := unmarshalTtSmiSnapshot([]byte(res.Stdout))
if err != nil {
log.Error(ctx, "cannot read tt-smi json", "err", err)
log.Debug(ctx, "tt-smi output", "stdout", res.Stdout)
return gpus
}

for i, device := range ttSmiSnapshot.DeviceInfo {
// Extract board type without R/L suffix
boardType := strings.TrimSpace(device.BoardInfo.BoardType)
name := boardType

// Remove " R" or " L" suffix if present
if strings.HasSuffix(boardType, " R") {
name = boardType[:len(boardType)-2]
} else if strings.HasSuffix(boardType, " L") {
name = boardType[:len(boardType)-2]
}

// Determine VRAM based on board type
vram := 0
if strings.HasPrefix(name, "n150") {
vram = 12 * 1024 // 12GB in MiB
} else if strings.HasPrefix(name, "n300") {
vram = 24 * 1024 // 24GB in MiB
}

gpus = append(gpus, GpuInfo{
Vendor: common.GpuVendorTenstorrent,
Name: name,
Vram: vram,
ID: device.BoardInfo.BusID,
Index: strconv.Itoa(i),
})
}

return gpus
return getGpusFromTtSmiSnapshot(ttSmiSnapshot)
}

func getAmdRenderNodePath(bdf string) (string, error) {
Expand Down
239 changes: 239 additions & 0 deletions runner/internal/shim/host/gpu_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package host

import (
"os"
"path/filepath"
"reflect"
"testing"

"github.com/dstackai/dstack/runner/internal/common"
)

func loadTestData(filename string) ([]byte, error) {
path := filepath.Join("testdata", filename)
return os.ReadFile(path)
}

func TestUnmarshalTtSmiSnapshot(t *testing.T) {
tests := []struct {
name string
filename string
want *ttSmiSnapshot
wantErr bool
}{
{
name: "valid single device",
filename: "tenstorrent/valid_single_device.json",
want: &ttSmiSnapshot{
DeviceInfo: []ttDeviceInfo{
{
BoardInfo: ttBoardInfo{
BoardType: "n150 L",
BoardID: "100018611902010",
},
},
},
},
wantErr: false,
},
{
name: "valid multiple devices",
filename: "tenstorrent/valid_multiple_devices.json",
want: &ttSmiSnapshot{
DeviceInfo: []ttDeviceInfo{
{
BoardInfo: ttBoardInfo{
BoardType: "n300 L",
BoardID: "10001451172208f",
},
},
{
BoardInfo: ttBoardInfo{
BoardType: "n300 L",
BoardID: "100014511722053",
},
},
{
BoardInfo: ttBoardInfo{
BoardType: "n300 L",
BoardID: "10001451172209c",
},
},
{
BoardInfo: ttBoardInfo{
BoardType: "n300 L",
BoardID: "100014511722058",
},
},
{
BoardInfo: ttBoardInfo{
BoardType: "n300 R",
BoardID: "10001451172208f",
},
},
{
BoardInfo: ttBoardInfo{
BoardType: "n300 R",
BoardID: "100014511722053",
},
},
{
BoardInfo: ttBoardInfo{
BoardType: "n300 R",
BoardID: "10001451172209c",
},
},
{
BoardInfo: ttBoardInfo{
BoardType: "n300 R",
BoardID: "100014511722058",
},
},
},
},
wantErr: false,
},
{
name: "empty device info",
filename: "tenstorrent/empty_device_info.json",
want: &ttSmiSnapshot{
DeviceInfo: []ttDeviceInfo{},
},
wantErr: false,
},
{
name: "invalid JSON",
filename: "tenstorrent/invalid_json.json",
want: nil,
wantErr: true,
},
{
name: "missing device_info field",
filename: "tenstorrent/missing_device_info.json",
want: &ttSmiSnapshot{DeviceInfo: nil},
wantErr: false,
},
{
name: "empty JSON",
filename: "tenstorrent/empty_json.json",
want: &ttSmiSnapshot{DeviceInfo: nil},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := loadTestData(tt.filename)
if err != nil {
t.Fatalf("Failed to load test data from %s: %v", tt.filename, err)
}

got, err := unmarshalTtSmiSnapshot(data)
if (err != nil) != tt.wantErr {
t.Errorf("unmarshalTtSmiSnapshot() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if got == nil {
t.Errorf("unmarshalTtSmiSnapshot() returned nil, expected non-nil result")
return
}
if len(got.DeviceInfo) != len(tt.want.DeviceInfo) {
t.Errorf("unmarshalTtSmiSnapshot() device count = %v, want %v", len(got.DeviceInfo), len(tt.want.DeviceInfo))
return
}
for i, device := range got.DeviceInfo {
if i >= len(tt.want.DeviceInfo) {
break
}
expected := tt.want.DeviceInfo[i]
if device.BoardInfo.BoardType != expected.BoardInfo.BoardType {
t.Errorf("unmarshalTtSmiSnapshot() device[%d].BoardInfo.BoardType = %v, want %v", i, device.BoardInfo.BoardType, expected.BoardInfo.BoardType)
}
if device.BoardInfo.BoardID != expected.BoardInfo.BoardID {
t.Errorf("unmarshalTtSmiSnapshot() device[%d].BoardInfo.BoardID = %v, want %v", i, device.BoardInfo.BoardID, expected.BoardInfo.BoardID)
}
}
}
})
}
}

func TestGetGpusFromTtSmiSnapshot(t *testing.T) {
data, err := loadTestData("tenstorrent/single_n150_gpu.json")
if err != nil {
t.Fatalf("Failed to load test data: %v", err)
}
snapshot, err := unmarshalTtSmiSnapshot(data)
if err != nil {
t.Fatalf("Failed to unmarshal snapshot: %v", err)
}

expectedGpus := []GpuInfo{
{
Vendor: common.GpuVendorTenstorrent,
Name: "n150",
Vram: 12 * 1024,
ID: "100018611902010",
Index: "0",
},
}

gpus := getGpusFromTtSmiSnapshot(snapshot)

if !reflect.DeepEqual(gpus, expectedGpus) {
t.Errorf("getGpusFromTtSmiSnapshot() = %v, want %v", gpus, expectedGpus)
}
}

func TestGetGpusFromTtSmiSnapshotMultipleDevices(t *testing.T) {
data, err := loadTestData("tenstorrent/valid_multiple_devices.json")
if err != nil {
t.Fatalf("Failed to load test data: %v", err)
}
snapshot, err := unmarshalTtSmiSnapshot(data)
if err != nil {
t.Fatalf("Failed to unmarshal snapshot: %v", err)
}

gpus := getGpusFromTtSmiSnapshot(snapshot)

// Verify we have 4 unique GPUs (grouped by board_id)
if len(gpus) != 4 {
t.Errorf("getGpusFromTtSmiSnapshot() returned %d GPUs, want 4", len(gpus))
}

// Create a map to check the results by board_id
gpusByID := make(map[string]GpuInfo)
for _, gpu := range gpus {
gpusByID[gpu.ID] = gpu
}

// Verify specific GPUs and their aggregated VRAM
expectedGpus := map[string]struct {
name string
vram int
}{
"10001451172208f": {"n300", 24 * 1024}, // 12GB (n300 L) + 12GB (n300 R) = 24GB
"100014511722053": {"n300", 24 * 1024}, // 12GB (n300 L) + 12GB (n300 R) = 24GB
"10001451172209c": {"n300", 24 * 1024}, // 12GB (n300 L) + 12GB (n300 R) = 24GB
"100014511722058": {"n300", 24 * 1024}, // 12GB (n300 L) + 12GB (n300 R) = 24GB
}

for boardID, expected := range expectedGpus {
gpu, exists := gpusByID[boardID]
if !exists {
t.Errorf("Expected GPU with board_id %s not found", boardID)
continue
}
if gpu.Name != expected.name {
t.Errorf("GPU %s: name = %s, want %s", boardID, gpu.Name, expected.name)
}
if gpu.Vram != expected.vram {
t.Errorf("GPU %s: VRAM = %d, want %d", boardID, gpu.Vram, expected.vram)
}
if gpu.Vendor != common.GpuVendorTenstorrent {
t.Errorf("GPU %s: vendor = %v, want %v", boardID, gpu.Vendor, common.GpuVendorTenstorrent)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"time": "2025-06-20T12:10:28.926938",
"host_info": {
"OS": "Linux",
"Distro": "Ubuntu 20.04.6 LTS",
"Kernel": "5.15.0-138-generic",
"Hostname": "empty-system",
"Platform": "x86_64",
"Python": "3.8.10",
"Memory": "16.00 GB",
"Driver": "TT-KMD 1.33"
},
"host_sw_vers": {
"tt_smi": "3.0.15",
"pyluwen": "0.7.2"
},
"device_info": []
}
Loading
Loading