diff --git a/core-services/document-grounding/pom.xml b/core-services/document-grounding/pom.xml
index 98db51a46..617eb6048 100644
--- a/core-services/document-grounding/pom.xml
+++ b/core-services/document-grounding/pom.xml
@@ -37,10 +37,10 @@
${project.basedir}/../../
- 80%
+ 77%
71%
85%
- 100%
+ 75%
80%
100%
diff --git a/core-services/document-grounding/src/main/java/com/sap/ai/sdk/grounding/GroundingClient.java b/core-services/document-grounding/src/main/java/com/sap/ai/sdk/grounding/GroundingClient.java
index 3360e126e..839bc9ba0 100644
--- a/core-services/document-grounding/src/main/java/com/sap/ai/sdk/grounding/GroundingClient.java
+++ b/core-services/document-grounding/src/main/java/com/sap/ai/sdk/grounding/GroundingClient.java
@@ -1,14 +1,20 @@
package com.sap.ai.sdk.grounding;
+import com.google.common.annotations.Beta;
import com.sap.ai.sdk.core.AiCoreService;
import com.sap.ai.sdk.grounding.client.PipelinesApi;
import com.sap.ai.sdk.grounding.client.RetrievalApi;
import com.sap.ai.sdk.grounding.client.VectorApi;
+import com.sap.cloud.sdk.cloudplatform.connectivity.Header;
+import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient;
+import java.util.ArrayList;
+import java.util.List;
import javax.annotation.Nonnull;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.experimental.Tolerate;
+import lombok.val;
/**
* Service class for the Document Grounding APIs.
@@ -20,6 +26,7 @@
public class GroundingClient {
@Nonnull private final AiCoreService service;
@Nonnull private final String basePath;
+ @Nonnull private final List customHeaders = new ArrayList<>();
static final String DEFAULT_BASE_PATH = "lm/document-grounding/";
@@ -45,7 +52,7 @@ public GroundingClient(final @Nonnull AiCoreService service) {
*/
@Nonnull
public PipelinesApi pipelines() {
- return new PipelinesApi(getService().getApiClient().setBasePath(getBasePath()));
+ return new PipelinesApi(getClient());
}
/**
@@ -55,7 +62,7 @@ public PipelinesApi pipelines() {
*/
@Nonnull
public VectorApi vector() {
- return new VectorApi(getService().getApiClient().setBasePath(getBasePath()));
+ return new VectorApi(getClient());
}
/**
@@ -65,6 +72,34 @@ public VectorApi vector() {
*/
@Nonnull
public RetrievalApi retrieval() {
- return new RetrievalApi(getService().getApiClient().setBasePath(getBasePath()));
+ return new RetrievalApi(getClient());
+ }
+
+ /**
+ * Create a new OpenAI client with a custom header added to every call made with this client
+ *
+ * @param key the key of the custom header to add
+ * @param value the value of the custom header to add
+ * @return a new client.
+ * @since 1.17.0
+ */
+ @Beta
+ @Nonnull
+ public GroundingClient withHeader(@Nonnull final String key, @Nonnull final String value) {
+ final var newClient = new GroundingClient(this.service, this.basePath);
+ newClient.customHeaders.addAll(this.customHeaders);
+ newClient.customHeaders.add(new Header(key, value));
+ return newClient;
+ }
+
+ @Nonnull
+ private ApiClient getClient() {
+ val apiClient = getService().getApiClient().setBasePath(getBasePath());
+ for (val header : customHeaders) {
+ if (header.getValue() != null) {
+ apiClient.addDefaultHeader(header.getName(), header.getValue());
+ }
+ }
+ return apiClient;
}
}
diff --git a/core-services/document-grounding/src/test/java/com/sap/ai/sdk/grounding/GroundingClientTest.java b/core-services/document-grounding/src/test/java/com/sap/ai/sdk/grounding/GroundingClientTest.java
index acda11eaf..2125d5c0f 100644
--- a/core-services/document-grounding/src/test/java/com/sap/ai/sdk/grounding/GroundingClientTest.java
+++ b/core-services/document-grounding/src/test/java/com/sap/ai/sdk/grounding/GroundingClientTest.java
@@ -1,5 +1,10 @@
package com.sap.ai.sdk.grounding;
+import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
+import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
+import static com.github.tomakehurst.wiremock.client.WireMock.get;
+import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor;
+import static com.github.tomakehurst.wiremock.client.WireMock.okJson;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;
@@ -36,7 +41,7 @@ class GroundingClientTest {
void testPipelines() {
final PipelinesApi api = new GroundingClient(SERVICE).pipelines();
- final GetPipelines allPipelines = api.getAllPipelines("reosurceGroup");
+ final GetPipelines allPipelines = api.getAllPipelines("resourceGroup");
assertThat(allPipelines).isNotNull();
assertThat(allPipelines.getResources()).isEmpty();
}
@@ -45,7 +50,7 @@ void testPipelines() {
void testVector() {
final VectorApi api = new GroundingClient(SERVICE).vector();
- final CollectionsListResponse collections = api.getAllCollections("reosurceGroup");
+ final CollectionsListResponse collections = api.getAllCollections("resourceGroup");
assertThat(collections).isNotNull();
assertThat(collections.getResources())
.isNotNull()
@@ -63,7 +68,7 @@ void testVector() {
});
final UUID collectionId = collections.getResources().get(0).getId();
- final Documents documents = api.getAllDocuments("reosurceGroup", collectionId);
+ final Documents documents = api.getAllDocuments("resourceGroup", collectionId);
assertThat(documents).isNotNull();
final var documentMeta =
VectorDocumentKeyValueListPair.create()
@@ -82,7 +87,7 @@ void testVector() {
final UUID documentId = documents.getResources().get(0).getId();
final DocumentResponse document =
- api.getDocumentById("reosurceGroup", collectionId, documentId);
+ api.getDocumentById("resourceGroup", collectionId, documentId);
assertThat(document).isNotNull();
assertThat(document.getId()).isEqualTo(documentId);
assertThat(document.getMetadata()).isNotNull().containsExactly(documentMeta);
@@ -115,7 +120,7 @@ void testVector() {
void testRetrieval() {
final RetrievalApi api = new GroundingClient(SERVICE).retrieval();
- DataRepositories repositories = api.getDataRepositories("reosurceGroup");
+ DataRepositories repositories = api.getDataRepositories("resourceGroup");
assertThat(repositories).isNotNull();
assertThat(repositories.getResources())
.isNotEmpty()
@@ -137,4 +142,26 @@ void testRetrieval() {
assertThat(r2.getMetadata()).isNotNull().isEmpty();
});
}
+
+ @Test
+ void testCustomHeaders() {
+ WM.stubFor(
+ get(anyUrl())
+ .withHeader("x-test-header", equalTo("test-value"))
+ .willReturn(
+ okJson(
+ """
+ {
+ "count": 0,
+ "resources": []
+ }
+ """)));
+
+ new GroundingClient(SERVICE)
+ .withHeader("x-test-header", "test-value")
+ .pipelines()
+ .getAllPipelines("resourceGroup");
+
+ WM.verify(getRequestedFor(anyUrl()).withHeader("x-test-header", equalTo("test-value")));
+ }
}
diff --git a/docs/release_notes.md b/docs/release_notes.md
index 5745d6681..80e2325c3 100644
--- a/docs/release_notes.md
+++ b/docs/release_notes.md
@@ -12,7 +12,7 @@
### ✨ New Functionality
--
+- [Grounding] Added `GroundingClient.withHeader()`.
### 📈 Improvements