Skip to content

Commit 269adba

Browse files
[Feature]: Support for Tenstorrent Galaxy #2898 (WIP)
1 parent 5b8ce35 commit 269adba

File tree

3 files changed

+3152
-26
lines changed

3 files changed

+3152
-26
lines changed

runner/internal/shim/host/gpu.go

Lines changed: 104 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -200,38 +200,38 @@ func unmarshalTtSmiSnapshot(data []byte) (*ttSmiSnapshot, error) {
200200
}
201201

202202
func getGpusFromTtSmiSnapshot(snapshot *ttSmiSnapshot) []GpuInfo {
203-
// Group devices by board_id to aggregate memory for the same physical GPU
204-
boardMap := make(map[string]*GpuInfo)
203+
// Create a map to track "L" devices and their corresponding "R" devices
204+
// Each "L" device becomes a separate GPU
205+
lDeviceMap := make(map[string]*GpuInfo)
205206
indexCounter := 0
206207

207-
for _, device := range snapshot.DeviceInfo {
208+
// First pass: identify all "L" and "R" devices
209+
for i, device := range snapshot.DeviceInfo {
208210
boardID := device.BoardInfo.BoardID
209-
210-
// Extract board type without R/L suffix
211211
boardType := strings.TrimSpace(device.BoardInfo.BoardType)
212-
name := boardType
213212

214-
// Remove " R" or " L" suffix if present
215-
if strings.HasSuffix(boardType, " R") {
216-
name = boardType[:len(boardType)-2]
217-
} else if strings.HasSuffix(boardType, " L") {
218-
name = boardType[:len(boardType)-2]
219-
}
213+
// Determine if this is an "L" device
214+
isLDevice := strings.HasSuffix(boardType, " L")
220215

221-
// Determine base VRAM based on board type
222-
baseVram := 0
223-
if strings.HasPrefix(name, "n150") {
224-
baseVram = 12 * 1024 // 12GB in MiB
225-
} else if strings.HasPrefix(name, "n300") {
226-
baseVram = 12 * 1024 // 12GB in MiB
227-
}
216+
if isLDevice {
217+
// Create unique identifier for this "L" device
218+
uniqueID := fmt.Sprintf("%s_L_%d", boardID, i)
219+
220+
// Extract base name without L suffix
221+
name := boardType[:len(boardType)-2]
222+
223+
// Determine base VRAM based on board type
224+
baseVram := 0
225+
if strings.HasPrefix(name, "n150") {
226+
baseVram = 12 * 1024 // 12GB in MiB
227+
} else if strings.HasPrefix(name, "n300") {
228+
baseVram = 12 * 1024 // 12GB in MiB
229+
} else if strings.HasPrefix(name, "tt-galaxy-wh") {
230+
baseVram = 12 * 1024 // 12GB in MiB
231+
}
228232

229-
if existingGpu, exists := boardMap[boardID]; exists {
230-
// Aggregate VRAM for the same board_id
231-
existingGpu.Vram += baseVram
232-
} else {
233-
// Create new GPU entry
234-
boardMap[boardID] = &GpuInfo{
233+
// Create new GPU entry for "L" device
234+
lDeviceMap[uniqueID] = &GpuInfo{
235235
Vendor: common.GpuVendorTenstorrent,
236236
Name: name,
237237
Vram: baseVram,
@@ -242,12 +242,90 @@ func getGpusFromTtSmiSnapshot(snapshot *ttSmiSnapshot) []GpuInfo {
242242
}
243243
}
244244

245+
// Second pass: add memory from "R" devices to corresponding "L" devices
246+
for _, device := range snapshot.DeviceInfo {
247+
boardID := device.BoardInfo.BoardID
248+
boardType := strings.TrimSpace(device.BoardInfo.BoardType)
249+
250+
if strings.HasSuffix(boardType, " R") {
251+
// Find the corresponding "L" device with the same board_id
252+
// Since we need to match "R" to "L", we'll use the board_id as the key
253+
// and add memory to the first "L" device we find with that board_id
254+
for _, gpu := range lDeviceMap {
255+
if gpu.ID == boardID {
256+
// Extract base name without R suffix
257+
name := boardType[:len(boardType)-2]
258+
259+
// Determine base VRAM based on board type
260+
baseVram := 0
261+
if strings.HasPrefix(name, "n150") {
262+
baseVram = 12 * 1024 // 12GB in MiB
263+
} else if strings.HasPrefix(name, "n300") {
264+
baseVram = 12 * 1024 // 12GB in MiB
265+
}
266+
267+
// Add memory to the "L" device
268+
gpu.Vram += baseVram
269+
break // Only add to the first matching "L" device
270+
}
271+
}
272+
}
273+
}
274+
275+
// Handle devices without L/R suffix (backward compatibility)
276+
for i, device := range snapshot.DeviceInfo {
277+
boardID := device.BoardInfo.BoardID
278+
boardType := strings.TrimSpace(device.BoardInfo.BoardType)
279+
280+
if !strings.HasSuffix(boardType, " L") && !strings.HasSuffix(boardType, " R") {
281+
// For devices without L/R suffix, treat them as standalone GPUs
282+
// This maintains backward compatibility with existing data
283+
uniqueID := fmt.Sprintf("%s_standalone_%d", boardID, i)
284+
285+
// Determine base VRAM based on board type
286+
baseVram := 0
287+
if strings.HasPrefix(boardType, "n150") {
288+
baseVram = 12 * 1024 // 12GB in MiB
289+
} else if strings.HasPrefix(boardType, "n300") {
290+
baseVram = 12 * 1024 // 12GB in MiB
291+
}
292+
293+
// Check if we already have a GPU with this board_id (old behavior)
294+
existingGpu := false
295+
for _, gpu := range lDeviceMap {
296+
if gpu.ID == boardID {
297+
gpu.Vram += baseVram
298+
existingGpu = true
299+
break
300+
}
301+
}
302+
303+
if !existingGpu {
304+
// Create new GPU entry
305+
lDeviceMap[uniqueID] = &GpuInfo{
306+
Vendor: common.GpuVendorTenstorrent,
307+
Name: boardType,
308+
Vram: baseVram,
309+
ID: boardID,
310+
Index: strconv.Itoa(indexCounter),
311+
}
312+
indexCounter++
313+
}
314+
}
315+
}
316+
245317
// Convert map to slice
246318
var gpus []GpuInfo
247-
for _, gpu := range boardMap {
319+
for _, gpu := range lDeviceMap {
248320
gpus = append(gpus, *gpu)
249321
}
250322

323+
// Sort by the original index to ensure consistent ordering
324+
// We'll reassign indices sequentially based on the original order
325+
for i := range gpus {
326+
gpus[i].Index = strconv.Itoa(i)
327+
}
328+
251329
return gpus
252330
}
253331

runner/internal/shim/host/gpu_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"os"
55
"path/filepath"
66
"reflect"
7+
"strconv"
78
"testing"
89

910
"github.com/dstackai/dstack/runner/internal/common"
@@ -237,3 +238,54 @@ func TestGetGpusFromTtSmiSnapshotMultipleDevices(t *testing.T) {
237238
}
238239
}
239240
}
241+
242+
func TestGetGpusFromTtSmiSnapshotGalaxy(t *testing.T) {
243+
data, err := loadTestData("tenstorrent/galaxy.json")
244+
if err != nil {
245+
t.Fatalf("Failed to load test data: %v", err)
246+
}
247+
snapshot, err := unmarshalTtSmiSnapshot(data)
248+
if err != nil {
249+
t.Fatalf("Failed to unmarshal snapshot: %v", err)
250+
}
251+
252+
gpus := getGpusFromTtSmiSnapshot(snapshot)
253+
254+
// Galaxy.json contains 32 devices with board_type "tt-galaxy-wh L"
255+
// Each "L" device should be treated as a separate GPU
256+
// Each "tt-galaxy-wh" device has 12GB VRAM
257+
if len(gpus) != 32 {
258+
t.Errorf("getGpusFromTtSmiSnapshot() returned %d GPUs, want 32", len(gpus))
259+
}
260+
261+
// Calculate total VRAM: 32 devices × 12GB = 384GB
262+
totalVram := 32 * 12 * 1024 // 32 devices × 12GB × 1024 MiB/GB
263+
actualTotalVram := 0
264+
265+
// Verify all GPUs have the correct properties
266+
for i, gpu := range gpus {
267+
if gpu.Vendor != common.GpuVendorTenstorrent {
268+
t.Errorf("GPU[%d] vendor = %v, want %v", i, gpu.Vendor, common.GpuVendorTenstorrent)
269+
}
270+
if gpu.Name != "tt-galaxy-wh" {
271+
t.Errorf("GPU[%d] name = %s, want tt-galaxy-wh", i, gpu.Name)
272+
}
273+
if gpu.ID != "100035100000000" {
274+
t.Errorf("GPU[%d] ID = %s, want 100035100000000", i, gpu.ID)
275+
}
276+
if gpu.Vram != 12*1024 {
277+
t.Errorf("GPU[%d] VRAM = %d, want %d", i, gpu.Vram, 12*1024)
278+
}
279+
// Verify indices are sequential (0, 1, 2, ..., 31)
280+
expectedIndex := strconv.Itoa(i)
281+
if gpu.Index != expectedIndex {
282+
t.Errorf("GPU[%d] index = %s, want %s", i, gpu.Index, expectedIndex)
283+
}
284+
actualTotalVram += gpu.Vram
285+
}
286+
287+
// Verify total VRAM is 384GB
288+
if actualTotalVram != totalVram {
289+
t.Errorf("Total VRAM = %d MiB, want %d MiB (384GB)", actualTotalVram, totalVram)
290+
}
291+
}

0 commit comments

Comments
 (0)