From 0fd16df95f5f1ed4e04130738c96cd1bb5769082 Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Sat, 29 Mar 2025 17:01:50 -0700 Subject: [PATCH 1/2] rewrite text streaming --- src/main/scala/com/supercoder/Main.scala | 3 +- .../com/supercoder/agents/CoderAgent.scala | 4 +- .../scala/com/supercoder/base/Agent.scala | 178 ++++++++++++++---- .../com/supercoder/config/ArgsParser.scala | 5 +- .../com/supercoder/ui/TerminalChat.scala | 7 +- 5 files changed, 153 insertions(+), 44 deletions(-) diff --git a/src/main/scala/com/supercoder/Main.scala b/src/main/scala/com/supercoder/Main.scala index 46924ea..384d54a 100644 --- a/src/main/scala/com/supercoder/Main.scala +++ b/src/main/scala/com/supercoder/Main.scala @@ -10,7 +10,8 @@ object Main { ArgsParser.parse(args) match { case Some(config) => val additionalPrompt = if config.useCursorRules then CursorRulesLoader.loadRules() else "" - val agent = new CoderAgent(additionalPrompt) + val modelName = config.model + val agent = new CoderAgent(additionalPrompt, modelName) TerminalChat.run(agent) case None => // invalid options, usage error message is already printed by scopt diff --git a/src/main/scala/com/supercoder/agents/CoderAgent.scala b/src/main/scala/com/supercoder/agents/CoderAgent.scala index a11a333..a65d8ba 100644 --- a/src/main/scala/com/supercoder/agents/CoderAgent.scala +++ b/src/main/scala/com/supercoder/agents/CoderAgent.scala @@ -21,8 +21,8 @@ The discussion is about the code of the current project/folder. Always use the r project if you are unsure before giving answer. """ -class CoderAgent(additionalPrompt: String = "") - extends BaseChatAgent(coderAgentPrompt + additionalPrompt) { +class CoderAgent(additionalPrompt: String = "", model: String = "") + extends BaseChatAgent(coderAgentPrompt + additionalPrompt, model) { final val availableTools = List( CodeSearchTool, diff --git a/src/main/scala/com/supercoder/base/Agent.scala b/src/main/scala/com/supercoder/base/Agent.scala index 9e99d0f..1f32fdd 100644 --- a/src/main/scala/com/supercoder/base/Agent.scala +++ b/src/main/scala/com/supercoder/base/Agent.scala @@ -3,27 +3,33 @@ package com.supercoder.base import com.openai.client.okhttp.OpenAIOkHttpClient import com.openai.core.http.Headers import com.openai.models.* - -import java.util -import java.util.Optional -import scala.collection.mutable.ListBuffer import com.supercoder.lib.Console.{blue, red} import io.circe.* import io.circe.generic.auto.* import io.circe.parser.* +import java.util +import java.util.Optional +import scala.collection.mutable.ListBuffer + val BasePrompt = s""" # Tool calling For each function call, return a json object with function name and arguments within <@TOOL> XML tags: -<@TOOL>{"name": , "arguments": ""} +<@TOOL> +{"name": , "arguments": ""} + The arguments value is ALWAYS a JSON-encoded string, when there is no arguments, use empty string "". For example: -<@TOOL>{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"} +<@TOOL> +{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"} + -<@TOOL>{"name": "project-structure", "arguments": ""} +<@TOOL> +{"name": "project-structure", "arguments": ""} + The client will response with <@TOOL-RESULT>[content] XML tags to provide the result of the function call. Use it to continue the conversation with the user. @@ -66,7 +72,7 @@ case class ToolCallDescription( } -abstract class BaseChatAgent(prompt: String) { +abstract class BaseChatAgent(prompt: String, model: String = AgentConfig.OpenAIModel) { private val client = OpenAIOkHttpClient.builder() .baseUrl(AgentConfig.OpenAIAPIBaseURL) .apiKey(AgentConfig.OpenAIAPIKey) @@ -79,6 +85,8 @@ abstract class BaseChatAgent(prompt: String) { private var chatHistory: ListBuffer[ChatCompletionMessageParam] = ListBuffer.empty + def selectedModel: String = if (model.nonEmpty) model else AgentConfig.OpenAIModel + def toolExecution(toolCall: ToolCallDescription): String def toolDefinitionList: List[FunctionDefinition] @@ -106,7 +114,7 @@ abstract class BaseChatAgent(prompt: String) { val params = ChatCompletionCreateParams .builder() .addSystemMessage(BasePrompt + prompt) - .model(AgentConfig.OpenAIModel) + .model(selectedModel) // Add all messages from chat history chatHistory.foreach(params.addMessage) @@ -144,48 +152,117 @@ abstract class BaseChatAgent(prompt: String) { val it = streamResponse.stream().iterator() streamingStarted = true val wordBuffer = new StringBuilder() - var isHiddenTokens = false + var isInToolTag = false while(it.hasNext && !cancelStreaming) { val chunk = it.next() val delta = chunk.choices.getFirst.delta if (delta.content().isPresent) { - val chunkContent = delta.content().get() - currentMessageBuilder.append(chunkContent) - wordBuffer.append(chunkContent) - val bufferContent = wordBuffer.toString() - if (bufferContent.contains(" ")) { - val words = bufferContent.split(" ") - val endsWithSpace = bufferContent.last.isWhitespace - val completeWords = if (endsWithSpace) words else words.dropRight(1) - for (word <- completeWords) { - if (word.contains("<@TOOL>")) { - isHiddenTokens = true + val content = delta.content().get() + wordBuffer.append(content) + // Append raw content immediately to the main builder for history/parsing + currentMessageBuilder.append(content) + + // Process wordBuffer for printing + val toolStart = "<@TOOL>" + val toolEnd = "" + val toolResultStart = "<@TOOL-RESULT>" + val toolResultEnd = "" + + // Store which end marker we are looking for when isInToolTag is true + var currentToolTagEndMarker: Option[String] = None + + var processedSomething = true // Flag to loop if buffer was modified + while (processedSomething && wordBuffer.nonEmpty) { + processedSomething = false // Assume no processing needed unless tag/word found + + if (isInToolTag) { + val endMarker = currentToolTagEndMarker.getOrElse(toolEnd) // Use stored marker + val endTagIndex = wordBuffer.indexOf(endMarker) + if (endTagIndex != -1) { + // End tag found in buffer + val contentToConsume = wordBuffer.substring(0, endTagIndex + endMarker.length) + // print(red(contentToConsume)) // Omit printing + wordBuffer.delete(0, contentToConsume.length) + isInToolTag = false + currentToolTagEndMarker = None // Clear expected end marker + processedSomething = true // Check remaining buffer for normal text + } else { + // End tag not yet in buffer. Consume the whole buffer internally, but don't print. + // print(red(wordBuffer.toString())) // Omit printing + wordBuffer.clear() + // Wait for the next chunk } - if (word.contains("")) { - isHiddenTokens = false + } else { // Not in tool tag + // Find the *earliest* start tag + val toolStartIndex = wordBuffer.indexOf(toolStart) + val toolResultStartIndex = wordBuffer.indexOf(toolResultStart) + + var startTagIndex = -1 + var startMarker = "" + var expectedEndMarker = "" + + // Determine which tag starts first, if any + if (toolStartIndex != -1 && (toolResultStartIndex == -1 || toolStartIndex < toolResultStartIndex)) { + startTagIndex = toolStartIndex + startMarker = toolStart + expectedEndMarker = toolEnd + } else if (toolResultStartIndex != -1) { + startTagIndex = toolResultStartIndex + startMarker = toolResultStart + expectedEndMarker = toolResultEnd } - if (!isHiddenTokens) { - print(blue(word + " ")) + + if (startTagIndex != -1) { + // A start tag was found + // Process content *before* the tag + val beforeTag = wordBuffer.substring(0, startTagIndex) + if (beforeTag.nonEmpty) { + val (words, remaining) = processWords(beforeTag) + // Only print if words were actually extracted + if (words.nonEmpty) { + words.foreach { case (word, ws) => print(blue(word)); print(ws) } + // Delete only the blue part that was fully processed (words + whitespace) + wordBuffer.delete(0, beforeTag.length - remaining.length) + processedSomething = true // Buffer content changed + } + } + + // Print the start tag itself red, but only if it's now at the start of the buffer + if (wordBuffer.indexOf(startMarker) == 0) { + // print(red(startMarker)) // Omit printing + wordBuffer.delete(0, startMarker.length) + isInToolTag = true + currentToolTagEndMarker = Some(expectedEndMarker) // Set expected end marker + processedSomething = true // Buffer content changed, loop again + } + + } else { + // No start tag found in the buffer, process as regular content + val (words, remaining) = processWords(wordBuffer.toString()) + if (words.nonEmpty) { // Only process if complete words were found + words.foreach { case (word, ws) => print(blue(word)); print(ws) } + val processedLength = wordBuffer.length() - remaining.length() + wordBuffer.delete(0, processedLength) + processedSomething = true // Buffer content changed + } + // If only `remaining` part is left, wait for the next chunk } } - wordBuffer.clear() - if (!endsWithSpace && words.nonEmpty) { - wordBuffer.append(words.last) - } - } + } // End while(processedSomething && wordBuffer.nonEmpty) } - } + } // End while(it.hasNext) + // Print out the rest of the word buffer if it has any content if (wordBuffer.nonEmpty) { - val remainingContent = wordBuffer.toString() - if (remainingContent.nonEmpty) { - if (!isHiddenTokens) { - println(blue(remainingContent)) - } - currentMessageBuilder.append(remainingContent) + if (isInToolTag) { + // Should ideally not happen if tags are well-formed, but print red if it does + print(red(wordBuffer.toString())) + } else { + print(blue(wordBuffer.toString())) } + wordBuffer.clear() } if (cancelStreaming) { @@ -227,6 +304,33 @@ abstract class BaseChatAgent(prompt: String) { } } + // Helper function to process words and whitespace + private def processWords(text: String): (ListBuffer[(String, String)], String) = { + val words = ListBuffer[(String, String)]() + var remainingText = text + var continueProcessing = true + + while (continueProcessing) { + val whitespaceIndex = remainingText.indexWhere(_.isWhitespace) + if (whitespaceIndex != -1) { + val word = remainingText.substring(0, whitespaceIndex) + val whitespace = remainingText.substring(whitespaceIndex).takeWhile(_.isWhitespace) + if (word.nonEmpty) { + words += ((word, whitespace)) + } else { + // Handle leading whitespace? For now, just consume it with the next word or as trailing. + // If printing just whitespace: print(whitespace) + } + remainingText = remainingText.substring(whitespaceIndex + whitespace.length) + if (remainingText.isEmpty) continueProcessing = false + } else { + // No more whitespace, the rest is a partial word or empty + continueProcessing = false + } + } + (words, remainingText) // Return processed words and any remaining partial word + } + private def handleToolCall(toolCall: ToolCallDescription): Unit = { val toolResult = toolExecution(toolCall) diff --git a/src/main/scala/com/supercoder/config/ArgsParser.scala b/src/main/scala/com/supercoder/config/ArgsParser.scala index 6b12568..0283722 100644 --- a/src/main/scala/com/supercoder/config/ArgsParser.scala +++ b/src/main/scala/com/supercoder/config/ArgsParser.scala @@ -2,7 +2,7 @@ package com.supercoder.config import scopt.OParser -case class Config(useCursorRules: Boolean = false) +case class Config(useCursorRules: Boolean = false, model: String = "") object ArgsParser { def parse(args: Array[String]): Option[Config] = { @@ -14,6 +14,9 @@ object ArgsParser { opt[String]('c', "use-cursor-rules") .action((x, c) => c.copy(useCursorRules = (x == "true"))) .text("use Cursor rules for the agent"), + opt[String]('m', "model") + .action((x, c) => c.copy(model = x)) + .text("model to use for the agent"), help("help").text("prints this usage text") ) } diff --git a/src/main/scala/com/supercoder/ui/TerminalChat.scala b/src/main/scala/com/supercoder/ui/TerminalChat.scala index 1bf2023..df79bf3 100644 --- a/src/main/scala/com/supercoder/ui/TerminalChat.scala +++ b/src/main/scala/com/supercoder/ui/TerminalChat.scala @@ -14,12 +14,13 @@ object TerminalChat { print("\u001b[H") } - def printHeader(): Unit = { + def printHeader(agent: BaseChatAgent): Unit = { clearScreen() println(blue("█▀ █░█ █▀█ █▀▀ █▀█ █▀▀ █▀█ █▀▄ █▀▀ █▀█")) println(blue("▄█ █▄█ █▀▀ ██▄ █▀▄ █▄▄ █▄█ █▄▀ ██▄ █▀▄")) println(blue(s"v${BuildInfo.version}")) println() + println(blue(s"Model: ${agent.selectedModel}")) println(blue("Type '/help' for available commands.\n")) } @@ -34,7 +35,7 @@ object TerminalChat { } def run(agent: BaseChatAgent): Unit = { - printHeader() + printHeader(agent) val terminal: Terminal = TerminalBuilder.builder().system(true).build() val reader: LineReader = LineReaderBuilder.builder().terminal(terminal).build() @@ -69,7 +70,7 @@ object TerminalChat { case "/help" => showHelp() case "/clear" => clearScreen() - printHeader() + printHeader(agent) case "exit" | "bye" => println(blue("\nChat session terminated. Goodbye!")) keepRunning = false From d4cc3e3919485548b51f750cb862b3c60bb84eda Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Sat, 29 Mar 2025 18:09:54 -0700 Subject: [PATCH 2/2] add debug mode --- src/main/scala/com/supercoder/Main.scala | 9 ++- .../scala/com/supercoder/base/Agent.scala | 60 +++++++++---------- .../com/supercoder/config/ArgsParser.scala | 5 +- 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/main/scala/com/supercoder/Main.scala b/src/main/scala/com/supercoder/Main.scala index 384d54a..36eedff 100644 --- a/src/main/scala/com/supercoder/Main.scala +++ b/src/main/scala/com/supercoder/Main.scala @@ -2,15 +2,18 @@ package com.supercoder import com.supercoder.ui.TerminalChat import com.supercoder.agents.CoderAgent -import com.supercoder.config.ArgsParser +import com.supercoder.config.{ArgsParser, Config} import com.supercoder.lib.CursorRulesLoader object Main { + var AppConfig: Config = Config() + def main(args: Array[String]): Unit = { ArgsParser.parse(args) match { case Some(config) => - val additionalPrompt = if config.useCursorRules then CursorRulesLoader.loadRules() else "" - val modelName = config.model + AppConfig = config + val additionalPrompt = if AppConfig.useCursorRules then CursorRulesLoader.loadRules() else "" + val modelName = AppConfig.model val agent = new CoderAgent(additionalPrompt, modelName) TerminalChat.run(agent) case None => diff --git a/src/main/scala/com/supercoder/base/Agent.scala b/src/main/scala/com/supercoder/base/Agent.scala index 1f32fdd..b27c38b 100644 --- a/src/main/scala/com/supercoder/base/Agent.scala +++ b/src/main/scala/com/supercoder/base/Agent.scala @@ -3,6 +3,8 @@ package com.supercoder.base import com.openai.client.okhttp.OpenAIOkHttpClient import com.openai.core.http.Headers import com.openai.models.* +import com.supercoder.Main +import com.supercoder.Main.AppConfig import com.supercoder.lib.Console.{blue, red} import io.circe.* import io.circe.generic.auto.* @@ -161,41 +163,38 @@ abstract class BaseChatAgent(prompt: String, model: String = AgentConfig.OpenAIM if (delta.content().isPresent) { val content = delta.content().get() wordBuffer.append(content) - // Append raw content immediately to the main builder for history/parsing currentMessageBuilder.append(content) - // Process wordBuffer for printing val toolStart = "<@TOOL>" val toolEnd = "" val toolResultStart = "<@TOOL-RESULT>" val toolResultEnd = "" - // Store which end marker we are looking for when isInToolTag is true var currentToolTagEndMarker: Option[String] = None - var processedSomething = true // Flag to loop if buffer was modified + var processedSomething = true while (processedSomething && wordBuffer.nonEmpty) { - processedSomething = false // Assume no processing needed unless tag/word found + processedSomething = false if (isInToolTag) { - val endMarker = currentToolTagEndMarker.getOrElse(toolEnd) // Use stored marker + val endMarker = currentToolTagEndMarker.getOrElse(toolEnd) val endTagIndex = wordBuffer.indexOf(endMarker) if (endTagIndex != -1) { - // End tag found in buffer val contentToConsume = wordBuffer.substring(0, endTagIndex + endMarker.length) - // print(red(contentToConsume)) // Omit printing + if (AppConfig.isDebugMode) { + print(red(contentToConsume)) + } wordBuffer.delete(0, contentToConsume.length) isInToolTag = false - currentToolTagEndMarker = None // Clear expected end marker - processedSomething = true // Check remaining buffer for normal text + currentToolTagEndMarker = None + processedSomething = true } else { - // End tag not yet in buffer. Consume the whole buffer internally, but don't print. - // print(red(wordBuffer.toString())) // Omit printing + if (AppConfig.isDebugMode) { + print(red(wordBuffer.toString())) + } wordBuffer.clear() - // Wait for the next chunk } - } else { // Not in tool tag - // Find the *earliest* start tag + } else { val toolStartIndex = wordBuffer.indexOf(toolStart) val toolResultStartIndex = wordBuffer.indexOf(toolResultStart) @@ -203,7 +202,6 @@ abstract class BaseChatAgent(prompt: String, model: String = AgentConfig.OpenAIM var startMarker = "" var expectedEndMarker = "" - // Determine which tag starts first, if any if (toolStartIndex != -1 && (toolResultStartIndex == -1 || toolStartIndex < toolResultStartIndex)) { startTagIndex = toolStartIndex startMarker = toolStart @@ -215,31 +213,27 @@ abstract class BaseChatAgent(prompt: String, model: String = AgentConfig.OpenAIM } if (startTagIndex != -1) { - // A start tag was found - // Process content *before* the tag val beforeTag = wordBuffer.substring(0, startTagIndex) if (beforeTag.nonEmpty) { val (words, remaining) = processWords(beforeTag) - // Only print if words were actually extracted if (words.nonEmpty) { words.foreach { case (word, ws) => print(blue(word)); print(ws) } - // Delete only the blue part that was fully processed (words + whitespace) wordBuffer.delete(0, beforeTag.length - remaining.length) - processedSomething = true // Buffer content changed + processedSomething = true } } - // Print the start tag itself red, but only if it's now at the start of the buffer if (wordBuffer.indexOf(startMarker) == 0) { - // print(red(startMarker)) // Omit printing - wordBuffer.delete(0, startMarker.length) - isInToolTag = true - currentToolTagEndMarker = Some(expectedEndMarker) // Set expected end marker - processedSomething = true // Buffer content changed, loop again + if (AppConfig.isDebugMode) { + print(red(startMarker)) + } + wordBuffer.delete(0, startMarker.length) + isInToolTag = true + currentToolTagEndMarker = Some(expectedEndMarker) + processedSomething = true } } else { - // No start tag found in the buffer, process as regular content val (words, remaining) = processWords(wordBuffer.toString()) if (words.nonEmpty) { // Only process if complete words were found words.foreach { case (word, ws) => print(blue(word)); print(ws) } @@ -247,18 +241,18 @@ abstract class BaseChatAgent(prompt: String, model: String = AgentConfig.OpenAIM wordBuffer.delete(0, processedLength) processedSomething = true // Buffer content changed } - // If only `remaining` part is left, wait for the next chunk } } - } // End while(processedSomething && wordBuffer.nonEmpty) + } } - } // End while(it.hasNext) + } // Print out the rest of the word buffer if it has any content if (wordBuffer.nonEmpty) { if (isInToolTag) { - // Should ideally not happen if tags are well-formed, but print red if it does - print(red(wordBuffer.toString())) + if (AppConfig.isDebugMode) { + print(red(wordBuffer.toString())) + } } else { print(blue(wordBuffer.toString())) } diff --git a/src/main/scala/com/supercoder/config/ArgsParser.scala b/src/main/scala/com/supercoder/config/ArgsParser.scala index 0283722..d22377e 100644 --- a/src/main/scala/com/supercoder/config/ArgsParser.scala +++ b/src/main/scala/com/supercoder/config/ArgsParser.scala @@ -2,7 +2,7 @@ package com.supercoder.config import scopt.OParser -case class Config(useCursorRules: Boolean = false, model: String = "") +case class Config(useCursorRules: Boolean = false, model: String = "", isDebugMode: Boolean = false) object ArgsParser { def parse(args: Array[String]): Option[Config] = { @@ -17,6 +17,9 @@ object ArgsParser { opt[String]('m', "model") .action((x, c) => c.copy(model = x)) .text("model to use for the agent"), + opt[String]('d', "debug") + .action((x, c) => c.copy(isDebugMode = (x == "true"))) + .text("enable debug mode"), help("help").text("prints this usage text") ) }