diff --git a/dotnet/src/CopilotRequestHandler.cs b/dotnet/src/CopilotRequestHandler.cs index 868f2efe6..b8292b1b4 100644 --- a/dotnet/src/CopilotRequestHandler.cs +++ b/dotnet/src/CopilotRequestHandler.cs @@ -3,6 +3,7 @@ *--------------------------------------------------------------------------------------------*/ using GitHub.Copilot.Rpc; +using System.Buffers; using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Net.WebSockets; @@ -76,13 +77,10 @@ public readonly struct CopilotWebSocketMessage(ReadOnlyMemory data, bool i public bool IsBinary { get; } = isBinary; /// Decodes the payload as UTF-8 text. - public string GetText() => Encoding.UTF8.GetString(Data.ToArray()); + public string GetText() => Encoding.UTF8.GetString(Data.Span); /// Creates a text message from a UTF-8 string. - public static CopilotWebSocketMessage Text(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false); - - /// Creates a binary message from raw bytes. - public static CopilotWebSocketMessage Binary(ReadOnlyMemory data) => new(data, isBinary: true); + public static CopilotWebSocketMessage FromText(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false); } /// @@ -253,7 +251,12 @@ internal override async Task OpenAsync() await socket.ConnectAsync(ToWebSocketUri(_url), Context.CancellationToken).ConfigureAwait(false); _upstream = socket; _pumpCts = CancellationTokenSource.CreateLinkedTokenSource(Context.CancellationToken); - _responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token), _pumpCts.Token); + + // Start the pump without a cancellation token on Task.Run itself: if the + // linked token is already cancelled, we still want PumpResponsesAsync to + // run so its cleanup (closing the upstream and finalising the response) + // executes rather than the task being cancelled before it ever starts. + _responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token)); } /// @@ -270,10 +273,10 @@ public override Task SendRequestMessageAsync(CopilotWebSocketMessage message) var type = message.IsBinary ? WebSocketMessageType.Binary : WebSocketMessageType.Text; return _upstream.SendAsync( - new ArraySegment(message.Data.ToArray()), + message.Data, type, endOfMessage: true, - Context.CancellationToken); + Context.CancellationToken).AsTask(); } /// @@ -346,34 +349,41 @@ await CloseAsync(new CopilotWebSocketCloseStatus private static async Task ReceiveMessageAsync(WebSocket socket, CancellationToken cancellationToken) { - var buffer = new byte[16 * 1024]; - using var assembled = new MemoryStream(); - WebSocketReceiveResult result; - do + var buffer = ArrayPool.Shared.Rent(16 * 1024); + try { - try - { - result = await socket.ReceiveAsync(new ArraySegment(buffer), cancellationToken).ConfigureAwait(false); - } - catch (OperationCanceledException) - { - return null; - } - catch (WebSocketException) + using var assembled = new MemoryStream(); + ValueWebSocketReceiveResult result; + do { - return null; - } + try + { + result = await socket.ReceiveAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + return null; + } + catch (WebSocketException) + { + return null; + } - if (result.MessageType == WebSocketMessageType.Close) - { - return null; + if (result.MessageType == WebSocketMessageType.Close) + { + return null; + } + + assembled.Write(buffer, 0, result.Count); } + while (!result.EndOfMessage); - assembled.Write(buffer, 0, result.Count); + return new CopilotWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary); + } + finally + { + ArrayPool.Shared.Return(buffer); } - while (!result.EndOfMessage); - - return new CopilotWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary); } private static async Task CloseWebSocketQuietlyAsync(WebSocket socket) @@ -430,13 +440,34 @@ public class CopilotRequestHandler { private static readonly HttpClient s_sharedHttpClient = new(); + private readonly HttpClient _httpClient; + + /// + /// Initializes a new instance that issues upstream requests using a shared + /// process-wide . + /// + public CopilotRequestHandler() + : this(null) + { + } + + /// + /// Initializes a new instance that issues upstream requests using the supplied + /// , or a shared process-wide instance when is . + /// + /// The to use, or to use the shared instance. + public CopilotRequestHandler(HttpClient? httpClient) + { + _httpClient = httpClient ?? s_sharedHttpClient; + } + /// /// Issue the upstream HTTP request. Override to mutate the request before /// calling base, mutate the returned response after, or replace the /// call entirely. /// protected virtual Task SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx) => - s_sharedHttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken); + _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken); /// /// Open the upstream WebSocket connection. Override to return a custom @@ -464,7 +495,7 @@ private async Task HandleHttpAsync(LlmInferenceExchange exchange) private static async Task BuildHttpRequestAsync(LlmInferenceExchange exchange) { - var method = new HttpMethod(exchange.Method.ToUpperInvariant()); + var method = new HttpMethod(exchange.Method); var message = new HttpRequestMessage(method, exchange.Context.Url); var hasBody = method != HttpMethod.Get && method != HttpMethod.Head; @@ -499,18 +530,10 @@ await exchange.StartResponseAsync( HeadersToMultiMap(response)).ConfigureAwait(false); var ct = exchange.Context.CancellationToken; -#if NETSTANDARD2_0 - using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); -#else using var stream = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false); -#endif var buffer = new byte[16 * 1024]; int read; -#if NETSTANDARD2_0 - while ((read = await stream.ReadAsync(buffer, 0, buffer.Length, ct).ConfigureAwait(false)) > 0) -#else while ((read = await stream.ReadAsync(buffer.AsMemory(), ct).ConfigureAwait(false)) > 0) -#endif { await exchange.WriteResponseAsync(new ReadOnlyMemory(buffer, 0, read)).ConfigureAwait(false); } @@ -579,7 +602,7 @@ private static async Task DrainAsync(IAsyncEnumerable 0) { - buffer.Write(chunk.ToArray(), 0, chunk.Length); + buffer.Write(chunk.Span); } } diff --git a/dotnet/src/Polyfills/DownlevelExtensions.cs b/dotnet/src/Polyfills/DownlevelExtensions.cs index 17c98643e..fc611010c 100644 --- a/dotnet/src/Polyfills/DownlevelExtensions.cs +++ b/dotnet/src/Polyfills/DownlevelExtensions.cs @@ -376,6 +376,25 @@ public async ValueTask ReadExactlyAsync(Memory buffer, Threading.Cancellat totalRead += bytesRead; } } + + public void Write(ReadOnlySpan buffer) + { + if (buffer.IsEmpty) + { + return; + } + + var rented = ArrayPool.Shared.Rent(buffer.Length); + try + { + buffer.CopyTo(rented); + stream.Write(rented, 0, buffer.Length); + } + finally + { + ArrayPool.Shared.Return(rented); + } + } } private static async ValueTask ReadAsyncSlow(Stream stream, Memory buffer, Threading.CancellationToken cancellationToken) @@ -646,3 +665,125 @@ public async Task WaitAsync(TimeSpan timeout, CancellationToken cancellationT } } } + +namespace System.Text +{ + internal static class DownlevelEncodingExtensions + { + extension(Encoding encoding) + { + public string GetString(ReadOnlySpan bytes) + { + if (bytes.IsEmpty) + { + return string.Empty; + } + + var rented = ArrayPool.Shared.Rent(bytes.Length); + try + { + bytes.CopyTo(rented); + return encoding.GetString(rented, 0, bytes.Length); + } + finally + { + ArrayPool.Shared.Return(rented); + } + } + } + } +} + +namespace System.Net.Http +{ + internal static class DownlevelHttpContentExtensions + { + extension(HttpContent content) + { + public Task ReadAsStreamAsync(Threading.CancellationToken cancellationToken) + { + // The underlying netstandard2.0 ReadAsStreamAsync() can't be cancelled, + // but honour an already-cancelled token to match the BCL overload. + cancellationToken.ThrowIfCancellationRequested(); + return content.ReadAsStreamAsync(); + } + } + } +} + +namespace System.Net.WebSockets +{ + /// + /// Polyfill for the System.Net.WebSockets.ValueWebSocketReceiveResult + /// struct, which is unavailable on .NET Standard 2.0. + /// + internal readonly struct ValueWebSocketReceiveResult + { + public ValueWebSocketReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage) + { + Count = count; + MessageType = messageType; + EndOfMessage = endOfMessage; + } + + public int Count { get; } + + public WebSocketMessageType MessageType { get; } + + public bool EndOfMessage { get; } + } + + internal static class DownlevelWebSocketExtensions + { + extension(WebSocket socket) + { + public ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, Threading.CancellationToken cancellationToken) + { + if (Runtime.InteropServices.MemoryMarshal.TryGetArray(buffer, out ArraySegment segment)) + { + return new ValueTask(socket.SendAsync(segment, messageType, endOfMessage, cancellationToken)); + } + + return SendAsyncSlow(socket, buffer, messageType, endOfMessage, cancellationToken); + } + + public ValueTask ReceiveAsync(Memory buffer, Threading.CancellationToken cancellationToken) => + ReceiveAsyncCore(socket, buffer, cancellationToken); + } + + private static async ValueTask SendAsyncSlow(WebSocket socket, ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, Threading.CancellationToken cancellationToken) + { + var rented = ArrayPool.Shared.Rent(buffer.Length); + try + { + buffer.CopyTo(rented); + await socket.SendAsync(new ArraySegment(rented, 0, buffer.Length), messageType, endOfMessage, cancellationToken).ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(rented); + } + } + + private static async ValueTask ReceiveAsyncCore(WebSocket socket, Memory buffer, Threading.CancellationToken cancellationToken) + { + if (Runtime.InteropServices.MemoryMarshal.TryGetArray(buffer, out ArraySegment segment)) + { + var result = await socket.ReceiveAsync(segment, cancellationToken).ConfigureAwait(false); + return new ValueWebSocketReceiveResult(result.Count, result.MessageType, result.EndOfMessage); + } + + var rented = ArrayPool.Shared.Rent(buffer.Length); + try + { + var result = await socket.ReceiveAsync(new ArraySegment(rented, 0, buffer.Length), cancellationToken).ConfigureAwait(false); + new ReadOnlyMemory(rented, 0, result.Count).CopyTo(buffer); + return new ValueWebSocketReceiveResult(result.Count, result.MessageType, result.EndOfMessage); + } + finally + { + ArrayPool.Shared.Return(rented); + } + } + } +} diff --git a/go/copilot_request_handler.go b/go/copilot_request_handler.go index 632d6a638..c1ac52bc2 100644 --- a/go/copilot_request_handler.go +++ b/go/copilot_request_handler.go @@ -90,16 +90,12 @@ type CopilotWebSocketMessage struct { // Text decodes the frame payload as a UTF-8 string. func (m CopilotWebSocketMessage) Text() string { return string(m.Data) } -// NewTextMessage creates a text-frame message from a UTF-8 string. +// NewTextMessage creates a text-frame message from a UTF-8 string. Binary +// frames are constructed directly with CopilotWebSocketMessage{Data: ..., Binary: true}. func NewTextMessage(text string) CopilotWebSocketMessage { return CopilotWebSocketMessage{Data: []byte(text), Binary: false} } -// NewBinaryMessage creates a binary-frame message from raw bytes. -func NewBinaryMessage(data []byte) CopilotWebSocketMessage { - return CopilotWebSocketMessage{Data: data, Binary: true} -} - // CopilotRequestHandler is the idiomatic handler for intercepting or replacing // LLM inference requests. HTTP requests are forwarded through Transport (an // [http.RoundTripper]); supply a custom RoundTripper to mutate the request, @@ -227,9 +223,9 @@ func streamResponseToSink(resp *http.Response, sink *responseSink) error { for { n, readErr := resp.Body.Read(buf) if n > 0 { - frame := make([]byte, n) - copy(frame, buf[:n]) - if err := sink.writeText(frame); err != nil { + // writeText copies eagerly via string(...), so the reused read + // buffer can be passed directly without an extra per-chunk alloc. + if err := sink.writeText(buf[:n]); err != nil { return err } } diff --git a/java/src/main/java/com/github/copilot/CopilotRequestHandler.java b/java/src/main/java/com/github/copilot/CopilotRequestHandler.java index 64cb7f0a2..7b34b20e7 100644 --- a/java/src/main/java/com/github/copilot/CopilotRequestHandler.java +++ b/java/src/main/java/com/github/copilot/CopilotRequestHandler.java @@ -143,9 +143,7 @@ private static void streamResponse(HttpResponse response, LlmInfere int n; while ((n = body.read(buffer)) != -1) { if (n > 0) { - byte[] frame = new byte[n]; - System.arraycopy(buffer, 0, frame, 0, n); - exchange.writeResponseBinary(frame); + exchange.writeResponseBinary(buffer, 0, n); } } } catch (IOException e) { diff --git a/java/src/main/java/com/github/copilot/CopilotWebSocketMessage.java b/java/src/main/java/com/github/copilot/CopilotWebSocketMessage.java index 748709b5f..87ffaf7fe 100644 --- a/java/src/main/java/com/github/copilot/CopilotWebSocketMessage.java +++ b/java/src/main/java/com/github/copilot/CopilotWebSocketMessage.java @@ -35,18 +35,7 @@ public String text() { * the text payload * @return a text message */ - public static CopilotWebSocketMessage text(String text) { + public static CopilotWebSocketMessage fromText(String text) { return new CopilotWebSocketMessage(text.getBytes(StandardCharsets.UTF_8), false); } - - /** - * Creates a binary message from raw bytes. - * - * @param data - * the binary payload - * @return a binary message - */ - public static CopilotWebSocketMessage binary(byte[] data) { - return new CopilotWebSocketMessage(data, true); - } } diff --git a/java/src/main/java/com/github/copilot/LlmInferenceExchange.java b/java/src/main/java/com/github/copilot/LlmInferenceExchange.java index 9c2bbe40c..67933e40b 100644 --- a/java/src/main/java/com/github/copilot/LlmInferenceExchange.java +++ b/java/src/main/java/com/github/copilot/LlmInferenceExchange.java @@ -6,6 +6,8 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.List; import java.util.Map; @@ -196,6 +198,11 @@ void writeResponseBinary(byte[] data) throws IOException { writeChunk(Base64.getEncoder().encodeToString(data), true); } + void writeResponseBinary(byte[] data, int offset, int length) throws IOException { + ByteBuffer encoded = Base64.getEncoder().encode(ByteBuffer.wrap(data, offset, length)); + writeChunk(new String(encoded.array(), 0, encoded.limit(), StandardCharsets.ISO_8859_1), true); + } + void endResponse() throws IOException { synchronized (lock) { if (finished) { diff --git a/rust/src/copilot_request_handler.rs b/rust/src/copilot_request_handler.rs index 09f6e0745..b686b6ead 100644 --- a/rust/src/copilot_request_handler.rs +++ b/rust/src/copilot_request_handler.rs @@ -209,18 +209,14 @@ pub struct CopilotWebSocketMessage { } impl CopilotWebSocketMessage { - /// A UTF-8 text message. - pub fn text(data: impl Into) -> Self { + /// A UTF-8 text message. Binary messages are constructed directly via the + /// public `data` / `binary` fields. + pub fn from_text(data: impl Into) -> Self { Self { data: data.into().into_bytes(), binary: false, } } - - /// A binary message. - pub fn binary(data: Vec) -> Self { - Self { data, binary: true } - } } /// The runtime-facing side of a WebSocket: a [`CopilotWebSocketHandler`] writes @@ -477,13 +473,13 @@ impl CopilotWebSocketForwarderBuilder { _ = loop_cancel.cancelled() => break, msg = read.next() => match msg { Some(Ok(Message::Text(text))) => { - let message = CopilotWebSocketMessage::text(text); + let message = CopilotWebSocketMessage::from_text(text); if let Some(out) = apply_transform(&on_response, message) { let _ = response.send_message(out).await; } } Some(Ok(Message::Binary(data))) => { - let message = CopilotWebSocketMessage::binary(data); + let message = CopilotWebSocketMessage { data, binary: true }; if let Some(out) = apply_transform(&on_response, message) { let _ = response.send_message(out).await; } @@ -542,7 +538,11 @@ impl CopilotWebSocketHandler for CopilotWebSocketForwarder { let ws_message = if message.binary { Message::Binary(message.data) } else { - Message::Text(String::from_utf8_lossy(&message.data).into_owned()) + let text = match String::from_utf8(message.data) { + Ok(text) => text, + Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(), + }; + Message::Text(text) }; let mut guard = self.write.lock().await; if let Some(write) = guard.as_mut() {