From f613d63f194208480239af73444ce5791d341d01 Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Fri, 28 Mar 2025 17:01:33 -0700 Subject: [PATCH] feat: Support Tool calling for all models --- .../scala/com/supercoder/base/Agent.scala | 201 ++++++++++-------- .../com/supercoder/tools/CodeEditTool.scala | 28 +-- .../com/supercoder/tools/CodeSearchTool.scala | 16 +- .../tools/CommandExecutionTool.scala | 16 +- .../com/supercoder/tools/FileReadTool.scala | 16 +- .../tools/ProjectStructureTool.scala | 12 +- .../com/supercoder/ui/TerminalChat.scala | 6 +- 7 files changed, 122 insertions(+), 173 deletions(-) diff --git a/src/main/scala/com/supercoder/base/Agent.scala b/src/main/scala/com/supercoder/base/Agent.scala index 8654ac7..e1098fd 100644 --- a/src/main/scala/com/supercoder/base/Agent.scala +++ b/src/main/scala/com/supercoder/base/Agent.scala @@ -1,12 +1,49 @@ package com.supercoder.base import com.openai.client.okhttp.OpenAIOkHttpClient +import com.openai.core.http.Headers import com.openai.models.* import java.util import java.util.Optional import scala.collection.mutable.ListBuffer -import com.supercoder.lib.Console.blue +import com.supercoder.lib.Console.{blue, red} +import io.circe.* +import io.circe.generic.auto.* +import io.circe.parser.* + +val BasePrompt = s""" +# Tool calling +For each function call, return a json object with function name and arguments within <@TOOL> XML tags: + +<@TOOL> +{"name": , "arguments": ""} + + +The arguments value is ALWAYS a JSON-encoded string, when there is no arguments, use empty string "". + +For example: +<@TOOL> +{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"} + + +<@TOOL> +{"name": "project-structure", "arguments": ""} + + +The client will response with <@TOOL-RESULT>[content] XML tags to provide the result of the function call. +Use it to continue the conversation with the user. + +# Response format +When responding to the user, use plain text format. NEVER use Markdown's bold or italic formatting. + +# Safety +Please refuse to answer any unsafe or unethical requests. +Do not execute any command that could harm the system or access sensitive information. +When you want to execute some potentially unsafe command, please ask for user confirmation first before generating the tool call instruction. + +# Agent Instructions +""" object AgentConfig { val OpenAIAPIBaseURL: String = sys.env.get("SUPERCODER_BASE_URL") @@ -20,14 +57,11 @@ object AgentConfig { val OpenAIAPIKey: String = sys.env.get("SUPERCODER_API_KEY") .orElse(sys.env.get("OPENAI_API_KEY")) .getOrElse(throw new RuntimeException("You need to config SUPERCODER_API_KEY or OPENAI_API_KEY variable")) - - val IsGeminiMode: String = sys.env.get("SUPERCODER_GEMINI_MODE").getOrElse("false").toLowerCase } case class ToolCallDescription( name: String = "", arguments: String = "", - id: String = "" ) { def addName(name: Optional[String]): ToolCallDescription = @@ -36,15 +70,16 @@ case class ToolCallDescription( def addArguments(arguments: Optional[String]): ToolCallDescription = copy(arguments = this.arguments + arguments.orElse("")) - def addId(id: Optional[String]): ToolCallDescription = - copy(id = this.id + id.orElse("")) - } abstract class BaseChatAgent(prompt: String) { private val client = OpenAIOkHttpClient.builder() .baseUrl(AgentConfig.OpenAIAPIBaseURL) .apiKey(AgentConfig.OpenAIAPIKey) + .headers(Headers.builder() + .put("HTTP-Referer", "https://github.com/huytd/supercoder/") + .put("X-Title", "SuperCoder") + .build()) .build() private var chatHistory: ListBuffer[ChatCompletionMessageParam] = @@ -72,49 +107,11 @@ abstract class BaseChatAgent(prompt: String) { .builder() .content(content) - private def createAssistantToolCallMessage( - toolCall: ToolCallDescription - ): Unit = { - var messageBuilder = createAssistantMessageBuilder("") - messageBuilder.addToolCall( - ChatCompletionMessageToolCall - .builder() - .id(toolCall.id) - .function( - ChatCompletionMessageToolCall.Function - .builder() - .name(toolCall.name) - .arguments(toolCall.arguments) - .build() - ) - .build() - ) - - addMessageToHistory( - ChatCompletionMessageParam.ofAssistant(messageBuilder.build()) - ) - } - - private def createToolResponseMessage( - result: String, - toolCallId: String - ): ChatCompletionMessageParam = { - val toolResponse = ChatCompletionMessageParam.ofTool( - ChatCompletionToolMessageParam - .builder() - .content(result) - .toolCallId(toolCallId) - .build() - ) - - toolResponse - } - // Helper method to build base parameters with system prompt and chat history private def buildBaseParams(): ChatCompletionCreateParams.Builder = { val params = ChatCompletionCreateParams .builder() - .addSystemMessage(prompt) + .addSystemMessage(BasePrompt + prompt) .model(AgentConfig.OpenAIModel) // Add all messages from chat history @@ -132,23 +129,11 @@ abstract class BaseChatAgent(prompt: String) { ) } - // Build parameters with tool definition - var params = buildBaseParams() - toolDefinitionList.foreach(tool => - params.addTool( - ChatCompletionTool - .builder() - .function(tool) - .build() - ) - ) - - // Stream the response with support for cancelling using Ctrl+C - val streamResponse = client.chat().completions().createStreaming(params.build()) - var currentMessageBuilder = new StringBuilder() + val params = buildBaseParams().build() + val streamResponse = client.chat().completions().createStreaming(params) + val currentMessageBuilder = new StringBuilder() var currentToolCall = ToolCallDescription() - // Set up a SIGINT handler to cancel the streaming response only after streaming starts import sun.misc.{Signal, SignalHandler} var cancelStreaming = false var streamingStarted = false @@ -164,25 +149,51 @@ abstract class BaseChatAgent(prompt: String) { try { val it = streamResponse.stream().iterator() streamingStarted = true + val wordBuffer = new StringBuilder() + var isHiddenTokens = false + while(it.hasNext && !cancelStreaming) { val chunk = it.next() val delta = chunk.choices.getFirst.delta - if (delta.toolCalls().isPresent && !delta.toolCalls().get().isEmpty) { - val toolCall = delta.toolCalls().get().getFirst - if (toolCall.function().isPresent) { - val toolFunction = toolCall.function().get() - currentToolCall = currentToolCall - .addName(toolFunction.name()) - .addArguments(toolFunction.arguments()) - .addId(toolCall.id()) - } - } + if (delta.content().isPresent) { val chunkContent = delta.content().get() currentMessageBuilder.append(chunkContent) - print(blue(chunkContent)) + wordBuffer.append(chunkContent) + val bufferContent = wordBuffer.toString() + if (bufferContent.contains(" ")) { + val words = bufferContent.split(" ") + val endsWithSpace = bufferContent.last.isWhitespace + val completeWords = if (endsWithSpace) words else words.dropRight(1) + for (word <- completeWords) { + if (word.contains("<@TOOL>")) { + isHiddenTokens = true + } + if (word.contains("")) { + isHiddenTokens = false + } + if (!isHiddenTokens) { + print(blue(word + " ")) + } + } + wordBuffer.clear() + if (!endsWithSpace && words.nonEmpty) { + wordBuffer.append(words.last) + } + } } } + + if (wordBuffer.nonEmpty) { + val remainingContent = wordBuffer.toString() + if (remainingContent.nonEmpty) { + if (!isHiddenTokens) { + println(blue(remainingContent)) + } + currentMessageBuilder.append(remainingContent) + } + } + if (cancelStreaming) { println(blue("\nStreaming cancelled by user")) } @@ -194,14 +205,29 @@ abstract class BaseChatAgent(prompt: String) { streamResponse.close() if (currentMessageBuilder.nonEmpty) { println() + val messageContent = currentMessageBuilder.toString() addMessageToHistory( ChatCompletionMessageParam.ofAssistant( - createAssistantMessageBuilder(currentMessageBuilder.toString()) + createAssistantMessageBuilder(messageContent) .build() ) ) + + // Check if the message contains a tool call + val toolCallRegex = """(?s)<@TOOL>(.*?)""".r + val toolCallMatch = toolCallRegex.findFirstMatchIn(messageContent).map(_.group(1)) + if (toolCallMatch.isDefined) { + val toolCallJson = toolCallMatch.get + try { + val parseResult: Either[Error, ToolCallDescription] = decode[ToolCallDescription](toolCallJson) + currentToolCall = parseResult.getOrElse(ToolCallDescription()) + } catch { + case e: Exception => + println(red(s"Error parsing tool call: ${e.getMessage}")) + } + } } - if (currentToolCall.id.nonEmpty || currentToolCall.name.nonEmpty) { + if (currentToolCall.name.nonEmpty) { handleToolCall(currentToolCall) } } @@ -210,26 +236,19 @@ abstract class BaseChatAgent(prompt: String) { private def handleToolCall(toolCall: ToolCallDescription): Unit = { val toolResult = toolExecution(toolCall) - if (AgentConfig.IsGeminiMode != "true") { - // Add the assistant's tool call message to chat history - createAssistantToolCallMessage(toolCall) - // Add result to chat history - addMessageToHistory(createToolResponseMessage(toolResult, toolCall.id)) - } else { - // Add the result as assistant's message - addMessageToHistory( - ChatCompletionMessageParam.ofAssistant( - createAssistantMessageBuilder(s"I will need to use the ${toolCall.name} tool...").build() - ) + // Add the result as assistant's message + addMessageToHistory( + ChatCompletionMessageParam.ofAssistant( + createAssistantMessageBuilder(s"Calling ${toolCall.name} tool...").build() ) - addMessageToHistory( - ChatCompletionMessageParam.ofUser( - createUserMessageBuilder(s"Here's the tool call result: ${toolResult}").build() - ) + ) + addMessageToHistory( + ChatCompletionMessageParam.ofUser( + createUserMessageBuilder(s"<@TOOL-RESULT>${toolResult}").build() ) - } + ) - // Trigger follow up response from assistant + // Trigger follow-up response from assistant chat("") } diff --git a/src/main/scala/com/supercoder/tools/CodeEditTool.scala b/src/main/scala/com/supercoder/tools/CodeEditTool.scala index b485c1e..801b1c7 100644 --- a/src/main/scala/com/supercoder/tools/CodeEditTool.scala +++ b/src/main/scala/com/supercoder/tools/CodeEditTool.scala @@ -1,7 +1,6 @@ package com.supercoder.tools -import com.openai.core.JsonValue -import com.openai.models.{FunctionDefinition, FunctionParameters} +import com.openai.models.FunctionDefinition import com.supercoder.base.Tool import com.supercoder.lib.Console.green import io.circe.* @@ -9,8 +8,6 @@ import io.circe.generic.auto.* import io.circe.parser.* import java.io.{File, PrintWriter} -import java.nio.file.{Files, Paths} -import java.util case class CodeEditToolArguments(filepath: String, content: String) @@ -20,28 +17,7 @@ object CodeEditTool extends Tool { .builder() .name("code-edit") .description( - "Edit a code file in the repository. Provide the file path and the new content for the file." - ) - .parameters( - FunctionParameters - .builder() - .putAdditionalProperty("type", JsonValue.from("object")) - .putAdditionalProperty( - "properties", - JsonValue.from( - util.Map.of( - "filepath", - util.Map.of("type", "string"), - "content", - util.Map.of("type", "string") - ) - ) - ) - .putAdditionalProperty( - "required", - JsonValue.from(util.List.of("filepath", "content")) - ) - .build() + "Edit a code file in the repository. Provide the file path and the new content for the file. Arguments: {\"filepath\": \"\", \"content\": \"\"}" ) .build() diff --git a/src/main/scala/com/supercoder/tools/CodeSearchTool.scala b/src/main/scala/com/supercoder/tools/CodeSearchTool.scala index 150c08b..e39c1aa 100644 --- a/src/main/scala/com/supercoder/tools/CodeSearchTool.scala +++ b/src/main/scala/com/supercoder/tools/CodeSearchTool.scala @@ -1,14 +1,12 @@ package com.supercoder.tools -import com.openai.core.JsonValue -import com.openai.models.{FunctionDefinition, FunctionParameters} +import com.openai.models.FunctionDefinition import com.supercoder.base.Tool import com.supercoder.lib.Console.green import io.circe.* import io.circe.generic.auto.* import io.circe.parser.* -import java.util import scala.sys.process.* case class CodeSearchToolArguments(query: String) @@ -19,17 +17,7 @@ object CodeSearchTool extends Tool { .builder() .name("code-search") .description( - "Search for code in a given repository. The query parameter should be a regular expression." - ) - .parameters( - FunctionParameters - .builder() - .putAdditionalProperty("type", JsonValue.from("object")) - .putAdditionalProperty( - "properties", - JsonValue.from(util.Map.of("query", util.Map.of("type", "string"))) - ) - .build() + "Search for code in a given repository. The query parameter should be a regular expression. Arguments: {\"query\": \"\"}" ) .build() diff --git a/src/main/scala/com/supercoder/tools/CommandExecutionTool.scala b/src/main/scala/com/supercoder/tools/CommandExecutionTool.scala index 29a655f..72d8fb6 100644 --- a/src/main/scala/com/supercoder/tools/CommandExecutionTool.scala +++ b/src/main/scala/com/supercoder/tools/CommandExecutionTool.scala @@ -1,14 +1,12 @@ package com.supercoder.tools -import com.openai.core.JsonValue -import com.openai.models.{FunctionDefinition, FunctionParameters} +import com.openai.models.FunctionDefinition import com.supercoder.base.Tool import com.supercoder.lib.Console.green import io.circe.* import io.circe.generic.auto.* import io.circe.parser.* -import java.util import scala.sys.process.* case class CommandExecutionToolArguments(command: String) @@ -19,17 +17,7 @@ object CommandExecutionTool extends Tool { .builder() .name("command-execution") .description( - "Execute a shell command on the user's terminal, and pass the output back to the agent." - ) - .parameters( - FunctionParameters - .builder() - .putAdditionalProperty("type", JsonValue.from("object")) - .putAdditionalProperty( - "properties", - JsonValue.from(util.Map.of("command", util.Map.of("type", "string"))) - ) - .build() + "Execute a shell command on the user's terminal, and pass the output back to the agent. Arguments: {\"command\": \"\"}" ) .build() diff --git a/src/main/scala/com/supercoder/tools/FileReadTool.scala b/src/main/scala/com/supercoder/tools/FileReadTool.scala index f17341a..098736c 100644 --- a/src/main/scala/com/supercoder/tools/FileReadTool.scala +++ b/src/main/scala/com/supercoder/tools/FileReadTool.scala @@ -1,14 +1,12 @@ package com.supercoder.tools -import com.openai.core.JsonValue -import com.openai.models.{FunctionDefinition, FunctionParameters} +import com.openai.models.FunctionDefinition import com.supercoder.base.Tool import com.supercoder.lib.Console.green import io.circe.* import io.circe.generic.auto.* import io.circe.parser.* -import java.util import scala.sys.process.* case class FileReadToolArguments(fileName: String) @@ -19,17 +17,7 @@ object FileReadTool extends Tool { .builder() .name("file-read") .description( - "Read a file to understand its content. Use this tool to read a file and understand its content." - ) - .parameters( - FunctionParameters - .builder() - .putAdditionalProperty("type", JsonValue.from("object")) - .putAdditionalProperty( - "properties", - JsonValue.from(util.Map.of("fileName", util.Map.of("type", "string"))) - ) - .build() + "Read a file to understand its content. Use this tool to read a file and understand its content. Arguments: {\"fileName\": \"\"}" ) .build() diff --git a/src/main/scala/com/supercoder/tools/ProjectStructureTool.scala b/src/main/scala/com/supercoder/tools/ProjectStructureTool.scala index 6aa42f3..b6b02f2 100644 --- a/src/main/scala/com/supercoder/tools/ProjectStructureTool.scala +++ b/src/main/scala/com/supercoder/tools/ProjectStructureTool.scala @@ -1,25 +1,17 @@ package com.supercoder.tools -import com.openai.models.{FunctionDefinition, FunctionParameters} +import com.openai.models.FunctionDefinition import com.supercoder.base.Tool import com.supercoder.lib.Console.green -import java.util import scala.sys.process.* -import com.openai.core.JsonValue object ProjectStructureTool extends Tool { override val functionDefinition = FunctionDefinition .builder() .name("project-structure") - .description("Get the structure of the current project") - .parameters( - FunctionParameters.builder() - .putAdditionalProperty("type", JsonValue.from("object")) - .putAdditionalProperty("properties", JsonValue.from(util.Map.of())) - .build() - ) + .description("Get the structure of the current project. Arguments: 'null'") .build() override def execute(arguments: String): String = { diff --git a/src/main/scala/com/supercoder/ui/TerminalChat.scala b/src/main/scala/com/supercoder/ui/TerminalChat.scala index 675b11f..431f17e 100644 --- a/src/main/scala/com/supercoder/ui/TerminalChat.scala +++ b/src/main/scala/com/supercoder/ui/TerminalChat.scala @@ -1,12 +1,10 @@ package com.supercoder.ui +import com.supercoder.base.BaseChatAgent import com.supercoder.lib.Console import com.supercoder.lib.Console.{blue, bold, green, underline} - import org.jline.reader.{LineReader, LineReaderBuilder, Reference, Widget} -import org.jline.terminal.Terminal -import org.jline.terminal.TerminalBuilder -import com.supercoder.base.BaseChatAgent +import org.jline.terminal.{Terminal, TerminalBuilder} object TerminalChat {