Skip to content

Commit 151ec19

Browse files
committed
Changes for getting GPU details from C function
1 parent 35901ea commit 151ec19

4 files changed

Lines changed: 79 additions & 28 deletions

File tree

java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,25 @@
3535

3636
public class Util {
3737

38-
/**
39-
* Returns the number of GPUs connected to the system using CuVSResources.
40-
*
41-
* @param resources The CuVSResources object managing native resources.
42-
* @return Number of GPUs connected, or -1 if an error occurred.
43-
*/
44-
public static int getNumberOfGPUs(CuVSResources resources) {
45-
try {
46-
MethodHandle getNumberOfGPUsHandle = resources.linker.downcallHandle(
47-
resources.libcuvsNativeLibrary.find("get_number_of_gpus")
48-
.orElseThrow(() -> new IllegalStateException("get_number_of_gpus not found in library")),
49-
FunctionDescriptor.of(ValueLayout.JAVA_INT));
50-
51-
return (int) getNumberOfGPUsHandle.invokeExact();
38+
public static String getGpuDetails(CuVSResources resources, int maxGpus, int maxDetailLength) {
39+
try (Arena arena = Arena.ofConfined()) {
40+
MemorySegment detailSegment = arena.allocate(maxGpus * maxDetailLength);
41+
MethodHandle getGpuDetailsHandle = resources.linker.downcallHandle(
42+
resources.libcuvsNativeLibrary.find("get_gpu_details")
43+
.orElseThrow(() -> new IllegalStateException("get_gpu_details not found in library")),
44+
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.JAVA_INT, ValueLayout.JAVA_INT));
45+
46+
int gpuCount = (int) getGpuDetailsHandle.invoke(detailSegment, maxGpus, maxDetailLength);
47+
if (gpuCount < 0) {
48+
throw new RuntimeException("Failed to retrieve GPU details");
49+
}
50+
51+
// Convert MemorySegment to String
52+
String details = new String(detailSegment.toArray(ValueLayout.JAVA_BYTE), 0, gpuCount * maxDetailLength);
53+
return details.trim();
5254
} catch (Throwable e) {
53-
System.err.println("Failed to invoke get_number_of_gpus: " + e.getMessage());
54-
return -1; // Return -1 to indicate an error
55+
System.err.println("Error invoking get_gpu_details: " + e.getMessage());
56+
throw new RuntimeException("Failed to invoke get_gpu_details", e);
5557
}
5658
}
5759

java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.slf4j.LoggerFactory;
1111

1212
import com.carrotsearch.randomizedtesting.RandomizedContext;
13-
import com.nvidia.cuvs.common.Util;
1413

1514
public class CagraRandomizedTest extends LuceneTestCase {
1615
private Random random;
@@ -56,21 +55,12 @@ public void testResultsTopKWithRandomValues() throws Throwable {
5655
}
5756

5857
log.info("Queries:");
59-
for (float[] query : queries) {
58+
for (float[] query : queries) {
6059
log.info(java.util.Arrays.toString(query));
6160
}
6261

6362
CuVSResources resources = new CuVSResources();
6463

65-
int gpuCount = Util.getNumberOfGPUs(resources);
66-
if (gpuCount == -1) {
67-
log.info("Failed to detect GPUs.");
68-
} else if (gpuCount == 0) {
69-
log.info("No GPUs detected.");
70-
} else {
71-
log.info("Number of GPUs detected: {}", gpuCount);
72-
}
73-
7464
CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build();
7565

7666
CagraIndex index = new CagraIndex.Builder(resources).withDataset(dataset).withIndexParams(indexParams).build();
@@ -87,4 +77,3 @@ public void testResultsTopKWithRandomValues() throws Throwable {
8777
}
8878

8979
}
90-
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package com.nvidia.cuvs.common;
2+
3+
import static org.junit.Assert.assertTrue;
4+
5+
import org.junit.Test;
6+
7+
import com.nvidia.cuvs.CuVSResources;
8+
9+
public class TestUtil {
10+
11+
@Test
12+
public void testGpuDetails() throws Throwable {
13+
try {
14+
CuVSResources resources = new CuVSResources();
15+
String details = Util.getGpuDetails(resources, 10, 256);
16+
System.out.println("GPU Details: " + details);
17+
assertTrue("GPU details should not be empty", !details.isEmpty());
18+
} catch (RuntimeException e) {
19+
e.printStackTrace();
20+
throw new AssertionError("Test failed due to an exception: " + e.getMessage());
21+
}
22+
}
23+
24+
}

java/internal/src/cuvs_java.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include <cuda_runtime.h>
2121
#include <stdio.h>
2222
#include <stdlib.h>
23+
#include <cuda_runtime.h>
24+
#include <string.h>
2325

2426
cuvsResources_t create_resource(int *returnValue) {
2527
cuvsResources_t cuvsResources;
@@ -104,3 +106,37 @@ int get_number_of_gpus() {
104106
}
105107
return deviceCount;
106108
}
109+
110+
int get_gpu_details(char *details, int max_gpus, int max_detail_length) {
111+
int deviceCount = 0;
112+
cudaError_t err = cudaGetDeviceCount(&deviceCount);
113+
114+
if (err != cudaSuccess || deviceCount == 0) {
115+
fprintf(stderr, "cudaGetDeviceCount failed or no GPUs found: %s\n", cudaGetErrorString(err));
116+
return -1;
117+
}
118+
119+
for (int i = 0; i < deviceCount && i < max_gpus; i++) {
120+
struct cudaDeviceProp deviceProp;
121+
err = cudaGetDeviceProperties(&deviceProp, i);
122+
if (err != cudaSuccess) {
123+
snprintf(&details[i * max_detail_length], max_detail_length,
124+
"Error fetching properties for device %d", i);
125+
continue;
126+
}
127+
128+
size_t freeMem = 0, totalMem = 0;
129+
cudaSetDevice(i);
130+
err = cudaMemGetInfo(&freeMem, &totalMem);
131+
if (err != cudaSuccess) {
132+
snprintf(&details[i * max_detail_length], max_detail_length,
133+
"%s | Memory info unavailable", deviceProp.name);
134+
continue;
135+
}
136+
137+
snprintf(&details[i * max_detail_length], max_detail_length,
138+
"%s | Total: %zuMB | Free: %zuMB", deviceProp.name, totalMem / (1024 * 1024), freeMem / (1024 * 1024));
139+
}
140+
141+
return deviceCount;
142+
}

0 commit comments

Comments
 (0)