diff --git a/dotnet/src/Aspire.Hosting.AgentFramework.DevUI/DevUIAggregatorHostedService.cs b/dotnet/src/Aspire.Hosting.AgentFramework.DevUI/DevUIAggregatorHostedService.cs index 583f23b6fe..1b89126438 100644 --- a/dotnet/src/Aspire.Hosting.AgentFramework.DevUI/DevUIAggregatorHostedService.cs +++ b/dotnet/src/Aspire.Hosting.AgentFramework.DevUI/DevUIAggregatorHostedService.cs @@ -691,7 +691,14 @@ private static async Task ProxyRequestAsync( var httpClientFactory = context.RequestServices.GetRequiredService(); using var client = httpClientFactory.CreateClient("devui-proxy"); - var targetUri = new Uri(new Uri(backendUrl), path); + var targetUri = ValidateProxyTarget(backendUrl, path); + if (targetUri is null) + { + context.Response.StatusCode = StatusCodes.Status400BadRequest; + await context.Response.WriteAsync("Invalid proxy target.", context.RequestAborted).ConfigureAwait(false); + return; + } + using var request = new HttpRequestMessage(new HttpMethod(context.Request.Method), targetUri); foreach (var header in context.Request.Headers) @@ -795,4 +802,27 @@ private static bool IsHopByHopHeader(string headerName) || headerName.Equals("Keep-Alive", StringComparison.OrdinalIgnoreCase) || headerName.Equals("Host", StringComparison.OrdinalIgnoreCase); } + + /// + /// Validates that constructing a proxy target URI from and + /// does not redirect the request to an unintended host. + /// Returns the validated if safe, or null if the target is invalid. + /// + internal static Uri? ValidateProxyTarget(string backendUrl, string path) + { + if (!Uri.TryCreate(backendUrl, UriKind.Absolute, out var baseUri) || + !Uri.TryCreate(baseUri, path, out var targetUri)) + { + return null; + } + + if (!string.Equals(targetUri.Host, baseUri.Host, StringComparison.OrdinalIgnoreCase) || + !string.Equals(targetUri.Scheme, baseUri.Scheme, StringComparison.OrdinalIgnoreCase) || + targetUri.Port != baseUri.Port) + { + return null; + } + + return targetUri; + } } diff --git a/dotnet/tests/Aspire.Hosting.AgentFramework.DevUI.UnitTests/DevUIAggregatorHostedServiceTests.cs b/dotnet/tests/Aspire.Hosting.AgentFramework.DevUI.UnitTests/DevUIAggregatorHostedServiceTests.cs index 14b1ed2e14..9af48926dc 100644 --- a/dotnet/tests/Aspire.Hosting.AgentFramework.DevUI.UnitTests/DevUIAggregatorHostedServiceTests.cs +++ b/dotnet/tests/Aspire.Hosting.AgentFramework.DevUI.UnitTests/DevUIAggregatorHostedServiceTests.cs @@ -1,9 +1,20 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; using System.Linq; +using System.Net; +using System.Net.Http; using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; using Aspire.Hosting.ApplicationModel; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; namespace Aspire.Hosting.AgentFramework.DevUI.UnitTests; @@ -166,11 +177,12 @@ public void DevUIResource_NoAnnotations_ResolveBackendsReturnsEmpty() var builder = DistributedApplication.CreateBuilder(); var devui = builder.AddDevUI("devui"); - // Assert - no AgentServiceAnnotation means no backends + // Act var annotations = devui.Resource.Annotations .OfType() .ToList(); + // Assert - no AgentServiceAnnotation means no backends Assert.Empty(annotations); } @@ -330,10 +342,12 @@ public void PrefixedEntityId_Format_ExtractsCorrectly(string prefixedId, string // - prefix is typically the resource name or custom prefix // - entityId is the original entity identifier from the backend + // Act var slashIndex = prefixedId.IndexOf('/'); var prefix = prefixedId[..slashIndex]; var rest = prefixedId[(slashIndex + 1)..]; + // Assert Assert.Equal(expectedPrefix, prefix); Assert.Equal(expectedRest, rest); } @@ -391,4 +405,209 @@ private static EndpointAnnotation AddEndpoint( private sealed class TestEndpointResource(string name) : Resource(name), IResourceWithEndpoints; #endregion + + #region Proxy Target Validation Tests + + [Theory] + [InlineData("http://localhost:5000", "/v1/conversations")] + [InlineData("http://localhost:5000", "/devui/index.html?v=1")] + public void ValidateProxyTarget_TargetStaysOnConfiguredBackend_ReturnsTargetUri(string backendUrl, string path) + { + // Arrange + var backendUri = new Uri(backendUrl); + + // Act + var target = DevUIAggregatorHostedService.ValidateProxyTarget(backendUrl, path); + + // Assert + Assert.NotNull(target); + Assert.Equal(backendUri.Host, target!.Host); + Assert.Equal(backendUri.Scheme, target.Scheme); + Assert.Equal(backendUri.Port, target.Port); + } + + [Theory] + [InlineData("http://localhost:5000", "http://alternate.example/data")] // absolute path overrides the host + [InlineData("http://localhost:5000", "//alternate.example/data")] // protocol-relative path overrides the host + [InlineData("http://localhost:5000", "https://localhost:5000/data")] // scheme differs from the backend + [InlineData("http://localhost:5000", "http://localhost:6000/data")] // port differs from the backend + [InlineData("this is not a url", "/v1/conversations")] // malformed backend url + public void ValidateProxyTarget_TargetLeavesConfiguredBackend_ReturnsNull(string backendUrl, string path) + { + // Act + var target = DevUIAggregatorHostedService.ValidateProxyTarget(backendUrl, path); + + // Assert + Assert.Null(target); + } + + [Fact] + public async Task ProxyRequest_ConversationRoute_ForwardsToConfiguredBackendAsync() + { + // Arrange + await using var proxy = await ProxyTestContext.StartAsync(); + + // Act + var response = await proxy.SendAsync("/v1/conversations?limit=10"); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + var forwarded = Assert.Single(proxy.BackendRequests); + Assert.Equal("/v1/conversations", forwarded.Path); + Assert.Equal("?limit=10", forwarded.QueryString); + } + + [Fact] + public async Task ProxyRequest_DevUIRoute_ForwardsToConfiguredBackendAsync() + { + // Arrange + await using var proxy = await ProxyTestContext.StartAsync(); + + // Act + var response = await proxy.SendAsync("/devui/index.html?v=1"); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + var forwarded = Assert.Single(proxy.BackendRequests); + Assert.Equal("/devui/index.html", forwarded.Path); + Assert.Equal("?v=1", forwarded.QueryString); + } + + [Theory] + [InlineData("/v1/conversations/../conversations")] + [InlineData("/devui/../devui/index.html")] + public async Task ProxyRequest_NormalizedPath_ForwardsToConfiguredBackendAsync(string requestPath) + { + // Arrange + await using var proxy = await ProxyTestContext.StartAsync(); + + // Act + var response = await proxy.SendAsync(requestPath); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Single(proxy.BackendRequests); + } + + #region Proxy Test Helpers + + /// + /// Hosts a stub backend together with a DevUI aggregator wired to it, and exposes an + /// targeting the aggregator so proxied requests can be observed + /// on the backend. + /// + private sealed class ProxyTestContext : IAsyncDisposable + { + private readonly WebApplication _backend; + private readonly DevUIAggregatorHostedService _aggregator; + private readonly HttpClient _client; + private readonly List<(string Path, string QueryString)> _backendRequests; + + private ProxyTestContext( + WebApplication backend, + DevUIAggregatorHostedService aggregator, + HttpClient client, + List<(string Path, string QueryString)> backendRequests) + { + this._backend = backend; + this._aggregator = aggregator; + this._client = client; + this._backendRequests = backendRequests; + } + + /// Gets the requests received by the stub backend, in arrival order. + public IReadOnlyList<(string Path, string QueryString)> BackendRequests => this._backendRequests; + + public static async Task StartAsync() + { + var backendRequests = new List<(string Path, string QueryString)>(); + var backend = await StartStubBackendAsync(backendRequests).ConfigureAwait(false); + + var aggregator = await StartAggregatorAsync(GetBaseAddress(backend)).ConfigureAwait(false); + var client = new HttpClient { BaseAddress = new Uri($"http://127.0.0.1:{aggregator.AllocatedPort}") }; + + return new ProxyTestContext(backend, aggregator, client, backendRequests); + } + + /// Sends a GET request to the aggregator using the given relative path. + public Task SendAsync(string relativePath) + => this._client.GetAsync(new Uri(relativePath, UriKind.Relative)); + + public async ValueTask DisposeAsync() + { + this._client.Dispose(); + await this._aggregator.DisposeAsync().ConfigureAwait(false); + await this._backend.StopAsync().ConfigureAwait(false); + await this._backend.DisposeAsync().ConfigureAwait(false); + } + } + + /// + /// Starts a minimal backend that records the path and query string of every request it receives. + /// + private static async Task StartStubBackendAsync(List<(string Path, string QueryString)> requests) + { + var builder = WebApplication.CreateSlimBuilder(); + builder.Logging.ClearProviders(); + + var app = builder.Build(); + app.Urls.Add("http://127.0.0.1:0"); + app.Map("{**path}", (HttpContext context) => + { + requests.Add((context.Request.Path.Value ?? string.Empty, context.Request.QueryString.Value ?? string.Empty)); + return Results.Json(new { ok = true }); + }); + + await app.StartAsync().ConfigureAwait(false); + return app; + } + + /// + /// Starts a DevUI aggregator configured with a single backend pointing at . + /// + private static async Task StartAggregatorAsync(string backendUrl) + { + var resource = new DevUIResource("test-devui"); + resource.Annotations.Add(new AgentServiceAnnotation(CreateBackendResource(backendUrl))); + + using var loggerFactory = LoggerFactory.Create(_ => { }); + var aggregator = new DevUIAggregatorHostedService( + resource, + loggerFactory.CreateLogger()); + + await aggregator.StartAsync(CancellationToken.None).ConfigureAwait(false); + return aggregator; + } + + /// + /// Creates a backend resource whose "http" endpoint is allocated to . + /// + private static TestBackendResource CreateBackendResource(string backendUrl) + { + var backendUri = new Uri(backendUrl); + var resource = new TestBackendResource("test-backend"); + + var endpoint = new EndpointAnnotation( + ProtocolType.Tcp, + uriScheme: "http", + name: "http", + port: backendUri.Port, + isProxied: false) + { + TargetHost = backendUri.Host + }; + endpoint.AllocatedEndpoint = new AllocatedEndpoint(endpoint, backendUri.Host, backendUri.Port); + + resource.Annotations.Add(endpoint); + return resource; + } + + private static string GetBaseAddress(WebApplication app) + => app.Services.GetRequiredService().Features.Get()!.Addresses.First(); + + private sealed class TestBackendResource(string name) : Resource(name), IResourceWithEndpoints; + + #endregion + + #endregion }