Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 64 additions & 41 deletions dotnet/src/CopilotRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*--------------------------------------------------------------------------------------------*/

using GitHub.Copilot.Rpc;
using System.Buffers;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Net.WebSockets;
Expand Down Expand Up @@ -76,13 +77,10 @@ public readonly struct CopilotWebSocketMessage(ReadOnlyMemory<byte> data, bool i
public bool IsBinary { get; } = isBinary;

/// <summary>Decodes the payload as UTF-8 text.</summary>
public string GetText() => Encoding.UTF8.GetString(Data.ToArray());
public string GetText() => Encoding.UTF8.GetString(Data.Span);

/// <summary>Creates a text message from a UTF-8 string.</summary>
public static CopilotWebSocketMessage Text(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false);

/// <summary>Creates a binary message from raw bytes.</summary>
public static CopilotWebSocketMessage Binary(ReadOnlyMemory<byte> data) => new(data, isBinary: true);
public static CopilotWebSocketMessage FromText(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false);
}

/// <summary>
Expand Down Expand Up @@ -253,7 +251,12 @@ internal override async Task OpenAsync()
await socket.ConnectAsync(ToWebSocketUri(_url), Context.CancellationToken).ConfigureAwait(false);
_upstream = socket;
_pumpCts = CancellationTokenSource.CreateLinkedTokenSource(Context.CancellationToken);
_responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token), _pumpCts.Token);

// Start the pump without a cancellation token on Task.Run itself: if the
// linked token is already cancelled, we still want PumpResponsesAsync to
// run so its cleanup (closing the upstream and finalising the response)
// executes rather than the task being cancelled before it ever starts.
_responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token));
}

/// <summary>
Expand All @@ -270,10 +273,10 @@ public override Task SendRequestMessageAsync(CopilotWebSocketMessage message)

var type = message.IsBinary ? WebSocketMessageType.Binary : WebSocketMessageType.Text;
return _upstream.SendAsync(
new ArraySegment<byte>(message.Data.ToArray()),
message.Data,
type,
endOfMessage: true,
Context.CancellationToken);
Context.CancellationToken).AsTask();
}

/// <inheritdoc />
Expand Down Expand Up @@ -346,34 +349,41 @@ await CloseAsync(new CopilotWebSocketCloseStatus

private static async Task<CopilotWebSocketMessage?> ReceiveMessageAsync(WebSocket socket, CancellationToken cancellationToken)
{
var buffer = new byte[16 * 1024];
using var assembled = new MemoryStream();
WebSocketReceiveResult result;
do
var buffer = ArrayPool<byte>.Shared.Rent(16 * 1024);
try
{
try
{
result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
return null;
}
catch (WebSocketException)
using var assembled = new MemoryStream();
ValueWebSocketReceiveResult result;
do
{
return null;
}
try
{
result = await socket.ReceiveAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
return null;
}
catch (WebSocketException)
{
return null;
}

if (result.MessageType == WebSocketMessageType.Close)
{
return null;
if (result.MessageType == WebSocketMessageType.Close)
{
return null;
}

assembled.Write(buffer, 0, result.Count);
}
while (!result.EndOfMessage);

assembled.Write(buffer, 0, result.Count);
return new CopilotWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary);
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
while (!result.EndOfMessage);

return new CopilotWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary);
}

private static async Task CloseWebSocketQuietlyAsync(WebSocket socket)
Expand Down Expand Up @@ -430,13 +440,34 @@ public class CopilotRequestHandler
{
private static readonly HttpClient s_sharedHttpClient = new();

/// <summary>The <see cref="HttpClient"/> selected when this handler was constructed.</summary>
private readonly HttpClient _httpClient;

/// <summary>
/// Creates a handler that relies on the shared, process-wide
/// <see cref="HttpClient"/> for its upstream requests.
/// </summary>
public CopilotRequestHandler()
    : this(httpClient: null)
{
}

/// <summary>
/// Creates a handler that uses <paramref name="httpClient"/> for its upstream requests,
/// falling back to the shared process-wide instance when <paramref name="httpClient"/>
/// is <see langword="null"/>.
/// </summary>
/// <param name="httpClient">
/// The caller-supplied <see cref="HttpClient"/>, or <see langword="null"/> to select the shared instance.
/// </param>
public CopilotRequestHandler(HttpClient? httpClient) =>
    _httpClient = httpClient ?? s_sharedHttpClient;

/// <summary>
/// Issue the upstream HTTP request. Override to mutate the request before
/// calling <c>base</c>, mutate the returned response after, or replace the
/// call entirely.
/// </summary>
protected virtual Task<HttpResponseMessage> SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx) =>
s_sharedHttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken);
_httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken);

/// <summary>
/// Open the upstream WebSocket connection. Override to return a custom
Expand Down Expand Up @@ -464,7 +495,7 @@ private async Task HandleHttpAsync(LlmInferenceExchange exchange)

private static async Task<HttpRequestMessage> BuildHttpRequestAsync(LlmInferenceExchange exchange)
{
var method = new HttpMethod(exchange.Method.ToUpperInvariant());
var method = new HttpMethod(exchange.Method);
var message = new HttpRequestMessage(method, exchange.Context.Url);

var hasBody = method != HttpMethod.Get && method != HttpMethod.Head;
Expand Down Expand Up @@ -499,18 +530,10 @@ await exchange.StartResponseAsync(
HeadersToMultiMap(response)).ConfigureAwait(false);

var ct = exchange.Context.CancellationToken;
#if NETSTANDARD2_0
using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
#else
using var stream = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false);
#endif
var buffer = new byte[16 * 1024];
int read;
#if NETSTANDARD2_0
while ((read = await stream.ReadAsync(buffer, 0, buffer.Length, ct).ConfigureAwait(false)) > 0)
#else
while ((read = await stream.ReadAsync(buffer.AsMemory(), ct).ConfigureAwait(false)) > 0)
#endif
{
await exchange.WriteResponseAsync(new ReadOnlyMemory<byte>(buffer, 0, read)).ConfigureAwait(false);
}
Expand Down Expand Up @@ -579,7 +602,7 @@ private static async Task<byte[]> DrainAsync(IAsyncEnumerable<ReadOnlyMemory<byt
{
if (chunk.Length > 0)
{
buffer.Write(chunk.ToArray(), 0, chunk.Length);
buffer.Write(chunk.Span);
}
}

Expand Down
141 changes: 141 additions & 0 deletions dotnet/src/Polyfills/DownlevelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,25 @@ public async ValueTask ReadExactlyAsync(Memory<byte> buffer, Threading.Cancellat
totalRead += bytesRead;
}
}

/// <summary>
/// Writes the bytes in <paramref name="buffer"/> to the stream by first staging them
/// in an array rented from <see cref="ArrayPool{T}.Shared"/>.
/// </summary>
/// <param name="buffer">The bytes to write. An empty span results in no write at all.</param>
/// <remarks>
/// <c>stream</c> is the receiver declared by the enclosing extension block, which is
/// outside this excerpt.
/// </remarks>
public void Write(ReadOnlySpan<byte> buffer)
{
    var length = buffer.Length;
    if (length == 0)
    {
        return;
    }

    // The pool may hand back a larger array than requested, so only the first
    // `length` bytes are meaningful and only those are written.
    byte[] scratch = ArrayPool<byte>.Shared.Rent(length);
    try
    {
        buffer.CopyTo(scratch);
        stream.Write(scratch, 0, length);
    }
    finally
    {
        // Always hand the array back, even if the underlying write throws.
        ArrayPool<byte>.Shared.Return(scratch);
    }
}
}

private static async ValueTask<int> ReadAsyncSlow(Stream stream, Memory<byte> buffer, Threading.CancellationToken cancellationToken)
Expand Down Expand Up @@ -646,3 +665,125 @@ public async Task<T> WaitAsync(TimeSpan timeout, CancellationToken cancellationT
}
}
}

namespace System.Text
{
    /// <summary>
    /// Adds a span-accepting <c>GetString</c> overload as an extension on <see cref="Encoding"/>.
    /// </summary>
    internal static class DownlevelEncodingExtensions
    {
        extension(Encoding encoding)
        {
            /// <summary>
            /// Decodes <paramref name="bytes"/> into a string using this encoding.
            /// </summary>
            /// <param name="bytes">The encoded bytes to decode.</param>
            /// <returns>
            /// <see cref="string.Empty"/> for an empty span; otherwise the decoded text.
            /// </returns>
            public string GetString(ReadOnlySpan<byte> bytes)
            {
                var byteCount = bytes.Length;
                if (byteCount == 0)
                {
                    return string.Empty;
                }

                // Stage the span in a pooled array so the array-based overload can be used.
                // The rented array may be longer than requested, hence the explicit count.
                byte[] scratch = ArrayPool<byte>.Shared.Rent(byteCount);
                try
                {
                    bytes.CopyTo(scratch);
                    return encoding.GetString(scratch, 0, byteCount);
                }
                finally
                {
                    ArrayPool<byte>.Shared.Return(scratch);
                }
            }
        }
    }
}

namespace System.Net.Http
{
    /// <summary>
    /// Adds a <c>ReadAsStreamAsync</c> overload that accepts a cancellation token
    /// as an extension on <see cref="HttpContent"/>.
    /// </summary>
    internal static class DownlevelHttpContentExtensions
    {
        extension(HttpContent content)
        {
            /// <summary>
            /// Returns the content as a stream by delegating to the parameterless
            /// <c>ReadAsStreamAsync()</c>.
            /// </summary>
            /// <param name="cancellationToken">
            /// Checked once, up front. It is not passed to the underlying read.
            /// </param>
            /// <returns>The task produced by the parameterless <c>ReadAsStreamAsync()</c>.</returns>
            /// <exception cref="OperationCanceledException">
            /// Thrown synchronously when <paramref name="cancellationToken"/> is already cancelled.
            /// </exception>
            public Task<IO.Stream> ReadAsStreamAsync(Threading.CancellationToken cancellationToken)
            {
                // The netstandard2.0 read cannot be cancelled once started, so the only
                // thing we can do with the token is refuse to start when it has already fired.
                cancellationToken.ThrowIfCancellationRequested();

                Task<IO.Stream> pendingStream = content.ReadAsStreamAsync();
                return pendingStream;
            }
        }
    }
}

namespace System.Net.WebSockets
{
/// <summary>
/// Stand-in for <c>System.Net.WebSockets.ValueWebSocketReceiveResult</c>, a struct
/// that does not exist on .NET Standard 2.0.
/// </summary>
internal readonly struct ValueWebSocketReceiveResult
{
    /// <summary>Captures the outcome of a single receive operation.</summary>
    /// <param name="count">Value exposed through <see cref="Count"/>.</param>
    /// <param name="messageType">Value exposed through <see cref="MessageType"/>.</param>
    /// <param name="endOfMessage">Value exposed through <see cref="EndOfMessage"/>.</param>
    public ValueWebSocketReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage) =>
        (Count, MessageType, EndOfMessage) = (count, messageType, endOfMessage);

    /// <summary>The count supplied at construction.</summary>
    public int Count { get; }

    /// <summary>The message type supplied at construction.</summary>
    public WebSocketMessageType MessageType { get; }

    /// <summary>The end-of-message flag supplied at construction.</summary>
    public bool EndOfMessage { get; }
}

/// <summary>
/// Adds <see cref="Memory{T}"/>-based <c>SendAsync</c> and <c>ReceiveAsync</c> overloads
/// as extensions on <see cref="WebSocket"/>, implemented on top of the
/// <see cref="ArraySegment{T}"/>-based methods.
/// </summary>
internal static class DownlevelWebSocketExtensions
{
    extension(WebSocket socket)
    {
        /// <summary>
        /// Sends <paramref name="buffer"/> over the socket. When the memory is backed by an
        /// array, that array is sent directly; otherwise the bytes are staged in a pooled array.
        /// </summary>
        public ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, Threading.CancellationToken cancellationToken) =>
            Runtime.InteropServices.MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> backing)
                ? new ValueTask(socket.SendAsync(backing, messageType, endOfMessage, cancellationToken))
                : SendAsyncSlow(socket, buffer, messageType, endOfMessage, cancellationToken);

        /// <summary>
        /// Receives into <paramref name="buffer"/> and reports the outcome as a
        /// <see cref="ValueWebSocketReceiveResult"/>.
        /// </summary>
        public ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, Threading.CancellationToken cancellationToken)
        {
            return ReceiveAsyncCore(socket, buffer, cancellationToken);
        }
    }

    /// <summary>Send path for memory that is not array-backed: copy to a pooled array, then send.</summary>
    private static async ValueTask SendAsyncSlow(WebSocket socket, ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, Threading.CancellationToken cancellationToken)
    {
        var length = buffer.Length;
        byte[] scratch = ArrayPool<byte>.Shared.Rent(length);
        try
        {
            buffer.CopyTo(scratch);

            // The rented array can be longer than requested; send only the copied prefix.
            var payload = new ArraySegment<byte>(scratch, 0, length);
            await socket.SendAsync(payload, messageType, endOfMessage, cancellationToken).ConfigureAwait(false);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(scratch);
        }
    }

    /// <summary>
    /// Receives directly into the backing array when there is one; otherwise receives into a
    /// pooled array and copies the received bytes back into <paramref name="buffer"/>.
    /// </summary>
    private static async ValueTask<ValueWebSocketReceiveResult> ReceiveAsyncCore(WebSocket socket, Memory<byte> buffer, Threading.CancellationToken cancellationToken)
    {
        if (Runtime.InteropServices.MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> backing))
        {
            var direct = await socket.ReceiveAsync(backing, cancellationToken).ConfigureAwait(false);
            return ToValueResult(direct);
        }

        var capacity = buffer.Length;
        byte[] scratch = ArrayPool<byte>.Shared.Rent(capacity);
        try
        {
            // Cap the receive window at the caller's buffer size so the copy below always fits.
            var window = new ArraySegment<byte>(scratch, 0, capacity);
            var staged = await socket.ReceiveAsync(window, cancellationToken).ConfigureAwait(false);
            new ReadOnlyMemory<byte>(scratch, 0, staged.Count).CopyTo(buffer);
            return ToValueResult(staged);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(scratch);
        }
    }

    /// <summary>Copies count, message type and end-of-message flag into the value-type result.</summary>
    private static ValueWebSocketReceiveResult ToValueResult(WebSocketReceiveResult source) =>
        new ValueWebSocketReceiveResult(source.Count, source.MessageType, source.EndOfMessage);
}
}
14 changes: 5 additions & 9 deletions go/copilot_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,12 @@ type CopilotWebSocketMessage struct {
// Text decodes the frame payload as a UTF-8 string.
func (m CopilotWebSocketMessage) Text() string { return string(m.Data) }

// NewTextMessage creates a text-frame message from a UTF-8 string.
// NewTextMessage creates a text-frame message from a UTF-8 string. Binary
// frames are constructed directly with CopilotWebSocketMessage{Data: ..., Binary: true}.
func NewTextMessage(text string) CopilotWebSocketMessage {
return CopilotWebSocketMessage{Data: []byte(text), Binary: false}
}

// NewBinaryMessage creates a binary-frame message from raw bytes.
func NewBinaryMessage(data []byte) CopilotWebSocketMessage {
return CopilotWebSocketMessage{Data: data, Binary: true}
}

// CopilotRequestHandler is the idiomatic handler for intercepting or replacing
// LLM inference requests. HTTP requests are forwarded through Transport (an
// [http.RoundTripper]); supply a custom RoundTripper to mutate the request,
Expand Down Expand Up @@ -227,9 +223,9 @@ func streamResponseToSink(resp *http.Response, sink *responseSink) error {
for {
n, readErr := resp.Body.Read(buf)
if n > 0 {
frame := make([]byte, n)
copy(frame, buf[:n])
if err := sink.writeText(frame); err != nil {
// writeText copies eagerly via string(...), so the reused read
// buffer can be passed directly without an extra per-chunk alloc.
if err := sink.writeText(buf[:n]); err != nil {
return err
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ private static void streamResponse(HttpResponse<InputStream> response, LlmInfere
int n;
while ((n = body.read(buffer)) != -1) {
if (n > 0) {
byte[] frame = new byte[n];
System.arraycopy(buffer, 0, frame, 0, n);
exchange.writeResponseBinary(frame);
exchange.writeResponseBinary(buffer, 0, n);
}
}
} catch (IOException e) {
Expand Down
Loading
Loading