Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,14 @@ private static async Task ProxyRequestAsync(
var httpClientFactory = context.RequestServices.GetRequiredService<IHttpClientFactory>();
using var client = httpClientFactory.CreateClient("devui-proxy");

var targetUri = new Uri(new Uri(backendUrl), path);
var targetUri = ValidateProxyTarget(backendUrl, path);
if (targetUri is null)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.WriteAsync("Invalid proxy target.", context.RequestAborted).ConfigureAwait(false);
return;
}

using var request = new HttpRequestMessage(new HttpMethod(context.Request.Method), targetUri);

foreach (var header in context.Request.Headers)
Expand Down Expand Up @@ -795,4 +802,27 @@ private static bool IsHopByHopHeader(string headerName)
|| headerName.Equals("Keep-Alive", StringComparison.OrdinalIgnoreCase)
|| headerName.Equals("Host", StringComparison.OrdinalIgnoreCase);
}

/// <summary>
/// Validates that constructing a proxy target URI from <paramref name="backendUrl"/> and
/// <paramref name="path"/> does not redirect the request to an unintended host.
/// Returns the validated <see cref="Uri"/> if safe, or <c>null</c> if the target is invalid.
/// </summary>
internal static Uri? ValidateProxyTarget(string backendUrl, string path)
{
if (!Uri.TryCreate(backendUrl, UriKind.Absolute, out var baseUri) ||
!Uri.TryCreate(baseUri, path, out var targetUri))
{
return null;
}

if (!string.Equals(targetUri.Host, baseUri.Host, StringComparison.OrdinalIgnoreCase) ||
!string.Equals(targetUri.Scheme, baseUri.Scheme, StringComparison.OrdinalIgnoreCase) ||
targetUri.Port != baseUri.Port)
{
return null;
}

return targetUri;
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Aspire.Hosting.ApplicationModel;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Aspire.Hosting.AgentFramework.DevUI.UnitTests;
Expand Down Expand Up @@ -166,11 +177,12 @@ public void DevUIResource_NoAnnotations_ResolveBackendsReturnsEmpty()
var builder = DistributedApplication.CreateBuilder();
var devui = builder.AddDevUI("devui");

// Assert - no AgentServiceAnnotation means no backends
// Act
var annotations = devui.Resource.Annotations
.OfType<AgentServiceAnnotation>()
.ToList();

// Assert - no AgentServiceAnnotation means no backends
Assert.Empty(annotations);
}

Expand Down Expand Up @@ -330,10 +342,12 @@ public void PrefixedEntityId_Format_ExtractsCorrectly(string prefixedId, string
// - prefix is typically the resource name or custom prefix
// - entityId is the original entity identifier from the backend

// Act
var slashIndex = prefixedId.IndexOf('/');
var prefix = prefixedId[..slashIndex];
var rest = prefixedId[(slashIndex + 1)..];

// Assert
Assert.Equal(expectedPrefix, prefix);
Assert.Equal(expectedRest, rest);
}
Expand Down Expand Up @@ -391,4 +405,209 @@ private static EndpointAnnotation AddEndpoint(
private sealed class TestEndpointResource(string name) : Resource(name), IResourceWithEndpoints;

#endregion

#region Proxy Target Validation Tests

[Theory]
[InlineData("http://localhost:5000", "/v1/conversations")]
[InlineData("http://localhost:5000", "/devui/index.html?v=1")]
public void ValidateProxyTarget_TargetStaysOnConfiguredBackend_ReturnsTargetUri(string backendUrl, string path)
{
// Arrange
var backendUri = new Uri(backendUrl);

// Act
var target = DevUIAggregatorHostedService.ValidateProxyTarget(backendUrl, path);

// Assert
Assert.NotNull(target);
Assert.Equal(backendUri.Host, target!.Host);
Assert.Equal(backendUri.Scheme, target.Scheme);
Assert.Equal(backendUri.Port, target.Port);
}

[Theory]
[InlineData("http://localhost:5000", "http://alternate.example/data")] // absolute path overrides the host
[InlineData("http://localhost:5000", "//alternate.example/data")] // protocol-relative path overrides the host
[InlineData("http://localhost:5000", "https://localhost:5000/data")] // scheme differs from the backend
[InlineData("http://localhost:5000", "http://localhost:6000/data")] // port differs from the backend
[InlineData("this is not a url", "/v1/conversations")] // malformed backend url
public void ValidateProxyTarget_TargetLeavesConfiguredBackend_ReturnsNull(string backendUrl, string path)
{
// Act
var target = DevUIAggregatorHostedService.ValidateProxyTarget(backendUrl, path);

// Assert
Assert.Null(target);
}

[Fact]
public async Task ProxyRequest_ConversationRoute_ForwardsToConfiguredBackendAsync()
{
// Arrange
await using var proxy = await ProxyTestContext.StartAsync();

// Act
var response = await proxy.SendAsync("/v1/conversations?limit=10");

// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
var forwarded = Assert.Single(proxy.BackendRequests);
Assert.Equal("/v1/conversations", forwarded.Path);
Assert.Equal("?limit=10", forwarded.QueryString);
}

[Fact]
public async Task ProxyRequest_DevUIRoute_ForwardsToConfiguredBackendAsync()
{
// Arrange
await using var proxy = await ProxyTestContext.StartAsync();

// Act
var response = await proxy.SendAsync("/devui/index.html?v=1");

// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
var forwarded = Assert.Single(proxy.BackendRequests);
Assert.Equal("/devui/index.html", forwarded.Path);
Assert.Equal("?v=1", forwarded.QueryString);
}

[Theory]
[InlineData("/v1/conversations/../conversations")]
[InlineData("/devui/../devui/index.html")]
public async Task ProxyRequest_NormalizedPath_ForwardsToConfiguredBackendAsync(string requestPath)
{
// Arrange
await using var proxy = await ProxyTestContext.StartAsync();

// Act
var response = await proxy.SendAsync(requestPath);

// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Single(proxy.BackendRequests);
}
Comment thread
SergeyMenshykh marked this conversation as resolved.

#region Proxy Test Helpers

/// <summary>
/// Hosts a stub backend together with a DevUI aggregator wired to it, and exposes an
/// <see cref="HttpClient"/> targeting the aggregator so proxied requests can be observed
/// on the backend.
/// </summary>
private sealed class ProxyTestContext : IAsyncDisposable
{
private readonly WebApplication _backend;
private readonly DevUIAggregatorHostedService _aggregator;
private readonly HttpClient _client;
private readonly List<(string Path, string QueryString)> _backendRequests;

private ProxyTestContext(
WebApplication backend,
DevUIAggregatorHostedService aggregator,
HttpClient client,
List<(string Path, string QueryString)> backendRequests)
{
this._backend = backend;
this._aggregator = aggregator;
this._client = client;
this._backendRequests = backendRequests;
}

/// <summary>Gets the requests received by the stub backend, in arrival order.</summary>
public IReadOnlyList<(string Path, string QueryString)> BackendRequests => this._backendRequests;

public static async Task<ProxyTestContext> StartAsync()
{
var backendRequests = new List<(string Path, string QueryString)>();
var backend = await StartStubBackendAsync(backendRequests).ConfigureAwait(false);

var aggregator = await StartAggregatorAsync(GetBaseAddress(backend)).ConfigureAwait(false);
var client = new HttpClient { BaseAddress = new Uri($"http://127.0.0.1:{aggregator.AllocatedPort}") };

return new ProxyTestContext(backend, aggregator, client, backendRequests);
}

/// <summary>Sends a GET request to the aggregator using the given relative path.</summary>
public Task<HttpResponseMessage> SendAsync(string relativePath)
=> this._client.GetAsync(new Uri(relativePath, UriKind.Relative));

public async ValueTask DisposeAsync()
{
this._client.Dispose();
await this._aggregator.DisposeAsync().ConfigureAwait(false);
await this._backend.StopAsync().ConfigureAwait(false);
await this._backend.DisposeAsync().ConfigureAwait(false);
}
}

/// <summary>
/// Starts a minimal backend that records the path and query string of every request it receives.
/// </summary>
private static async Task<WebApplication> StartStubBackendAsync(List<(string Path, string QueryString)> requests)
{
var builder = WebApplication.CreateSlimBuilder();
builder.Logging.ClearProviders();

var app = builder.Build();
app.Urls.Add("http://127.0.0.1:0");
app.Map("{**path}", (HttpContext context) =>
{
requests.Add((context.Request.Path.Value ?? string.Empty, context.Request.QueryString.Value ?? string.Empty));
return Results.Json(new { ok = true });
});

await app.StartAsync().ConfigureAwait(false);
return app;
}

/// <summary>
/// Starts a DevUI aggregator configured with a single backend pointing at <paramref name="backendUrl"/>.
/// </summary>
private static async Task<DevUIAggregatorHostedService> StartAggregatorAsync(string backendUrl)
{
var resource = new DevUIResource("test-devui");
resource.Annotations.Add(new AgentServiceAnnotation(CreateBackendResource(backendUrl)));

using var loggerFactory = LoggerFactory.Create(_ => { });
var aggregator = new DevUIAggregatorHostedService(
resource,
loggerFactory.CreateLogger<DevUIAggregatorHostedService>());

await aggregator.StartAsync(CancellationToken.None).ConfigureAwait(false);
return aggregator;
}

/// <summary>
/// Creates a backend resource whose "http" endpoint is allocated to <paramref name="backendUrl"/>.
/// </summary>
private static TestBackendResource CreateBackendResource(string backendUrl)
{
var backendUri = new Uri(backendUrl);
var resource = new TestBackendResource("test-backend");

var endpoint = new EndpointAnnotation(
ProtocolType.Tcp,
uriScheme: "http",
name: "http",
port: backendUri.Port,
isProxied: false)
{
TargetHost = backendUri.Host
};
endpoint.AllocatedEndpoint = new AllocatedEndpoint(endpoint, backendUri.Host, backendUri.Port);

resource.Annotations.Add(endpoint);
return resource;
}

private static string GetBaseAddress(WebApplication app)
=> app.Services.GetRequiredService<IServer>().Features.Get<IServerAddressesFeature>()!.Addresses.First();

private sealed class TestBackendResource(string name) : Resource(name), IResourceWithEndpoints;

#endregion

#endregion
}
Loading