Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .github/workflows/docker-tt-smi.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Build TT SMI Docker image

# Manually triggered workflow: builds the image from docker/tt-smi and pushes
# it to DockerHub under both the ":latest" tag and the tt-smi version tag.
on:
  workflow_dispatch:
    inputs:
      image_name:
        description: "Docker image name"
        required: true
        default: "dstackai/tt-smi"
      tt_smi_version:
        description: "TT SMI version"
        required: true
        default: "3.0.25"

jobs:
  build-tt-smi:
    defaults:
      run:
        working-directory: docker/tt-smi
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Build and upload to DockerHub
        # Inputs are handed to the script through the environment instead of
        # being expanded into the script text, so a crafted input value cannot
        # inject shell commands.
        env:
          IMAGE_NAME: ${{ inputs.image_name }}
          TT_SMI_VERSION: ${{ inputs.tt_smi_version }}
        run: |
          # BUILD_DATE is an RFC 3339 UTC timestamp. The explicit format is used
          # because "--iso-8601=seconds" already appends "+00:00", and adding a
          # trailing "Z" to that produces an invalid timestamp.
          docker buildx build . \
            --load \
            --provenance=false \
            --platform linux/amd64 \
            --build-arg IMAGE_NAME="${IMAGE_NAME}" \
            --build-arg TT_SMI_VERSION="${TT_SMI_VERSION}" \
            --build-arg BUILD_DATE="$(date --utc +%Y-%m-%dT%H:%M:%SZ)" \
            --tag "${IMAGE_NAME}:latest"
          # Read the version back from the image label so the pushed tag always
          # matches what was actually baked into the image.
          VERSION=$(docker inspect --format '{{ index .Config.Labels "org.opencontainers.image.version" }}' "${IMAGE_NAME}:latest")
          docker tag "${IMAGE_NAME}:latest" "${IMAGE_NAME}:${VERSION}"
          docker push "${IMAGE_NAME}:${VERSION}"
          docker push "${IMAGE_NAME}:latest"
22 changes: 22 additions & 0 deletions docker/tt-smi/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Ubuntu 20.04 based image whose entrypoint is the tt-smi command-line tool.
FROM ubuntu:20.04

# Build-time inputs (passed with --build-arg). TT_SMI_VERSION selects the
# tt-smi Git tag to install; all three values are recorded in the OCI labels
# at the bottom of this file.
ARG IMAGE_NAME
ARG TT_SMI_VERSION
ARG BUILD_DATE

# Put /root/.cargo/bin on PATH.
# NOTE(review): presumably this is where the rustup installer below places
# cargo/rustc for the root user - confirm.
ENV PATH="/root/.cargo/bin:${PATH}"

# Single layer that: installs curl, git, Python 3 and pip; installs a Rust
# toolchain through the rustup script; pip-installs tt-smi from the
# "v${TT_SMI_VERSION}" tag of the upstream repository; then removes apt caches
# and package lists.
# NOTE(review): Rust is presumably required to build tt-smi's dependencies
# from source - confirm before removing it.
RUN \
apt-get update && \
apt-get install -y curl git python3 python3-pip && \
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
pip3 install --no-cache-dir git+https://github.com/tenstorrent/tt-smi@v${TT_SMI_VERSION} && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# The container runs tt-smi; arguments given to "docker run" replace the
# default "--help" argument.
ENTRYPOINT ["/usr/local/bin/tt-smi"]
CMD ["--help"]

# OCI image metadata populated from the build arguments above.
LABEL org.opencontainers.image.title="${IMAGE_NAME}"
LABEL org.opencontainers.image.version="${TT_SMI_VERSION}"
LABEL org.opencontainers.image.created="${BUILD_DATE}"
9 changes: 9 additions & 0 deletions docker/tt-smi/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# dstack TT SMI

An Ubuntu-based image with [TT SMI](https://github.com/tenstorrent/tt-smi/) preinstalled. Suitable for Tenstorrent GPU detection.

## Usage

```shell
docker run --device /dev/tenstorrent/ dstackai/tt-smi -s
```
130 changes: 104 additions & 26 deletions runner/internal/shim/host/gpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,38 +200,38 @@ func unmarshalTtSmiSnapshot(data []byte) (*ttSmiSnapshot, error) {
}

func getGpusFromTtSmiSnapshot(snapshot *ttSmiSnapshot) []GpuInfo {
// Group devices by board_id to aggregate memory for the same physical GPU
boardMap := make(map[string]*GpuInfo)
// Create a map to track "L" devices and their corresponding "R" devices
// Each "L" device becomes a separate GPU
lDeviceMap := make(map[string]*GpuInfo)
indexCounter := 0

for _, device := range snapshot.DeviceInfo {
// First pass: identify all "L" and "R" devices
for i, device := range snapshot.DeviceInfo {
boardID := device.BoardInfo.BoardID

// Extract board type without R/L suffix
boardType := strings.TrimSpace(device.BoardInfo.BoardType)
name := boardType

// Remove " R" or " L" suffix if present
if strings.HasSuffix(boardType, " R") {
name = boardType[:len(boardType)-2]
} else if strings.HasSuffix(boardType, " L") {
name = boardType[:len(boardType)-2]
}
// Determine if this is an "L" device
isLDevice := strings.HasSuffix(boardType, " L")

// Determine base VRAM based on board type
baseVram := 0
if strings.HasPrefix(name, "n150") {
baseVram = 12 * 1024 // 12GB in MiB
} else if strings.HasPrefix(name, "n300") {
baseVram = 12 * 1024 // 12GB in MiB
}
if isLDevice {
// Create unique identifier for this "L" device
uniqueID := fmt.Sprintf("%s_L_%d", boardID, i)

// Extract base name without L suffix
name := boardType[:len(boardType)-2]

// Determine base VRAM based on board type
baseVram := 0
if strings.HasPrefix(name, "n150") {
baseVram = 12 * 1024 // 12GB in MiB
} else if strings.HasPrefix(name, "n300") {
baseVram = 12 * 1024 // 12GB in MiB
} else if strings.HasPrefix(name, "tt-galaxy-wh") {
baseVram = 12 * 1024 // 12GB in MiB
}

if existingGpu, exists := boardMap[boardID]; exists {
// Aggregate VRAM for the same board_id
existingGpu.Vram += baseVram
} else {
// Create new GPU entry
boardMap[boardID] = &GpuInfo{
// Create new GPU entry for "L" device
lDeviceMap[uniqueID] = &GpuInfo{
Vendor: common.GpuVendorTenstorrent,
Name: name,
Vram: baseVram,
Expand All @@ -242,12 +242,90 @@ func getGpusFromTtSmiSnapshot(snapshot *ttSmiSnapshot) []GpuInfo {
}
}

// Second pass: add memory from "R" devices to corresponding "L" devices
for _, device := range snapshot.DeviceInfo {
boardID := device.BoardInfo.BoardID
boardType := strings.TrimSpace(device.BoardInfo.BoardType)

if strings.HasSuffix(boardType, " R") {
// Find the corresponding "L" device with the same board_id
// Since we need to match "R" to "L", we'll use the board_id as the key
// and add memory to the first "L" device we find with that board_id
for _, gpu := range lDeviceMap {
if gpu.ID == boardID {
// Extract base name without R suffix
name := boardType[:len(boardType)-2]

// Determine base VRAM based on board type
baseVram := 0
if strings.HasPrefix(name, "n150") {
baseVram = 12 * 1024 // 12GB in MiB
} else if strings.HasPrefix(name, "n300") {
baseVram = 12 * 1024 // 12GB in MiB
}

// Add memory to the "L" device
gpu.Vram += baseVram
break // Only add to the first matching "L" device
}
}
}
}

// Handle devices without L/R suffix (backward compatibility)
for i, device := range snapshot.DeviceInfo {
boardID := device.BoardInfo.BoardID
boardType := strings.TrimSpace(device.BoardInfo.BoardType)

if !strings.HasSuffix(boardType, " L") && !strings.HasSuffix(boardType, " R") {
// For devices without L/R suffix, treat them as standalone GPUs
// This maintains backward compatibility with existing data
uniqueID := fmt.Sprintf("%s_standalone_%d", boardID, i)

// Determine base VRAM based on board type
baseVram := 0
if strings.HasPrefix(boardType, "n150") {
baseVram = 12 * 1024 // 12GB in MiB
} else if strings.HasPrefix(boardType, "n300") {
baseVram = 12 * 1024 // 12GB in MiB
}

// Check if we already have a GPU with this board_id (old behavior)
existingGpu := false
for _, gpu := range lDeviceMap {
if gpu.ID == boardID {
gpu.Vram += baseVram
existingGpu = true
break
}
}

if !existingGpu {
// Create new GPU entry
lDeviceMap[uniqueID] = &GpuInfo{
Vendor: common.GpuVendorTenstorrent,
Name: boardType,
Vram: baseVram,
ID: boardID,
Index: strconv.Itoa(indexCounter),
}
indexCounter++
}
}
}

// Convert map to slice
var gpus []GpuInfo
for _, gpu := range boardMap {
for _, gpu := range lDeviceMap {
gpus = append(gpus, *gpu)
}

// Sort by the original index to ensure consistent ordering
// We'll reassign indices sequentially based on the original order
for i := range gpus {
gpus[i].Index = strconv.Itoa(i)
}

return gpus
}

Expand Down
52 changes: 52 additions & 0 deletions runner/internal/shim/host/gpu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"os"
"path/filepath"
"reflect"
"strconv"
"testing"

"github.com/dstackai/dstack/runner/internal/common"
Expand Down Expand Up @@ -237,3 +238,54 @@ func TestGetGpusFromTtSmiSnapshotMultipleDevices(t *testing.T) {
}
}
}

// TestGetGpusFromTtSmiSnapshotGalaxy feeds the tenstorrent/galaxy.json fixture
// through getGpusFromTtSmiSnapshot and expects 32 GPUs, each reported as
// "tt-galaxy-wh" with 12 GiB of VRAM, the same board ID, and an Index equal to
// its position in the returned slice.
func TestGetGpusFromTtSmiSnapshotGalaxy(t *testing.T) {
	raw, err := loadTestData("tenstorrent/galaxy.json")
	if err != nil {
		t.Fatalf("Failed to load test data: %v", err)
	}
	snapshot, err := unmarshalTtSmiSnapshot(raw)
	if err != nil {
		t.Fatalf("Failed to unmarshal snapshot: %v", err)
	}

	// Expected per-GPU values and the aggregate derived from them:
	// 32 devices x 12 GiB (expressed in MiB) = 384 GiB.
	const (
		wantName       = "tt-galaxy-wh"
		wantBoardID    = "100035100000000"
		wantVramPerGpu = 12 * 1024
		wantVramTotal  = 32 * wantVramPerGpu
	)

	gpus := getGpusFromTtSmiSnapshot(snapshot)
	if got := len(gpus); got != 32 {
		t.Errorf("getGpusFromTtSmiSnapshot() returned %d GPUs, want 32", got)
	}

	vramSum := 0
	for idx := range gpus {
		g := gpus[idx]

		if g.Vendor != common.GpuVendorTenstorrent {
			t.Errorf("GPU[%d] vendor = %v, want %v", idx, g.Vendor, common.GpuVendorTenstorrent)
		}
		if g.Name != wantName {
			t.Errorf("GPU[%d] name = %s, want tt-galaxy-wh", idx, g.Name)
		}
		if g.ID != wantBoardID {
			t.Errorf("GPU[%d] ID = %s, want 100035100000000", idx, g.ID)
		}
		if g.Vram != wantVramPerGpu {
			t.Errorf("GPU[%d] VRAM = %d, want %d", idx, g.Vram, wantVramPerGpu)
		}
		// Indices must run 0..31 in slice order.
		if want := strconv.Itoa(idx); g.Index != want {
			t.Errorf("GPU[%d] index = %s, want %s", idx, g.Index, want)
		}

		vramSum += g.Vram
	}

	if vramSum != wantVramTotal {
		t.Errorf("Total VRAM = %d MiB, want %d MiB (384GB)", vramSum, wantVramTotal)
	}
}
Loading
Loading