diff --git a/core-services/document-grounding/pom.xml b/core-services/document-grounding/pom.xml index 98db51a46..617eb6048 100644 --- a/core-services/document-grounding/pom.xml +++ b/core-services/document-grounding/pom.xml @@ -37,10 +37,10 @@ ${project.basedir}/../../ - 80% + 77% 71% 85% - 100% + 75% 80% 100% diff --git a/core-services/document-grounding/src/main/java/com/sap/ai/sdk/grounding/GroundingClient.java b/core-services/document-grounding/src/main/java/com/sap/ai/sdk/grounding/GroundingClient.java index 3360e126e..839bc9ba0 100644 --- a/core-services/document-grounding/src/main/java/com/sap/ai/sdk/grounding/GroundingClient.java +++ b/core-services/document-grounding/src/main/java/com/sap/ai/sdk/grounding/GroundingClient.java @@ -1,14 +1,20 @@ package com.sap.ai.sdk.grounding; +import com.google.common.annotations.Beta; import com.sap.ai.sdk.core.AiCoreService; import com.sap.ai.sdk.grounding.client.PipelinesApi; import com.sap.ai.sdk.grounding.client.RetrievalApi; import com.sap.ai.sdk.grounding.client.VectorApi; +import com.sap.cloud.sdk.cloudplatform.connectivity.Header; +import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; +import java.util.ArrayList; +import java.util.List; import javax.annotation.Nonnull; import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.experimental.Tolerate; +import lombok.val; /** * Service class for the Document Grounding APIs. @@ -20,6 +26,7 @@ public class GroundingClient { @Nonnull private final AiCoreService service; @Nonnull private final String basePath; + @Nonnull private final List
customHeaders = new ArrayList<>(); static final String DEFAULT_BASE_PATH = "lm/document-grounding/"; @@ -45,7 +52,7 @@ public GroundingClient(final @Nonnull AiCoreService service) { */ @Nonnull public PipelinesApi pipelines() { - return new PipelinesApi(getService().getApiClient().setBasePath(getBasePath())); + return new PipelinesApi(getClient()); } /** @@ -55,7 +62,7 @@ public PipelinesApi pipelines() { */ @Nonnull public VectorApi vector() { - return new VectorApi(getService().getApiClient().setBasePath(getBasePath())); + return new VectorApi(getClient()); } /** @@ -65,6 +72,34 @@ public VectorApi vector() { */ @Nonnull public RetrievalApi retrieval() { - return new RetrievalApi(getService().getApiClient().setBasePath(getBasePath())); + return new RetrievalApi(getClient()); + } + + /** + * Create a new OpenAI client with a custom header added to every call made with this client + * + * @param key the key of the custom header to add + * @param value the value of the custom header to add + * @return a new client. + * @since 1.17.0 + */ + @Beta + @Nonnull + public GroundingClient withHeader(@Nonnull final String key, @Nonnull final String value) { + final var newClient = new GroundingClient(this.service, this.basePath); + newClient.customHeaders.addAll(this.customHeaders); + newClient.customHeaders.add(new Header(key, value)); + return newClient; + } + + @Nonnull + private ApiClient getClient() { + val apiClient = getService().getApiClient().setBasePath(getBasePath()); + for (val header : customHeaders) { + if (header.getValue() != null) { + apiClient.addDefaultHeader(header.getName(), header.getValue()); + } + } + return apiClient; } } diff --git a/core-services/document-grounding/src/test/java/com/sap/ai/sdk/grounding/GroundingClientTest.java b/core-services/document-grounding/src/test/java/com/sap/ai/sdk/grounding/GroundingClientTest.java index acda11eaf..2125d5c0f 100644 --- a/core-services/document-grounding/src/test/java/com/sap/ai/sdk/grounding/GroundingClientTest.java +++ b/core-services/document-grounding/src/test/java/com/sap/ai/sdk/grounding/GroundingClientTest.java @@ -1,5 +1,10 @@ package com.sap.ai.sdk.grounding; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static org.assertj.core.api.Assertions.assertThat; @@ -36,7 +41,7 @@ class GroundingClientTest { void testPipelines() { final PipelinesApi api = new GroundingClient(SERVICE).pipelines(); - final GetPipelines allPipelines = api.getAllPipelines("reosurceGroup"); + final GetPipelines allPipelines = api.getAllPipelines("resourceGroup"); assertThat(allPipelines).isNotNull(); assertThat(allPipelines.getResources()).isEmpty(); } @@ -45,7 +50,7 @@ void testPipelines() { void testVector() { final VectorApi api = new GroundingClient(SERVICE).vector(); - final CollectionsListResponse collections = api.getAllCollections("reosurceGroup"); + final CollectionsListResponse collections = api.getAllCollections("resourceGroup"); assertThat(collections).isNotNull(); assertThat(collections.getResources()) .isNotNull() @@ -63,7 +68,7 @@ void testVector() { }); final UUID collectionId = collections.getResources().get(0).getId(); - final Documents documents = api.getAllDocuments("reosurceGroup", collectionId); + final Documents documents = api.getAllDocuments("resourceGroup", collectionId); assertThat(documents).isNotNull(); final var documentMeta = VectorDocumentKeyValueListPair.create() @@ -82,7 +87,7 @@ void testVector() { final UUID documentId = documents.getResources().get(0).getId(); final DocumentResponse document = - api.getDocumentById("reosurceGroup", collectionId, documentId); + api.getDocumentById("resourceGroup", collectionId, documentId); assertThat(document).isNotNull(); assertThat(document.getId()).isEqualTo(documentId); assertThat(document.getMetadata()).isNotNull().containsExactly(documentMeta); @@ -115,7 +120,7 @@ void testVector() { void testRetrieval() { final RetrievalApi api = new GroundingClient(SERVICE).retrieval(); - DataRepositories repositories = api.getDataRepositories("reosurceGroup"); + DataRepositories repositories = api.getDataRepositories("resourceGroup"); assertThat(repositories).isNotNull(); assertThat(repositories.getResources()) .isNotEmpty() @@ -137,4 +142,26 @@ void testRetrieval() { assertThat(r2.getMetadata()).isNotNull().isEmpty(); }); } + + @Test + void testCustomHeaders() { + WM.stubFor( + get(anyUrl()) + .withHeader("x-test-header", equalTo("test-value")) + .willReturn( + okJson( + """ + { + "count": 0, + "resources": [] + } + """))); + + new GroundingClient(SERVICE) + .withHeader("x-test-header", "test-value") + .pipelines() + .getAllPipelines("resourceGroup"); + + WM.verify(getRequestedFor(anyUrl()).withHeader("x-test-header", equalTo("test-value"))); + } } diff --git a/docs/release_notes.md b/docs/release_notes.md index 5745d6681..80e2325c3 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -12,7 +12,7 @@ ### ✨ New Functionality -- +- [Grounding] Added `GroundingClient.withHeader()`. ### 📈 Improvements