From 3a76bfaf4d6b68fe076e984225fe66ec36048977 Mon Sep 17 00:00:00 2001 From: Avichay Marciano Date: Mon, 9 Feb 2026 16:42:59 +0200 Subject: [PATCH 1/8] [FLINK-AGENTS] Add Amazon Bedrock chat model and embedding model integrations Add two new integration modules for Amazon Bedrock: - Chat model using the Converse API with native tool calling support, SigV4 auth via DefaultCredentialsProvider, and token metrics reporting. Supports all Bedrock models accessible via Converse API. - Embedding model using Titan Text Embeddings V2 via InvokeModel. Batch embed(List) parallelizes via configurable thread pool (embed_concurrency parameter, default 4). Includes unit tests for constructors, parameter handling, and inheritance. --- dist/pom.xml | 10 + integrations/chat-models/bedrock/pom.xml | 48 +++ .../bedrock/BedrockChatModelConnection.java | 350 ++++++++++++++++++ .../bedrock/BedrockChatModelSetup.java | 64 ++++ .../BedrockChatModelConnectionTest.java | 83 +++++ .../bedrock/BedrockChatModelSetupTest.java | 77 ++++ integrations/chat-models/pom.xml | 1 + integrations/embedding-models/bedrock/pom.xml | 48 +++ .../BedrockEmbeddingModelConnection.java | 132 +++++++ .../bedrock/BedrockEmbeddingModelSetup.java | 57 +++ .../bedrock/BedrockEmbeddingModelTest.java | 98 +++++ integrations/embedding-models/pom.xml | 1 + integrations/pom.xml | 1 + 13 files changed, 970 insertions(+) create mode 100644 integrations/chat-models/bedrock/pom.xml create mode 100644 integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java create mode 100644 integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java create mode 100644 integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java create mode 100644 integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java create mode 100644 integrations/embedding-models/bedrock/pom.xml create mode 100644 integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java create mode 100644 integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java create mode 100644 integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java diff --git a/dist/pom.xml b/dist/pom.xml index 8199eaa3..b4e479a8 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -79,11 +79,21 @@ under the License. 
flink-agents-integrations-chat-models-azureai ${project.version} + + org.apache.flink + flink-agents-integrations-chat-models-bedrock + ${project.version} + org.apache.flink flink-agents-integrations-embedding-models-ollama ${project.version} + + org.apache.flink + flink-agents-integrations-embedding-models-bedrock + ${project.version} + org.apache.flink flink-agents-integrations-vector-stores-elasticsearch diff --git a/integrations/chat-models/bedrock/pom.xml b/integrations/chat-models/bedrock/pom.xml new file mode 100644 index 00000000..a10b7a45 --- /dev/null +++ b/integrations/chat-models/bedrock/pom.xml @@ -0,0 +1,48 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-chat-models + 0.3-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-chat-models-bedrock + Flink Agents : Integrations: Chat Models: Bedrock + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + + software.amazon.awssdk + bedrockruntime + ${aws.sdk.version} + + + + diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java new file mode 100644 index 00000000..56b3b3da --- /dev/null +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.chatmodels.bedrock; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import software.amazon.awssdk.services.bedrockruntime.model.*; + +import java.util.*; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +/** + * Bedrock Converse API chat model connection for flink-agents. + * + *
<p>Supported connection parameters:
+ *
+ * <ul>
+ *   <li>region (optional): AWS region (defaults to us-east-1)</li>
+ *   <li>model (optional): Default model ID (e.g. us.anthropic.claude-sonnet-4-20250514-v1:0)</li>
+ * </ul>
+ */ +public class BedrockChatModelConnection extends BaseChatModelConnection { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private final BedrockRuntimeClient client; + private final String defaultModel; + + public BedrockChatModelConnection( + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + + String region = descriptor.getArgument("region"); + if (region == null || region.isBlank()) { + region = "us-east-1"; + } + + this.client = BedrockRuntimeClient.builder() + .region(Region.of(region)) + .credentialsProvider(DefaultCredentialsProvider.create()) + .build(); + + this.defaultModel = descriptor.getArgument("model"); + } + + @Override + public ChatMessage chat( + List messages, List tools, Map arguments) { + try { + String modelId = resolveModel(arguments); + + // Separate system messages from conversation messages + List systemMsgs = messages.stream() + .filter(m -> m.getRole() == MessageRole.SYSTEM) + .collect(Collectors.toList()); + List conversationMsgs = messages.stream() + .filter(m -> m.getRole() != MessageRole.SYSTEM) + .collect(Collectors.toList()); + + ConverseRequest.Builder requestBuilder = ConverseRequest.builder() + .modelId(modelId) + .messages(mergeMessages(conversationMsgs)); + + // System prompt + if (!systemMsgs.isEmpty()) { + requestBuilder.system(systemMsgs.stream() + .map(m -> SystemContentBlock.builder().text(m.getContent()).build()) + .collect(Collectors.toList())); + } + + // Tools + if (tools != null && !tools.isEmpty()) { + requestBuilder.toolConfig(ToolConfiguration.builder() + .tools(tools.stream().map(this::toBedrockTool).collect(Collectors.toList())) + .build()); + } + + // Temperature + if (arguments != null) { + Object temp = arguments.get("temperature"); + if (temp instanceof Number) { + requestBuilder.inferenceConfig(InferenceConfiguration.builder() + .temperature(((Number) temp).floatValue()) + .build()); + } + } + + ConverseResponse response = client.converse(requestBuilder.build()); + + // Record token metrics + if (response.usage() != null) { + recordTokenMetrics(modelId, + response.usage().inputTokens(), + response.usage().outputTokens()); + } + + return convertResponse(response); + } catch (Exception e) { + throw new RuntimeException("Failed to call Bedrock Converse API.", e); + } + } + + private String resolveModel(Map arguments) { + String model = arguments != null ? (String) arguments.get("model") : null; + if (model == null || model.isBlank()) { + model = this.defaultModel; + } + if (model == null || model.isBlank()) { + throw new IllegalArgumentException("No model specified for Bedrock."); + } + return model; + } + + /** + * Merge consecutive TOOL messages into a single USER message with multiple + * toolResult content blocks, as required by Bedrock Converse API. 
+ */ + private List mergeMessages(List msgs) { + List result = new ArrayList<>(); + int i = 0; + while (i < msgs.size()) { + ChatMessage msg = msgs.get(i); + if (msg.getRole() == MessageRole.TOOL) { + // Collect all consecutive TOOL messages into one USER message + List toolResultBlocks = new ArrayList<>(); + while (i < msgs.size() && msgs.get(i).getRole() == MessageRole.TOOL) { + ChatMessage toolMsg = msgs.get(i); + String toolCallId = (String) toolMsg.getExtraArgs().get("externalId"); + toolResultBlocks.add(ContentBlock.fromToolResult(ToolResultBlock.builder() + .toolUseId(toolCallId) + .content(ToolResultContentBlock.builder() + .text(toolMsg.getContent()) + .build()) + .build())); + i++; + } + result.add(Message.builder() + .role(ConversationRole.USER) + .content(toolResultBlocks) + .build()); + } else { + result.add(toBedrockMessage(msg)); + i++; + } + } + return result; + } + + private Message toBedrockMessage(ChatMessage msg) { + switch (msg.getRole()) { + case USER: + return Message.builder() + .role(ConversationRole.USER) + .content(ContentBlock.fromText(msg.getContent())) + .build(); + case ASSISTANT: + List blocks = new ArrayList<>(); + if (msg.getContent() != null && !msg.getContent().isEmpty()) { + blocks.add(ContentBlock.fromText(msg.getContent())); + } + // Re-emit tool use blocks for multi-turn tool calling + if (msg.getToolCalls() != null && !msg.getToolCalls().isEmpty()) { + for (Map call : msg.getToolCalls()) { + @SuppressWarnings("unchecked") + Map fn = (Map) call.get("function"); + String toolUseId = (String) call.get("id"); + String name = (String) fn.get("name"); + Object args = fn.get("arguments"); + blocks.add(ContentBlock.fromToolUse(ToolUseBlock.builder() + .toolUseId(toolUseId) + .name(name) + .input(toDocument(args)) + .build())); + } + } + return Message.builder() + .role(ConversationRole.ASSISTANT) + .content(blocks) + .build(); + case TOOL: + String toolCallId = (String) msg.getExtraArgs().get("externalId"); + return Message.builder() + .role(ConversationRole.USER) + .content(ContentBlock.fromToolResult(ToolResultBlock.builder() + .toolUseId(toolCallId) + .content(ToolResultContentBlock.builder() + .text(msg.getContent()) + .build()) + .build())) + .build(); + default: + throw new IllegalArgumentException("Unsupported role for Bedrock: " + msg.getRole()); + } + } + + private software.amazon.awssdk.services.bedrockruntime.model.Tool toBedrockTool(Tool tool) { + ToolMetadata meta = tool.getMetadata(); + software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification.Builder specBuilder = + software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification.builder() + .name(meta.getName()) + .description(meta.getDescription()); + + String schema = meta.getInputSchema(); + if (schema != null && !schema.isBlank()) { + try { + Map schemaMap = MAPPER.readValue(schema, + new TypeReference>() {}); + specBuilder.inputSchema(ToolInputSchema.fromJson(toDocument(schemaMap))); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse tool schema.", e); + } + } + + return software.amazon.awssdk.services.bedrockruntime.model.Tool.builder() + .toolSpec(specBuilder.build()) + .build(); + } + + private ChatMessage convertResponse(ConverseResponse response) { + List outputBlocks = response.output().message().content(); + StringBuilder textContent = new StringBuilder(); + List> toolCalls = new ArrayList<>(); + + for (ContentBlock block : outputBlocks) { + if (block.text() != null) { + textContent.append(block.text()); + } + if 
(block.toolUse() != null) { + ToolUseBlock toolUse = block.toolUse(); + Map callMap = new LinkedHashMap<>(); + callMap.put("id", toolUse.toolUseId()); + callMap.put("type", "function"); + Map fnMap = new LinkedHashMap<>(); + fnMap.put("name", toolUse.name()); + fnMap.put("arguments", documentToMap(toolUse.input())); + callMap.put("function", fnMap); + callMap.put("original_id", toolUse.toolUseId()); + toolCalls.add(callMap); + } + } + + ChatMessage result = ChatMessage.assistant(stripMarkdownFences(textContent.toString())); + if (!toolCalls.isEmpty()) { + result.setToolCalls(toolCalls); + } + return result; + } + + /** Strip markdown code fences and extract JSON from mixed text responses. */ + private static String stripMarkdownFences(String text) { + if (text == null) return null; + String trimmed = text.trim(); + // Strip ```json ... ``` fences + if (trimmed.startsWith("```")) { + int firstNewline = trimmed.indexOf('\n'); + if (firstNewline >= 0) { + trimmed = trimmed.substring(firstNewline + 1); + } + if (trimmed.endsWith("```")) { + trimmed = trimmed.substring(0, trimmed.length() - 3).trim(); + } + return trimmed; + } + // Extract first JSON object from mixed text + int start = trimmed.indexOf('{'); + int end = trimmed.lastIndexOf('}'); + if (start >= 0 && end > start) { + return trimmed.substring(start, end + 1); + } + return trimmed; + } + + @SuppressWarnings("unchecked") + private software.amazon.awssdk.core.document.Document toDocument(Object obj) { + if (obj == null) { + return software.amazon.awssdk.core.document.Document.fromNull(); + } + if (obj instanceof Map) { + Map docMap = new LinkedHashMap<>(); + ((Map) obj).forEach((k, v) -> docMap.put(k, toDocument(v))); + return software.amazon.awssdk.core.document.Document.fromMap(docMap); + } + if (obj instanceof List) { + return software.amazon.awssdk.core.document.Document.fromList( + ((List) obj).stream().map(this::toDocument).collect(Collectors.toList())); + } + if (obj instanceof String) { + return software.amazon.awssdk.core.document.Document.fromString((String) obj); + } + if (obj instanceof Number) { + return software.amazon.awssdk.core.document.Document.fromNumber( + software.amazon.awssdk.core.SdkNumber.fromBigDecimal( + new java.math.BigDecimal(obj.toString()))); + } + if (obj instanceof Boolean) { + return software.amazon.awssdk.core.document.Document.fromBoolean((Boolean) obj); + } + return software.amazon.awssdk.core.document.Document.fromString(obj.toString()); + } + + @SuppressWarnings("unchecked") + private Map documentToMap(software.amazon.awssdk.core.document.Document doc) { + if (doc == null || !doc.isMap()) { + return Collections.emptyMap(); + } + Map result = new LinkedHashMap<>(); + doc.asMap().forEach((k, v) -> result.put(k, documentToObject(v))); + return result; + } + + private Object documentToObject(software.amazon.awssdk.core.document.Document doc) { + if (doc == null || doc.isNull()) return null; + if (doc.isString()) return doc.asString(); + if (doc.isNumber()) return doc.asNumber().bigDecimalValue(); + if (doc.isBoolean()) return doc.asBoolean(); + if (doc.isList()) { + return doc.asList().stream().map(this::documentToObject).collect(Collectors.toList()); + } + if (doc.isMap()) return documentToMap(doc); + return doc.toString(); + } +} diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java 
b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java new file mode 100644 index 00000000..74c9e4e2 --- /dev/null +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.chatmodels.bedrock; + +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; + +/** + * Chat model setup for AWS Bedrock Converse API. + * + *
<p>Supported parameters:
+ *
+ * <ul>
+ *   <li>connection (required): name of the BedrockChatModelConnection resource</li>
+ *   <li>model (required): Bedrock model ID (e.g. us.anthropic.claude-sonnet-4-20250514-v1:0)</li>
+ *   <li>temperature (optional): sampling temperature (default 0.1)</li>
+ *   <li>prompt (optional): prompt resource name</li>
+ *   <li>tools (optional): list of tool resource names</li>
+ * </ul>
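+ *
+ * <p>Example usage (an illustrative sketch; the {@code @ChatModelSetup} annotation and the
+ * connection resource name are assumptions, mirroring the connection-side example):
+ *
+ * <pre>{@code
+ * @ChatModelSetup
+ * public static ResourceDescriptor bedrockChatModel() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName())
+ *             .addInitialArgument("connection", "bedrockConnection")
+ *             .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0")
+ *             .addInitialArgument("temperature", 0.1)
+ *             .build();
+ * }
+ * }</pre>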
+ */ +public class BedrockChatModelSetup extends BaseChatModelSetup { + + private final Double temperature; + + public BedrockChatModelSetup( + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + this.temperature = Optional.ofNullable(descriptor.getArgument("temperature")) + .map(Number::doubleValue) + .orElse(0.1); + } + + @Override + public Map getParameters() { + Map params = new HashMap<>(); + if (model != null) { + params.put("model", model); + } + params.put("temperature", temperature); + return params; + } +} diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java new file mode 100644 index 00000000..fb2234fd --- /dev/null +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.bedrock; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.*; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** Tests for {@link BedrockChatModelConnection}. 
*/ +class BedrockChatModelConnectionTest { + + private static final BiFunction NOOP = (a, b) -> null; + + private static ResourceDescriptor descriptor(String region, String model) { + ResourceDescriptor.Builder b = + ResourceDescriptor.Builder.newBuilder(BedrockChatModelConnection.class.getName()); + if (region != null) b.addInitialArgument("region", region); + if (model != null) b.addInitialArgument("model", model); + return b.build(); + } + + @Test + @DisplayName("Constructor creates client with default region") + void testConstructorDefaultRegion() { + BedrockChatModelConnection conn = new BedrockChatModelConnection( + descriptor(null, "us.anthropic.claude-sonnet-4-20250514-v1:0"), NOOP); + assertNotNull(conn); + } + + @Test + @DisplayName("Constructor creates client with explicit region") + void testConstructorExplicitRegion() { + BedrockChatModelConnection conn = new BedrockChatModelConnection( + descriptor("us-west-2", "us.anthropic.claude-sonnet-4-20250514-v1:0"), NOOP); + assertNotNull(conn); + } + + @Test + @DisplayName("Extends BaseChatModelConnection") + void testInheritance() { + BedrockChatModelConnection conn = new BedrockChatModelConnection( + descriptor("us-east-1", "test-model"), NOOP); + assertThat(conn).isInstanceOf(BaseChatModelConnection.class); + } + + @Test + @DisplayName("Chat throws when no model specified") + void testChatThrowsWithoutModel() { + BedrockChatModelConnection conn = new BedrockChatModelConnection( + descriptor("us-east-1", null), NOOP); + List msgs = List.of(new ChatMessage(MessageRole.USER, "hello")); + assertThatThrownBy(() -> conn.chat(msgs, null, Collections.emptyMap())) + .isInstanceOf(RuntimeException.class); + } +} diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java new file mode 100644 index 00000000..88f2fa13 --- /dev/null +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.integrations.chatmodels.bedrock; + +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link BedrockChatModelSetup}. */ +class BedrockChatModelSetupTest { + + private static final BiFunction NOOP = (a, b) -> null; + + @Test + @DisplayName("getParameters includes model and default temperature") + void testGetParametersDefaults() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0") + .build(); + BedrockChatModelSetup setup = new BedrockChatModelSetup(desc, NOOP); + + Map params = setup.getParameters(); + assertThat(params).containsEntry("model", "us.anthropic.claude-sonnet-4-20250514-v1:0"); + assertThat(params).containsEntry("temperature", 0.1); + } + + @Test + @DisplayName("getParameters uses custom temperature") + void testGetParametersCustomTemperature() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "test-model") + .addInitialArgument("temperature", 0.7) + .build(); + BedrockChatModelSetup setup = new BedrockChatModelSetup(desc, NOOP); + + assertThat(setup.getParameters()).containsEntry("temperature", 0.7); + } + + @Test + @DisplayName("Extends BaseChatModelSetup") + void testInheritance() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "m") + .build(); + assertThat(new BedrockChatModelSetup(desc, NOOP)).isInstanceOf(BaseChatModelSetup.class); + } +} diff --git a/integrations/chat-models/pom.xml b/integrations/chat-models/pom.xml index 20b1b425..e5f4b9d4 100644 --- a/integrations/chat-models/pom.xml +++ b/integrations/chat-models/pom.xml @@ -33,6 +33,7 @@ under the License. 
anthropic azureai + bedrock ollama openai diff --git a/integrations/embedding-models/bedrock/pom.xml b/integrations/embedding-models/bedrock/pom.xml new file mode 100644 index 00000000..353c32c8 --- /dev/null +++ b/integrations/embedding-models/bedrock/pom.xml @@ -0,0 +1,48 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-embedding-models + 0.3-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-embedding-models-bedrock + Flink Agents : Integrations: Embedding Models: Bedrock + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + + software.amazon.awssdk + bedrockruntime + ${aws.sdk.version} + + + + diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java new file mode 100644 index 00000000..5d655718 --- /dev/null +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.embeddingmodels.bedrock; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; + +/** + * Bedrock embedding model connection using Amazon Titan Text Embeddings V2. + * + *
<p>Uses the InvokeModel API to generate embeddings. Supports configurable dimensions (256, 512,
+ * or 1024) and normalization.
+ *
+ * <p>Parameters:
+ *
+ * <ul>
+ *   <li>region (optional): AWS region, defaults to us-east-1</li>
+ *   <li>model (optional): default model ID, defaults to amazon.titan-embed-text-v2:0</li>
+ *   <li>embed_concurrency (optional): thread pool size for batch embed(List), defaults to 4</li>
+ * </ul>
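+ *
+ * <p>Example usage (an illustrative sketch; the {@code @EmbeddingModelConnection} annotation
+ * name is an assumption, mirroring the chat model connection example):
+ *
+ * <pre>{@code
+ * @EmbeddingModelConnection
+ * public static ResourceDescriptor bedrockEmbeddingConnection() {
+ *     return ResourceDescriptor.Builder.newBuilder(
+ *                     BedrockEmbeddingModelConnection.class.getName())
+ *             .addInitialArgument("region", "us-east-1")
+ *             .addInitialArgument("embed_concurrency", 4)
+ *             .build();
+ * }
+ * }</pre>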
+ */ +public class BedrockEmbeddingModelConnection extends BaseEmbeddingModelConnection { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String DEFAULT_MODEL = "amazon.titan-embed-text-v2:0"; + + private final BedrockRuntimeClient client; + private final String defaultModel; + private final ExecutorService embedPool; + + public BedrockEmbeddingModelConnection( + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + + String region = descriptor.getArgument("region"); + if (region == null || region.isBlank()) region = "us-east-1"; + + this.client = BedrockRuntimeClient.builder() + .region(Region.of(region)) + .credentialsProvider(DefaultCredentialsProvider.create()) + .build(); + + String model = descriptor.getArgument("model"); + this.defaultModel = (model != null && !model.isBlank()) ? model : DEFAULT_MODEL; + + Integer concurrency = descriptor.getArgument("embed_concurrency"); + int threads = concurrency != null ? concurrency : 4; + this.embedPool = Executors.newFixedThreadPool(threads); + } + + @Override + public float[] embed(String text, Map parameters) { + try { + String model = (String) parameters.getOrDefault("model", defaultModel); + Integer dimensions = (Integer) parameters.get("dimensions"); + + ObjectNode body = MAPPER.createObjectNode(); + body.put("inputText", text); + if (dimensions != null) { + body.put("dimensions", dimensions); + } + body.put("normalize", true); + + InvokeModelResponse response = client.invokeModel(InvokeModelRequest.builder() + .modelId(model) + .contentType("application/json") + .body(SdkBytes.fromUtf8String(body.toString())) + .build()); + + JsonNode result = MAPPER.readTree(response.body().asUtf8String()); + JsonNode embeddingNode = result.get("embedding"); + float[] embedding = new float[embeddingNode.size()]; + for (int i = 0; i < embeddingNode.size(); i++) { + embedding[i] = (float) embeddingNode.get(i).asDouble(); + } + return embedding; + } catch (Exception e) { + throw new RuntimeException("Failed to generate Bedrock embedding.", e); + } + } + + @Override + public List embed(List texts, Map parameters) { + if (texts.size() <= 1) { + List results = new ArrayList<>(texts.size()); + for (String text : texts) results.add(embed(text, parameters)); + return results; + } + // Parallelize — Titan V2 is single-text per call, so concurrent requests help throughput + @SuppressWarnings("unchecked") + CompletableFuture[] futures = texts.stream() + .map(text -> CompletableFuture.supplyAsync(() -> embed(text, parameters), embedPool)) + .toArray(CompletableFuture[]::new); + CompletableFuture.allOf(futures).join(); + List results = new ArrayList<>(texts.size()); + for (CompletableFuture f : futures) results.add(f.join()); + return results; + } +} diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java new file mode 100644 index 00000000..f022f005 --- /dev/null +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.embeddingmodels.bedrock; + +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Embedding model setup for Bedrock Titan Text Embeddings. + * + *
<p>Parameters:
+ *
+ * <ul>
+ *   <li>connection (required): name of the BedrockEmbeddingModelConnection resource</li>
+ *   <li>model (optional): model ID (default: amazon.titan-embed-text-v2:0)</li>
+ *   <li>dimensions (optional): embedding dimensions (256, 512, or 1024)</li>
+ * </ul>
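+ *
+ * <p>Example usage (an illustrative sketch; the {@code @EmbeddingModelSetup} annotation and
+ * the connection resource name are assumptions):
+ *
+ * <pre>{@code
+ * @EmbeddingModelSetup
+ * public static ResourceDescriptor titanEmbeddings() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelSetup.class.getName())
+ *             .addInitialArgument("connection", "bedrockEmbeddingConnection")
+ *             .addInitialArgument("dimensions", 1024)
+ *             .build();
+ * }
+ * }</pre>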
+ */ +public class BedrockEmbeddingModelSetup extends BaseEmbeddingModelSetup { + + private final Integer dimensions; + + public BedrockEmbeddingModelSetup( + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + this.dimensions = descriptor.getArgument("dimensions"); + } + + @Override + public Map getParameters() { + Map params = new HashMap<>(); + if (model != null) params.put("model", model); + if (dimensions != null) params.put("dimensions", dimensions); + return params; + } +} diff --git a/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java b/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java new file mode 100644 index 00000000..21dda8ee --- /dev/null +++ b/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.embeddingmodels.bedrock; + +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** Tests for {@link BedrockEmbeddingModelConnection} and {@link BedrockEmbeddingModelSetup}. 
*/ +class BedrockEmbeddingModelTest { + + private static final BiFunction NOOP = (a, b) -> null; + + private static ResourceDescriptor connDescriptor(String region) { + ResourceDescriptor.Builder b = ResourceDescriptor.Builder + .newBuilder(BedrockEmbeddingModelConnection.class.getName()); + if (region != null) b.addInitialArgument("region", region); + return b.build(); + } + + @Test + @DisplayName("Connection constructor creates client with defaults") + void testConnectionDefaults() { + BedrockEmbeddingModelConnection conn = + new BedrockEmbeddingModelConnection(connDescriptor(null), NOOP); + assertNotNull(conn); + assertThat(conn).isInstanceOf(BaseEmbeddingModelConnection.class); + } + + @Test + @DisplayName("Connection constructor with explicit region and concurrency") + void testConnectionExplicitParams() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(BedrockEmbeddingModelConnection.class.getName()) + .addInitialArgument("region", "eu-west-1") + .addInitialArgument("embed_concurrency", 8) + .build(); + BedrockEmbeddingModelConnection conn = + new BedrockEmbeddingModelConnection(desc, NOOP); + assertNotNull(conn); + } + + @Test + @DisplayName("Setup getParameters includes model and dimensions") + void testSetupParameters() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(BedrockEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "amazon.titan-embed-text-v2:0") + .addInitialArgument("dimensions", 1024) + .build(); + BedrockEmbeddingModelSetup setup = new BedrockEmbeddingModelSetup(desc, NOOP); + + Map params = setup.getParameters(); + assertThat(params).containsEntry("model", "amazon.titan-embed-text-v2:0"); + assertThat(params).containsEntry("dimensions", 1024); + assertThat(setup).isInstanceOf(BaseEmbeddingModelSetup.class); + } + + @Test + @DisplayName("Setup getParameters omits null dimensions") + void testSetupParametersNoDimensions() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(BedrockEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "amazon.titan-embed-text-v2:0") + .build(); + BedrockEmbeddingModelSetup setup = new BedrockEmbeddingModelSetup(desc, NOOP); + + assertThat(setup.getParameters()).doesNotContainKey("dimensions"); + } +} diff --git a/integrations/embedding-models/pom.xml b/integrations/embedding-models/pom.xml index f1bc6c08..1845a480 100644 --- a/integrations/embedding-models/pom.xml +++ b/integrations/embedding-models/pom.xml @@ -31,6 +31,7 @@ under the License. pom + bedrock ollama diff --git a/integrations/pom.xml b/integrations/pom.xml index 0e5df222..9989a5f0 100644 --- a/integrations/pom.xml +++ b/integrations/pom.xml @@ -35,6 +35,7 @@ under the License. 
8.19.0 4.8.0 2.11.1 + 2.32.16 From 481d7a97ea390751b2f8cc5cf69d49e4317493b6 Mon Sep 17 00:00:00 2001 From: Avichay Marciano Date: Tue, 10 Feb 2026 00:30:22 +0200 Subject: [PATCH 2/8] Add exponential backoff retry for throttling and model errors in BedrockEmbeddingModelConnection --- .../BedrockEmbeddingModelConnection.java | 64 +++++++++++-------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java index 5d655718..b569bee1 100644 --- a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java @@ -83,35 +83,49 @@ public BedrockEmbeddingModelConnection( @Override public float[] embed(String text, Map parameters) { - try { - String model = (String) parameters.getOrDefault("model", defaultModel); - Integer dimensions = (Integer) parameters.get("dimensions"); - - ObjectNode body = MAPPER.createObjectNode(); - body.put("inputText", text); - if (dimensions != null) { - body.put("dimensions", dimensions); - } - body.put("normalize", true); - - InvokeModelResponse response = client.invokeModel(InvokeModelRequest.builder() - .modelId(model) - .contentType("application/json") - .body(SdkBytes.fromUtf8String(body.toString())) - .build()); - - JsonNode result = MAPPER.readTree(response.body().asUtf8String()); - JsonNode embeddingNode = result.get("embedding"); - float[] embedding = new float[embeddingNode.size()]; - for (int i = 0; i < embeddingNode.size(); i++) { - embedding[i] = (float) embeddingNode.get(i).asDouble(); + String model = (String) parameters.getOrDefault("model", defaultModel); + Integer dimensions = (Integer) parameters.get("dimensions"); + + ObjectNode body = MAPPER.createObjectNode(); + body.put("inputText", text); + if (dimensions != null) { + body.put("dimensions", dimensions); + } + body.put("normalize", true); + + int maxRetries = 5; + for (int attempt = 0; ; attempt++) { + try { + InvokeModelResponse response = client.invokeModel(InvokeModelRequest.builder() + .modelId(model) + .contentType("application/json") + .body(SdkBytes.fromUtf8String(body.toString())) + .build()); + + JsonNode result = MAPPER.readTree(response.body().asUtf8String()); + JsonNode embeddingNode = result.get("embedding"); + float[] embedding = new float[embeddingNode.size()]; + for (int i = 0; i < embeddingNode.size(); i++) { + embedding[i] = (float) embeddingNode.get(i).asDouble(); + } + return embedding; + } catch (Exception e) { + if (attempt < maxRetries && isRetryable(e)) { + try { Thread.sleep((long) (Math.pow(2, attempt) * 200)); } + catch (InterruptedException ie) { Thread.currentThread().interrupt(); } + } else { + throw new RuntimeException("Failed to generate Bedrock embedding.", e); + } } - return embedding; - } catch (Exception e) { - throw new RuntimeException("Failed to generate Bedrock embedding.", e); } } + private static boolean isRetryable(Exception e) { + String msg = e.toString(); + return msg.contains("ThrottlingException") || msg.contains("ModelErrorException") + || msg.contains("429") || msg.contains("424") || msg.contains("503"); + } 
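+
+    // Retry timing for embed(String, Map) above: sleep = 2^attempt * 200 ms, i.e. 200, 400,
+    // 800, 1600, and 3200 ms for attempts 0-4; once maxRetries (5) is exhausted, the last
+    // exception is rethrown wrapped in a RuntimeException.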
+ @Override public List embed(List texts, Map parameters) { if (texts.size() <= 1) { From 34d38c2be98285aec38d1e5bd8fd7681163a7e6f Mon Sep 17 00:00:00 2001 From: Avichay Marciano Date: Sat, 14 Feb 2026 00:18:34 +0200 Subject: [PATCH 3/8] Improve Bedrock integration: add close(), max_tokens, retry jitter, Javadocs --- .../bedrock/BedrockChatModelConnection.java | 245 +++++++++++------- .../bedrock/BedrockChatModelSetup.java | 45 +++- .../BedrockChatModelConnectionTest.java | 19 +- .../bedrock/BedrockChatModelSetupTest.java | 32 +-- .../BedrockEmbeddingModelConnection.java | 100 ++++--- .../bedrock/BedrockEmbeddingModelSetup.java | 39 ++- .../bedrock/BedrockEmbeddingModelTest.java | 41 +-- 7 files changed, 339 insertions(+), 182 deletions(-) diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java index 56b3b3da..85270cfe 100644 --- a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java @@ -15,11 +15,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.flink.agents.integrations.chatmodels.bedrock; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; @@ -30,22 +30,61 @@ import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.agents.api.tools.ToolMetadata; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.core.SdkNumber; +import software.amazon.awssdk.core.document.Document; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; -import software.amazon.awssdk.services.bedrockruntime.model.*; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification; +import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock; -import java.util.*; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import 
java.util.List; +import java.util.Map; import java.util.function.BiFunction; import java.util.stream.Collectors; /** * Bedrock Converse API chat model connection for flink-agents. * + *
<p>Uses the Converse API which provides a unified interface across all Bedrock models with native
+ * tool calling support. Authentication is handled via SigV4 using the default AWS credentials
+ * chain.
+ *
+ * <p>Future work: support reasoning content blocks (Claude extended thinking), citation blocks, and
+ * image/document content blocks.
+ *
  * <p>Supported connection parameters:
  *
  * <ul>
- *   <li>region (optional): AWS region (defaults to us-east-1)</li>
- *   <li>model (optional): Default model ID (e.g. us.anthropic.claude-sonnet-4-20250514-v1:0)</li>
+ *   <li>region (optional): AWS region (defaults to us-east-1)
+ *   <li>model (optional): Default model ID (e.g. us.anthropic.claude-sonnet-4-20250514-v1:0)
  * </ul>
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * @ChatModelConnection
+ * public static ResourceDescriptor bedrockConnection() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockChatModelConnection.class.getName())
+ *             .addInitialArgument("region", "us-east-1")
+ *             .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0")
+ *             .build();
+ * }
+ * }</pre>
*/ public class BedrockChatModelConnection extends BaseChatModelConnection { @@ -54,8 +93,7 @@ public class BedrockChatModelConnection extends BaseChatModelConnection { private final String defaultModel; public BedrockChatModelConnection( - ResourceDescriptor descriptor, - BiFunction getResource) { + ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); String region = descriptor.getArgument("region"); @@ -63,10 +101,11 @@ public BedrockChatModelConnection( region = "us-east-1"; } - this.client = BedrockRuntimeClient.builder() - .region(Region.of(region)) - .credentialsProvider(DefaultCredentialsProvider.create()) - .build(); + this.client = + BedrockRuntimeClient.builder() + .region(Region.of(region)) + .credentialsProvider(DefaultCredentialsProvider.create()) + .build(); this.defaultModel = descriptor.getArgument("model"); } @@ -77,49 +116,62 @@ public ChatMessage chat( try { String modelId = resolveModel(arguments); - // Separate system messages from conversation messages - List systemMsgs = messages.stream() - .filter(m -> m.getRole() == MessageRole.SYSTEM) - .collect(Collectors.toList()); - List conversationMsgs = messages.stream() - .filter(m -> m.getRole() != MessageRole.SYSTEM) - .collect(Collectors.toList()); + List systemMsgs = + messages.stream() + .filter(m -> m.getRole() == MessageRole.SYSTEM) + .collect(Collectors.toList()); + List conversationMsgs = + messages.stream() + .filter(m -> m.getRole() != MessageRole.SYSTEM) + .collect(Collectors.toList()); - ConverseRequest.Builder requestBuilder = ConverseRequest.builder() - .modelId(modelId) - .messages(mergeMessages(conversationMsgs)); + ConverseRequest.Builder requestBuilder = + ConverseRequest.builder() + .modelId(modelId) + .messages(mergeMessages(conversationMsgs)); - // System prompt if (!systemMsgs.isEmpty()) { - requestBuilder.system(systemMsgs.stream() - .map(m -> SystemContentBlock.builder().text(m.getContent()).build()) - .collect(Collectors.toList())); + requestBuilder.system( + systemMsgs.stream() + .map(m -> SystemContentBlock.builder().text(m.getContent()).build()) + .collect(Collectors.toList())); } - // Tools if (tools != null && !tools.isEmpty()) { - requestBuilder.toolConfig(ToolConfiguration.builder() - .tools(tools.stream().map(this::toBedrockTool).collect(Collectors.toList())) - .build()); + requestBuilder.toolConfig( + ToolConfiguration.builder() + .tools( + tools.stream() + .map(this::toBedrockTool) + .collect(Collectors.toList())) + .build()); } - // Temperature + // Inference config: temperature and max_tokens if (arguments != null) { + InferenceConfiguration.Builder inferenceBuilder = null; Object temp = arguments.get("temperature"); if (temp instanceof Number) { - requestBuilder.inferenceConfig(InferenceConfiguration.builder() - .temperature(((Number) temp).floatValue()) - .build()); + inferenceBuilder = InferenceConfiguration.builder(); + inferenceBuilder.temperature(((Number) temp).floatValue()); + } + Object maxTokens = arguments.get("max_tokens"); + if (maxTokens instanceof Number) { + if (inferenceBuilder == null) { + inferenceBuilder = InferenceConfiguration.builder(); + } + inferenceBuilder.maxTokens(((Number) maxTokens).intValue()); + } + if (inferenceBuilder != null) { + requestBuilder.inferenceConfig(inferenceBuilder.build()); } } ConverseResponse response = client.converse(requestBuilder.build()); - // Record token metrics if (response.usage() != null) { - recordTokenMetrics(modelId, - response.usage().inputTokens(), - response.usage().outputTokens()); + 
recordTokenMetrics( + modelId, response.usage().inputTokens(), response.usage().outputTokens()); } return convertResponse(response); @@ -128,6 +180,11 @@ public ChatMessage chat( } } + @Override + public void close() throws Exception { + this.client.close(); + } + private String resolveModel(Map arguments) { String model = arguments != null ? (String) arguments.get("model") : null; if (model == null || model.isBlank()) { @@ -140,8 +197,8 @@ private String resolveModel(Map arguments) { } /** - * Merge consecutive TOOL messages into a single USER message with multiple - * toolResult content blocks, as required by Bedrock Converse API. + * Merge consecutive TOOL messages into a single USER message with multiple toolResult content + * blocks, as required by Bedrock Converse API. */ private List mergeMessages(List msgs) { List result = new ArrayList<>(); @@ -149,23 +206,26 @@ private List mergeMessages(List msgs) { while (i < msgs.size()) { ChatMessage msg = msgs.get(i); if (msg.getRole() == MessageRole.TOOL) { - // Collect all consecutive TOOL messages into one USER message List toolResultBlocks = new ArrayList<>(); while (i < msgs.size() && msgs.get(i).getRole() == MessageRole.TOOL) { ChatMessage toolMsg = msgs.get(i); String toolCallId = (String) toolMsg.getExtraArgs().get("externalId"); - toolResultBlocks.add(ContentBlock.fromToolResult(ToolResultBlock.builder() - .toolUseId(toolCallId) - .content(ToolResultContentBlock.builder() - .text(toolMsg.getContent()) - .build()) - .build())); + toolResultBlocks.add( + ContentBlock.fromToolResult( + ToolResultBlock.builder() + .toolUseId(toolCallId) + .content( + ToolResultContentBlock.builder() + .text(toolMsg.getContent()) + .build()) + .build())); i++; } - result.add(Message.builder() - .role(ConversationRole.USER) - .content(toolResultBlocks) - .build()); + result.add( + Message.builder() + .role(ConversationRole.USER) + .content(toolResultBlocks) + .build()); } else { result.add(toBedrockMessage(msg)); i++; @@ -186,7 +246,6 @@ private Message toBedrockMessage(ChatMessage msg) { if (msg.getContent() != null && !msg.getContent().isEmpty()) { blocks.add(ContentBlock.fromText(msg.getContent())); } - // Re-emit tool use blocks for multi-turn tool calling if (msg.getToolCalls() != null && !msg.getToolCalls().isEmpty()) { for (Map call : msg.getToolCalls()) { @SuppressWarnings("unchecked") @@ -194,45 +253,46 @@ private Message toBedrockMessage(ChatMessage msg) { String toolUseId = (String) call.get("id"); String name = (String) fn.get("name"); Object args = fn.get("arguments"); - blocks.add(ContentBlock.fromToolUse(ToolUseBlock.builder() - .toolUseId(toolUseId) - .name(name) - .input(toDocument(args)) - .build())); + blocks.add( + ContentBlock.fromToolUse( + ToolUseBlock.builder() + .toolUseId(toolUseId) + .name(name) + .input(toDocument(args)) + .build())); } } - return Message.builder() - .role(ConversationRole.ASSISTANT) - .content(blocks) - .build(); + return Message.builder().role(ConversationRole.ASSISTANT).content(blocks).build(); case TOOL: String toolCallId = (String) msg.getExtraArgs().get("externalId"); return Message.builder() .role(ConversationRole.USER) - .content(ContentBlock.fromToolResult(ToolResultBlock.builder() - .toolUseId(toolCallId) - .content(ToolResultContentBlock.builder() - .text(msg.getContent()) - .build()) - .build())) + .content( + ContentBlock.fromToolResult( + ToolResultBlock.builder() + .toolUseId(toolCallId) + .content( + ToolResultContentBlock.builder() + .text(msg.getContent()) + .build()) + .build())) .build(); 
default: - throw new IllegalArgumentException("Unsupported role for Bedrock: " + msg.getRole()); + throw new IllegalArgumentException( + "Unsupported role for Bedrock: " + msg.getRole()); } } private software.amazon.awssdk.services.bedrockruntime.model.Tool toBedrockTool(Tool tool) { ToolMetadata meta = tool.getMetadata(); - software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification.Builder specBuilder = - software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification.builder() - .name(meta.getName()) - .description(meta.getDescription()); + ToolSpecification.Builder specBuilder = + ToolSpecification.builder().name(meta.getName()).description(meta.getDescription()); String schema = meta.getInputSchema(); if (schema != null && !schema.isBlank()) { try { - Map schemaMap = MAPPER.readValue(schema, - new TypeReference>() {}); + Map schemaMap = + MAPPER.readValue(schema, new TypeReference>() {}); specBuilder.inputSchema(ToolInputSchema.fromJson(toDocument(schemaMap))); } catch (JsonProcessingException e) { throw new RuntimeException("Failed to parse tool schema.", e); @@ -274,11 +334,17 @@ private ChatMessage convertResponse(ConverseResponse response) { return result; } - /** Strip markdown code fences and extract JSON from mixed text responses. */ + /** + * Strip markdown code fences and extract JSON from mixed text responses. Some Bedrock models + * wrap JSON output in markdown fences or add prose around it. + * + *

Note: Unlike OpenAI's strict mode, Bedrock models do not guarantee pure JSON output. The + * flink-agents framework's {@code ChatModelAction.generateStructuredOutput} expects clean JSON, + * so this extraction is necessary at the connection layer. + */ private static String stripMarkdownFences(String text) { if (text == null) return null; String trimmed = text.trim(); - // Strip ```json ... ``` fences if (trimmed.startsWith("```")) { int firstNewline = trimmed.indexOf('\n'); if (firstNewline >= 0) { @@ -289,7 +355,6 @@ private static String stripMarkdownFences(String text) { } return trimmed; } - // Extract first JSON object from mixed text int start = trimmed.indexOf('{'); int end = trimmed.lastIndexOf('}'); if (start >= 0 && end > start) { @@ -299,35 +364,33 @@ private static String stripMarkdownFences(String text) { } @SuppressWarnings("unchecked") - private software.amazon.awssdk.core.document.Document toDocument(Object obj) { + private Document toDocument(Object obj) { if (obj == null) { - return software.amazon.awssdk.core.document.Document.fromNull(); + return Document.fromNull(); } if (obj instanceof Map) { - Map docMap = new LinkedHashMap<>(); + Map docMap = new LinkedHashMap<>(); ((Map) obj).forEach((k, v) -> docMap.put(k, toDocument(v))); - return software.amazon.awssdk.core.document.Document.fromMap(docMap); + return Document.fromMap(docMap); } if (obj instanceof List) { - return software.amazon.awssdk.core.document.Document.fromList( - ((List) obj).stream().map(this::toDocument).collect(Collectors.toList())); + return Document.fromList( + ((List) obj) + .stream().map(this::toDocument).collect(Collectors.toList())); } if (obj instanceof String) { - return software.amazon.awssdk.core.document.Document.fromString((String) obj); + return Document.fromString((String) obj); } if (obj instanceof Number) { - return software.amazon.awssdk.core.document.Document.fromNumber( - software.amazon.awssdk.core.SdkNumber.fromBigDecimal( - new java.math.BigDecimal(obj.toString()))); + return Document.fromNumber(SdkNumber.fromBigDecimal(new BigDecimal(obj.toString()))); } if (obj instanceof Boolean) { - return software.amazon.awssdk.core.document.Document.fromBoolean((Boolean) obj); + return Document.fromBoolean((Boolean) obj); } - return software.amazon.awssdk.core.document.Document.fromString(obj.toString()); + return Document.fromString(obj.toString()); } - @SuppressWarnings("unchecked") - private Map documentToMap(software.amazon.awssdk.core.document.Document doc) { + private Map documentToMap(Document doc) { if (doc == null || !doc.isMap()) { return Collections.emptyMap(); } @@ -336,7 +399,7 @@ private Map documentToMap(software.amazon.awssdk.core.document.D return result; } - private Object documentToObject(software.amazon.awssdk.core.document.Document doc) { + private Object documentToObject(Document doc) { if (doc == null || doc.isNull()) return null; if (doc.isString()) return doc.asString(); if (doc.isNumber()) return doc.asNumber().bigDecimalValue(); diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java index 74c9e4e2..cbcd380b 100644 --- a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java +++ 
b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.flink.agents.integrations.chatmodels.bedrock; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; @@ -31,25 +32,46 @@ * Chat model setup for AWS Bedrock Converse API. * *

 * <p>Supported parameters:
+ *
 * <ul>
- *   <li>connection (required): name of the BedrockChatModelConnection resource
- *   <li>model (required): Bedrock model ID (e.g. us.anthropic.claude-sonnet-4-20250514-v1:0)
- *   <li>temperature (optional): sampling temperature (default 0.1)
- *   <li>prompt (optional): prompt resource name
- *   <li>tools (optional): list of tool resource names
+ *   <li>connection (required): name of the BedrockChatModelConnection resource
+ *   <li>model (required): Bedrock model ID (e.g. us.anthropic.claude-sonnet-4-20250514-v1:0)
+ *   <li>temperature (optional): sampling temperature (default 0.1)
+ *   <li>max_tokens (optional): maximum tokens in the response
+ *   <li>prompt (optional): prompt resource name
+ *   <li>tools (optional): list of tool resource names
 * </ul>
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * @ChatModelSetup
+ * public static ResourceDescriptor bedrockModel() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName())
+ *             .addInitialArgument("connection", "bedrockConnection")
+ *             .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0")
+ *             .addInitialArgument("temperature", 0.1)
+ *             .addInitialArgument("max_tokens", 4096)
+ *             .build();
+ * }
+ * }</pre>
*/ public class BedrockChatModelSetup extends BaseChatModelSetup { private final Double temperature; + private final Integer maxTokens; public BedrockChatModelSetup( - ResourceDescriptor descriptor, - BiFunction getResource) { + ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); - this.temperature = Optional.ofNullable(descriptor.getArgument("temperature")) - .map(Number::doubleValue) - .orElse(0.1); + this.temperature = + Optional.ofNullable(descriptor.getArgument("temperature")) + .map(Number::doubleValue) + .orElse(0.1); + this.maxTokens = + Optional.ofNullable(descriptor.getArgument("max_tokens")) + .map(Number::intValue) + .orElse(null); } @Override @@ -59,6 +81,9 @@ public Map getParameters() { params.put("model", model); } params.put("temperature", temperature); + if (maxTokens != null) { + params.put("max_tokens", maxTokens); + } return params; } } diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java index fb2234fd..1fed3f2e 100644 --- a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java @@ -50,32 +50,35 @@ private static ResourceDescriptor descriptor(String region, String model) { @Test @DisplayName("Constructor creates client with default region") void testConstructorDefaultRegion() { - BedrockChatModelConnection conn = new BedrockChatModelConnection( - descriptor(null, "us.anthropic.claude-sonnet-4-20250514-v1:0"), NOOP); + BedrockChatModelConnection conn = + new BedrockChatModelConnection( + descriptor(null, "us.anthropic.claude-sonnet-4-20250514-v1:0"), NOOP); assertNotNull(conn); } @Test @DisplayName("Constructor creates client with explicit region") void testConstructorExplicitRegion() { - BedrockChatModelConnection conn = new BedrockChatModelConnection( - descriptor("us-west-2", "us.anthropic.claude-sonnet-4-20250514-v1:0"), NOOP); + BedrockChatModelConnection conn = + new BedrockChatModelConnection( + descriptor("us-west-2", "us.anthropic.claude-sonnet-4-20250514-v1:0"), + NOOP); assertNotNull(conn); } @Test @DisplayName("Extends BaseChatModelConnection") void testInheritance() { - BedrockChatModelConnection conn = new BedrockChatModelConnection( - descriptor("us-east-1", "test-model"), NOOP); + BedrockChatModelConnection conn = + new BedrockChatModelConnection(descriptor("us-east-1", "test-model"), NOOP); assertThat(conn).isInstanceOf(BaseChatModelConnection.class); } @Test @DisplayName("Chat throws when no model specified") void testChatThrowsWithoutModel() { - BedrockChatModelConnection conn = new BedrockChatModelConnection( - descriptor("us-east-1", null), NOOP); + BedrockChatModelConnection conn = + new BedrockChatModelConnection(descriptor("us-east-1", null), NOOP); List msgs = List.of(new ChatMessage(MessageRole.USER, "hello")); assertThatThrownBy(() -> conn.chat(msgs, null, Collections.emptyMap())) .isInstanceOf(RuntimeException.class); diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java 
b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java index 88f2fa13..05094f02 100644 --- a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java @@ -38,11 +38,11 @@ class BedrockChatModelSetupTest { @Test @DisplayName("getParameters includes model and default temperature") void testGetParametersDefaults() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(BedrockChatModelSetup.class.getName()) - .addInitialArgument("connection", "conn") - .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0") - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0") + .build(); BedrockChatModelSetup setup = new BedrockChatModelSetup(desc, NOOP); Map params = setup.getParameters(); @@ -53,12 +53,12 @@ void testGetParametersDefaults() { @Test @DisplayName("getParameters uses custom temperature") void testGetParametersCustomTemperature() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(BedrockChatModelSetup.class.getName()) - .addInitialArgument("connection", "conn") - .addInitialArgument("model", "test-model") - .addInitialArgument("temperature", 0.7) - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "test-model") + .addInitialArgument("temperature", 0.7) + .build(); BedrockChatModelSetup setup = new BedrockChatModelSetup(desc, NOOP); assertThat(setup.getParameters()).containsEntry("temperature", 0.7); @@ -67,11 +67,11 @@ void testGetParametersCustomTemperature() { @Test @DisplayName("Extends BaseChatModelSetup") void testInheritance() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(BedrockChatModelSetup.class.getName()) - .addInitialArgument("connection", "conn") - .addInitialArgument("model", "m") - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "m") + .build(); assertThat(new BedrockChatModelSetup(desc, NOOP)).isInstanceOf(BaseChatModelSetup.class); } } diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java index b569bee1..2366d123 100644 --- a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.flink.agents.integrations.embeddingmodels.bedrock; import com.fasterxml.jackson.databind.JsonNode; @@ -42,36 +43,55 @@ /** * Bedrock embedding model connection using Amazon Titan Text Embeddings V2. * - *

- * <p>Uses the InvokeModel API to generate embeddings. Supports configurable
- * dimensions (256, 512, or 1024) and normalization.
+ * <p>Uses the InvokeModel API to generate embeddings. Supports configurable dimensions (256, 512,
+ * or 1024) and normalization. Since Titan V2 processes one text per API call, batch embedding is
+ * parallelized via a configurable thread pool.
+ *
+ * <p>Supported connection parameters:
 *
- * <p>Parameters:
 * <ul>
- *   <li>region (optional): AWS region, defaults to us-east-1
- *   <li>model (optional): default model ID, defaults to amazon.titan-embed-text-v2:0
+ *   <li>region (optional): AWS region, defaults to us-east-1
+ *   <li>model (optional): default model ID, defaults to amazon.titan-embed-text-v2:0
+ *   <li>embed_concurrency (optional): thread pool size for parallel embedding (default: 4)
 * </ul>
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * @EmbeddingModelConnection
+ * public static ResourceDescriptor bedrockEmbedding() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelConnection.class.getName())
+ *             .addInitialArgument("region", "us-east-1")
+ *             .addInitialArgument("model", "amazon.titan-embed-text-v2:0")
+ *             .addInitialArgument("embed_concurrency", 8)
+ *             .build();
+ * }
+ * }</pre>
*/ public class BedrockEmbeddingModelConnection extends BaseEmbeddingModelConnection { private static final ObjectMapper MAPPER = new ObjectMapper(); private static final String DEFAULT_MODEL = "amazon.titan-embed-text-v2:0"; + private static final int MAX_RETRIES = 5; private final BedrockRuntimeClient client; private final String defaultModel; private final ExecutorService embedPool; public BedrockEmbeddingModelConnection( - ResourceDescriptor descriptor, - BiFunction getResource) { + ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); String region = descriptor.getArgument("region"); - if (region == null || region.isBlank()) region = "us-east-1"; + if (region == null || region.isBlank()) { + region = "us-east-1"; + } - this.client = BedrockRuntimeClient.builder() - .region(Region.of(region)) - .credentialsProvider(DefaultCredentialsProvider.create()) - .build(); + this.client = + BedrockRuntimeClient.builder() + .region(Region.of(region)) + .credentialsProvider(DefaultCredentialsProvider.create()) + .build(); String model = descriptor.getArgument("model"); this.defaultModel = (model != null && !model.isBlank()) ? model : DEFAULT_MODEL; @@ -93,14 +113,15 @@ public float[] embed(String text, Map parameters) { } body.put("normalize", true); - int maxRetries = 5; for (int attempt = 0; ; attempt++) { try { - InvokeModelResponse response = client.invokeModel(InvokeModelRequest.builder() - .modelId(model) - .contentType("application/json") - .body(SdkBytes.fromUtf8String(body.toString())) - .build()); + InvokeModelResponse response = + client.invokeModel( + InvokeModelRequest.builder() + .modelId(model) + .contentType("application/json") + .body(SdkBytes.fromUtf8String(body.toString())) + .build()); JsonNode result = MAPPER.readTree(response.body().asUtf8String()); JsonNode embeddingNode = result.get("embedding"); @@ -110,9 +131,14 @@ public float[] embed(String text, Map parameters) { } return embedding; } catch (Exception e) { - if (attempt < maxRetries && isRetryable(e)) { - try { Thread.sleep((long) (Math.pow(2, attempt) * 200)); } - catch (InterruptedException ie) { Thread.currentThread().interrupt(); } + if (attempt < MAX_RETRIES && isRetryable(e)) { + try { + long delay = + (long) (Math.pow(2, attempt) * 200 * (0.5 + Math.random() * 0.5)); + Thread.sleep(delay); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } } else { throw new RuntimeException("Failed to generate Bedrock embedding.", e); } @@ -122,25 +148,41 @@ public float[] embed(String text, Map parameters) { private static boolean isRetryable(Exception e) { String msg = e.toString(); - return msg.contains("ThrottlingException") || msg.contains("ModelErrorException") - || msg.contains("429") || msg.contains("424") || msg.contains("503"); + return msg.contains("ThrottlingException") + || msg.contains("ModelErrorException") + || msg.contains("429") + || msg.contains("424") + || msg.contains("503"); } @Override public List embed(List texts, Map parameters) { if (texts.size() <= 1) { List results = new ArrayList<>(texts.size()); - for (String text : texts) results.add(embed(text, parameters)); + for (String text : texts) { + results.add(embed(text, parameters)); + } return results; } - // Parallelize — Titan V2 is single-text per call, so concurrent requests help throughput @SuppressWarnings("unchecked") - CompletableFuture[] futures = texts.stream() - .map(text -> CompletableFuture.supplyAsync(() -> embed(text, parameters), embedPool)) - 
.toArray(CompletableFuture[]::new); + CompletableFuture[] futures = + texts.stream() + .map( + text -> + CompletableFuture.supplyAsync( + () -> embed(text, parameters), embedPool)) + .toArray(CompletableFuture[]::new); CompletableFuture.allOf(futures).join(); List results = new ArrayList<>(texts.size()); - for (CompletableFuture f : futures) results.add(f.join()); + for (CompletableFuture f : futures) { + results.add(f.join()); + } return results; } + + @Override + public void close() throws Exception { + this.embedPool.shutdown(); + this.client.close(); + } } diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java index f022f005..90dc1934 100644 --- a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.flink.agents.integrations.embeddingmodels.bedrock; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; @@ -29,20 +30,33 @@ /** * Embedding model setup for Bedrock Titan Text Embeddings. * - *

- * <p>Parameters:
+ * <p>Supported parameters:
+ *
 * <ul>
- *   <li>connection (required): name of the BedrockEmbeddingModelConnection resource
- *   <li>model (optional): model ID (default: amazon.titan-embed-text-v2:0)
- *   <li>dimensions (optional): embedding dimensions (256, 512, or 1024)
+ *   <li>connection (required): name of the BedrockEmbeddingModelConnection resource
+ *   <li>model (optional): model ID (default: amazon.titan-embed-text-v2:0)
+ *   <li>dimensions (optional): embedding dimensions (256, 512, or 1024)
 * </ul>
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * @EmbeddingModelSetup
+ * public static ResourceDescriptor bedrockEmbeddingSetup() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelSetup.class.getName())
+ *             .addInitialArgument("connection", "bedrockEmbedding")
+ *             .addInitialArgument("model", "amazon.titan-embed-text-v2:0")
+ *             .addInitialArgument("dimensions", 1024)
+ *             .build();
+ * }
+ * }</pre>
*/ public class BedrockEmbeddingModelSetup extends BaseEmbeddingModelSetup { private final Integer dimensions; public BedrockEmbeddingModelSetup( - ResourceDescriptor descriptor, - BiFunction getResource) { + ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); this.dimensions = descriptor.getArgument("dimensions"); } @@ -50,8 +64,17 @@ public BedrockEmbeddingModelSetup( @Override public Map getParameters() { Map params = new HashMap<>(); - if (model != null) params.put("model", model); - if (dimensions != null) params.put("dimensions", dimensions); + if (model != null) { + params.put("model", model); + } + if (dimensions != null) { + params.put("dimensions", dimensions); + } return params; } + + @Override + public BedrockEmbeddingModelConnection getConnection() { + return (BedrockEmbeddingModelConnection) super.getConnection(); + } } diff --git a/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java b/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java index 21dda8ee..3d2d3d07 100644 --- a/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java +++ b/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java @@ -38,8 +38,9 @@ class BedrockEmbeddingModelTest { private static final BiFunction NOOP = (a, b) -> null; private static ResourceDescriptor connDescriptor(String region) { - ResourceDescriptor.Builder b = ResourceDescriptor.Builder - .newBuilder(BedrockEmbeddingModelConnection.class.getName()); + ResourceDescriptor.Builder b = + ResourceDescriptor.Builder.newBuilder( + BedrockEmbeddingModelConnection.class.getName()); if (region != null) b.addInitialArgument("region", region); return b.build(); } @@ -56,25 +57,25 @@ void testConnectionDefaults() { @Test @DisplayName("Connection constructor with explicit region and concurrency") void testConnectionExplicitParams() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(BedrockEmbeddingModelConnection.class.getName()) - .addInitialArgument("region", "eu-west-1") - .addInitialArgument("embed_concurrency", 8) - .build(); - BedrockEmbeddingModelConnection conn = - new BedrockEmbeddingModelConnection(desc, NOOP); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder( + BedrockEmbeddingModelConnection.class.getName()) + .addInitialArgument("region", "eu-west-1") + .addInitialArgument("embed_concurrency", 8) + .build(); + BedrockEmbeddingModelConnection conn = new BedrockEmbeddingModelConnection(desc, NOOP); assertNotNull(conn); } @Test @DisplayName("Setup getParameters includes model and dimensions") void testSetupParameters() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(BedrockEmbeddingModelSetup.class.getName()) - .addInitialArgument("connection", "conn") - .addInitialArgument("model", "amazon.titan-embed-text-v2:0") - .addInitialArgument("dimensions", 1024) - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "amazon.titan-embed-text-v2:0") + .addInitialArgument("dimensions", 1024) + .build(); BedrockEmbeddingModelSetup setup = new 
BedrockEmbeddingModelSetup(desc, NOOP); Map params = setup.getParameters(); @@ -86,11 +87,11 @@ void testSetupParameters() { @Test @DisplayName("Setup getParameters omits null dimensions") void testSetupParametersNoDimensions() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(BedrockEmbeddingModelSetup.class.getName()) - .addInitialArgument("connection", "conn") - .addInitialArgument("model", "amazon.titan-embed-text-v2:0") - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "amazon.titan-embed-text-v2:0") + .build(); BedrockEmbeddingModelSetup setup = new BedrockEmbeddingModelSetup(desc, NOOP); assertThat(setup.getParameters()).doesNotContainKey("dimensions"); From 6d50c47dda14cd7b9e88b8305e89dbe7bc79ae51 Mon Sep 17 00:00:00 2001 From: Avichay Marciano Date: Thu, 19 Feb 2026 09:36:26 +0200 Subject: [PATCH 4/8] Address PR review: add retry for transient errors, fix stripMarkdownFences MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add exponential backoff retry (MAX_RETRIES=5) for ThrottlingException, ServiceUnavailableException, ModelErrorException, 429, 503 — consistent with BedrockEmbeddingModelConnection in this PR. - Remove {..} JSON extraction fallback from stripMarkdownFences that could corrupt normal text responses containing braces. - Only apply markdown fence stripping on non-tool-call responses. - Add 5 unit tests for stripMarkdownFences covering: text with braces, clean JSON, json fences, plain fences, and null input. --- .../bedrock/BedrockChatModelConnection.java | 161 ++++++++++-------- .../BedrockChatModelConnectionTest.java | 38 +++++ 2 files changed, 131 insertions(+), 68 deletions(-) diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java index 85270cfe..ddd19400 100644 --- a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java @@ -110,76 +110,104 @@ public BedrockChatModelConnection( this.defaultModel = descriptor.getArgument("model"); } + private static final int MAX_RETRIES = 5; + @Override public ChatMessage chat( List messages, List tools, Map arguments) { - try { - String modelId = resolveModel(arguments); - - List systemMsgs = - messages.stream() - .filter(m -> m.getRole() == MessageRole.SYSTEM) - .collect(Collectors.toList()); - List conversationMsgs = - messages.stream() - .filter(m -> m.getRole() != MessageRole.SYSTEM) - .collect(Collectors.toList()); - - ConverseRequest.Builder requestBuilder = - ConverseRequest.builder() - .modelId(modelId) - .messages(mergeMessages(conversationMsgs)); - - if (!systemMsgs.isEmpty()) { - requestBuilder.system( - systemMsgs.stream() - .map(m -> SystemContentBlock.builder().text(m.getContent()).build()) - .collect(Collectors.toList())); - } + String modelId = resolveModel(arguments); + + List systemMsgs = + messages.stream() + .filter(m -> m.getRole() == MessageRole.SYSTEM) + .collect(Collectors.toList()); + List conversationMsgs = + messages.stream() + .filter(m -> 
m.getRole() != MessageRole.SYSTEM) + .collect(Collectors.toList()); + + ConverseRequest.Builder requestBuilder = + ConverseRequest.builder() + .modelId(modelId) + .messages(mergeMessages(conversationMsgs)); + + if (!systemMsgs.isEmpty()) { + requestBuilder.system( + systemMsgs.stream() + .map(m -> SystemContentBlock.builder().text(m.getContent()).build()) + .collect(Collectors.toList())); + } - if (tools != null && !tools.isEmpty()) { - requestBuilder.toolConfig( - ToolConfiguration.builder() - .tools( - tools.stream() - .map(this::toBedrockTool) - .collect(Collectors.toList())) - .build()); - } + if (tools != null && !tools.isEmpty()) { + requestBuilder.toolConfig( + ToolConfiguration.builder() + .tools( + tools.stream() + .map(this::toBedrockTool) + .collect(Collectors.toList())) + .build()); + } - // Inference config: temperature and max_tokens - if (arguments != null) { - InferenceConfiguration.Builder inferenceBuilder = null; - Object temp = arguments.get("temperature"); - if (temp instanceof Number) { + // Inference config: temperature and max_tokens + if (arguments != null) { + InferenceConfiguration.Builder inferenceBuilder = null; + Object temp = arguments.get("temperature"); + if (temp instanceof Number) { + inferenceBuilder = InferenceConfiguration.builder(); + inferenceBuilder.temperature(((Number) temp).floatValue()); + } + Object maxTokens = arguments.get("max_tokens"); + if (maxTokens instanceof Number) { + if (inferenceBuilder == null) { inferenceBuilder = InferenceConfiguration.builder(); - inferenceBuilder.temperature(((Number) temp).floatValue()); - } - Object maxTokens = arguments.get("max_tokens"); - if (maxTokens instanceof Number) { - if (inferenceBuilder == null) { - inferenceBuilder = InferenceConfiguration.builder(); - } - inferenceBuilder.maxTokens(((Number) maxTokens).intValue()); - } - if (inferenceBuilder != null) { - requestBuilder.inferenceConfig(inferenceBuilder.build()); } + inferenceBuilder.maxTokens(((Number) maxTokens).intValue()); + } + if (inferenceBuilder != null) { + requestBuilder.inferenceConfig(inferenceBuilder.build()); } + } - ConverseResponse response = client.converse(requestBuilder.build()); + ConverseRequest request = requestBuilder.build(); - if (response.usage() != null) { - recordTokenMetrics( - modelId, response.usage().inputTokens(), response.usage().outputTokens()); - } + for (int attempt = 0; ; attempt++) { + try { + ConverseResponse response = client.converse(request); - return convertResponse(response); - } catch (Exception e) { - throw new RuntimeException("Failed to call Bedrock Converse API.", e); + if (response.usage() != null) { + recordTokenMetrics( + modelId, + response.usage().inputTokens(), + response.usage().outputTokens()); + } + + return convertResponse(response); + } catch (Exception e) { + if (attempt < MAX_RETRIES && isRetryable(e)) { + try { + long delay = + (long) (Math.pow(2, attempt) * 200 * (0.5 + Math.random() * 0.5)); + Thread.sleep(delay); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted during Bedrock retry.", ie); + } + } else { + throw new RuntimeException("Failed to call Bedrock Converse API.", e); + } + } } } + private static boolean isRetryable(Exception e) { + String msg = e.toString(); + return msg.contains("ThrottlingException") + || msg.contains("ServiceUnavailableException") + || msg.contains("ModelErrorException") + || msg.contains("429") + || msg.contains("503"); + } + @Override public void close() throws Exception { 
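        // Editor's note: closing the shared BedrockRuntimeClient releases its
        // underlying HTTP connection pool; the connection cannot be reused afterwards.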
this.client.close(); @@ -327,22 +355,24 @@ private ChatMessage convertResponse(ConverseResponse response) { } } - ChatMessage result = ChatMessage.assistant(stripMarkdownFences(textContent.toString())); + ChatMessage result = ChatMessage.assistant(textContent.toString()); if (!toolCalls.isEmpty()) { result.setToolCalls(toolCalls); + } else { + // Only strip markdown fences for non-tool-call responses. + result = ChatMessage.assistant(stripMarkdownFences(textContent.toString())); } return result; } /** - * Strip markdown code fences and extract JSON from mixed text responses. Some Bedrock models - * wrap JSON output in markdown fences or add prose around it. + * Strip markdown code fences from text responses. Some Bedrock models wrap JSON output in + * markdown fences like {@code ```json ... ```}. * - *

- * <p>Note: Unlike OpenAI's strict mode, Bedrock models do not guarantee pure JSON output. The
- * flink-agents framework's {@code ChatModelAction.generateStructuredOutput} expects clean JSON,
- * so this extraction is necessary at the connection layer.
+ * <p>

Only strips code fences; does not extract JSON from arbitrary text, as that could corrupt + * normal prose responses containing braces. */ - private static String stripMarkdownFences(String text) { + static String stripMarkdownFences(String text) { if (text == null) return null; String trimmed = text.trim(); if (trimmed.startsWith("```")) { @@ -355,11 +385,6 @@ private static String stripMarkdownFences(String text) { } return trimmed; } - int start = trimmed.indexOf('{'); - int end = trimmed.lastIndexOf('}'); - if (start >= 0 && end > start) { - return trimmed.substring(start, end + 1); - } return trimmed; } diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java index 1fed3f2e..84c29481 100644 --- a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java @@ -83,4 +83,42 @@ void testChatThrowsWithoutModel() { assertThatThrownBy(() -> conn.chat(msgs, null, Collections.emptyMap())) .isInstanceOf(RuntimeException.class); } + + @Test + @DisplayName("stripMarkdownFences: normal text with braces is not modified") + void testStripMarkdownFencesPreservesTextWithBraces() { + assertThat( + BedrockChatModelConnection.stripMarkdownFences( + "Use the format {key: value} for config")) + .isEqualTo("Use the format {key: value} for config"); + } + + @Test + @DisplayName("stripMarkdownFences: clean JSON passes through") + void testStripMarkdownFencesCleanJson() { + assertThat( + BedrockChatModelConnection.stripMarkdownFences( + "{\"score\": 5, \"reasons\": []}")) + .isEqualTo("{\"score\": 5, \"reasons\": []}"); + } + + @Test + @DisplayName("stripMarkdownFences: strips ```json fences") + void testStripMarkdownFencesJsonBlock() { + assertThat(BedrockChatModelConnection.stripMarkdownFences("```json\n{\"score\": 5}\n```")) + .isEqualTo("{\"score\": 5}"); + } + + @Test + @DisplayName("stripMarkdownFences: strips plain ``` fences") + void testStripMarkdownFencesPlainBlock() { + assertThat(BedrockChatModelConnection.stripMarkdownFences("```\n{\"id\": \"P001\"}\n```")) + .isEqualTo("{\"id\": \"P001\"}"); + } + + @Test + @DisplayName("stripMarkdownFences: null returns null") + void testStripMarkdownFencesNull() { + assertThat(BedrockChatModelConnection.stripMarkdownFences(null)).isNull(); + } } From f19dfab750fcd4933188b3fc51ddc0493116977c Mon Sep 17 00:00:00 2001 From: Avichay Marciano Date: Mon, 9 Feb 2026 16:44:28 +0200 Subject: [PATCH 5/8] [FLINK-AGENTS] Add Amazon OpenSearch and S3 Vectors vector store integrations Add two new vector store integration modules: - OpenSearch: supports Serverless (AOSS) and Service domains, IAM (SigV4) or basic auth. Implements CollectionManageableVectorStore. ANN search via knn query with ef_search, min_score, and filter_query support. Bulk writes chunked by configurable max_bulk_mb. - S3 Vectors: uses S3 Vectors SDK for PutVectors/QueryVectors/ GetVectors/DeleteVectors. PutVectors chunked at 500 (API limit). Both override add() for batch embedding via embed(List). Includes unit tests and integration tests (auto-enabled via OPENSEARCH_ENDPOINT / S3V_BUCKET environment variables). 
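For reviewers, a minimal registration sketch for an OpenSearch Service domain
with basic auth (argument values are illustrative and mirror the unit tests
in this patch, not a recommended production setup):

    ResourceDescriptor desc =
            ResourceDescriptor.Builder.newBuilder(OpenSearchVectorStore.class.getName())
                    .addInitialArgument("embedding_model", "emb")
                    .addInitialArgument("endpoint", "https://my-domain.us-east-1.es.amazonaws.com")
                    .addInitialArgument("index", "test-index")
                    .addInitialArgument("service_type", "domain")
                    .addInitialArgument("auth", "basic")
                    .addInitialArgument("username", "admin")
                    .addInitialArgument("password", "password")
                    .build();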
Validated against real OpenSearch domain and S3 Vectors bucket. --- dist/pom.xml | 10 + integrations/vector-stores/opensearch/pom.xml | 59 +++ .../opensearch/OpenSearchVectorStore.java | 453 ++++++++++++++++++ .../opensearch/OpenSearchVectorStoreTest.java | 226 +++++++++ integrations/vector-stores/pom.xml | 2 + integrations/vector-stores/s3vectors/pom.xml | 54 +++ .../s3vectors/S3VectorsVectorStore.java | 205 ++++++++ .../s3vectors/S3VectorsVectorStoreTest.java | 132 +++++ 8 files changed, 1141 insertions(+) create mode 100644 integrations/vector-stores/opensearch/pom.xml create mode 100644 integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java create mode 100644 integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java create mode 100644 integrations/vector-stores/s3vectors/pom.xml create mode 100644 integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java create mode 100644 integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java diff --git a/dist/pom.xml b/dist/pom.xml index b4e479a8..16ed7fcd 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -99,6 +99,16 @@ under the License. flink-agents-integrations-vector-stores-elasticsearch ${project.version} + + org.apache.flink + flink-agents-integrations-vector-stores-opensearch + ${project.version} + + + org.apache.flink + flink-agents-integrations-vector-stores-s3vectors + ${project.version} + org.apache.flink flink-agents-integrations-mcp diff --git a/integrations/vector-stores/opensearch/pom.xml b/integrations/vector-stores/opensearch/pom.xml new file mode 100644 index 00000000..6c34b885 --- /dev/null +++ b/integrations/vector-stores/opensearch/pom.xml @@ -0,0 +1,59 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-vector-stores + 0.3-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-vector-stores-opensearch + Flink Agents : Integrations: Vector Stores: OpenSearch + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + + software.amazon.awssdk + apache-client + ${aws.sdk.version} + + + software.amazon.awssdk + auth + ${aws.sdk.version} + + + com.google.code.findbugs + jsr305 + 1.3.9 + provided + + + + diff --git a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java new file mode 100644 index 00000000..59435389 --- /dev/null +++ b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.vectorstores.opensearch; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.Document; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.signer.Aws4Signer; +import software.amazon.awssdk.auth.signer.params.Aws4SignerParams; +import software.amazon.awssdk.http.*; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.regions.Region; + +import javax.annotation.Nullable; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.function.BiFunction; + +/** + * OpenSearch vector store supporting both OpenSearch Serverless (AOSS) and + * OpenSearch Service domains, with IAM (SigV4) or basic auth. + * + *

Implements {@link CollectionManageableVectorStore} for Long-Term Memory support. + * Collections map to OpenSearch indices. Collection metadata is stored in a + * dedicated {@code _collection_metadata} index. + * + *

+ * <p>Parameters:
+ *
+ * <ul>
+ *   <li>embedding_model (required): name of the embedding model resource
+ *   <li>endpoint (required): OpenSearch endpoint URL
+ *   <li>index (required): default index name
+ *   <li>service_type (optional): "serverless" (default) or "domain"
+ *   <li>auth (optional): "iam" (default) or "basic"
+ *   <li>username (required if auth=basic): basic auth username
+ *   <li>password (required if auth=basic): basic auth password
+ *   <li>vector_field (optional): vector field name (default: "embedding")
+ *   <li>content_field (optional): content field name (default: "content")
+ *   <li>region (optional): AWS region (default: us-east-1)
+ *   <li>dims (optional): vector dimensions for index creation (default: 1024)
+ *   <li>max_bulk_mb (optional): max bulk request payload size in MB (default: 5)
+ * </ul>
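+ *
+ * <p>Example usage (illustrative; the endpoint and resource names are placeholders):
+ *
+ * <pre>{@code
+ * ResourceDescriptor.Builder.newBuilder(OpenSearchVectorStore.class.getName())
+ *         .addInitialArgument("embedding_model", "emb")
+ *         .addInitialArgument("endpoint", "https://example.aoss.us-east-1.amazonaws.com")
+ *         .addInitialArgument("index", "docs")
+ *         .addInitialArgument("service_type", "serverless")
+ *         .addInitialArgument("auth", "iam")
+ *         .build();
+ * }</pre>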
+ */ +public class OpenSearchVectorStore extends BaseVectorStore + implements CollectionManageableVectorStore { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String METADATA_INDEX = "flink_agents_collection_metadata"; + + private final String endpoint; + private final String index; + private final String vectorField; + private final String contentField; + private final int dims; + private final Region region; + private final boolean serverless; + private final boolean useIamAuth; + private final String basicAuthHeader; + + private final SdkHttpClient httpClient; + private final Aws4Signer signer; + + public OpenSearchVectorStore( + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + + this.endpoint = descriptor.getArgument("endpoint"); + this.index = descriptor.getArgument("index"); + this.vectorField = Objects.requireNonNullElse(descriptor.getArgument("vector_field"), "embedding"); + this.contentField = Objects.requireNonNullElse(descriptor.getArgument("content_field"), "content"); + Integer dimsArg = descriptor.getArgument("dims"); + this.dims = dimsArg != null ? dimsArg : 1024; + + String regionStr = descriptor.getArgument("region"); + this.region = Region.of(regionStr != null ? regionStr : "us-east-1"); + + String serviceType = Objects.requireNonNullElse(descriptor.getArgument("service_type"), "serverless"); + this.serverless = serviceType.equalsIgnoreCase("serverless"); + + String auth = Objects.requireNonNullElse(descriptor.getArgument("auth"), "iam"); + this.useIamAuth = auth.equalsIgnoreCase("iam"); + + if (!useIamAuth) { + String username = descriptor.getArgument("username"); + String password = descriptor.getArgument("password"); + if (username == null || password == null) { + throw new IllegalArgumentException("username and password required for basic auth"); + } + this.basicAuthHeader = "Basic " + Base64.getEncoder() + .encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8)); + } else { + this.basicAuthHeader = null; + } + + this.httpClient = ApacheHttpClient.create(); + this.signer = Aws4Signer.create(); + + Integer bulkMb = descriptor.getArgument("max_bulk_mb"); + this.maxBulkBytes = (bulkMb != null ? bulkMb : 5) * 1024 * 1024; + } + + /** Batch-embeds all documents in a single call, then delegates to addEmbedding. 
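+     * Documents that already carry an embedding are passed through unchanged; only documents
+     * whose embedding is null are sent to the embedding model, and the resulting vectors are
+     * written back in input order.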
*/ + @Override + public List add(List documents, @Nullable String collection, + Map extraArgs) throws IOException { + BaseEmbeddingModelSetup emb = (BaseEmbeddingModelSetup) + this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); + List texts = new ArrayList<>(); + List needsEmbedding = new ArrayList<>(); + for (int i = 0; i < documents.size(); i++) { + if (documents.get(i).getEmbedding() == null) { + texts.add(documents.get(i).getContent()); + needsEmbedding.add(i); + } + } + if (!texts.isEmpty()) { + List embeddings = emb.embed(texts); + for (int j = 0; j < needsEmbedding.size(); j++) { + documents.get(needsEmbedding.get(j)).setEmbedding(embeddings.get(j)); + } + } + return this.addEmbedding(documents, collection, extraArgs); + } + + // ---- CollectionManageableVectorStore ---- + + @Override + public Collection getOrCreateCollection(String name, Map metadata) + throws Exception { + String idx = sanitizeIndexName(name); + if (!indexExists(idx)) { + createKnnIndex(idx); + } + // Store metadata + ensureMetadataIndex(); + ObjectNode doc = MAPPER.createObjectNode(); + doc.put("collection_name", name); + doc.set("metadata", MAPPER.valueToTree(metadata)); + executeRequest("PUT", "/" + METADATA_INDEX + "/_doc/" + idx, doc.toString()); + executeRequest("POST", "/" + METADATA_INDEX + "/_refresh", null); + return new Collection(name, metadata != null ? metadata : Collections.emptyMap()); + } + + @Override + @SuppressWarnings("unchecked") + public Collection getCollection(String name) throws Exception { + String idx = sanitizeIndexName(name); + if (!indexExists(idx)) { + throw new RuntimeException("Collection " + name + " not found"); + } + try { + ensureMetadataIndex(); + JsonNode resp = executeRequest("GET", "/" + METADATA_INDEX + "/_doc/" + idx, null); + if (resp.has("found") && resp.get("found").asBoolean()) { + Map meta = MAPPER.convertValue( + resp.path("_source").path("metadata"), Map.class); + return new Collection(name, meta != null ? 
meta : Collections.emptyMap()); + } + } catch (Exception ignored) { + // metadata index may not exist yet + } + return new Collection(name, Collections.emptyMap()); + } + + @Override + public Collection deleteCollection(String name) throws Exception { + String idx = sanitizeIndexName(name); + Collection col = getCollection(name); + executeRequest("DELETE", "/" + idx, null); + try { + executeRequest("DELETE", "/" + METADATA_INDEX + "/_doc/" + idx, null); + } catch (Exception ignored) { + } + return col; + } + + private boolean indexExists(String idx) { + try { + executeRequest("HEAD", "/" + idx, null); + return true; + } catch (Exception e) { + return false; + } + } + + private void createKnnIndex(String idx) { + String body = String.format( + "{\"settings\":{\"index\":{\"knn\":true}}," + + "\"mappings\":{\"properties\":{\"%s\":{\"type\":\"knn_vector\",\"dimension\":%d}," + + "\"%s\":{\"type\":\"text\"},\"metadata\":{\"type\":\"object\"}}}}", + vectorField, dims, contentField); + try { + executeRequest("PUT", "/" + idx, body); + } catch (RuntimeException e) { + if (!e.getMessage().contains("resource_already_exists_exception")) { + throw e; + } + } + } + + private void ensureMetadataIndex() { + if (!indexExists(METADATA_INDEX)) { + try { + executeRequest("PUT", "/" + METADATA_INDEX, + "{\"mappings\":{\"properties\":{\"collection_name\":{\"type\":\"keyword\"}," + + "\"metadata\":{\"type\":\"object\"}}}}"); + } catch (RuntimeException e) { + if (!e.getMessage().contains("resource_already_exists_exception")) { + throw e; + } + } + } + } + + /** Sanitize collection name to valid OpenSearch index name (lowercase, no special chars). */ + private String sanitizeIndexName(String name) { + return name.toLowerCase(Locale.ROOT) + .replaceAll("[^a-z0-9\\-_]", "-") + .replaceAll("^[^a-z]+", "a-"); // must start with letter + } + + // ---- BaseVectorStore ---- + + @Override + public Map getStoreKwargs() { + Map m = new HashMap<>(); + m.put("index", index); + m.put("vector_field", vectorField); + return m; + } + + @Override + public long size(@Nullable String collection) throws Exception { + String idx = collection != null ? sanitizeIndexName(collection) : this.index; + JsonNode response = executeRequest("GET", "/" + idx + "/_count", null); + return response.get("count").asLong(); + } + + @Override + public List get(@Nullable List ids, @Nullable String collection, + Map extraArgs) throws IOException { + String idx = collection != null ? sanitizeIndexName(collection) : this.index; + if (ids != null && !ids.isEmpty()) { + ObjectNode body = MAPPER.createObjectNode(); + ArrayNode idsArray = body.putObject("query").putObject("ids").putArray("values"); + ids.forEach(idsArray::add); + body.put("size", ids.size()); + return parseHits(executeRequest("POST", "/" + idx + "/_search", body.toString())); + } + return parseHits(executeRequest("POST", "/" + idx + "/_search", + "{\"query\":{\"match_all\":{}},\"size\":10000}")); + } + + @Override + public void delete(@Nullable List ids, @Nullable String collection, + Map extraArgs) throws IOException { + String idx = collection != null ? 
sanitizeIndexName(collection) : this.index; + if (ids != null && !ids.isEmpty()) { + ObjectNode body = MAPPER.createObjectNode(); + ArrayNode idsArray = body.putObject("query").putObject("ids").putArray("values"); + ids.forEach(idsArray::add); + executeRequest("POST", "/" + idx + "/_delete_by_query", body.toString()); + } else { + executeRequest("POST", "/" + idx + "/_delete_by_query", + "{\"query\":{\"match_all\":{}}}"); + } + executeRequest("POST", "/" + idx + "/_refresh", null); + } + + @Override + protected List queryEmbedding(float[] embedding, int limit, + @Nullable String collection, Map args) { + try { + String idx = collection != null ? sanitizeIndexName(collection) : this.index; + int k = (int) args.getOrDefault("k", Math.max(1, limit)); + + ObjectNode body = MAPPER.createObjectNode(); + body.put("size", k); + ObjectNode knnQuery = body.putObject("query").putObject("knn"); + ObjectNode fieldQuery = knnQuery.putObject(vectorField); + ArrayNode vectorArray = fieldQuery.putArray("vector"); + for (float v : embedding) vectorArray.add(v); + fieldQuery.put("k", k); + if (args.containsKey("min_score")) { + fieldQuery.put("min_score", ((Number) args.get("min_score")).floatValue()); + } + if (args.containsKey("ef_search")) { + fieldQuery.putObject("method_parameters") + .put("ef_search", ((Number) args.get("ef_search")).intValue()); + } + if (args.containsKey("filter_query")) { + fieldQuery.set("filter", MAPPER.readTree((String) args.get("filter_query"))); + } + + return parseHits(executeRequest("POST", "/" + idx + "/_search", body.toString())); + } catch (Exception e) { + throw new RuntimeException("OpenSearch KNN search failed.", e); + } + } + + /** Max bulk payload size in bytes (default 5MB, configurable via constructor). */ + private final int maxBulkBytes; + + @Override + protected List addEmbedding(List documents, @Nullable String collection, + Map extraArgs) throws IOException { + String idx = collection != null ? sanitizeIndexName(collection) : this.index; + if (!indexExists(idx)) { + createKnnIndex(idx); + } + List allIds = new ArrayList<>(); + StringBuilder bulk = new StringBuilder(); + int bulkBytes = 0; + + for (Document doc : documents) { + String id = doc.getId() != null ? 
doc.getId() : UUID.randomUUID().toString(); + allIds.add(id); + + ObjectNode action = MAPPER.createObjectNode(); + action.putObject("index").put("_index", idx).put("_id", id); + String actionLine = action.toString() + "\n"; + + ObjectNode source = MAPPER.createObjectNode(); + source.put(contentField, doc.getContent()); + if (doc.getEmbedding() != null) { + ArrayNode vec = source.putArray(vectorField); + for (float v : doc.getEmbedding()) vec.add(v); + } + if (doc.getMetadata() != null) { + source.set("metadata", MAPPER.valueToTree(doc.getMetadata())); + } + String sourceLine = source.toString() + "\n"; + + int entryBytes = actionLine.length() + sourceLine.length(); + + // Flush if adding this entry would exceed the bulk size limit + if (bulkBytes > 0 && bulkBytes + entryBytes > maxBulkBytes) { + executeRequest("POST", "/_bulk", bulk.toString()); + bulk.setLength(0); + bulkBytes = 0; + } + + bulk.append(actionLine).append(sourceLine); + bulkBytes += entryBytes; + } + + if (bulkBytes > 0) { + executeRequest("POST", "/_bulk", bulk.toString()); + } + executeRequest("POST", "/" + idx + "/_refresh", null); + return allIds; + } + + private List parseHits(JsonNode response) { + List docs = new ArrayList<>(); + JsonNode hits = response.path("hits").path("hits"); + for (JsonNode hit : hits) { + String id = hit.get("_id").asText(); + JsonNode source = hit.get("_source"); + String content = source.has(contentField) ? source.get(contentField).asText() : ""; + Map metadata = new HashMap<>(); + if (source.has("metadata")) { + metadata = MAPPER.convertValue(source.get("metadata"), Map.class); + } + docs.add(new Document(content, metadata, id)); + } + return docs; + } + + @SuppressWarnings("unchecked") + private JsonNode executeRequest(String method, String path, @Nullable String body) { + try { + URI uri = URI.create(endpoint + path); + SdkHttpFullRequest.Builder reqBuilder = SdkHttpFullRequest.builder() + .uri(uri) + .method(SdkHttpMethod.valueOf(method)) + .putHeader("Content-Type", "application/json"); + + if (body != null) { + reqBuilder.contentStreamProvider(() -> + new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); + } + + SdkHttpFullRequest request; + if (useIamAuth) { + Aws4SignerParams signerParams = Aws4SignerParams.builder() + .awsCredentials(DefaultCredentialsProvider.create().resolveCredentials()) + .signingName(serverless ? 
"aoss" : "es") + .signingRegion(region) + .build(); + request = signer.sign(reqBuilder.build(), signerParams); + } else { + request = reqBuilder + .putHeader("Authorization", basicAuthHeader) + .build(); + } + + HttpExecuteRequest.Builder execBuilder = HttpExecuteRequest.builder().request(request); + if (request.contentStreamProvider().isPresent()) { + execBuilder.contentStreamProvider(request.contentStreamProvider().get()); + } + + HttpExecuteResponse response = httpClient.prepareRequest(execBuilder.build()).call(); + int statusCode = response.httpResponse().statusCode(); + + if ("HEAD".equals(method)) { + if (statusCode >= 400) { + throw new RuntimeException("OpenSearch HEAD request failed (" + statusCode + ")"); + } + return MAPPER.createObjectNode().put("status", statusCode); + } + + String responseBody = new String( + response.responseBody().orElseThrow().readAllBytes()); + + if (response.httpResponse().statusCode() >= 400) { + throw new RuntimeException("OpenSearch request failed (" + + response.httpResponse().statusCode() + "): " + responseBody); + } + return MAPPER.readTree(responseBody); + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException("OpenSearch request failed.", e); + } + } +} diff --git a/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java b/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java new file mode 100644 index 00000000..dfed9a91 --- /dev/null +++ b/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.integrations.vectorstores.opensearch; + +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OpenSearchVectorStore}. + * + *
<p>
Integration tests require an OpenSearch Serverless collection or domain. Set + * OPENSEARCH_ENDPOINT environment variable to run. + */ +public class OpenSearchVectorStoreTest { + + private static final BiFunction NOOP = (a, b) -> null; + + @Test + @DisplayName("Constructor creates store with IAM auth") + void testConstructorIamAuth() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("endpoint", "https://example.aoss.us-east-1.amazonaws.com") + .addInitialArgument("index", "test-index") + .addInitialArgument("region", "us-east-1") + .addInitialArgument("service_type", "serverless") + .addInitialArgument("auth", "iam") + .build(); + OpenSearchVectorStore store = new OpenSearchVectorStore(desc, NOOP); + assertThat(store).isInstanceOf(BaseVectorStore.class); + assertThat(store).isInstanceOf(CollectionManageableVectorStore.class); + } + + @Test + @DisplayName("Constructor creates store with basic auth") + void testConstructorBasicAuth() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("endpoint", "https://my-domain.us-east-1.es.amazonaws.com") + .addInitialArgument("index", "test-index") + .addInitialArgument("region", "us-east-1") + .addInitialArgument("service_type", "domain") + .addInitialArgument("auth", "basic") + .addInitialArgument("username", "admin") + .addInitialArgument("password", "password") + .build(); + OpenSearchVectorStore store = new OpenSearchVectorStore(desc, NOOP); + assertThat(store).isInstanceOf(BaseVectorStore.class); + } + + @Test + @DisplayName("Constructor with custom max_bulk_mb") + void testConstructorCustomBulkSize() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("endpoint", "https://example.aoss.us-east-1.amazonaws.com") + .addInitialArgument("index", "test-index") + .addInitialArgument("max_bulk_mb", 10) + .build(); + OpenSearchVectorStore store = new OpenSearchVectorStore(desc, NOOP); + assertThat(store.getStoreKwargs()).containsEntry("index", "test-index"); + } + + @Test + @DisplayName("Basic auth requires username and password") + void testBasicAuthRequiresCredentials() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("endpoint", "https://example.com") + .addInitialArgument("index", "test") + .addInitialArgument("auth", "basic") + .build(); + Assertions.assertThrows(IllegalArgumentException.class, + () -> new OpenSearchVectorStore(desc, NOOP)); + } + + // --- Integration tests (require real OpenSearch) --- + + private static BaseVectorStore store; + + private static Resource getResource(String name, ResourceType type) { + BaseEmbeddingModelSetup emb = Mockito.mock(BaseEmbeddingModelSetup.class); + Mockito.when(emb.embed("OpenSearch is a search engine")) + .thenReturn(new float[]{0.2f, 0.3f, 0.4f, 0.5f, 0.6f}); + Mockito.when(emb.embed("Flink Agents is an AI framework")) + .thenReturn(new float[]{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + Mockito.when(emb.embed("search engine")) + .thenReturn(new float[]{0.2f, 0.3f, 0.4f, 0.5f, 0.6f}); + Mockito.when(emb.embed(Mockito.anyList())).thenAnswer(inv -> { + List texts = inv.getArgument(0); + List result = new 
ArrayList<>(); + for (String t : texts) { + result.add(emb.embed(t)); + } + return result; + }); + return emb; + } + + @BeforeAll + static void initialize() { + String endpoint = System.getenv("OPENSEARCH_ENDPOINT"); + if (endpoint == null) return; + String auth = System.getenv().getOrDefault("OPENSEARCH_AUTH", "iam"); + ResourceDescriptor.Builder builder = ResourceDescriptor.Builder + .newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("endpoint", endpoint) + .addInitialArgument("index", "test-opensearch") + .addInitialArgument("dims", 5) + .addInitialArgument("region", System.getenv().getOrDefault("AWS_REGION", "us-east-1")) + .addInitialArgument("service_type", + System.getenv().getOrDefault("OPENSEARCH_SERVICE_TYPE", "serverless")) + .addInitialArgument("auth", auth); + if ("basic".equals(auth)) { + builder.addInitialArgument("username", System.getenv("OPENSEARCH_USERNAME")); + builder.addInitialArgument("password", System.getenv("OPENSEARCH_PASSWORD")); + } + store = new OpenSearchVectorStore(builder.build(), OpenSearchVectorStoreTest::getResource); + } + + @Test + @EnabledIfEnvironmentVariable(named = "OPENSEARCH_ENDPOINT", matches = ".+") + @DisplayName("Collection management: create, get, delete") + void testCollectionManagement() throws Exception { + CollectionManageableVectorStore vs = (CollectionManageableVectorStore) store; + String name = "test_collection"; + Map metadata = Map.of("key1", "value1"); + vs.getOrCreateCollection(name, metadata); + + CollectionManageableVectorStore.Collection col = vs.getCollection(name); + Assertions.assertNotNull(col); + Assertions.assertEquals(name, col.getName()); + + vs.deleteCollection(name); + } + + @Test + @EnabledIfEnvironmentVariable(named = "OPENSEARCH_ENDPOINT", matches = ".+") + @DisplayName("Document management: add, get, delete") + void testDocumentManagement() throws Exception { + String name = "test_docs"; + ((CollectionManageableVectorStore) store).getOrCreateCollection(name, Map.of()); + + List docs = new ArrayList<>(); + docs.add(new Document("OpenSearch is a search engine", Map.of("src", "test"), "doc1")); + docs.add(new Document("Flink Agents is an AI framework", Map.of("src", "test"), "doc2")); + store.add(docs, name, Collections.emptyMap()); + Thread.sleep(1000); + + List all = store.get(null, name, Collections.emptyMap()); + Assertions.assertEquals(2, all.size()); + + store.delete(Collections.singletonList("doc1"), name, Collections.emptyMap()); + Thread.sleep(1000); + List remaining = store.get(null, name, Collections.emptyMap()); + Assertions.assertEquals(1, remaining.size()); + + ((CollectionManageableVectorStore) store).deleteCollection(name); + } + + @Test + @EnabledIfEnvironmentVariable(named = "OPENSEARCH_ENDPOINT", matches = ".+") + @DisplayName("Query with filter_query restricts results") + void testQueryWithFilter() throws Exception { + String name = "test_filter"; + ((CollectionManageableVectorStore) store).getOrCreateCollection(name, Map.of()); + + List docs = new ArrayList<>(); + docs.add(new Document("OpenSearch is a search engine", Map.of("src", "web"), "f1")); + docs.add(new Document("Flink Agents is an AI framework", Map.of("src", "code"), "f2")); + store.add(docs, name, Collections.emptyMap()); + Thread.sleep(1000); + + // Query with filter: only src=web + VectorStoreQuery q = new VectorStoreQuery( + "search engine", 5, name, + Map.of("filter_query", "{\"term\":{\"metadata.src.keyword\":\"web\"}}")); + List results = 
store.query(q).getDocuments(); + Assertions.assertFalse(results.isEmpty()); + Assertions.assertTrue(results.stream().allMatch( + d -> "web".equals(d.getMetadata().get("src")))); + + ((CollectionManageableVectorStore) store).deleteCollection(name); + } +} diff --git a/integrations/vector-stores/pom.xml b/integrations/vector-stores/pom.xml index 7b9612da..4d4766d9 100644 --- a/integrations/vector-stores/pom.xml +++ b/integrations/vector-stores/pom.xml @@ -32,6 +32,8 @@ under the License. elasticsearch + opensearch + s3vectors \ No newline at end of file diff --git a/integrations/vector-stores/s3vectors/pom.xml b/integrations/vector-stores/s3vectors/pom.xml new file mode 100644 index 00000000..64fbf87a --- /dev/null +++ b/integrations/vector-stores/s3vectors/pom.xml @@ -0,0 +1,54 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-vector-stores + 0.3-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-vector-stores-s3vectors + Flink Agents : Integrations: Vector Stores: S3 Vectors + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + + software.amazon.awssdk + s3vectors + ${aws.sdk.version} + + + com.google.code.findbugs + jsr305 + 1.3.9 + provided + + + + diff --git a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java new file mode 100644 index 00000000..9fec3151 --- /dev/null +++ b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.vectorstores.s3vectors; + +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.Document; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3vectors.S3VectorsClient; +import software.amazon.awssdk.services.s3vectors.model.*; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.util.*; +import java.util.function.BiFunction; + +/** + * Amazon S3 Vectors vector store for flink-agents. 
+ */ +public class S3VectorsVectorStore extends BaseVectorStore { + + private final S3VectorsClient client; + private final String vectorBucket; + private final String vectorIndex; + + public S3VectorsVectorStore( + ResourceDescriptor descriptor, + BiFunction getResource) { + super(descriptor, getResource); + this.vectorBucket = descriptor.getArgument("vector_bucket"); + this.vectorIndex = descriptor.getArgument("vector_index"); + String regionStr = descriptor.getArgument("region"); + this.client = S3VectorsClient.builder() + .region(Region.of(regionStr != null ? regionStr : "us-east-1")) + .credentialsProvider(DefaultCredentialsProvider.create()) + .build(); + } + + /** Batch-embeds all documents in a single call, then delegates to addEmbedding. */ + @Override + public List add(List documents, @Nullable String collection, + Map extraArgs) throws IOException { + BaseEmbeddingModelSetup emb = (BaseEmbeddingModelSetup) + this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); + List texts = new ArrayList<>(); + List needsEmbedding = new ArrayList<>(); + for (int i = 0; i < documents.size(); i++) { + if (documents.get(i).getEmbedding() == null) { + texts.add(documents.get(i).getContent()); + needsEmbedding.add(i); + } + } + if (!texts.isEmpty()) { + List embeddings = emb.embed(texts); + for (int j = 0; j < needsEmbedding.size(); j++) { + documents.get(needsEmbedding.get(j)).setEmbedding(embeddings.get(j)); + } + } + return this.addEmbedding(documents, collection, extraArgs); + } + + @Override + public Map getStoreKwargs() { + Map m = new HashMap<>(); + m.put("vector_bucket", vectorBucket); + m.put("vector_index", vectorIndex); + return m; + } + + @Override + public long size(@Nullable String collection) { return -1; } + + @Override + public List get(@Nullable List ids, @Nullable String collection, + Map extraArgs) throws IOException { + if (ids == null || ids.isEmpty()) return Collections.emptyList(); + String idx = collection != null ? collection : vectorIndex; + + GetVectorsResponse response = client.getVectors(GetVectorsRequest.builder() + .vectorBucketName(vectorBucket) + .indexName(idx) + .keys(ids) + .returnMetadata(true) + .build()); + + List docs = new ArrayList<>(); + for (GetOutputVector v : response.vectors()) { + docs.add(toDocument(v.key(), v.metadata())); + } + return docs; + } + + @Override + public void delete(@Nullable List ids, @Nullable String collection, + Map extraArgs) throws IOException { + if (ids == null || ids.isEmpty()) return; + String idx = collection != null ? collection : vectorIndex; + client.deleteVectors(DeleteVectorsRequest.builder() + .vectorBucketName(vectorBucket).indexName(idx).keys(ids).build()); + } + + @Override + protected List queryEmbedding(float[] embedding, int limit, + @Nullable String collection, Map args) { + try { + String idx = collection != null ? 
collection : vectorIndex; + int topK = (int) args.getOrDefault("top_k", Math.max(1, limit)); + + List queryVector = new ArrayList<>(embedding.length); + for (float v : embedding) queryVector.add(v); + + QueryVectorsResponse response = client.queryVectors(QueryVectorsRequest.builder() + .vectorBucketName(vectorBucket) + .indexName(idx) + .queryVector(VectorData.fromFloat32(queryVector)) + .topK(topK) + .returnMetadata(true) + .build()); + + List docs = new ArrayList<>(); + for (QueryOutputVector v : response.vectors()) { + docs.add(toDocument(v.key(), v.metadata())); + } + return docs; + } catch (Exception e) { + throw new RuntimeException("S3 Vectors query failed.", e); + } + } + + private static final int MAX_PUT_VECTORS_BATCH = 500; + + @Override + protected List addEmbedding(List documents, @Nullable String collection, + Map extraArgs) throws IOException { + String idx = collection != null ? collection : vectorIndex; + List allKeys = new ArrayList<>(); + + // Build all vectors first + List allVectors = new ArrayList<>(); + for (Document doc : documents) { + String key = doc.getId() != null ? doc.getId() : UUID.randomUUID().toString(); + allKeys.add(key); + + List embeddingList = new ArrayList<>(); + if (doc.getEmbedding() != null) { + for (float v : doc.getEmbedding()) embeddingList.add(v); + } + + Map metaMap = new LinkedHashMap<>(); + metaMap.put("source_text", + software.amazon.awssdk.core.document.Document.fromString(doc.getContent())); + if (doc.getMetadata() != null) { + doc.getMetadata().forEach((k, v) -> metaMap.put(k, + software.amazon.awssdk.core.document.Document.fromString(String.valueOf(v)))); + } + + allVectors.add(PutInputVector.builder() + .key(key) + .data(VectorData.fromFloat32(embeddingList)) + .metadata(software.amazon.awssdk.core.document.Document.fromMap(metaMap)) + .build()); + } + + // Chunk into batches of 500 (S3 Vectors API limit) + for (int i = 0; i < allVectors.size(); i += MAX_PUT_VECTORS_BATCH) { + List batch = allVectors.subList(i, + Math.min(i + MAX_PUT_VECTORS_BATCH, allVectors.size())); + client.putVectors(PutVectorsRequest.builder() + .vectorBucketName(vectorBucket).indexName(idx).vectors(batch).build()); + } + return allKeys; + } + + private Document toDocument(String key, + software.amazon.awssdk.core.document.Document metadata) { + Map metaMap = new HashMap<>(); + String content = ""; + if (metadata != null && metadata.isMap()) { + metadata.asMap().forEach((k, v) -> { + if (v.isString()) metaMap.put(k, v.asString()); + }); + content = metaMap.getOrDefault("source_text", "").toString(); + } + return new Document(content, metaMap, key); + } +} diff --git a/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java b/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java new file mode 100644 index 00000000..6ffac314 --- /dev/null +++ b/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.vectorstores.s3vectors; + +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.Document; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link S3VectorsVectorStore}. + * + *
<p>
Integration tests require an S3 Vectors bucket and index. Set S3V_BUCKET env var to run. + */ +public class S3VectorsVectorStoreTest { + + private static final BiFunction NOOP = (a, b) -> null; + + @Test + @DisplayName("Constructor creates store") + void testConstructor() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(S3VectorsVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("vector_bucket", "my-bucket") + .addInitialArgument("vector_index", "my-index") + .addInitialArgument("region", "us-east-1") + .build(); + S3VectorsVectorStore store = new S3VectorsVectorStore(desc, NOOP); + assertThat(store).isInstanceOf(BaseVectorStore.class); + } + + @Test + @DisplayName("getStoreKwargs returns bucket and index") + void testStoreKwargs() { + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(S3VectorsVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("vector_bucket", "test-bucket") + .addInitialArgument("vector_index", "test-index") + .build(); + S3VectorsVectorStore store = new S3VectorsVectorStore(desc, NOOP); + Map kwargs = store.getStoreKwargs(); + assertThat(kwargs).containsEntry("vector_bucket", "test-bucket"); + assertThat(kwargs).containsEntry("vector_index", "test-index"); + } + + // --- Integration tests (require real S3 Vectors bucket) --- + + private static BaseVectorStore store; + + private static Resource getResource(String name, ResourceType type) { + BaseEmbeddingModelSetup emb = Mockito.mock(BaseEmbeddingModelSetup.class); + Mockito.when(emb.embed("Test document one")) + .thenReturn(new float[]{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + Mockito.when(emb.embed("Test document two")) + .thenReturn(new float[]{0.5f, 0.4f, 0.3f, 0.2f, 0.1f}); + Mockito.when(emb.embed(Mockito.anyList())).thenAnswer(inv -> { + List texts = inv.getArgument(0); + List result = new ArrayList<>(); + for (String t : texts) { + result.add(emb.embed(t)); + } + return result; + }); + return emb; + } + + @BeforeAll + static void initialize() { + String bucket = System.getenv("S3V_BUCKET"); + if (bucket == null) return; + ResourceDescriptor desc = ResourceDescriptor.Builder + .newBuilder(S3VectorsVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("vector_bucket", bucket) + .addInitialArgument("vector_index", + System.getenv().getOrDefault("S3V_INDEX", "test-index")) + .addInitialArgument("region", + System.getenv().getOrDefault("AWS_REGION", "us-east-1")) + .build(); + store = new S3VectorsVectorStore(desc, S3VectorsVectorStoreTest::getResource); + } + + @Test + @EnabledIfEnvironmentVariable(named = "S3V_BUCKET", matches = ".+") + @DisplayName("Document add and get") + void testDocumentAddAndGet() throws Exception { + List docs = new ArrayList<>(); + docs.add(new Document("Test document one", Map.of("src", "test"), "s3v-doc1")); + docs.add(new Document("Test document two", Map.of("src", "test"), "s3v-doc2")); + store.add(docs, null, Collections.emptyMap()); + + List retrieved = store.get(List.of("s3v-doc1", "s3v-doc2"), + null, Collections.emptyMap()); + Assertions.assertEquals(2, retrieved.size()); + + store.delete(List.of("s3v-doc1", "s3v-doc2"), null, Collections.emptyMap()); + } +} From 7673d15491b2f6ff3ae7a7b931373eb67b5e720a Mon Sep 17 00:00:00 2001 From: Avichay Marciano Date: Tue, 10 Feb 2026 18:09:43 +0200 Subject: [PATCH 6/8] Improve AWS integrations: add close(), validation, retry jitter, Javadocs Bedrock (chat + embedding): - Add 
close() to release AWS SDK clients and thread pools - Wire max_tokens through BedrockChatModelSetup into InferenceConfiguration - Add retry jitter to BedrockEmbeddingModelConnection - Add typed getConnection() override to BedrockEmbeddingModelSetup - Document stripMarkdownFences necessity and future work OpenSearch vector store: - Add close() to release SdkHttpClient - Cache DefaultCredentialsProvider (was creating new instance per request) - Add constructor validation for required endpoint/index params - Add limit support in get() via extraArgs - Add TODO for Aws4Signer deprecation and batch add() dedup S3 Vectors vector store: - Add close() to release S3VectorsClient - Add constructor validation for required vector_bucket/vector_index params - size() now throws UnsupportedOperationException instead of returning -1 - Add TODO for batch add() dedup All files: expand wildcard imports, add usage example Javadocs --- .../opensearch/OpenSearchVectorStore.java | 260 ++++++++++++------ .../s3vectors/S3VectorsVectorStore.java | 244 ++++++++++++---- 2 files changed, 361 insertions(+), 143 deletions(-) diff --git a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java index 59435389..bc790f06 100644 --- a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java +++ b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java @@ -15,12 +15,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.flink.agents.integrations.vectorstores.opensearch; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; + import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; @@ -28,43 +30,76 @@ import org.apache.flink.agents.api.vectorstores.BaseVectorStore; import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; import org.apache.flink.agents.api.vectorstores.Document; + +import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.auth.signer.Aws4Signer; import software.amazon.awssdk.auth.signer.params.Aws4SignerParams; -import software.amazon.awssdk.http.*; +import software.amazon.awssdk.http.HttpExecuteRequest; +import software.amazon.awssdk.http.HttpExecuteResponse; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.regions.Region; import javax.annotation.Nullable; + import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; import java.util.function.BiFunction; /** - * OpenSearch vector store supporting both OpenSearch Serverless (AOSS) and - * OpenSearch Service domains, with IAM (SigV4) or basic auth. + * OpenSearch vector store supporting both OpenSearch Serverless (AOSS) and OpenSearch Service + * domains, with IAM (SigV4) or basic auth. + * + *
<p>Implements {@link CollectionManageableVectorStore} for Long-Term Memory support. Collections
+ * map to OpenSearch indices. Collection metadata is stored in a dedicated {@code
+ * flink_agents_collection_metadata} index.
  *
- * <p>Implements {@link CollectionManageableVectorStore} for Long-Term Memory support.
- * Collections map to OpenSearch indices. Collection metadata is stored in a
- * dedicated {@code _collection_metadata} index.
+ * <p>Supported parameters:
  *
- * <p>Parameters:
  * <ul>
- *   <li>embedding_model (required): name of the embedding model resource</li>
- *   <li>endpoint (required): OpenSearch endpoint URL</li>
- *   <li>index (required): default index name</li>
- *   <li>service_type (optional): "serverless" (default) or "domain"</li>
- *   <li>auth (optional): "iam" (default) or "basic"</li>
- *   <li>username (required if auth=basic): basic auth username</li>
- *   <li>password (required if auth=basic): basic auth password</li>
- *   <li>vector_field (optional): vector field name (default: "embedding")</li>
- *   <li>content_field (optional): content field name (default: "content")</li>
- *   <li>region (optional): AWS region (default: us-east-1)</li>
- *   <li>dims (optional): vector dimensions for index creation (default: 1024)</li>
+ *   <li>embedding_model (required): name of the embedding model resource
+ *   <li>endpoint (required): OpenSearch endpoint URL
+ *   <li>index (required): default index name
+ *   <li>service_type (optional): "serverless" (default) or "domain"
+ *   <li>auth (optional): "iam" (default) or "basic"
+ *   <li>username (required if auth=basic): basic auth username
+ *   <li>password (required if auth=basic): basic auth password
+ *   <li>vector_field (optional): vector field name (default: "embedding")
+ *   <li>content_field (optional): content field name (default: "content")
+ *   <li>region (optional): AWS region (default: us-east-1)
+ *   <li>dims (optional): vector dimensions for index creation (default: 1024)
+ *   <li>max_bulk_mb (optional): max bulk payload size in MB (default: 5)
  * </ul>
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * @VectorStore
+ * public static ResourceDescriptor opensearchStore() {
+ *     return ResourceDescriptor.Builder.newBuilder(OpenSearchVectorStore.class.getName())
+ *             .addInitialArgument("embedding_model", "bedrockEmbeddingSetup")
+ *             .addInitialArgument("endpoint", "https://my-domain.us-east-1.es.amazonaws.com")
+ *             .addInitialArgument("index", "my-vectors")
+ *             .addInitialArgument("service_type", "domain")
+ *             .addInitialArgument("auth", "iam")
+ *             .addInitialArgument("dims", 1024)
+ *             .build();
+ * }
+ * }</pre>
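+ *
+ * <p>Query sketch (illustrative only; given a store instance, it reuses the {@code
+ * VectorStoreQuery} constructor and {@code store.query(...).getDocuments()} call exercised by
+ * the tests, and passes the optional OpenSearch-specific "filter_query" argument; "k",
+ * "min_score" and "ef_search" are also recognized in the query arguments):
+ *
+ * <pre>{@code
+ * VectorStoreQuery query =
+ *         new VectorStoreQuery(
+ *                 "search engine",
+ *                 5,
+ *                 "my-vectors",
+ *                 Map.of("filter_query", "{\"term\":{\"metadata.src.keyword\":\"web\"}}"));
+ * List<Document> results = store.query(query).getDocuments();
+ * }</pre>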
*/ public class OpenSearchVectorStore extends BaseVectorStore implements CollectionManageableVectorStore { @@ -81,9 +116,12 @@ public class OpenSearchVectorStore extends BaseVectorStore private final boolean serverless; private final boolean useIamAuth; private final String basicAuthHeader; + private final int maxBulkBytes; private final SdkHttpClient httpClient; + // TODO: Aws4Signer is legacy; migrate to AwsV4HttpSigner from http-auth-aws in a follow-up. private final Aws4Signer signer; + private final DefaultCredentialsProvider credentialsProvider; public OpenSearchVectorStore( ResourceDescriptor descriptor, @@ -91,16 +129,27 @@ public OpenSearchVectorStore( super(descriptor, getResource); this.endpoint = descriptor.getArgument("endpoint"); + if (this.endpoint == null || this.endpoint.isBlank()) { + throw new IllegalArgumentException("endpoint is required for OpenSearchVectorStore"); + } + this.index = descriptor.getArgument("index"); - this.vectorField = Objects.requireNonNullElse(descriptor.getArgument("vector_field"), "embedding"); - this.contentField = Objects.requireNonNullElse(descriptor.getArgument("content_field"), "content"); + if (this.index == null || this.index.isBlank()) { + throw new IllegalArgumentException("index is required for OpenSearchVectorStore"); + } + + this.vectorField = + Objects.requireNonNullElse(descriptor.getArgument("vector_field"), "embedding"); + this.contentField = + Objects.requireNonNullElse(descriptor.getArgument("content_field"), "content"); Integer dimsArg = descriptor.getArgument("dims"); this.dims = dimsArg != null ? dimsArg : 1024; String regionStr = descriptor.getArgument("region"); this.region = Region.of(regionStr != null ? regionStr : "us-east-1"); - String serviceType = Objects.requireNonNullElse(descriptor.getArgument("service_type"), "serverless"); + String serviceType = + Objects.requireNonNullElse(descriptor.getArgument("service_type"), "serverless"); this.serverless = serviceType.equalsIgnoreCase("serverless"); String auth = Objects.requireNonNullElse(descriptor.getArgument("auth"), "iam"); @@ -112,25 +161,44 @@ public OpenSearchVectorStore( if (username == null || password == null) { throw new IllegalArgumentException("username and password required for basic auth"); } - this.basicAuthHeader = "Basic " + Base64.getEncoder() - .encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8)); + this.basicAuthHeader = + "Basic " + + Base64.getEncoder() + .encodeToString( + (username + ":" + password) + .getBytes(StandardCharsets.UTF_8)); } else { this.basicAuthHeader = null; } this.httpClient = ApacheHttpClient.create(); this.signer = Aws4Signer.create(); + this.credentialsProvider = DefaultCredentialsProvider.create(); Integer bulkMb = descriptor.getArgument("max_bulk_mb"); this.maxBulkBytes = (bulkMb != null ? bulkMb : 5) * 1024 * 1024; } - /** Batch-embeds all documents in a single call, then delegates to addEmbedding. */ @Override - public List add(List documents, @Nullable String collection, - Map extraArgs) throws IOException { - BaseEmbeddingModelSetup emb = (BaseEmbeddingModelSetup) - this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); + public void close() throws Exception { + this.httpClient.close(); + } + + /** + * Batch-embeds all documents in a single call, then delegates to addEmbedding. + * + *
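+ * <p>Call sketch (illustrative only, reusing the {@code Document} constructor and the
+ * {@code add} signature from the tests; documents without a preset embedding are embedded in
+ * one batched {@code embed(List)} call before indexing):
+ *
+ * <pre>{@code
+ * List<Document> docs = new ArrayList<>();
+ * docs.add(new Document("OpenSearch is a search engine", Map.of("src", "test"), "doc1"));
+ * docs.add(new Document("Flink Agents is an AI framework", Map.of("src", "test"), "doc2"));
+ * store.add(docs, "my-collection", Collections.emptyMap());
+ * }</pre>
+ *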
<p>
TODO: This batch embedding logic is duplicated in S3VectorsVectorStore. Consider + * extracting to BaseVectorStore in a follow-up (would also benefit ElasticsearchVectorStore). + */ + @Override + public List add( + List documents, + @Nullable String collection, + Map extraArgs) + throws IOException { + BaseEmbeddingModelSetup emb = + (BaseEmbeddingModelSetup) + this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); List texts = new ArrayList<>(); List needsEmbedding = new ArrayList<>(); for (int i = 0; i < documents.size(); i++) { @@ -157,7 +225,6 @@ public Collection getOrCreateCollection(String name, Map metadat if (!indexExists(idx)) { createKnnIndex(idx); } - // Store metadata ensureMetadataIndex(); ObjectNode doc = MAPPER.createObjectNode(); doc.put("collection_name", name); @@ -176,10 +243,11 @@ public Collection getCollection(String name) throws Exception { } try { ensureMetadataIndex(); - JsonNode resp = executeRequest("GET", "/" + METADATA_INDEX + "/_doc/" + idx, null); + JsonNode resp = + executeRequest("GET", "/" + METADATA_INDEX + "/_doc/" + idx, null); if (resp.has("found") && resp.get("found").asBoolean()) { - Map meta = MAPPER.convertValue( - resp.path("_source").path("metadata"), Map.class); + Map meta = + MAPPER.convertValue(resp.path("_source").path("metadata"), Map.class); return new Collection(name, meta != null ? meta : Collections.emptyMap()); } } catch (Exception ignored) { @@ -196,6 +264,7 @@ public Collection deleteCollection(String name) throws Exception { try { executeRequest("DELETE", "/" + METADATA_INDEX + "/_doc/" + idx, null); } catch (Exception ignored) { + // metadata doc may not exist } return col; } @@ -210,11 +279,13 @@ private boolean indexExists(String idx) { } private void createKnnIndex(String idx) { - String body = String.format( - "{\"settings\":{\"index\":{\"knn\":true}}," - + "\"mappings\":{\"properties\":{\"%s\":{\"type\":\"knn_vector\",\"dimension\":%d}," - + "\"%s\":{\"type\":\"text\"},\"metadata\":{\"type\":\"object\"}}}}", - vectorField, dims, contentField); + String body = + String.format( + "{\"settings\":{\"index\":{\"knn\":true}}," + + "\"mappings\":{\"properties\":{\"%s\":{\"type\":\"knn_vector\"," + + "\"dimension\":%d},\"%s\":{\"type\":\"text\"}," + + "\"metadata\":{\"type\":\"object\"}}}}", + vectorField, dims, contentField); try { executeRequest("PUT", "/" + idx, body); } catch (RuntimeException e) { @@ -227,9 +298,11 @@ private void createKnnIndex(String idx) { private void ensureMetadataIndex() { if (!indexExists(METADATA_INDEX)) { try { - executeRequest("PUT", "/" + METADATA_INDEX, + executeRequest( + "PUT", + "/" + METADATA_INDEX, "{\"mappings\":{\"properties\":{\"collection_name\":{\"type\":\"keyword\"}," - + "\"metadata\":{\"type\":\"object\"}}}}"); + + "\"metadata\":{\"type\":\"object\"}}}}"); } catch (RuntimeException e) { if (!e.getMessage().contains("resource_already_exists_exception")) { throw e; @@ -242,7 +315,7 @@ private void ensureMetadataIndex() { private String sanitizeIndexName(String name) { return name.toLowerCase(Locale.ROOT) .replaceAll("[^a-z0-9\\-_]", "-") - .replaceAll("^[^a-z]+", "a-"); // must start with letter + .replaceAll("^[^a-z]+", "a-"); } // ---- BaseVectorStore ---- @@ -263,8 +336,11 @@ public long size(@Nullable String collection) throws Exception { } @Override - public List get(@Nullable List ids, @Nullable String collection, - Map extraArgs) throws IOException { + public List get( + @Nullable List ids, + @Nullable String collection, + Map extraArgs) + throws IOException { 
String idx = collection != null ? sanitizeIndexName(collection) : this.index; if (ids != null && !ids.isEmpty()) { ObjectNode body = MAPPER.createObjectNode(); @@ -273,13 +349,23 @@ public List get(@Nullable List ids, @Nullable String collectio body.put("size", ids.size()); return parseHits(executeRequest("POST", "/" + idx + "/_search", body.toString())); } - return parseHits(executeRequest("POST", "/" + idx + "/_search", - "{\"query\":{\"match_all\":{}},\"size\":10000}")); + int limit = 10000; + if (extraArgs != null && extraArgs.containsKey("limit")) { + limit = ((Number) extraArgs.get("limit")).intValue(); + } + return parseHits( + executeRequest( + "POST", + "/" + idx + "/_search", + "{\"query\":{\"match_all\":{}},\"size\":" + limit + "}")); } @Override - public void delete(@Nullable List ids, @Nullable String collection, - Map extraArgs) throws IOException { + public void delete( + @Nullable List ids, + @Nullable String collection, + Map extraArgs) + throws IOException { String idx = collection != null ? sanitizeIndexName(collection) : this.index; if (ids != null && !ids.isEmpty()) { ObjectNode body = MAPPER.createObjectNode(); @@ -287,15 +373,18 @@ public void delete(@Nullable List ids, @Nullable String collection, ids.forEach(idsArray::add); executeRequest("POST", "/" + idx + "/_delete_by_query", body.toString()); } else { - executeRequest("POST", "/" + idx + "/_delete_by_query", - "{\"query\":{\"match_all\":{}}}"); + executeRequest( + "POST", "/" + idx + "/_delete_by_query", "{\"query\":{\"match_all\":{}}}"); } executeRequest("POST", "/" + idx + "/_refresh", null); } @Override - protected List queryEmbedding(float[] embedding, int limit, - @Nullable String collection, Map args) { + protected List queryEmbedding( + float[] embedding, + int limit, + @Nullable String collection, + Map args) { try { String idx = collection != null ? sanitizeIndexName(collection) : this.index; int k = (int) args.getOrDefault("k", Math.max(1, limit)); @@ -305,13 +394,16 @@ protected List queryEmbedding(float[] embedding, int limit, ObjectNode knnQuery = body.putObject("query").putObject("knn"); ObjectNode fieldQuery = knnQuery.putObject(vectorField); ArrayNode vectorArray = fieldQuery.putArray("vector"); - for (float v : embedding) vectorArray.add(v); + for (float v : embedding) { + vectorArray.add(v); + } fieldQuery.put("k", k); if (args.containsKey("min_score")) { fieldQuery.put("min_score", ((Number) args.get("min_score")).floatValue()); } if (args.containsKey("ef_search")) { - fieldQuery.putObject("method_parameters") + fieldQuery + .putObject("method_parameters") .put("ef_search", ((Number) args.get("ef_search")).intValue()); } if (args.containsKey("filter_query")) { @@ -324,12 +416,12 @@ protected List queryEmbedding(float[] embedding, int limit, } } - /** Max bulk payload size in bytes (default 5MB, configurable via constructor). */ - private final int maxBulkBytes; - @Override - protected List addEmbedding(List documents, @Nullable String collection, - Map extraArgs) throws IOException { + protected List addEmbedding( + List documents, + @Nullable String collection, + Map extraArgs) + throws IOException { String idx = collection != null ? 
sanitizeIndexName(collection) : this.index; if (!indexExists(idx)) { createKnnIndex(idx); @@ -350,7 +442,9 @@ protected List addEmbedding(List documents, @Nullable String c source.put(contentField, doc.getContent()); if (doc.getEmbedding() != null) { ArrayNode vec = source.putArray(vectorField); - for (float v : doc.getEmbedding()) vec.add(v); + for (float v : doc.getEmbedding()) { + vec.add(v); + } } if (doc.getMetadata() != null) { source.set("metadata", MAPPER.valueToTree(doc.getMetadata())); @@ -359,7 +453,6 @@ protected List addEmbedding(List documents, @Nullable String c int entryBytes = actionLine.length() + sourceLine.length(); - // Flush if adding this entry would exceed the bulk size limit if (bulkBytes > 0 && bulkBytes + entryBytes > maxBulkBytes) { executeRequest("POST", "/_bulk", bulk.toString()); bulk.setLength(0); @@ -377,6 +470,7 @@ protected List addEmbedding(List documents, @Nullable String c return allIds; } + @SuppressWarnings("unchecked") private List parseHits(JsonNode response) { List docs = new ArrayList<>(); JsonNode hits = response.path("hits").path("hits"); @@ -393,35 +487,36 @@ private List parseHits(JsonNode response) { return docs; } - @SuppressWarnings("unchecked") private JsonNode executeRequest(String method, String path, @Nullable String body) { try { URI uri = URI.create(endpoint + path); - SdkHttpFullRequest.Builder reqBuilder = SdkHttpFullRequest.builder() - .uri(uri) - .method(SdkHttpMethod.valueOf(method)) - .putHeader("Content-Type", "application/json"); + SdkHttpFullRequest.Builder reqBuilder = + SdkHttpFullRequest.builder() + .uri(uri) + .method(SdkHttpMethod.valueOf(method)) + .putHeader("Content-Type", "application/json"); if (body != null) { - reqBuilder.contentStreamProvider(() -> - new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); + reqBuilder.contentStreamProvider( + () -> new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8))); } SdkHttpFullRequest request; if (useIamAuth) { - Aws4SignerParams signerParams = Aws4SignerParams.builder() - .awsCredentials(DefaultCredentialsProvider.create().resolveCredentials()) - .signingName(serverless ? "aoss" : "es") - .signingRegion(region) - .build(); + AwsCredentials credentials = credentialsProvider.resolveCredentials(); + Aws4SignerParams signerParams = + Aws4SignerParams.builder() + .awsCredentials(credentials) + .signingName(serverless ? 
"aoss" : "es") + .signingRegion(region) + .build(); request = signer.sign(reqBuilder.build(), signerParams); } else { - request = reqBuilder - .putHeader("Authorization", basicAuthHeader) - .build(); + request = reqBuilder.putHeader("Authorization", basicAuthHeader).build(); } - HttpExecuteRequest.Builder execBuilder = HttpExecuteRequest.builder().request(request); + HttpExecuteRequest.Builder execBuilder = + HttpExecuteRequest.builder().request(request); if (request.contentStreamProvider().isPresent()) { execBuilder.contentStreamProvider(request.contentStreamProvider().get()); } @@ -431,17 +526,18 @@ private JsonNode executeRequest(String method, String path, @Nullable String bod if ("HEAD".equals(method)) { if (statusCode >= 400) { - throw new RuntimeException("OpenSearch HEAD request failed (" + statusCode + ")"); + throw new RuntimeException( + "OpenSearch HEAD request failed (" + statusCode + ")"); } return MAPPER.createObjectNode().put("status", statusCode); } - String responseBody = new String( - response.responseBody().orElseThrow().readAllBytes()); + String responseBody = + new String(response.responseBody().orElseThrow().readAllBytes()); - if (response.httpResponse().statusCode() >= 400) { - throw new RuntimeException("OpenSearch request failed (" + - response.httpResponse().statusCode() + "): " + responseBody); + if (statusCode >= 400) { + throw new RuntimeException( + "OpenSearch request failed (" + statusCode + "): " + responseBody); } return MAPPER.readTree(responseBody); } catch (RuntimeException e) { diff --git a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java index 9fec3151..92f47d36 100644 --- a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java +++ b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.flink.agents.integrations.vectorstores.s3vectors; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; @@ -23,21 +24,66 @@ import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; import org.apache.flink.agents.api.vectorstores.Document; + import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3vectors.S3VectorsClient; -import software.amazon.awssdk.services.s3vectors.model.*; +import software.amazon.awssdk.services.s3vectors.model.DeleteVectorsRequest; +import software.amazon.awssdk.services.s3vectors.model.GetOutputVector; +import software.amazon.awssdk.services.s3vectors.model.GetVectorsRequest; +import software.amazon.awssdk.services.s3vectors.model.GetVectorsResponse; +import software.amazon.awssdk.services.s3vectors.model.PutInputVector; +import software.amazon.awssdk.services.s3vectors.model.PutVectorsRequest; +import software.amazon.awssdk.services.s3vectors.model.QueryOutputVector; +import software.amazon.awssdk.services.s3vectors.model.QueryVectorsRequest; +import software.amazon.awssdk.services.s3vectors.model.QueryVectorsResponse; +import software.amazon.awssdk.services.s3vectors.model.VectorData; import javax.annotation.Nullable; + import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; import java.util.function.BiFunction; /** * Amazon S3 Vectors vector store for flink-agents. + * + *
<p>Uses the S3 Vectors SDK for PutVectors/QueryVectors/GetVectors/DeleteVectors. PutVectors
+ * calls are chunked at 500 vectors per request (API limit).
+ *
+ * <p>Supported parameters:
+ *
+ * <ul>
+ *   <li>embedding_model (required): name of the embedding model resource
+ *   <li>vector_bucket (required): S3 Vectors bucket name
+ *   <li>vector_index (required): S3 Vectors index name
+ *   <li>region (optional): AWS region (default: us-east-1)
+ * </ul>
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * @VectorStore
+ * public static ResourceDescriptor s3VectorsStore() {
+ *     return ResourceDescriptor.Builder.newBuilder(S3VectorsVectorStore.class.getName())
+ *             .addInitialArgument("embedding_model", "bedrockEmbeddingSetup")
+ *             .addInitialArgument("vector_bucket", "my-vector-bucket")
+ *             .addInitialArgument("vector_index", "my-index")
+ *             .addInitialArgument("region", "us-east-1")
+ *             .build();
+ * }
+ * }</pre>
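+ *
+ * <p>Query sketch (illustrative only; given a store instance, it reuses the {@code
+ * VectorStoreQuery} constructor from the tests, assumes a null collection falls back to the
+ * configured vector_index, and passes the optional "top_k" argument, which defaults to the
+ * query limit):
+ *
+ * <pre>{@code
+ * VectorStoreQuery query =
+ *         new VectorStoreQuery("Test document one", 5, null, Map.of("top_k", 10));
+ * List<Document> results = store.query(query).getDocuments();
+ * }</pre>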
*/ public class S3VectorsVectorStore extends BaseVectorStore { + private static final int MAX_PUT_VECTORS_BATCH = 500; + private final S3VectorsClient client; private final String vectorBucket; private final String vectorIndex; @@ -46,21 +92,47 @@ public S3VectorsVectorStore( ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); + this.vectorBucket = descriptor.getArgument("vector_bucket"); + if (this.vectorBucket == null || this.vectorBucket.isBlank()) { + throw new IllegalArgumentException( + "vector_bucket is required for S3VectorsVectorStore"); + } + this.vectorIndex = descriptor.getArgument("vector_index"); + if (this.vectorIndex == null || this.vectorIndex.isBlank()) { + throw new IllegalArgumentException( + "vector_index is required for S3VectorsVectorStore"); + } + String regionStr = descriptor.getArgument("region"); - this.client = S3VectorsClient.builder() - .region(Region.of(regionStr != null ? regionStr : "us-east-1")) - .credentialsProvider(DefaultCredentialsProvider.create()) - .build(); + this.client = + S3VectorsClient.builder() + .region(Region.of(regionStr != null ? regionStr : "us-east-1")) + .credentialsProvider(DefaultCredentialsProvider.create()) + .build(); + } + + @Override + public void close() throws Exception { + this.client.close(); } - /** Batch-embeds all documents in a single call, then delegates to addEmbedding. */ + /** + * Batch-embeds all documents in a single call, then delegates to addEmbedding. + * + *
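+ * <p>Call sketch (illustrative only, mirroring the integration test; a null collection targets
+ * the configured vector_index):
+ *
+ * <pre>{@code
+ * List<Document> docs = new ArrayList<>();
+ * docs.add(new Document("Test document one", Map.of("src", "test"), "s3v-doc1"));
+ * store.add(docs, null, Collections.emptyMap());
+ * List<Document> retrieved = store.get(List.of("s3v-doc1"), null, Collections.emptyMap());
+ * }</pre>
+ *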
<p>
TODO: This batch embedding logic is duplicated in OpenSearchVectorStore. Consider + * extracting to BaseVectorStore in a follow-up. + */ @Override - public List add(List documents, @Nullable String collection, - Map extraArgs) throws IOException { - BaseEmbeddingModelSetup emb = (BaseEmbeddingModelSetup) - this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); + public List add( + List documents, + @Nullable String collection, + Map extraArgs) + throws IOException { + BaseEmbeddingModelSetup emb = + (BaseEmbeddingModelSetup) + this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); List texts = new ArrayList<>(); List needsEmbedding = new ArrayList<>(); for (int i = 0; i < documents.size(); i++) { @@ -86,21 +158,36 @@ public Map getStoreKwargs() { return m; } + /** + * S3 Vectors does not support a count operation. + * + * @throws UnsupportedOperationException always + */ @Override - public long size(@Nullable String collection) { return -1; } + public long size(@Nullable String collection) { + throw new UnsupportedOperationException( + "S3 Vectors does not support count operations"); + } @Override - public List get(@Nullable List ids, @Nullable String collection, - Map extraArgs) throws IOException { - if (ids == null || ids.isEmpty()) return Collections.emptyList(); + public List get( + @Nullable List ids, + @Nullable String collection, + Map extraArgs) + throws IOException { + if (ids == null || ids.isEmpty()) { + return Collections.emptyList(); + } String idx = collection != null ? collection : vectorIndex; - GetVectorsResponse response = client.getVectors(GetVectorsRequest.builder() - .vectorBucketName(vectorBucket) - .indexName(idx) - .keys(ids) - .returnMetadata(true) - .build()); + GetVectorsResponse response = + client.getVectors( + GetVectorsRequest.builder() + .vectorBucketName(vectorBucket) + .indexName(idx) + .keys(ids) + .returnMetadata(true) + .build()); List docs = new ArrayList<>(); for (GetOutputVector v : response.vectors()) { @@ -110,31 +197,47 @@ public List get(@Nullable List ids, @Nullable String collectio } @Override - public void delete(@Nullable List ids, @Nullable String collection, - Map extraArgs) throws IOException { - if (ids == null || ids.isEmpty()) return; + public void delete( + @Nullable List ids, + @Nullable String collection, + Map extraArgs) + throws IOException { + if (ids == null || ids.isEmpty()) { + return; + } String idx = collection != null ? collection : vectorIndex; - client.deleteVectors(DeleteVectorsRequest.builder() - .vectorBucketName(vectorBucket).indexName(idx).keys(ids).build()); + client.deleteVectors( + DeleteVectorsRequest.builder() + .vectorBucketName(vectorBucket) + .indexName(idx) + .keys(ids) + .build()); } @Override - protected List queryEmbedding(float[] embedding, int limit, - @Nullable String collection, Map args) { + protected List queryEmbedding( + float[] embedding, + int limit, + @Nullable String collection, + Map args) { try { String idx = collection != null ? 
collection : vectorIndex; int topK = (int) args.getOrDefault("top_k", Math.max(1, limit)); List queryVector = new ArrayList<>(embedding.length); - for (float v : embedding) queryVector.add(v); + for (float v : embedding) { + queryVector.add(v); + } - QueryVectorsResponse response = client.queryVectors(QueryVectorsRequest.builder() - .vectorBucketName(vectorBucket) - .indexName(idx) - .queryVector(VectorData.fromFloat32(queryVector)) - .topK(topK) - .returnMetadata(true) - .build()); + QueryVectorsResponse response = + client.queryVectors( + QueryVectorsRequest.builder() + .vectorBucketName(vectorBucket) + .indexName(idx) + .queryVector(VectorData.fromFloat32(queryVector)) + .topK(topK) + .returnMetadata(true) + .build()); List docs = new ArrayList<>(); for (QueryOutputVector v : response.vectors()) { @@ -146,15 +249,15 @@ protected List queryEmbedding(float[] embedding, int limit, } } - private static final int MAX_PUT_VECTORS_BATCH = 500; - @Override - protected List addEmbedding(List documents, @Nullable String collection, - Map extraArgs) throws IOException { + protected List addEmbedding( + List documents, + @Nullable String collection, + Map extraArgs) + throws IOException { String idx = collection != null ? collection : vectorIndex; List allKeys = new ArrayList<>(); - // Build all vectors first List allVectors = new ArrayList<>(); for (Document doc : documents) { String key = doc.getId() != null ? doc.getId() : UUID.randomUUID().toString(); @@ -162,42 +265,61 @@ protected List addEmbedding(List documents, @Nullable String c List embeddingList = new ArrayList<>(); if (doc.getEmbedding() != null) { - for (float v : doc.getEmbedding()) embeddingList.add(v); + for (float v : doc.getEmbedding()) { + embeddingList.add(v); + } } - Map metaMap = new LinkedHashMap<>(); - metaMap.put("source_text", + Map metaMap = + new LinkedHashMap<>(); + metaMap.put( + "source_text", software.amazon.awssdk.core.document.Document.fromString(doc.getContent())); if (doc.getMetadata() != null) { - doc.getMetadata().forEach((k, v) -> metaMap.put(k, - software.amazon.awssdk.core.document.Document.fromString(String.valueOf(v)))); + doc.getMetadata() + .forEach( + (k, v) -> + metaMap.put( + k, + software.amazon.awssdk.core.document.Document + .fromString(String.valueOf(v)))); } - allVectors.add(PutInputVector.builder() - .key(key) - .data(VectorData.fromFloat32(embeddingList)) - .metadata(software.amazon.awssdk.core.document.Document.fromMap(metaMap)) - .build()); + allVectors.add( + PutInputVector.builder() + .key(key) + .data(VectorData.fromFloat32(embeddingList)) + .metadata( + software.amazon.awssdk.core.document.Document.fromMap(metaMap)) + .build()); } - // Chunk into batches of 500 (S3 Vectors API limit) for (int i = 0; i < allVectors.size(); i += MAX_PUT_VECTORS_BATCH) { - List batch = allVectors.subList(i, - Math.min(i + MAX_PUT_VECTORS_BATCH, allVectors.size())); - client.putVectors(PutVectorsRequest.builder() - .vectorBucketName(vectorBucket).indexName(idx).vectors(batch).build()); + List batch = + allVectors.subList( + i, Math.min(i + MAX_PUT_VECTORS_BATCH, allVectors.size())); + client.putVectors( + PutVectorsRequest.builder() + .vectorBucketName(vectorBucket) + .indexName(idx) + .vectors(batch) + .build()); } return allKeys; } - private Document toDocument(String key, - software.amazon.awssdk.core.document.Document metadata) { + private Document toDocument( + String key, software.amazon.awssdk.core.document.Document metadata) { Map metaMap = new HashMap<>(); String content = ""; if (metadata != null 
&& metadata.isMap()) { - metadata.asMap().forEach((k, v) -> { - if (v.isString()) metaMap.put(k, v.asString()); - }); + metadata.asMap() + .forEach( + (k, v) -> { + if (v.isString()) { + metaMap.put(k, v.asString()); + } + }); content = metaMap.getOrDefault("source_text", "").toString(); } return new Document(content, metaMap, key); From dd77d5072e336480ef78a7164a34508493919d07 Mon Sep 17 00:00:00 2001 From: Avichay Marciano Date: Sat, 14 Feb 2026 00:19:33 +0200 Subject: [PATCH 7/8] Apply spotless formatting to pass CI code-style check --- .../opensearch/OpenSearchVectorStore.java | 35 ++--- .../opensearch/OpenSearchVectorStoreTest.java | 135 ++++++++++-------- .../s3vectors/S3VectorsVectorStore.java | 38 ++--- .../s3vectors/S3VectorsVectorStoreTest.java | 71 ++++----- 4 files changed, 130 insertions(+), 149 deletions(-) diff --git a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java index bc790f06..7315f8c1 100644 --- a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java +++ b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java @@ -22,7 +22,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; - import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; @@ -30,7 +29,6 @@ import org.apache.flink.agents.api.vectorstores.BaseVectorStore; import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; import org.apache.flink.agents.api.vectorstores.Document; - import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.auth.signer.Aws4Signer; @@ -124,8 +122,7 @@ public class OpenSearchVectorStore extends BaseVectorStore private final DefaultCredentialsProvider credentialsProvider; public OpenSearchVectorStore( - ResourceDescriptor descriptor, - BiFunction getResource) { + ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); this.endpoint = descriptor.getArgument("endpoint"); @@ -192,9 +189,7 @@ public void close() throws Exception { */ @Override public List add( - List documents, - @Nullable String collection, - Map extraArgs) + List documents, @Nullable String collection, Map extraArgs) throws IOException { BaseEmbeddingModelSetup emb = (BaseEmbeddingModelSetup) @@ -243,8 +238,7 @@ public Collection getCollection(String name) throws Exception { } try { ensureMetadataIndex(); - JsonNode resp = - executeRequest("GET", "/" + METADATA_INDEX + "/_doc/" + idx, null); + JsonNode resp = executeRequest("GET", "/" + METADATA_INDEX + "/_doc/" + idx, null); if (resp.has("found") && resp.get("found").asBoolean()) { Map meta = MAPPER.convertValue(resp.path("_source").path("metadata"), Map.class); @@ -337,9 +331,7 @@ public long size(@Nullable String collection) throws Exception { @Override public List get( - @Nullable List ids, - @Nullable String collection, - Map extraArgs) + @Nullable List ids, @Nullable 
String collection, Map extraArgs) throws IOException { String idx = collection != null ? sanitizeIndexName(collection) : this.index; if (ids != null && !ids.isEmpty()) { @@ -362,9 +354,7 @@ public List get( @Override public void delete( - @Nullable List ids, - @Nullable String collection, - Map extraArgs) + @Nullable List ids, @Nullable String collection, Map extraArgs) throws IOException { String idx = collection != null ? sanitizeIndexName(collection) : this.index; if (ids != null && !ids.isEmpty()) { @@ -381,10 +371,7 @@ public void delete( @Override protected List queryEmbedding( - float[] embedding, - int limit, - @Nullable String collection, - Map args) { + float[] embedding, int limit, @Nullable String collection, Map args) { try { String idx = collection != null ? sanitizeIndexName(collection) : this.index; int k = (int) args.getOrDefault("k", Math.max(1, limit)); @@ -418,9 +405,7 @@ protected List queryEmbedding( @Override protected List addEmbedding( - List documents, - @Nullable String collection, - Map extraArgs) + List documents, @Nullable String collection, Map extraArgs) throws IOException { String idx = collection != null ? sanitizeIndexName(collection) : this.index; if (!indexExists(idx)) { @@ -515,8 +500,7 @@ private JsonNode executeRequest(String method, String path, @Nullable String bod request = reqBuilder.putHeader("Authorization", basicAuthHeader).build(); } - HttpExecuteRequest.Builder execBuilder = - HttpExecuteRequest.builder().request(request); + HttpExecuteRequest.Builder execBuilder = HttpExecuteRequest.builder().request(request); if (request.contentStreamProvider().isPresent()) { execBuilder.contentStreamProvider(request.contentStreamProvider().get()); } @@ -532,8 +516,7 @@ private JsonNode executeRequest(String method, String path, @Nullable String bod return MAPPER.createObjectNode().put("status", statusCode); } - String responseBody = - new String(response.responseBody().orElseThrow().readAllBytes()); + String responseBody = new String(response.responseBody().orElseThrow().readAllBytes()); if (statusCode >= 400) { throw new RuntimeException( diff --git a/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java b/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java index dfed9a91..6e99cc0d 100644 --- a/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java +++ b/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java @@ -54,15 +54,16 @@ public class OpenSearchVectorStoreTest { @Test @DisplayName("Constructor creates store with IAM auth") void testConstructorIamAuth() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(OpenSearchVectorStore.class.getName()) - .addInitialArgument("embedding_model", "emb") - .addInitialArgument("endpoint", "https://example.aoss.us-east-1.amazonaws.com") - .addInitialArgument("index", "test-index") - .addInitialArgument("region", "us-east-1") - .addInitialArgument("service_type", "serverless") - .addInitialArgument("auth", "iam") - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument( + "endpoint", 
"https://example.aoss.us-east-1.amazonaws.com") + .addInitialArgument("index", "test-index") + .addInitialArgument("region", "us-east-1") + .addInitialArgument("service_type", "serverless") + .addInitialArgument("auth", "iam") + .build(); OpenSearchVectorStore store = new OpenSearchVectorStore(desc, NOOP); assertThat(store).isInstanceOf(BaseVectorStore.class); assertThat(store).isInstanceOf(CollectionManageableVectorStore.class); @@ -71,17 +72,18 @@ void testConstructorIamAuth() { @Test @DisplayName("Constructor creates store with basic auth") void testConstructorBasicAuth() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(OpenSearchVectorStore.class.getName()) - .addInitialArgument("embedding_model", "emb") - .addInitialArgument("endpoint", "https://my-domain.us-east-1.es.amazonaws.com") - .addInitialArgument("index", "test-index") - .addInitialArgument("region", "us-east-1") - .addInitialArgument("service_type", "domain") - .addInitialArgument("auth", "basic") - .addInitialArgument("username", "admin") - .addInitialArgument("password", "password") - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument( + "endpoint", "https://my-domain.us-east-1.es.amazonaws.com") + .addInitialArgument("index", "test-index") + .addInitialArgument("region", "us-east-1") + .addInitialArgument("service_type", "domain") + .addInitialArgument("auth", "basic") + .addInitialArgument("username", "admin") + .addInitialArgument("password", "password") + .build(); OpenSearchVectorStore store = new OpenSearchVectorStore(desc, NOOP); assertThat(store).isInstanceOf(BaseVectorStore.class); } @@ -89,13 +91,14 @@ void testConstructorBasicAuth() { @Test @DisplayName("Constructor with custom max_bulk_mb") void testConstructorCustomBulkSize() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(OpenSearchVectorStore.class.getName()) - .addInitialArgument("embedding_model", "emb") - .addInitialArgument("endpoint", "https://example.aoss.us-east-1.amazonaws.com") - .addInitialArgument("index", "test-index") - .addInitialArgument("max_bulk_mb", 10) - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument( + "endpoint", "https://example.aoss.us-east-1.amazonaws.com") + .addInitialArgument("index", "test-index") + .addInitialArgument("max_bulk_mb", 10) + .build(); OpenSearchVectorStore store = new OpenSearchVectorStore(desc, NOOP); assertThat(store.getStoreKwargs()).containsEntry("index", "test-index"); } @@ -103,15 +106,15 @@ void testConstructorCustomBulkSize() { @Test @DisplayName("Basic auth requires username and password") void testBasicAuthRequiresCredentials() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(OpenSearchVectorStore.class.getName()) - .addInitialArgument("embedding_model", "emb") - .addInitialArgument("endpoint", "https://example.com") - .addInitialArgument("index", "test") - .addInitialArgument("auth", "basic") - .build(); - Assertions.assertThrows(IllegalArgumentException.class, - () -> new OpenSearchVectorStore(desc, NOOP)); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("endpoint", "https://example.com") + .addInitialArgument("index", "test") + 
.addInitialArgument("auth", "basic") + .build(); + Assertions.assertThrows( + IllegalArgumentException.class, () -> new OpenSearchVectorStore(desc, NOOP)); } // --- Integration tests (require real OpenSearch) --- @@ -121,19 +124,21 @@ void testBasicAuthRequiresCredentials() { private static Resource getResource(String name, ResourceType type) { BaseEmbeddingModelSetup emb = Mockito.mock(BaseEmbeddingModelSetup.class); Mockito.when(emb.embed("OpenSearch is a search engine")) - .thenReturn(new float[]{0.2f, 0.3f, 0.4f, 0.5f, 0.6f}); + .thenReturn(new float[] {0.2f, 0.3f, 0.4f, 0.5f, 0.6f}); Mockito.when(emb.embed("Flink Agents is an AI framework")) - .thenReturn(new float[]{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + .thenReturn(new float[] {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); Mockito.when(emb.embed("search engine")) - .thenReturn(new float[]{0.2f, 0.3f, 0.4f, 0.5f, 0.6f}); - Mockito.when(emb.embed(Mockito.anyList())).thenAnswer(inv -> { - List texts = inv.getArgument(0); - List result = new ArrayList<>(); - for (String t : texts) { - result.add(emb.embed(t)); - } - return result; - }); + .thenReturn(new float[] {0.2f, 0.3f, 0.4f, 0.5f, 0.6f}); + Mockito.when(emb.embed(Mockito.anyList())) + .thenAnswer( + inv -> { + List texts = inv.getArgument(0); + List result = new ArrayList<>(); + for (String t : texts) { + result.add(emb.embed(t)); + } + return result; + }); return emb; } @@ -142,16 +147,19 @@ static void initialize() { String endpoint = System.getenv("OPENSEARCH_ENDPOINT"); if (endpoint == null) return; String auth = System.getenv().getOrDefault("OPENSEARCH_AUTH", "iam"); - ResourceDescriptor.Builder builder = ResourceDescriptor.Builder - .newBuilder(OpenSearchVectorStore.class.getName()) - .addInitialArgument("embedding_model", "emb") - .addInitialArgument("endpoint", endpoint) - .addInitialArgument("index", "test-opensearch") - .addInitialArgument("dims", 5) - .addInitialArgument("region", System.getenv().getOrDefault("AWS_REGION", "us-east-1")) - .addInitialArgument("service_type", - System.getenv().getOrDefault("OPENSEARCH_SERVICE_TYPE", "serverless")) - .addInitialArgument("auth", auth); + ResourceDescriptor.Builder builder = + ResourceDescriptor.Builder.newBuilder(OpenSearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("endpoint", endpoint) + .addInitialArgument("index", "test-opensearch") + .addInitialArgument("dims", 5) + .addInitialArgument( + "region", System.getenv().getOrDefault("AWS_REGION", "us-east-1")) + .addInitialArgument( + "service_type", + System.getenv() + .getOrDefault("OPENSEARCH_SERVICE_TYPE", "serverless")) + .addInitialArgument("auth", auth); if ("basic".equals(auth)) { builder.addInitialArgument("username", System.getenv("OPENSEARCH_USERNAME")); builder.addInitialArgument("password", System.getenv("OPENSEARCH_PASSWORD")); @@ -213,13 +221,16 @@ void testQueryWithFilter() throws Exception { Thread.sleep(1000); // Query with filter: only src=web - VectorStoreQuery q = new VectorStoreQuery( - "search engine", 5, name, - Map.of("filter_query", "{\"term\":{\"metadata.src.keyword\":\"web\"}}")); + VectorStoreQuery q = + new VectorStoreQuery( + "search engine", + 5, + name, + Map.of("filter_query", "{\"term\":{\"metadata.src.keyword\":\"web\"}}")); List results = store.query(q).getDocuments(); Assertions.assertFalse(results.isEmpty()); - Assertions.assertTrue(results.stream().allMatch( - d -> "web".equals(d.getMetadata().get("src")))); + Assertions.assertTrue( + results.stream().allMatch(d -> 
"web".equals(d.getMetadata().get("src")))); ((CollectionManageableVectorStore) store).deleteCollection(name); } diff --git a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java index 92f47d36..d47576fd 100644 --- a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java +++ b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java @@ -24,7 +24,6 @@ import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; import org.apache.flink.agents.api.vectorstores.Document; - import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3vectors.S3VectorsClient; @@ -54,8 +53,8 @@ /** * Amazon S3 Vectors vector store for flink-agents. * - *

Uses the S3 Vectors SDK for PutVectors/QueryVectors/GetVectors/DeleteVectors. PutVectors - * calls are chunked at 500 vectors per request (API limit). + *

Uses the S3 Vectors SDK for PutVectors/QueryVectors/GetVectors/DeleteVectors. PutVectors calls + * are chunked at 500 vectors per request (API limit). * *

Supported parameters: * @@ -89,8 +88,7 @@ public class S3VectorsVectorStore extends BaseVectorStore { private final String vectorIndex; public S3VectorsVectorStore( - ResourceDescriptor descriptor, - BiFunction getResource) { + ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); this.vectorBucket = descriptor.getArgument("vector_bucket"); @@ -101,8 +99,7 @@ public S3VectorsVectorStore( this.vectorIndex = descriptor.getArgument("vector_index"); if (this.vectorIndex == null || this.vectorIndex.isBlank()) { - throw new IllegalArgumentException( - "vector_index is required for S3VectorsVectorStore"); + throw new IllegalArgumentException("vector_index is required for S3VectorsVectorStore"); } String regionStr = descriptor.getArgument("region"); @@ -126,9 +123,7 @@ public void close() throws Exception { */ @Override public List add( - List documents, - @Nullable String collection, - Map extraArgs) + List documents, @Nullable String collection, Map extraArgs) throws IOException { BaseEmbeddingModelSetup emb = (BaseEmbeddingModelSetup) @@ -165,15 +160,12 @@ public Map getStoreKwargs() { */ @Override public long size(@Nullable String collection) { - throw new UnsupportedOperationException( - "S3 Vectors does not support count operations"); + throw new UnsupportedOperationException("S3 Vectors does not support count operations"); } @Override public List get( - @Nullable List ids, - @Nullable String collection, - Map extraArgs) + @Nullable List ids, @Nullable String collection, Map extraArgs) throws IOException { if (ids == null || ids.isEmpty()) { return Collections.emptyList(); @@ -198,9 +190,7 @@ public List get( @Override public void delete( - @Nullable List ids, - @Nullable String collection, - Map extraArgs) + @Nullable List ids, @Nullable String collection, Map extraArgs) throws IOException { if (ids == null || ids.isEmpty()) { return; @@ -216,10 +206,7 @@ public void delete( @Override protected List queryEmbedding( - float[] embedding, - int limit, - @Nullable String collection, - Map args) { + float[] embedding, int limit, @Nullable String collection, Map args) { try { String idx = collection != null ? collection : vectorIndex; int topK = (int) args.getOrDefault("top_k", Math.max(1, limit)); @@ -251,9 +238,7 @@ protected List queryEmbedding( @Override protected List addEmbedding( - List documents, - @Nullable String collection, - Map extraArgs) + List documents, @Nullable String collection, Map extraArgs) throws IOException { String idx = collection != null ? 
collection : vectorIndex; List allKeys = new ArrayList<>(); @@ -296,8 +281,7 @@ protected List addEmbedding( for (int i = 0; i < allVectors.size(); i += MAX_PUT_VECTORS_BATCH) { List batch = - allVectors.subList( - i, Math.min(i + MAX_PUT_VECTORS_BATCH, allVectors.size())); + allVectors.subList(i, Math.min(i + MAX_PUT_VECTORS_BATCH, allVectors.size())); client.putVectors( PutVectorsRequest.builder() .vectorBucketName(vectorBucket) diff --git a/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java b/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java index 6ffac314..ae8cd0f6 100644 --- a/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java +++ b/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java @@ -51,13 +51,13 @@ public class S3VectorsVectorStoreTest { @Test @DisplayName("Constructor creates store") void testConstructor() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(S3VectorsVectorStore.class.getName()) - .addInitialArgument("embedding_model", "emb") - .addInitialArgument("vector_bucket", "my-bucket") - .addInitialArgument("vector_index", "my-index") - .addInitialArgument("region", "us-east-1") - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(S3VectorsVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("vector_bucket", "my-bucket") + .addInitialArgument("vector_index", "my-index") + .addInitialArgument("region", "us-east-1") + .build(); S3VectorsVectorStore store = new S3VectorsVectorStore(desc, NOOP); assertThat(store).isInstanceOf(BaseVectorStore.class); } @@ -65,12 +65,12 @@ void testConstructor() { @Test @DisplayName("getStoreKwargs returns bucket and index") void testStoreKwargs() { - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(S3VectorsVectorStore.class.getName()) - .addInitialArgument("embedding_model", "emb") - .addInitialArgument("vector_bucket", "test-bucket") - .addInitialArgument("vector_index", "test-index") - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(S3VectorsVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("vector_bucket", "test-bucket") + .addInitialArgument("vector_index", "test-index") + .build(); S3VectorsVectorStore store = new S3VectorsVectorStore(desc, NOOP); Map kwargs = store.getStoreKwargs(); assertThat(kwargs).containsEntry("vector_bucket", "test-bucket"); @@ -84,17 +84,19 @@ void testStoreKwargs() { private static Resource getResource(String name, ResourceType type) { BaseEmbeddingModelSetup emb = Mockito.mock(BaseEmbeddingModelSetup.class); Mockito.when(emb.embed("Test document one")) - .thenReturn(new float[]{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + .thenReturn(new float[] {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); Mockito.when(emb.embed("Test document two")) - .thenReturn(new float[]{0.5f, 0.4f, 0.3f, 0.2f, 0.1f}); - Mockito.when(emb.embed(Mockito.anyList())).thenAnswer(inv -> { - List texts = inv.getArgument(0); - List result = new ArrayList<>(); - for (String t : texts) { - result.add(emb.embed(t)); - } - return result; - }); + .thenReturn(new float[] {0.5f, 0.4f, 0.3f, 0.2f, 0.1f}); + 
Mockito.when(emb.embed(Mockito.anyList())) + .thenAnswer( + inv -> { + List texts = inv.getArgument(0); + List result = new ArrayList<>(); + for (String t : texts) { + result.add(emb.embed(t)); + } + return result; + }); return emb; } @@ -102,15 +104,16 @@ private static Resource getResource(String name, ResourceType type) { static void initialize() { String bucket = System.getenv("S3V_BUCKET"); if (bucket == null) return; - ResourceDescriptor desc = ResourceDescriptor.Builder - .newBuilder(S3VectorsVectorStore.class.getName()) - .addInitialArgument("embedding_model", "emb") - .addInitialArgument("vector_bucket", bucket) - .addInitialArgument("vector_index", - System.getenv().getOrDefault("S3V_INDEX", "test-index")) - .addInitialArgument("region", - System.getenv().getOrDefault("AWS_REGION", "us-east-1")) - .build(); + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(S3VectorsVectorStore.class.getName()) + .addInitialArgument("embedding_model", "emb") + .addInitialArgument("vector_bucket", bucket) + .addInitialArgument( + "vector_index", + System.getenv().getOrDefault("S3V_INDEX", "test-index")) + .addInitialArgument( + "region", System.getenv().getOrDefault("AWS_REGION", "us-east-1")) + .build(); store = new S3VectorsVectorStore(desc, S3VectorsVectorStoreTest::getResource); } @@ -123,8 +126,8 @@ void testDocumentAddAndGet() throws Exception { docs.add(new Document("Test document two", Map.of("src", "test"), "s3v-doc2")); store.add(docs, null, Collections.emptyMap()); - List retrieved = store.get(List.of("s3v-doc1", "s3v-doc2"), - null, Collections.emptyMap()); + List retrieved = + store.get(List.of("s3v-doc1", "s3v-doc2"), null, Collections.emptyMap()); Assertions.assertEquals(2, retrieved.size()); store.delete(List.of("s3v-doc1", "s3v-doc2"), null, Collections.emptyMap()); From 9c6d61e98977abf2e551f838c3ca7652328300be Mon Sep 17 00:00:00 2001 From: Avichay Marciano Date: Thu, 19 Feb 2026 11:23:32 +0200 Subject: [PATCH 8/8] Proactive fixes: retry, exception handling, resource cleanup OpenSearchVectorStore: - Add retry with exponential backoff for 429/502/503 in executeRequest() - Only ignore 404s in getCollection/deleteCollection (not all exceptions) - Close credentialsProvider in close() S3VectorsVectorStore: - Add retry with backoff for putVectors (ThrottlingException, 429, 503) Consistent with retry patterns in BedrockChatModelConnection and BedrockEmbeddingModelConnection. 
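
Note: both new retry paths compute the same full-jitter exponential backoff.
The sketch below is illustrative only -- BackoffSketch and computeDelayMillis
do not exist in the patch -- and simply reproduces the delay formula added to
OpenSearchVectorStore.executeRequest() and
S3VectorsVectorStore.putVectorsWithRetry() so the schedule is easy to inspect
in isolation:

// Illustrative sketch, not part of the patch: reproduces the retry delay
// formula used by both stores.
public final class BackoffSketch {

    private static final int MAX_RETRIES = 5;

    /** As in the patch: 2^attempt * 200 ms, jittered into [0.5, 1.0) of that value. */
    static long computeDelayMillis(int attempt) {
        return (long) (Math.pow(2, attempt) * 200 * (0.5 + Math.random() * 0.5));
    }

    public static void main(String[] args) {
        // Per-attempt upper bounds: 200, 400, 800, 1600, 3200 ms (~6.2 s worst-case total).
        for (int attempt = 0; attempt < MAX_RETRIES; attempt++) {
            System.out.printf("retry %d: sleep ~%d ms%n", attempt, computeDelayMillis(attempt));
        }
    }
}

Jittering each step into [0.5, 1.0) of the exponential value keeps parallel
subtasks from retrying in lockstep against an already-throttled endpoint.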
--- .../opensearch/OpenSearchVectorStore.java | 44 +++++++++++++++++-- .../s3vectors/S3VectorsVectorStore.java | 44 ++++++++++++++++--- 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java index 7315f8c1..0ffcb494 100644 --- a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java +++ b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java @@ -179,6 +179,7 @@ public OpenSearchVectorStore( @Override public void close() throws Exception { this.httpClient.close(); + this.credentialsProvider.close(); } /** @@ -244,8 +245,11 @@ public Collection getCollection(String name) throws Exception { MAPPER.convertValue(resp.path("_source").path("metadata"), Map.class); return new Collection(name, meta != null ? meta : Collections.emptyMap()); } - } catch (Exception ignored) { - // metadata index may not exist yet + } catch (RuntimeException e) { + // metadata index may not exist yet; only ignore 404s + if (!e.getMessage().contains("(404)")) { + throw e; + } } return new Collection(name, Collections.emptyMap()); } @@ -257,8 +261,11 @@ public Collection deleteCollection(String name) throws Exception { executeRequest("DELETE", "/" + idx, null); try { executeRequest("DELETE", "/" + METADATA_INDEX + "/_doc/" + idx, null); - } catch (Exception ignored) { - // metadata doc may not exist + } catch (RuntimeException e) { + // metadata doc may not exist; only ignore 404s + if (!e.getMessage().contains("(404)")) { + throw e; + } } return col; } @@ -472,7 +479,36 @@ private List parseHits(JsonNode response) { return docs; } + private static final int MAX_RETRIES = 5; + private JsonNode executeRequest(String method, String path, @Nullable String body) { + for (int attempt = 0; ; attempt++) { + try { + return doExecuteRequest(method, path, body); + } catch (RuntimeException e) { + if (attempt < MAX_RETRIES && isRetryableStatus(e)) { + try { + long delay = + (long) (Math.pow(2, attempt) * 200 * (0.5 + Math.random() * 0.5)); + Thread.sleep(delay); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted during OpenSearch retry.", ie); + } + } else { + throw e; + } + } + } + } + + private static boolean isRetryableStatus(RuntimeException e) { + String msg = e.getMessage(); + return msg != null + && (msg.contains("(429)") || msg.contains("(503)") || msg.contains("(502)")); + } + + private JsonNode doExecuteRequest(String method, String path, @Nullable String body) { try { URI uri = URI.create(endpoint + path); SdkHttpFullRequest.Builder reqBuilder = diff --git a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java index d47576fd..0c25adee 100644 --- a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java +++ 
b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java @@ -236,6 +236,8 @@ protected List queryEmbedding( } } + private static final int MAX_RETRIES = 5; + @Override protected List addEmbedding( List documents, @Nullable String collection, Map extraArgs) @@ -282,16 +284,46 @@ protected List addEmbedding( for (int i = 0; i < allVectors.size(); i += MAX_PUT_VECTORS_BATCH) { List batch = allVectors.subList(i, Math.min(i + MAX_PUT_VECTORS_BATCH, allVectors.size())); - client.putVectors( - PutVectorsRequest.builder() - .vectorBucketName(vectorBucket) - .indexName(idx) - .vectors(batch) - .build()); + putVectorsWithRetry(idx, batch); } return allKeys; } + private void putVectorsWithRetry(String idx, List batch) { + for (int attempt = 0; ; attempt++) { + try { + client.putVectors( + PutVectorsRequest.builder() + .vectorBucketName(vectorBucket) + .indexName(idx) + .vectors(batch) + .build()); + return; + } catch (Exception e) { + if (attempt < MAX_RETRIES && isRetryable(e)) { + try { + long delay = + (long) (Math.pow(2, attempt) * 200 * (0.5 + Math.random() * 0.5)); + Thread.sleep(delay); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted during S3 Vectors retry.", ie); + } + } else { + throw new RuntimeException("S3 Vectors putVectors failed.", e); + } + } + } + } + + private static boolean isRetryable(Exception e) { + String msg = e.toString(); + return msg.contains("ThrottlingException") + || msg.contains("ServiceUnavailableException") + || msg.contains("429") + || msg.contains("503"); + } + private Document toDocument( String key, software.amazon.awssdk.core.document.Document metadata) { Map metaMap = new HashMap<>();
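
For context, a minimal wiring sketch mirroring testConstructor() above. All
argument values are placeholders, S3VectorsWiringSketch is not part of the
patch, and the (name, type) -> null resolver stands in for the runtime's real
resource resolver (the unit tests' NOOP):

import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.integrations.vectorstores.s3vectors.S3VectorsVectorStore;

// Illustrative only: shows how the store is configured via ResourceDescriptor.
public final class S3VectorsWiringSketch {
    public static void main(String[] args) {
        ResourceDescriptor desc =
                ResourceDescriptor.Builder.newBuilder(S3VectorsVectorStore.class.getName())
                        .addInitialArgument("embedding_model", "emb")
                        .addInitialArgument("vector_bucket", "my-bucket")
                        .addInitialArgument("vector_index", "my-index")
                        .addInitialArgument("region", "us-east-1")
                        .build();
        // "emb" must name a registered embedding model in a real deployment.
        S3VectorsVectorStore store = new S3VectorsVectorStore(desc, (name, type) -> null);
        // Expect vector_bucket and vector_index entries, as asserted in testStoreKwargs().
        System.out.println(store.getStoreKwargs());
    }
}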