diff --git a/charts/hami/templates/_helpers.tpl b/charts/hami/templates/_helpers.tpl index 7a882789e..8833421c8 100644 --- a/charts/hami/templates/_helpers.tpl +++ b/charts/hami/templates/_helpers.tpl @@ -233,6 +233,12 @@ Returns a YAML list that can be used directly or converted to JSON via fromYaml {{- $resources = append $resources (dict "name" . "ignoredByScheduler" true) -}} {{- end -}} {{- end -}} +{{/* Vastai resources */}} +{{- if .Values.devices.vastai.enabled -}} +{{- range .Values.devices.vastai.customresources -}} +{{- $resources = append $resources (dict "name" . "ignoredByScheduler" true) -}} +{{- end -}} +{{- end -}} {{/* AMD resources */}} {{- range .Values.devices.amd.customresources -}} {{- $resources = append $resources (dict "name" . "ignoredByScheduler" true) -}} diff --git a/charts/hami/templates/scheduler/device-configmap.yaml b/charts/hami/templates/scheduler/device-configmap.yaml index 92a5a020d..faea2bfa7 100644 --- a/charts/hami/templates/scheduler/device-configmap.yaml +++ b/charts/hami/templates/scheduler/device-configmap.yaml @@ -283,6 +283,8 @@ data: resourceCoreName: "aws.amazon.com/neuroncore" amd: resourceCountName: "amd.com/gpu" + vastai: + resourceCountName: {{ .Values.vastaiResourceName }} vnpus: - chipName: 910A commonWord: Ascend910A diff --git a/charts/hami/values.yaml b/charts/hami/values.yaml index 55af5e626..8726c3c5f 100644 --- a/charts/hami/values.yaml +++ b/charts/hami/values.yaml @@ -53,6 +53,9 @@ kunlunResourceName: "kunlunxin.com/xpu" kunlunResourceVCountName: "kunlunxin.com/vxpu" kunlunResourceVMemoryName: "kunlunxin.com/vxpu-memory" +#Vastai Parameters +vastaiResourceName: "vastaitech.com/va" + schedulerName: "hami-scheduler" podSecurityPolicy: @@ -440,6 +443,10 @@ devices: enabled: true customresources: - mthreads.com/vgpu + vastai: + enabled: true + customresources: + - vastaitech.com/va nvidia: gpuCorePolicy: default libCudaLogLevel: 1 diff --git a/docs/vastai-support.md b/docs/vastai-support.md new file mode 100644 index 000000000..564fdfd64 --- /dev/null +++ b/docs/vastai-support.md @@ -0,0 +1,231 @@ +## Introduction + +We now support sharing `vastaitech.com/va` (Vastaitech) devices and provides the following capabilities: + +***Supports both Full-Card mode and Die mode***: Only Full-Card mode and Die mode are supported currently. + +***die-mode topology awareness***: When multiple resources are requested in die mode, the scheduler will try to allocate them on the same AIC whenever possible. + +***Device UUID selection***: You can specify or exclude particular devices through annotations. + +## Using Vastai Devices + +### Enabling Vastai Device Sharing + +#### Label the Node + +``` +kubectl label node {vastai-node} vastai=on +``` + +#### Deploy the `vastai-device-plugin` + +##### Full Card Mode + +``` +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: hami-vastai +rules: + - apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "update", "watch", "patch"] + - apiGroups: [""] + resources: ["nodes"] + verbs: ["get", "update", "patch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: hami-vastai +subjects: + - kind: ServiceAccount + name: hami-vastai + namespace: kube-system +roleRef: + kind: ClusterRole + name: hami-vastai + apiGroup: rbac.authorization.k8s.io +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: hami-vastai + namespace: kube-system + labels: + app.kubernetes.io/component: "hami-vastai" +--- +apiVersion: apps/v1 +kind: DaemonSet +metadata: + name: vastai-device-plugin-daemonset + namespace: kube-system + labels: + app.kubernetes.io/component: hami-vastai-device-plugin +spec: + selector: + matchLabels: + app.kubernetes.io/component: hami-vastai-device-plugin + hami.io/webhook: ignore + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + app.kubernetes.io/component: hami-vastai-device-plugin + hami.io/webhook: ignore + spec: + priorityClassName: "system-node-critical" + serviceAccountName: hami-vastai + nodeSelector: + vastai-device: "vastai" + containers: + - image: projecthami/vastai-device-plugin:latest + imagePullPolicy: Always + name: vastai-device-plugin-dp + env: + - name: NODE_NAME + valueFrom: + fieldRef: + fieldPath: spec.nodeName + args: ["--fail-on-init-error=false", "--pass-device-specs=true"] + securityContext: + privileged: true + volumeMounts: + - name: device-plugin + mountPath: /var/lib/kubelet/device-plugins + - name: libvaml-lib + mountPath: /usr/lib/libvaml.so + - name: libvaml-lib64 + mountPath: /usr/lib64/libvaml.so + volumes: + - name: device-plugin + hostPath: + path: /var/lib/kubelet/device-plugins + - name: libvaml-lib + hostPath: + path: /usr/lib/libvaml.so + - name: libvaml-lib64 + hostPath: + path: /usr/lib64/libvaml.so + nodeSelector: + vastai: "on" +``` + +##### Die Mode + +``` +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: hami-vastai +rules: + - apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "update", "watch", "patch"] + - apiGroups: [""] + resources: ["nodes"] + verbs: ["get", "update", "patch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: hami-vastai +subjects: + - kind: ServiceAccount + name: hami-vastai + namespace: kube-system +roleRef: + kind: ClusterRole + name: hami-vastai + apiGroup: rbac.authorization.k8s.io +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: hami-vastai + namespace: kube-system + labels: + app.kubernetes.io/component: "hami-vastai" +--- +apiVersion: apps/v1 +kind: DaemonSet +metadata: + name: vastai-device-plugin-daemonset + namespace: kube-system + labels: + app.kubernetes.io/component: hami-vastai-device-plugin +spec: + selector: + matchLabels: + app.kubernetes.io/component: hami-vastai-device-plugin + hami.io/webhook: ignore + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + app.kubernetes.io/component: hami-vastai-device-plugin + hami.io/webhook: ignore + spec: + priorityClassName: "system-node-critical" + serviceAccountName: hami-vastai + nodeSelector: + vastai-device: "vastai" + containers: + - image: projecthami/vastai-device-plugin:latest + imagePullPolicy: Always + name: vastai-device-plugin-dp + env: + - name: NODE_NAME + valueFrom: + fieldRef: + fieldPath: spec.nodeName + args: ["--fail-on-init-error=false", "--pass-device-specs=true", "--device-strategy=die", "--rename-on-die=false"] + securityContext: + privileged: true + volumeMounts: + - name: device-plugin + mountPath: /var/lib/kubelet/device-plugins + - name: libvaml-lib + mountPath: /usr/lib/libvaml.so + - name: libvaml-lib64 + mountPath: /usr/lib64/libvaml.so + volumes: + - name: device-plugin + hostPath: + path: /var/lib/kubelet/device-plugins + - name: libvaml-lib + hostPath: + path: /usr/lib/libvaml.so + - name: libvaml-lib64 + hostPath: + path: /usr/lib64/libvaml.so + nodeSelector: + vastai: "on" +``` + +### Run Vastai jobs + +``` +apiVersion: v1 +kind: Pod +metadata: + name: vastai-pod +spec: + restartPolicy: Never + containers: + - name: vastai-container + image: harbor.vastaitech.com/ai_deliver/vllm_vacc:VVI-25.12.SP2 + command: ["sleep", "infinity"] + resources: + limits: + vastaitech.com/va: "1" +``` + +## Notes +1. When requesting Vastai resources, you cannot specify the memory size. +2. The `vastai-device-plugin` does not mount the `vasmi` into the container.If you need to use the `vasmi` command inside the container, please mount it manually. \ No newline at end of file diff --git a/docs/vastai-support_cn.md b/docs/vastai-support_cn.md new file mode 100644 index 000000000..5cd00f341 --- /dev/null +++ b/docs/vastai-support_cn.md @@ -0,0 +1,231 @@ +## 简介 + +本组件支持复用瀚博设备,并为此提供以下几种复用功能,包括: + +***支持整卡模式和die模式***: 目前只支持整卡模式和die模式 + +***die模式拓扑感知***: die模式下,申请多个资源时尽可能的分配到同一个AIC上 + +***设备 UUID 选择***: 你可以通过注解指定使用或排除特定的设备 + +## 复用瀚博设备 + +### 开启复用瀚博设备 + +#### 给node打标签 + +``` +kubectl label node {vastai-node} vastai=on +``` + +#### 部署 vastai-device-plugin + +##### 整卡模式 + +``` +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: hami-vastai +rules: + - apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "update", "watch", "patch"] + - apiGroups: [""] + resources: ["nodes"] + verbs: ["get", "update", "patch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: hami-vastai +subjects: + - kind: ServiceAccount + name: hami-vastai + namespace: kube-system +roleRef: + kind: ClusterRole + name: hami-vastai + apiGroup: rbac.authorization.k8s.io +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: hami-vastai + namespace: kube-system + labels: + app.kubernetes.io/component: "hami-vastai" +--- +apiVersion: apps/v1 +kind: DaemonSet +metadata: + name: vastai-device-plugin-daemonset + namespace: kube-system + labels: + app.kubernetes.io/component: hami-vastai-device-plugin +spec: + selector: + matchLabels: + app.kubernetes.io/component: hami-vastai-device-plugin + hami.io/webhook: ignore + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + app.kubernetes.io/component: hami-vastai-device-plugin + hami.io/webhook: ignore + spec: + priorityClassName: "system-node-critical" + serviceAccountName: hami-vastai + nodeSelector: + vastai-device: "vastai" + containers: + - image: projecthami/vastai-device-plugin:latest + imagePullPolicy: Always + name: vastai-device-plugin-dp + env: + - name: NODE_NAME + valueFrom: + fieldRef: + fieldPath: spec.nodeName + args: ["--fail-on-init-error=false", "--pass-device-specs=true"] + securityContext: + privileged: true + volumeMounts: + - name: device-plugin + mountPath: /var/lib/kubelet/device-plugins + - name: libvaml-lib + mountPath: /usr/lib/libvaml.so + - name: libvaml-lib64 + mountPath: /usr/lib64/libvaml.so + volumes: + - name: device-plugin + hostPath: + path: /var/lib/kubelet/device-plugins + - name: libvaml-lib + hostPath: + path: /usr/lib/libvaml.so + - name: libvaml-lib64 + hostPath: + path: /usr/lib64/libvaml.so + nodeSelector: + vastai: "on" +``` + +##### die 模式 + +``` +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: hami-vastai +rules: + - apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "update", "watch", "patch"] + - apiGroups: [""] + resources: ["nodes"] + verbs: ["get", "update", "patch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: hami-vastai +subjects: + - kind: ServiceAccount + name: hami-vastai + namespace: kube-system +roleRef: + kind: ClusterRole + name: hami-vastai + apiGroup: rbac.authorization.k8s.io +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: hami-vastai + namespace: kube-system + labels: + app.kubernetes.io/component: "hami-vastai" +--- +apiVersion: apps/v1 +kind: DaemonSet +metadata: + name: vastai-device-plugin-daemonset + namespace: kube-system + labels: + app.kubernetes.io/component: hami-vastai-device-plugin +spec: + selector: + matchLabels: + app.kubernetes.io/component: hami-vastai-device-plugin + hami.io/webhook: ignore + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + app.kubernetes.io/component: hami-vastai-device-plugin + hami.io/webhook: ignore + spec: + priorityClassName: "system-node-critical" + serviceAccountName: hami-vastai + nodeSelector: + vastai-device: "vastai" + containers: + - image: projecthami/vastai-device-plugin:latest + imagePullPolicy: Always + name: vastai-device-plugin-dp + env: + - name: NODE_NAME + valueFrom: + fieldRef: + fieldPath: spec.nodeName + args: ["--fail-on-init-error=false", "--pass-device-specs=true", "--device-strategy=die", "--rename-on-die=false"] + securityContext: + privileged: true + volumeMounts: + - name: device-plugin + mountPath: /var/lib/kubelet/device-plugins + - name: libvaml-lib + mountPath: /usr/lib/libvaml.so + - name: libvaml-lib64 + mountPath: /usr/lib64/libvaml.so + volumes: + - name: device-plugin + hostPath: + path: /var/lib/kubelet/device-plugins + - name: libvaml-lib + hostPath: + path: /usr/lib/libvaml.so + - name: libvaml-lib64 + hostPath: + path: /usr/lib64/libvaml.so + nodeSelector: + vastai: "on" +``` + +### 运行瀚博任务 + +``` +apiVersion: v1 +kind: Pod +metadata: + name: vastai-pod +spec: + restartPolicy: Never + containers: + - name: vastai-container + image: harbor.vastaitech.com/ai_deliver/vllm_vacc:VVI-25.12.SP2 + command: ["sleep", "infinity"] + resources: + limits: + vastaitech.com/va: "1" +``` + +## 注意事项 +1. 申请瀚博资源时不可以指定显存大小 +2. `vastai-device-plugin` 没有把 `vasmi` 文件挂载到容器中。如果想在容器里使用 `vasmi` 命令,请自行挂载 diff --git a/examples/vastai/default_use.yaml b/examples/vastai/default_use.yaml new file mode 100644 index 000000000..f0f1f688d --- /dev/null +++ b/examples/vastai/default_use.yaml @@ -0,0 +1,12 @@ +apiVersion: v1 +kind: Pod +metadata: + name: vastai-pod +spec: + containers: + - name: test + image: harbor.vastaitech.com/ai_deliver/vllm_vacc:VVI-25.12.SP2 + command: ["sleep", "infinity"] + resources: + limits: + vastaitech.com/va: 1 \ No newline at end of file diff --git a/pkg/device/vastai/device.go b/pkg/device/vastai/device.go new file mode 100644 index 000000000..a39565498 --- /dev/null +++ b/pkg/device/vastai/device.go @@ -0,0 +1,357 @@ +/* +Copyright 2026 The HAMi Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vastai + +import ( + "errors" + "fmt" + "sort" + "strings" + + "github.com/Project-HAMi/HAMi/pkg/device" + "github.com/Project-HAMi/HAMi/pkg/device/common" + "github.com/Project-HAMi/HAMi/pkg/util" + "github.com/Project-HAMi/HAMi/pkg/util/nodelock" + + corev1 "k8s.io/api/core/v1" + "k8s.io/klog/v2" +) + +type VastaiDevices struct { +} + +const ( + HandshakeAnnos = "hami.io/node-handshake-va" + RegisterAnnos = "hami.io/node-va-register" + VastaiDevice = "Vastai" + VastaiCommonWord = "Vastai" + VastaiInUse = "vastaitech.com/use-va" + VastaiNoUse = "vastaitech.com/nouse-va" + VastaiUseUUID = "vastaitech.com/use-gpuuuid" + VastaiNoUseUUID = "vastaitech.com/nouse-gpuuuid" +) + +var ( + VastaiResourceCount string +) + +type VastaiConfig struct { + ResourceCountName string `yaml:"resourceCountName"` +} + +func InitVastaiDevice(config VastaiConfig) *VastaiDevices { + VastaiResourceCount = config.ResourceCountName + commonWord := VastaiCommonWord + _, ok := device.InRequestDevices[commonWord] + if !ok { + device.InRequestDevices[commonWord] = fmt.Sprintf("hami.io/%s-devices-to-allocate", commonWord) + device.SupportDevices[commonWord] = fmt.Sprintf("hami.io/%s-devices-allocated", commonWord) + util.HandshakeAnnos[commonWord] = HandshakeAnnos + } + return &VastaiDevices{} +} + +func (dev *VastaiDevices) CommonWord() string { + return VastaiCommonWord +} + +func (dev *VastaiDevices) GetNodeDevices(n corev1.Node) ([]*device.DeviceInfo, error) { + devEncoded, ok := n.Annotations[RegisterAnnos] + if !ok { + return []*device.DeviceInfo{}, errors.New("annos not found " + RegisterAnnos) + } + nodedevices, err := device.UnMarshalNodeDevices(devEncoded) + if err != nil { + klog.ErrorS(err, "failed to decode node devices", "node", n.Name, "device annotation", devEncoded) + return []*device.DeviceInfo{}, err + } + klog.V(5).InfoS("nodes device information", "node", n.Name, "nodedevices", devEncoded) + for idx := range nodedevices { + nodedevices[idx].DeviceVendor = VastaiCommonWord + nodedevices[idx].Devcore = 100 // only for calscore use + } + if len(nodedevices) == 0 { + klog.InfoS("no vastai device found", "node", n.Name, "device annotation", devEncoded) + return []*device.DeviceInfo{}, errors.New("no gpu found on node") + } + return nodedevices, nil +} + +func (dev *VastaiDevices) MutateAdmission(ctr *corev1.Container, p *corev1.Pod) (bool, error) { + _, ok := ctr.Resources.Limits[corev1.ResourceName(VastaiResourceCount)] + return ok, nil +} + +func (dev *VastaiDevices) LockNode(n *corev1.Node, p *corev1.Pod) error { + found := false + for _, val := range p.Spec.Containers { + if (dev.GenerateResourceRequests(&val).Nums) > 0 { + found = true + break + } + } + if !found { + return nil + } + return nodelock.LockNode(n.Name, nodelock.NodeLockKey, p) +} + +func (dev *VastaiDevices) ReleaseNodeLock(n *corev1.Node, p *corev1.Pod) error { + found := false + for _, val := range p.Spec.Containers { + if (dev.GenerateResourceRequests(&val).Nums) > 0 { + found = true + break + } + } + if !found { + return nil + } + return nodelock.ReleaseNodeLock(n.Name, nodelock.NodeLockKey, p, false) +} + +func (dev *VastaiDevices) NodeCleanUp(nn string) error { + return util.MarkAnnotationsToDelete(HandshakeAnnos, nn) +} + +func (dev *VastaiDevices) checkType(annos map[string]string, d device.DeviceUsage, n device.ContainerDeviceRequest) (bool, bool, bool) { + if strings.Compare(n.Type, VastaiDevice) == 0 { + return true, true, false + } + return false, false, false +} + +func (dev *VastaiDevices) CheckHealth(devType string, n *corev1.Node) (bool, bool) { + return device.CheckHealth(devType, n) +} + +func (dev *VastaiDevices) GenerateResourceRequests(ctr *corev1.Container) device.ContainerDeviceRequest { + klog.V(5).Info("Start to count vastai devices for container ", ctr.Name) + vastaiResourceCount := corev1.ResourceName(VastaiResourceCount) + v, ok := ctr.Resources.Limits[vastaiResourceCount] + if !ok { + v, ok = ctr.Resources.Requests[vastaiResourceCount] + } + if ok { + if n, ok := v.AsInt64(); ok { + klog.Info("Found vastai devices") + memnum := 0 + corenum := int32(0) + mempnum := 100 + + return device.ContainerDeviceRequest{ + Nums: int32(n), + Type: VastaiDevice, + Memreq: int32(memnum), + MemPercentagereq: int32(mempnum), + Coresreq: corenum, + } + } + } + return device.ContainerDeviceRequest{} +} + +func (dev *VastaiDevices) PatchAnnotations(pod *corev1.Pod, annoinput *map[string]string, pd device.PodDevices) map[string]string { + devlist, ok := pd[VastaiDevice] + if ok && len(devlist) > 0 { + deviceStr := device.EncodePodSingleDevice(devlist) + (*annoinput)[device.InRequestDevices[VastaiDevice]] = deviceStr + (*annoinput)[device.SupportDevices[VastaiDevice]] = deviceStr + klog.V(5).Infof("pod add notation key [%s], values is [%s]", device.InRequestDevices[VastaiDevice], deviceStr) + klog.V(5).Infof("pod add notation key [%s], values is [%s]", device.SupportDevices[VastaiDevice], deviceStr) + } + return *annoinput +} + +func (dev *VastaiDevices) ScoreNode(node *corev1.Node, podDevices device.PodSingleDevice, previous []*device.DeviceUsage, policy string) float32 { + score := float32(0) + for _, containerDevices := range podDevices { + if len(containerDevices) == 0 { + continue + } + cntMap := make(map[string]int) + for _, device := range containerDevices { + if device.CustomInfo == nil { + return 0 + } + if strategy, ok := device.CustomInfo["DeviceStrategy"]; ok { + if val, ok := strategy.(string); ok && val != "die" { + return 0 + } + } + if AIC, ok := device.CustomInfo["AIC"]; ok { + if id, ok := AIC.(string); ok { + cntMap[id]++ + } + } else { + return 0 + } + } + maxCnt, totalCnt := 0, 0 + for _, cnt := range cntMap { + if cnt > maxCnt { + maxCnt = cnt + } + totalCnt += cnt + } + if totalCnt == 0 { + continue + } + score += float32(maxCnt) / float32(totalCnt) + } + klog.V(4).InfoS("ScoreNode", "node", node.Name, "deviceType", dev.CommonWord(), "score", score) + return score +} + +func (dev *VastaiDevices) AddResourceUsage(pod *corev1.Pod, n *device.DeviceUsage, ctr *device.ContainerDevice) error { + n.Used++ + n.Usedcores += ctr.Usedcores + n.Usedmem += ctr.Usedmem + return nil +} + +func (va *VastaiDevices) Fit(devices []*device.DeviceUsage, request device.ContainerDeviceRequest, pod *corev1.Pod, nodeInfo *device.NodeInfo, allocated *device.PodDevices) (bool, map[string]device.ContainerDevices, string) { + k := request + originReq := k.Nums + klog.InfoS("Allocating device for container request", "pod", klog.KObj(pod), "card request", k) + tmpDevs := make(map[string]device.ContainerDevices) + reason := make(map[string]int) + dieMode := isDieMode(devices) + for i := range len(devices) { + dev := devices[i] + klog.V(4).InfoS("scoring pod", "pod", klog.KObj(pod), "device", dev.ID, "Memreq", k.Memreq, "MemPercentagereq", k.MemPercentagereq, "Coresreq", k.Coresreq, "Nums", k.Nums, "device index", i) + + _, found, _ := va.checkType(pod.GetAnnotations(), *dev, k) + if !found { + reason[common.CardTypeMismatch]++ + klog.V(5).InfoS(common.CardTypeMismatch, "pod", klog.KObj(pod), "device", dev.ID, dev.Type, k.Type) + continue + } + if !device.CheckUUID(pod.GetAnnotations(), dev.ID, VastaiUseUUID, VastaiNoUseUUID, VastaiCommonWord) { + reason[common.CardUUIDMismatch]++ + klog.V(5).InfoS(common.CardUUIDMismatch, "pod", klog.KObj(pod), "device", dev.ID, "current device info is:", *dev) + continue + } + + if dev.Count <= dev.Used { + reason[common.CardTimeSlicingExhausted]++ + klog.V(5).InfoS(common.CardTimeSlicingExhausted, "pod", klog.KObj(pod), "device", dev.ID, "count", dev.Count, "used", dev.Used) + continue + } + if k.Nums > 0 { + klog.V(5).InfoS("find fit device", "pod", klog.KObj(pod), "device", dev.ID) + if !dieMode { + k.Nums-- + } + tmpDevs[k.Type] = append(tmpDevs[k.Type], device.ContainerDevice{ + Idx: int(dev.Index), + UUID: dev.ID, + Type: k.Type, + Usedcores: k.Coresreq, + CustomInfo: dev.CustomInfo, + }) + } + if k.Nums == 0 && !dieMode { + klog.V(4).InfoS("device allocate success", "pod", klog.KObj(pod), "allocate device", tmpDevs) + return true, tmpDevs, "" + } + + } + + if dieMode { + if len(tmpDevs[k.Type]) == int(originReq) { + klog.V(5).InfoS("device allocate success", "pod", klog.KObj(pod), "allocate device", tmpDevs) + return true, tmpDevs, "" + } else if len(tmpDevs[k.Type]) > int(originReq) { + if originReq == 1 { + tmpDevs[k.Type] = device.ContainerDevices{tmpDevs[k.Type][0]} + } else { + // If requesting multiple devices, select the best combination of cards. + tmpDevs[k.Type] = va.computeBestCombination(int(originReq), tmpDevs[k.Type]) + } + klog.V(5).InfoS("device allocate success", "pod", klog.KObj(pod), "best device combination", tmpDevs) + return true, tmpDevs, "" + } + } + if len(tmpDevs) > 0 { + reason[common.AllocatedCardsInsufficientRequest] = len(tmpDevs) + klog.V(5).InfoS(common.AllocatedCardsInsufficientRequest, "pod", klog.KObj(pod), "request", originReq, "allocated", len(tmpDevs)) + } + return false, tmpDevs, common.GenReason(reason, len(devices)) +} + +func (dev *VastaiDevices) GetResourceNames() device.ResourceNames { + return device.ResourceNames{ + ResourceCountName: VastaiResourceCount, + } +} + +func (dev *VastaiDevices) computeBestCombination(reqNum int, containerDevices device.ContainerDevices) device.ContainerDevices { + deviceMap := make(map[string]device.ContainerDevices) + for _, dev := range containerDevices { + if dev.CustomInfo != nil { + if AIC, ok := dev.CustomInfo["AIC"]; ok { + if id, ok := AIC.(string); ok { + deviceMap[id] = append(deviceMap[id], dev) + } + } + } + } + + type DeviceCount struct { + ID string + Count int + } + var sortedDevices []DeviceCount + for id, devices := range deviceMap { + sortedDevices = append(sortedDevices, DeviceCount{ + ID: id, + Count: len(devices), + }) + } + + sort.SliceStable(sortedDevices, func(i, j int) bool { + return sortedDevices[i].Count > sortedDevices[j].Count + }) + result := device.ContainerDevices{} + for _, item := range sortedDevices { + devices := deviceMap[item.ID] + for _, dev := range devices { + result = append(result, dev) + if len(result) == reqNum { + return result + } + } + } + return result +} + +func isDieMode(devices []*device.DeviceUsage) bool { + if len(devices) == 0 { + return false + } + dev := devices[0] + if dev.CustomInfo == nil { + return false + } + if strategy, ok := dev.CustomInfo["DeviceStrategy"]; ok { + if val, ok := strategy.(string); ok && val == "die" { + return true + } + } + return false +} diff --git a/pkg/device/vastai/device_test.go b/pkg/device/vastai/device_test.go new file mode 100644 index 000000000..2a5dbc385 --- /dev/null +++ b/pkg/device/vastai/device_test.go @@ -0,0 +1,956 @@ +/* +Copyright 2026 The HAMi Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vastai + +import ( + "context" + "errors" + "fmt" + "testing" + + "gotest.tools/v3/assert" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/klog/v2" + + "github.com/Project-HAMi/HAMi/pkg/device" + "github.com/Project-HAMi/HAMi/pkg/util" + "github.com/Project-HAMi/HAMi/pkg/util/client" + "github.com/Project-HAMi/HAMi/pkg/util/nodelock" +) + +func Test_MutateAdmission(t *testing.T) { + config := VastaiConfig{ + ResourceCountName: "vastaitech.com/va", + } + InitVastaiDevice(config) + tests := []struct { + name string + args struct { + ctr *corev1.Container + p *corev1.Pod + } + want bool + err error + }{ + { + name: "set to resource limits", + args: struct { + ctr *corev1.Container + p *corev1.Pod + }{ + ctr: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + }, + }, + p: &corev1.Pod{}, + }, + want: true, + err: nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result, err := dev.MutateAdmission(test.args.ctr, test.args.p) + if err != test.err { + klog.InfoS("set to resource limits failed") + } + assert.Equal(t, result, test.want) + }) + } +} + +func Test_GetNodeDevices(t *testing.T) { + dev := VastaiDevices{} + tests := []struct { + name string + args corev1.Node + want []*device.DeviceInfo + err error + }{ + { + name: "no annotation", + args: corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{}, + }, + }, + want: []*device.DeviceInfo{}, + err: errors.New("annos not found " + RegisterAnnos), + }, + { + name: "exist vastai device", + args: corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "hami.io/node-va-register": "[{\"id\":\"7-0\",\"count\":1,\"devmem\":32768,\"type\":\"Vastai\",\"health\":true,\"devicepairscore\":{}}]", + }, + }, + }, + want: []*device.DeviceInfo{ + { + ID: "7-0", + Count: int32(1), + Devmem: int32(32768), + Devcore: int32(100), + Type: dev.CommonWord(), + Numa: 0, + Health: true, + Index: uint(0), + Mode: "", + DeviceVendor: VastaiCommonWord, + }, + }, + err: nil, + }, + { + name: "no vasta device", + args: corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "hami.io/node-va-register": ":", + }, + }, + }, + want: []*device.DeviceInfo{}, + err: errors.New("no gpu found on node"), + }, + { + name: "node annotations not decode successfully", + args: corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "hami.io/node-va-register": "", + }, + }, + }, + want: []*device.DeviceInfo{}, + err: errors.New("node annotations not decode successfully"), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result, err := dev.GetNodeDevices(test.args) + if err != nil { + klog.Errorf("got %v, want %v", err, test.err) + } + assert.DeepEqual(t, result, test.want) + }) + } +} + +func Test_CheckHealth(t *testing.T) { + tests := []struct { + name string + args struct { + devType string + n *corev1.Node + } + want1 bool + want2 bool + }{ + { + name: "Requesting state expired", + args: struct { + devType string + n *corev1.Node + }{ + devType: "vastaitech.com/va", + n: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + util.HandshakeAnnos["hami.io/node-handshake-va"]: "Requesting_2025-01-07 00:00:00", + }, + }, + }, + }, + want1: false, + want2: false, + }, + { + name: "Deleted state", + args: struct { + devType string + n *corev1.Node + }{ + devType: "vastaitech.com/va", + n: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + util.HandshakeAnnos["hami.io/node-handshake-va"]: "Deleted", + }, + }, + }, + }, + want1: true, + want2: false, + }, + { + name: "Unknown state", + args: struct { + devType string + n *corev1.Node + }{ + devType: "vastaitech.com/va", + n: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + util.HandshakeAnnos["hami.io/node-handshake-va"]: "Unknown", + }, + }, + }, + }, + want1: true, + want2: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result1, result2 := dev.CheckHealth(test.args.devType, test.args.n) + assert.Equal(t, result1, test.want1) + assert.Equal(t, result2, test.want2) + }) + } +} + +func Test_checkType(t *testing.T) { + tests := []struct { + name string + args struct { + annos map[string]string + d device.DeviceUsage + n device.ContainerDeviceRequest + } + want1 bool + want2 bool + want3 bool + }{ + { + name: "the same type", + args: struct { + annos map[string]string + d device.DeviceUsage + n device.ContainerDeviceRequest + }{ + annos: map[string]string{}, + d: device.DeviceUsage{ + Type: "Vastai", + }, + n: device.ContainerDeviceRequest{ + Type: "Vastai", + }, + }, + want1: true, + want2: true, + want3: false, + }, + { + name: "the different type", + args: struct { + annos map[string]string + d device.DeviceUsage + n device.ContainerDeviceRequest + }{ + annos: map[string]string{}, + d: device.DeviceUsage{ + Type: "Vastai", + }, + n: device.ContainerDeviceRequest{ + Type: "test", + }, + }, + want1: false, + want2: false, + want3: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result1, result2, result3 := dev.checkType(test.args.annos, test.args.d, test.args.n) + assert.Equal(t, result1, test.want1) + assert.Equal(t, result2, test.want2) + assert.Equal(t, result3, test.want3) + }) + } +} + +func Test_PatchAnnotations(t *testing.T) { + tests := []struct { + name string + args struct { + annoinput *map[string]string + pd device.PodDevices + } + want map[string]string + }{ + { + name: "exist device", + args: struct { + annoinput *map[string]string + pd device.PodDevices + }{ + annoinput: &map[string]string{}, + pd: device.PodDevices{ + VastaiDevice: device.PodSingleDevice{ + []device.ContainerDevice{ + { + Idx: 1, + UUID: "test1", + Type: VastaiDevice, + Usedmem: int32(2048), + Usedcores: int32(1), + }, + }, + }, + }, + }, + want: map[string]string{ + device.InRequestDevices[VastaiDevice]: "test1,Vastai,2048,1:;", + device.SupportDevices[VastaiDevice]: "test1,Vastai,2048,1:;", + }, + }, + { + name: "no device", + args: struct { + annoinput *map[string]string + pd device.PodDevices + }{ + annoinput: &map[string]string{}, + pd: device.PodDevices{}, + }, + want: map[string]string{}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result := dev.PatchAnnotations(&corev1.Pod{}, test.args.annoinput, test.args.pd) + assert.DeepEqual(t, result, test.want) + }) + } +} + +func Test_GenerateResourceRequests(t *testing.T) { + tests := []struct { + name string + args *corev1.Container + want device.ContainerDeviceRequest + }{ + { + name: "don't set to limits and request", + args: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{}, + Requests: corev1.ResourceList{}, + }, + }, + want: device.ContainerDeviceRequest{}, + }, + { + name: "set to limits and request", + args: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + }, + }, + want: device.ContainerDeviceRequest{ + Nums: int32(1), + Type: VastaiDevice, + Memreq: int32(0), + MemPercentagereq: int32(100), + Coresreq: int32(0), + }, + }, + { + name: "only set to limits", + args: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + }, + }, + want: device.ContainerDeviceRequest{ + Nums: int32(1), + Type: VastaiDevice, + Memreq: int32(0), + MemPercentagereq: int32(100), + Coresreq: int32(0), + }, + }, + { + name: "only set to request", + args: &corev1.Container{ + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }, + }, + }, + want: device.ContainerDeviceRequest{ + Nums: int32(1), + Type: VastaiDevice, + Memreq: int32(0), + MemPercentagereq: int32(100), + Coresreq: int32(0), + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dev := VastaiDevices{} + result := dev.GenerateResourceRequests(test.args) + assert.DeepEqual(t, result, test.want) + }) + } +} + +func TestDevices_LockNode(t *testing.T) { + tests := []struct { + name string + node *corev1.Node + pod *corev1.Pod + hasLock bool + expectError bool + }{ + { + name: "Test with no containers", + node: &corev1.Node{}, + pod: &corev1.Pod{Spec: corev1.PodSpec{}}, + hasLock: false, + expectError: false, + }, + { + name: "Test with non-zero resource requests", + node: &corev1.Node{}, + pod: &corev1.Pod{Spec: corev1.PodSpec{Containers: []corev1.Container{{Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }}}}}}, + hasLock: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize fake clientset and pre-load test data + client.KubeClient = fake.NewSimpleClientset() + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testNode", + Annotations: map[string]string{"test-annotation-key": "test-annotation-value", device.InRequestDevices["DCU"]: "some-value"}, + }, + } + + // Add the node to the fake clientset + _, err := client.KubeClient.CoreV1().Nodes().Create(context.Background(), node, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test node: %v", err) + } + + dev := InitVastaiDevice(VastaiConfig{ + ResourceCountName: "vastaitech.com/va", + }) + err = dev.LockNode(node, tt.pod) + if tt.expectError { + assert.Equal(t, err != nil, true) + } else { + assert.NilError(t, err) + } + node, err = client.KubeClient.CoreV1().Nodes().Get(context.Background(), node.Name, metav1.GetOptions{}) + assert.NilError(t, err) + fmt.Println(node.Annotations) + _, ok := node.Annotations[nodelock.NodeLockKey] + assert.Equal(t, ok, tt.hasLock) + }) + } +} + +func TestDevices_ReleaseNodeLock(t *testing.T) { + tests := []struct { + name string + node *corev1.Node + pod *corev1.Pod + hasLock bool + expectError bool + }{ + { + name: "Test with no containers", + node: &corev1.Node{}, + pod: &corev1.Pod{Spec: corev1.PodSpec{}}, + hasLock: true, + expectError: false, + }, + { + name: "Test with non-zero resource requests", + node: &corev1.Node{}, + pod: &corev1.Pod{ObjectMeta: metav1.ObjectMeta{ + Name: "nozerorr", + Namespace: "default", + }, Spec: corev1.PodSpec{Containers: []corev1.Container{{Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{ + "vastaitech.com/va": resource.MustParse("1"), + }}}}}}, + hasLock: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize fake clientset and pre-load test data + client.KubeClient = fake.NewSimpleClientset() + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testNode", + Annotations: map[string]string{"test-annotation-key": "test-annotation-value", device.InRequestDevices[VastaiDevice]: "some-value", nodelock.NodeLockKey: "lock-values,default,nozerorr"}, + }, + } + + // Add the node to the fake clientset + _, err := client.KubeClient.CoreV1().Nodes().Create(context.Background(), node, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create test node: %v", err) + } + + dev := InitVastaiDevice(VastaiConfig{ + ResourceCountName: "vastaitech.com/va", + }) + err = dev.ReleaseNodeLock(node, tt.pod) + if tt.expectError { + assert.Equal(t, err != nil, true) + } else { + assert.NilError(t, err) + } + node, err = client.KubeClient.CoreV1().Nodes().Get(context.Background(), node.Name, metav1.GetOptions{}) + assert.NilError(t, err) + fmt.Println(node.Annotations) + _, ok := node.Annotations[nodelock.NodeLockKey] + assert.Equal(t, ok, tt.hasLock) + }) + } +} + +func TestDevices_Fit(t *testing.T) { + config := VastaiConfig{ + ResourceCountName: "vastaitech.com/va", + } + dev := InitVastaiDevice(config) + + tests := []struct { + name string + devices []*device.DeviceUsage + request device.ContainerDeviceRequest + annos map[string]string + wantFit bool + wantLen int + wantDevIDs []string + wantReason string + }{ + { + name: "fit success", + devices: []*device.DeviceUsage{ + { + ID: "dev-0", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }, + { + ID: "dev-1", + Index: 1, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }, + }, + request: device.ContainerDeviceRequest{ + Nums: 1, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{}, + wantFit: true, + wantLen: 1, + wantDevIDs: []string{"dev-0"}, + wantReason: "", + }, + { + name: "fit fail: type mismatch", + devices: []*device.DeviceUsage{{ + ID: "dev-0", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Health: true, + Type: VastaiDevice, + }}, + request: device.ContainerDeviceRequest{ + Nums: 1, + Type: "OtherType", + Memreq: 512, + MemPercentagereq: 0, + Coresreq: 50, + }, + annos: map[string]string{}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 CardTypeMismatch", + }, + { + name: "fit fail: user assign use uuid mismatch", + devices: []*device.DeviceUsage{{ + ID: "dev-1", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }}, + request: device.ContainerDeviceRequest{ + Nums: 1, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{VastaiUseUUID: "dev-0"}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 CardUuidMismatch", + }, + { + name: "fit fail: user assign no use uuid match", + devices: []*device.DeviceUsage{{ + ID: "dev-0", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }}, + request: device.ContainerDeviceRequest{ + Nums: 1, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{VastaiNoUseUUID: "dev-0"}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 CardUuidMismatch", + }, + { + name: "fit fail: card overused", + devices: []*device.DeviceUsage{{ + ID: "dev-0", + Index: 0, + Used: 1, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }}, + request: device.ContainerDeviceRequest{ + Nums: 1, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 CardTimeSlicingExhausted", + }, + { + name: "fit fail: AllocatedCardsInsufficientRequest", + devices: []*device.DeviceUsage{{ + ID: "dev-0", + Index: 0, + Used: 0, + Count: 1, + Usedmem: 0, + Totalmem: 0, + Totalcore: 0, + Usedcores: 0, + Numa: 0, + Type: VastaiDevice, + Health: true, + }}, + request: device.ContainerDeviceRequest{ + Nums: 2, + Memreq: 0, + MemPercentagereq: 0, + Coresreq: 0, + Type: VastaiDevice, + }, + annos: map[string]string{}, + wantFit: false, + wantLen: 0, + wantDevIDs: []string{}, + wantReason: "1/1 AllocatedCardsInsufficientRequest", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + allocated := &device.PodDevices{} + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: test.annos, + }, + } + fit, result, reason := dev.Fit(test.devices, test.request, pod, &device.NodeInfo{}, allocated) + if fit != test.wantFit { + t.Errorf("Fit: got %v, want %v", fit, test.wantFit) + } + if test.wantFit { + if len(result[VastaiDevice]) != test.wantLen { + t.Errorf("expected len: %d, got len %d", test.wantLen, len(result[VastaiDevice])) + } + for idx, id := range test.wantDevIDs { + if id != result[VastaiDevice][idx].UUID { + t.Errorf("expected device id: %s, got device id %s", id, result[VastaiDevice][idx].UUID) + } + } + } + + if reason != test.wantReason { + t.Errorf("expected reason: %s, got reason: %s", test.wantReason, reason) + } + }) + } +} + +func TestDevices_AddResourceUsage(t *testing.T) { + tests := []struct { + name string + deviceUsage *device.DeviceUsage + ctr *device.ContainerDevice + wantErr bool + wantUsage *device.DeviceUsage + }{ + { + name: "test add resource usage", + deviceUsage: &device.DeviceUsage{ + ID: "dev-0", + Used: 1, + Usedcores: 0, + Usedmem: 0, + }, + ctr: &device.ContainerDevice{ + UUID: "dev-0", + Usedcores: 0, + Usedmem: 0, + }, + wantUsage: &device.DeviceUsage{ + ID: "dev-0", + Used: 2, + Usedcores: 0, + Usedmem: 0, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dev := &VastaiDevices{} + if err := dev.AddResourceUsage(&corev1.Pod{}, tt.deviceUsage, tt.ctr); (err != nil) != tt.wantErr { + t.Errorf("AddResourceUsage() error=%v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + if tt.deviceUsage.Usedcores != tt.wantUsage.Usedcores { + t.Errorf("expected used cores: %d, got used cores %d", tt.wantUsage.Usedcores, tt.deviceUsage.Usedcores) + } + if tt.deviceUsage.Usedmem != tt.wantUsage.Usedmem { + t.Errorf("expected used mem: %d, got used mem %d", tt.wantUsage.Usedmem, tt.deviceUsage.Usedmem) + } + if tt.deviceUsage.Used != tt.wantUsage.Used { + t.Errorf("expected used: %d, got used %d", tt.wantUsage.Used, tt.deviceUsage.Used) + } + } + }) + } +} + +func TestComputeBestCombination(t *testing.T) { + tests := []struct { + name string + reqNum int + containerDevices device.ContainerDevices + expected device.ContainerDevices + expectedCount int + }{ + { + name: "single AIC group, sufficient to meet demand", + reqNum: 2, + containerDevices: device.ContainerDevices{ + {Idx: 1, UUID: "dev1", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 2, UUID: "dev2", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 3, UUID: "dev3", CustomInfo: map[string]any{"AIC": "1"}}, + }, + expected: device.ContainerDevices{ + {Idx: 1, UUID: "dev1", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 2, UUID: "dev2", CustomInfo: map[string]any{"AIC": "1"}}, + }, + expectedCount: 2, + }, + { + name: "multiple AIC groups, select by count after sorting", + reqNum: 3, + containerDevices: device.ContainerDevices{ + // aic1 group has 3 devices + {Idx: 1, UUID: "dev1", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 2, UUID: "dev2", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 3, UUID: "dev3", CustomInfo: map[string]any{"AIC": "1"}}, + // aic2 group has 2 devices + {Idx: 4, UUID: "dev4", CustomInfo: map[string]any{"AIC": "2"}}, + {Idx: 5, UUID: "dev5", CustomInfo: map[string]any{"AIC": "2"}}, + // aic3 group has 1 device + {Idx: 6, UUID: "dev6", CustomInfo: map[string]any{"AIC": "3"}}, + }, + expected: device.ContainerDevices{ + // Should take from the group with the highest count (aic1) + {Idx: 1, UUID: "dev1", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 2, UUID: "dev2", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 3, UUID: "dev3", CustomInfo: map[string]any{"AIC": "1"}}, + }, + expectedCount: 3, + }, + { + name: "multiple AIC groups, need to take across groups", + reqNum: 4, + containerDevices: device.ContainerDevices{ + // aic1 group has 3 devices + {Idx: 1, UUID: "dev1", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 2, UUID: "dev2", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 3, UUID: "dev3", CustomInfo: map[string]any{"AIC": "1"}}, + // aic2 group has 2 devices + {Idx: 4, UUID: "dev4", CustomInfo: map[string]any{"AIC": "2"}}, + {Idx: 5, UUID: "dev5", CustomInfo: map[string]any{"AIC": "2"}}, + // aic3 group has 1 device + {Idx: 6, UUID: "dev6", CustomInfo: map[string]any{"AIC": "3"}}, + }, + expected: device.ContainerDevices{ + // Take all 3 from aic1 group first + {Idx: 1, UUID: "dev1", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 2, UUID: "dev2", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 3, UUID: "dev3", CustomInfo: map[string]any{"AIC": "1"}}, + // Then take 1 from aic2 group + {Idx: 4, UUID: "dev4", CustomInfo: map[string]any{"AIC": "2"}}, + }, + expectedCount: 4, + }, + { + name: "request number exceeds available devices", + reqNum: 5, + containerDevices: device.ContainerDevices{ + {Idx: 1, UUID: "dev1", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 2, UUID: "dev2", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 3, UUID: "dev3", CustomInfo: map[string]any{"AIC": "2"}}, + }, + expected: device.ContainerDevices{ + {Idx: 1, UUID: "dev1", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 2, UUID: "dev2", CustomInfo: map[string]any{"AIC": "1"}}, + {Idx: 3, UUID: "dev3", CustomInfo: map[string]any{"AIC": "2"}}, + }, + expectedCount: 3, + }, + } + compareFunc := func(a, b device.ContainerDevices) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].UUID != b[i].UUID { + return false + } + } + return true + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dev := &VastaiDevices{} // Assume this is your struct + result := dev.computeBestCombination(tt.reqNum, tt.containerDevices) + + // Check result count + if len(result) != tt.expectedCount { + t.Errorf("Expected %d devices, got %d", tt.expectedCount, len(result)) + } + + // If expected result is not empty, check specific content + if tt.expectedCount > 0 { + if !compareFunc(result, tt.expected) { + t.Errorf("Device list does not match\nExpected: %+v\nGot: %+v", tt.expected, result) + } + } + }) + } +} diff --git a/pkg/scheduler/config/config.go b/pkg/scheduler/config/config.go index a012eabf0..f131c2325 100644 --- a/pkg/scheduler/config/config.go +++ b/pkg/scheduler/config/config.go @@ -38,6 +38,7 @@ import ( "github.com/Project-HAMi/HAMi/pkg/device/metax" "github.com/Project-HAMi/HAMi/pkg/device/mthreads" "github.com/Project-HAMi/HAMi/pkg/device/nvidia" + "github.com/Project-HAMi/HAMi/pkg/device/vastai" "github.com/Project-HAMi/HAMi/pkg/util" ) @@ -82,6 +83,7 @@ type Config struct { KunlunConfig kunlun.KunlunConfig `yaml:"kunlun"` AWSNeuronConfig awsneuron.AWSNeuronConfig `yaml:"awsneuron"` AMDGPUConfig amd.AMDConfig `yaml:"amd"` + VastaiConfig vastai.VastaiConfig `yaml:"vastai"` VNPUs []ascend.VNPUConfig `yaml:"vnpus"` } @@ -209,6 +211,13 @@ func InitDevicesWithConfig(config *Config) error { } return amd.InitAMDGPUDevice(amdGPUConfig), nil }, config.AMDGPUConfig}, + {vastai.VastaiDevice, vastai.VastaiCommonWord, func(cfg any) (device.Devices, error) { + vastaiConfig, ok := cfg.(vastai.VastaiConfig) + if !ok { + return nil, fmt.Errorf("invalid configuration for %s", vastai.VastaiCommonWord) + } + return vastai.InitVastaiDevice(vastaiConfig), nil + }, config.VastaiConfig}, } // Initialize all devices using the wrapped functions diff --git a/pkg/scheduler/config/config_test.go b/pkg/scheduler/config/config_test.go index 997b0f601..10857a873 100644 --- a/pkg/scheduler/config/config_test.go +++ b/pkg/scheduler/config/config_test.go @@ -39,6 +39,7 @@ import ( "github.com/Project-HAMi/HAMi/pkg/device/metax" "github.com/Project-HAMi/HAMi/pkg/device/mthreads" "github.com/Project-HAMi/HAMi/pkg/device/nvidia" + "github.com/Project-HAMi/HAMi/pkg/device/vastai" ) func loadTestConfig() string { @@ -426,6 +427,7 @@ func setupTest(t *testing.T) (map[string]string, map[string]device.Devices) { kunlun.XPUDevice: kunlun.XPUCommonWord, awsneuron.AWSNeuronDevice: awsneuron.AWSNeuronCommonWord, amd.AMDDevice: amd.AMDDevice, + vastai.VastaiDevice: vastai.VastaiCommonWord, } return expectedDevices, device.DevicesMap