From c446d64480528d6c7c8a34144d492036b1cc010f Mon Sep 17 00:00:00 2001 From: JasonEran Date: Sun, 1 Feb 2026 22:48:06 +0800 Subject: [PATCH 01/24] feat: implement external signals ingestion pipeline with database integration and API support --- CHANGELOG.md | 1 + docs/Quickstart.md | 17 + .../Controllers/ExternalSignalsController.cs | 81 +++++ .../Data/ApplicationDbContext.cs | 21 ++ ...60201123714_AddExternalSignals.Designer.cs | 332 ++++++++++++++++++ .../20260201123714_AddExternalSignals.cs | 57 +++ .../ApplicationDbContextModelSnapshot.cs | 66 ++++ .../core-dotnet/AetherGuard.Core/Program.cs | 8 + .../ExternalSignalIngestionService.cs | 327 +++++++++++++++++ .../ExternalSignals/ExternalSignalsOptions.cs | 17 + .../AetherGuard.Core/appsettings.json | 18 + .../AetherGuard.Core/models/ExternalSignal.cs | 17 + 12 files changed, 962 insertions(+) create mode 100644 src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260201123714_AddExternalSignals.Designer.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260201123714_AddExternalSignals.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/models/ExternalSignal.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index f6c8d29..fbed843 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Semantic Versioning. - Snapshot retention sweeper with optional S3 lifecycle configuration. - Supply-chain workflow for SBOM generation, cosign signing, and SLSA container provenance. - API key protection for telemetry ingestion and snapshot artifact endpoints. +- External signals ingestion pipeline (RSS feeds) with persisted `external_signals` table. 
- v2.3 multimodal predictive architecture document in `docs/ARCHITECTURE-v2.3.md`. - v2.3 delivery roadmap in `docs/ROADMAP-v2.3.md`. - Expanded v2.3 roadmap with model choices, data sources, and validation guidance. diff --git a/docs/Quickstart.md b/docs/Quickstart.md index d8718ab..02aa07c 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -35,6 +35,23 @@ Open the dashboard at `http://localhost:3000` and log in with: Open Jaeger at `http://localhost:16686` to view traces. +## Optional: Enable external signals (v2.3 Milestone 0) + +External signals ingestion is disabled by default. To enable: + +```bash +# PowerShell +$env:ExternalSignals__Enabled="true" +# Bash +export ExternalSignals__Enabled=true +``` + +Then restart the core service. Signals are accessible via: + +``` +GET /api/v1/signals?limit=50 +``` + If you want to simulate migrations, start at least two agents: ```bash diff --git a/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs b/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs new file mode 100644 index 0000000..82e46e7 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs @@ -0,0 +1,81 @@ +using AetherGuard.Core.Data; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; + +namespace AetherGuard.Core.Controllers; + +[ApiController] +[Route("api/v1/signals")] +public class ExternalSignalsController : ControllerBase +{ + private readonly ApplicationDbContext _db; + private readonly ILogger _logger; + + public ExternalSignalsController(ApplicationDbContext db, ILogger logger) + { + _db = db; + _logger = logger; + } + + [HttpGet] + public async Task GetSignals( + [FromQuery] DateTimeOffset? from, + [FromQuery] DateTimeOffset? to, + [FromQuery] string? source, + [FromQuery] string? region, + [FromQuery] string? 
severity, + [FromQuery] int limit = 200, + CancellationToken cancellationToken = default) + { + limit = Math.Clamp(limit, 1, 500); + var query = _db.ExternalSignals.AsNoTracking(); + + if (from.HasValue) + { + query = query.Where(signal => signal.PublishedAt >= from.Value); + } + + if (to.HasValue) + { + query = query.Where(signal => signal.PublishedAt <= to.Value); + } + + if (!string.IsNullOrWhiteSpace(source)) + { + query = query.Where(signal => signal.Source == source); + } + + if (!string.IsNullOrWhiteSpace(region)) + { + query = query.Where(signal => signal.Region == region); + } + + if (!string.IsNullOrWhiteSpace(severity)) + { + query = query.Where(signal => signal.Severity == severity); + } + + var results = await query + .OrderByDescending(signal => signal.PublishedAt) + .Take(limit) + .Select(signal => new + { + signal.Id, + signal.Source, + signal.ExternalId, + signal.Title, + signal.Summary, + signal.Region, + signal.Severity, + signal.Category, + signal.Url, + signal.Tags, + signal.PublishedAt, + signal.IngestedAt + }) + .ToListAsync(cancellationToken); + + _logger.LogInformation("Returned {Count} external signals.", results.Count); + return Ok(results); + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs b/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs index d7e72c6..d545952 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs @@ -13,6 +13,7 @@ public ApplicationDbContext(DbContextOptions options) public DbSet Agents => Set(); public DbSet AgentCommands => Set(); public DbSet CommandAudits => Set(); + public DbSet ExternalSignals => Set(); public DbSet TelemetryRecords => Set(); public DbSet SchemaRegistryEntries => Set(); @@ -87,5 +88,25 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) entity.Property(e => e.Schema).HasColumnName("schema"); entity.Property(e => 
e.CreatedAt).HasColumnName("created_at"); }); + + modelBuilder.Entity(entity => + { + entity.ToTable("external_signals"); + entity.HasKey(e => e.Id); + entity.HasIndex(e => new { e.Source, e.ExternalId }).IsUnique(); + entity.HasIndex(e => e.PublishedAt); + entity.Property(e => e.Id).HasColumnName("id").ValueGeneratedOnAdd(); + entity.Property(e => e.Source).HasColumnName("source"); + entity.Property(e => e.ExternalId).HasColumnName("external_id"); + entity.Property(e => e.Title).HasColumnName("title"); + entity.Property(e => e.Summary).HasColumnName("summary"); + entity.Property(e => e.Region).HasColumnName("region"); + entity.Property(e => e.Severity).HasColumnName("severity"); + entity.Property(e => e.Category).HasColumnName("category"); + entity.Property(e => e.Url).HasColumnName("url"); + entity.Property(e => e.Tags).HasColumnName("tags"); + entity.Property(e => e.PublishedAt).HasColumnName("published_at"); + entity.Property(e => e.IngestedAt).HasColumnName("ingested_at"); + }); } } diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260201123714_AddExternalSignals.Designer.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260201123714_AddExternalSignals.Designer.cs new file mode 100644 index 0000000..7ecc710 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260201123714_AddExternalSignals.Designer.cs @@ -0,0 +1,332 @@ +// +using System; +using AetherGuard.Core.Data; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Migrations; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; + +#nullable disable + +namespace AetherGuard.Core.Data.Migrations +{ + [DbContext(typeof(ApplicationDbContext))] + [Migration("20260201123714_AddExternalSignals")] + partial class AddExternalSignals + { + /// + protected override void BuildTargetModel(ModelBuilder modelBuilder) + 
{ +#pragma warning disable 612, 618 + modelBuilder + .HasAnnotation("ProductVersion", "8.0.0") + .HasAnnotation("Relational:MaxIdentifierLength", 63); + + NpgsqlModelBuilderExtensions.UseIdentityByDefaultColumns(modelBuilder); + + modelBuilder.Entity("AetherGuard.Core.Models.Agent", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uuid") + .HasColumnName("id"); + + b.Property("AgentToken") + .IsRequired() + .HasColumnType("text") + .HasColumnName("agenttoken"); + + b.Property("Hostname") + .IsRequired() + .HasColumnType("text") + .HasColumnName("hostname"); + + b.Property("LastHeartbeat") + .HasColumnType("timestamp with time zone") + .HasColumnName("lastheartbeat"); + + b.Property("Status") + .IsRequired() + .HasColumnType("text") + .HasColumnName("status"); + + b.HasKey("Id"); + + b.HasIndex("AgentToken") + .IsUnique(); + + b.ToTable("agents", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.AgentCommand", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("integer") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Action") + .IsRequired() + .HasColumnType("text") + .HasColumnName("action"); + + b.Property("AgentId") + .HasColumnType("uuid") + .HasColumnName("agent_id"); + + b.Property("CommandId") + .HasColumnType("uuid") + .HasColumnName("command_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("ExpiresAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("expires_at"); + + b.Property("Nonce") + .IsRequired() + .HasColumnType("text") + .HasColumnName("nonce"); + + b.Property("Parameters") + .IsRequired() + .HasColumnType("text") + .HasColumnName("parameters"); + + b.Property("Signature") + .IsRequired() + .HasColumnType("text") + .HasColumnName("signature"); + + b.Property("Status") + .IsRequired() + .HasColumnType("text") + 
.HasColumnName("status"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at"); + + b.Property("WorkloadId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("workload_id"); + + b.HasKey("Id"); + + b.ToTable("agent_commands", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.CommandAudit", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Action") + .IsRequired() + .HasColumnType("text") + .HasColumnName("action"); + + b.Property("Actor") + .IsRequired() + .HasColumnType("text") + .HasColumnName("actor"); + + b.Property("CommandId") + .HasColumnType("uuid") + .HasColumnName("command_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("Error") + .IsRequired() + .HasColumnType("text") + .HasColumnName("error"); + + b.Property("Result") + .IsRequired() + .HasColumnType("text") + .HasColumnName("result"); + + b.HasKey("Id"); + + b.ToTable("command_audits", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.ExternalSignal", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Category") + .HasColumnType("text") + .HasColumnName("category"); + + b.Property("ExternalId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("external_id"); + + b.Property("IngestedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("ingested_at"); + + b.Property("PublishedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("published_at"); + + b.Property("Region") + .HasColumnType("text") + .HasColumnName("region"); + + b.Property("Severity") + .HasColumnType("text") + 
.HasColumnName("severity"); + + b.Property("Source") + .IsRequired() + .HasColumnType("text") + .HasColumnName("source"); + + b.Property("Summary") + .HasColumnType("text") + .HasColumnName("summary"); + + b.Property("Tags") + .HasColumnType("text") + .HasColumnName("tags"); + + b.Property("Title") + .IsRequired() + .HasColumnType("text") + .HasColumnName("title"); + + b.Property("Url") + .HasColumnType("text") + .HasColumnName("url"); + + b.HasKey("Id"); + + b.HasIndex("PublishedAt"); + + b.HasIndex("Source", "ExternalId") + .IsUnique(); + + b.ToTable("external_signals", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.SchemaRegistryEntry", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("Schema") + .IsRequired() + .HasColumnType("text") + .HasColumnName("schema"); + + b.Property("Subject") + .IsRequired() + .HasColumnType("text") + .HasColumnName("subject"); + + b.Property("Version") + .HasColumnType("integer") + .HasColumnName("version"); + + b.HasKey("Id"); + + b.HasIndex("Subject", "Version") + .IsUnique(); + + b.ToTable("schema_registry", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.TelemetryRecord", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("Id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Timestamp") + .HasColumnType("timestamp with time zone") + .HasColumnName("Timestamp"); + + b.Property("AgentId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("AgentId"); + + b.Property("AiConfidence") + .HasColumnType("double precision") + .HasColumnName("AiConfidence"); + + b.Property("AiStatus") + .IsRequired() + .HasColumnType("text") + 
.HasColumnName("AiStatus"); + + b.Property("CpuUsage") + .HasColumnType("double precision") + .HasColumnName("CpuUsage"); + + b.Property("DiskAvailable") + .HasColumnType("bigint") + .HasColumnName("DiskAvailable"); + + b.Property("MemoryUsage") + .HasColumnType("double precision") + .HasColumnName("MemoryUsage"); + + b.Property("PredictedCpu") + .HasColumnType("double precision") + .HasColumnName("PredictedCpu"); + + b.Property("RebalanceSignal") + .HasColumnType("boolean") + .HasColumnName("RebalanceSignal"); + + b.Property("RootCause") + .HasColumnType("text") + .HasColumnName("RootCause"); + + b.Property("WorkloadTier") + .IsRequired() + .HasColumnType("text") + .HasColumnName("WorkloadTier"); + + b.HasKey("Id", "Timestamp"); + + b.ToTable("TelemetryRecords", (string)null); + }); +#pragma warning restore 612, 618 + } + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260201123714_AddExternalSignals.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260201123714_AddExternalSignals.cs new file mode 100644 index 0000000..9e92b84 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260201123714_AddExternalSignals.cs @@ -0,0 +1,57 @@ +using System; +using Microsoft.EntityFrameworkCore.Migrations; +using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; + +#nullable disable + +namespace AetherGuard.Core.Data.Migrations +{ + /// + public partial class AddExternalSignals : Migration + { + /// + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.CreateTable( + name: "external_signals", + columns: table => new + { + id = table.Column(type: "bigint", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + source = table.Column(type: "text", nullable: false), + external_id = table.Column(type: "text", nullable: false), + title = table.Column(type: "text", nullable: false), + summary = table.Column(type: 
"text", nullable: true), + region = table.Column(type: "text", nullable: true), + severity = table.Column(type: "text", nullable: true), + category = table.Column(type: "text", nullable: true), + url = table.Column(type: "text", nullable: true), + tags = table.Column(type: "text", nullable: true), + published_at = table.Column(type: "timestamp with time zone", nullable: false), + ingested_at = table.Column(type: "timestamp with time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_external_signals", x => x.id); + }); + + migrationBuilder.CreateIndex( + name: "IX_external_signals_published_at", + table: "external_signals", + column: "published_at"); + + migrationBuilder.CreateIndex( + name: "IX_external_signals_source_external_id", + table: "external_signals", + columns: new[] { "source", "external_id" }, + unique: true); + } + + /// + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "external_signals"); + } + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs index a5d65fb..aa841dc 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs @@ -162,6 +162,72 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.ToTable("command_audits", (string)null); }); + modelBuilder.Entity("AetherGuard.Core.Models.ExternalSignal", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Category") + .HasColumnType("text") + .HasColumnName("category"); + + b.Property("ExternalId") + .IsRequired() + .HasColumnType("text") + 
.HasColumnName("external_id"); + + b.Property("IngestedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("ingested_at"); + + b.Property("PublishedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("published_at"); + + b.Property("Region") + .HasColumnType("text") + .HasColumnName("region"); + + b.Property("Severity") + .HasColumnType("text") + .HasColumnName("severity"); + + b.Property("Source") + .IsRequired() + .HasColumnType("text") + .HasColumnName("source"); + + b.Property("Summary") + .HasColumnType("text") + .HasColumnName("summary"); + + b.Property("Tags") + .HasColumnType("text") + .HasColumnName("tags"); + + b.Property("Title") + .IsRequired() + .HasColumnType("text") + .HasColumnName("title"); + + b.Property("Url") + .HasColumnType("text") + .HasColumnName("url"); + + b.HasKey("Id"); + + b.HasIndex("PublishedAt"); + + b.HasIndex("Source", "ExternalId") + .IsUnique(); + + b.ToTable("external_signals", (string)null); + }); + modelBuilder.Entity("AetherGuard.Core.Models.SchemaRegistryEntry", b => { b.Property("Id") diff --git a/src/services/core-dotnet/AetherGuard.Core/Program.cs b/src/services/core-dotnet/AetherGuard.Core/Program.cs index feba0a9..209083c 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Program.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Program.cs @@ -51,6 +51,11 @@ builder.Services.AddSingleton(); builder.Services.AddHttpClient(); +builder.Services.AddHttpClient("external-signals", client => +{ + client.DefaultRequestHeaders.UserAgent.ParseAdd("Aether-Guard/ExternalSignals"); + client.Timeout = TimeSpan.FromSeconds(15); +}); builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); @@ -64,6 +69,9 @@ builder.Services.AddHostedService(); builder.Services.AddSingleton(); builder.Services.AddHostedService(); +builder.Services.Configure( + builder.Configuration.GetSection("ExternalSignals")); +builder.Services.AddHostedService(); var otelOptions = 
builder.Configuration.GetSection("OpenTelemetry").Get() ?? new OpenTelemetryOptions(); diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs new file mode 100644 index 0000000..cc0b20e --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs @@ -0,0 +1,327 @@ +using System.Globalization; +using System.Net; +using System.Text.RegularExpressions; +using System.Xml.Linq; +using AetherGuard.Core.Data; +using AetherGuard.Core.Models; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Options; + +namespace AetherGuard.Core.Services.ExternalSignals; + +public sealed class ExternalSignalIngestionService : BackgroundService +{ + private static readonly string[] SeverityKeywords = + [ + "outage", + "degraded", + "disruption", + "investigating", + "incident", + "partial", + "unavailable", + "restored", + "resolved" + ]; + + private readonly IHttpClientFactory _httpClientFactory; + private readonly IServiceScopeFactory _scopeFactory; + private readonly ILogger _logger; + private readonly ExternalSignalsOptions _options; + + public ExternalSignalIngestionService( + IHttpClientFactory httpClientFactory, + IServiceScopeFactory scopeFactory, + IOptions options, + ILogger logger) + { + _httpClientFactory = httpClientFactory; + _scopeFactory = scopeFactory; + _logger = logger; + _options = options.Value; + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + if (!_options.Enabled) + { + _logger.LogInformation("External signals ingestion is disabled."); + return; + } + + if (_options.Feeds.Count == 0) + { + _logger.LogWarning("External signals ingestion enabled but no feeds configured."); + return; + } + + var interval = TimeSpan.FromSeconds(Math.Max(60, _options.PollingIntervalSeconds)); + + while 
(!stoppingToken.IsCancellationRequested) + { + try + { + await IngestOnceAsync(stoppingToken); + } + catch (Exception ex) + { + _logger.LogError(ex, "External signals ingestion cycle failed."); + } + + try + { + await Task.Delay(interval, stoppingToken); + } + catch (TaskCanceledException) + { + break; + } + } + } + + private async Task IngestOnceAsync(CancellationToken cancellationToken) + { + using var scope = _scopeFactory.CreateScope(); + var db = scope.ServiceProvider.GetRequiredService(); + var httpClient = _httpClientFactory.CreateClient("external-signals"); + + var lookback = DateTimeOffset.UtcNow.AddHours(-Math.Max(1, _options.LookbackHours)); + var maxItems = Math.Clamp(_options.MaxItemsPerFeed, 10, 1000); + + foreach (var feed in _options.Feeds) + { + if (string.IsNullOrWhiteSpace(feed.Url) || string.IsNullOrWhiteSpace(feed.Name)) + { + _logger.LogWarning("Skipping external signal feed with missing name/url."); + continue; + } + + List parsed; + try + { + var response = await httpClient.GetAsync(feed.Url, cancellationToken); + response.EnsureSuccessStatusCode(); + var content = await response.Content.ReadAsStringAsync(cancellationToken); + parsed = ParseFeed(content, feed); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to fetch external signal feed {Feed}.", feed.Name); + continue; + } + + if (parsed.Count == 0) + { + continue; + } + + parsed = parsed + .Where(signal => signal.PublishedAt >= lookback) + .OrderByDescending(signal => signal.PublishedAt) + .Take(maxItems) + .ToList(); + + if (parsed.Count == 0) + { + continue; + } + + var existingIds = await db.ExternalSignals + .Where(signal => signal.Source == feed.Name && signal.PublishedAt >= lookback) + .Select(signal => signal.ExternalId) + .ToListAsync(cancellationToken); + + var existing = existingIds.ToHashSet(StringComparer.OrdinalIgnoreCase); + var newSignals = parsed + .Where(signal => !existing.Contains(signal.ExternalId)) + .GroupBy(signal => signal.ExternalId, 
StringComparer.OrdinalIgnoreCase) + .Select(group => group.First()) + .ToList(); + + if (newSignals.Count == 0) + { + continue; + } + + db.ExternalSignals.AddRange(newSignals); + await db.SaveChangesAsync(cancellationToken); + + _logger.LogInformation("Ingested {Count} signals from {Feed}.", newSignals.Count, feed.Name); + } + } + + private static List ParseFeed(string content, ExternalSignalFeedOptions feed) + { + var document = XDocument.Parse(content); + var root = document.Root; + if (root is null) + { + return []; + } + + var items = root.Name.LocalName switch + { + "rss" => root.Descendants().Where(e => e.Name.LocalName == "item"), + "feed" => root.Descendants().Where(e => e.Name.LocalName == "entry"), + _ => root.Descendants().Where(e => e.Name.LocalName == "item" || e.Name.LocalName == "entry") + }; + + var signals = new List(); + foreach (var item in items) + { + var title = GetValue(item, "title"); + if (string.IsNullOrWhiteSpace(title)) + { + continue; + } + + var externalId = GetValue(item, "guid") ?? GetValue(item, "id") ?? title; + var url = GetValue(item, "link"); + if (string.IsNullOrWhiteSpace(url)) + { + url = GetLinkHref(item); + } + + var summaryRaw = GetValue(item, "description") ?? GetValue(item, "summary") ?? GetValue(item, "content"); + var summary = NormalizeText(summaryRaw); + var published = ParseDate(GetValue(item, "pubDate")) + ?? ParseDate(GetValue(item, "updated")) + ?? ParseDate(GetValue(item, "published")) + ?? DateTimeOffset.UtcNow; + + var severity = GuessSeverity(title, summary); + var category = GuessCategory(title, summary); + var region = ExtractRegion(title, summary) ?? 
feed.DefaultRegion; + + signals.Add(new ExternalSignal + { + Source = feed.Name, + ExternalId = externalId, + Title = title, + Summary = summary, + Region = region, + Severity = severity, + Category = category, + Url = url, + Tags = BuildTags(title, summary, severity, category), + PublishedAt = published, + IngestedAt = DateTimeOffset.UtcNow + }); + } + + return signals; + } + + private static string? GetValue(XElement parent, string name) + { + return parent.Elements().FirstOrDefault(e => e.Name.LocalName.Equals(name, StringComparison.OrdinalIgnoreCase))?.Value; + } + + private static string? GetLinkHref(XElement parent) + { + var link = parent.Elements().FirstOrDefault(e => e.Name.LocalName.Equals("link", StringComparison.OrdinalIgnoreCase)); + return link?.Attribute("href")?.Value; + } + + private static DateTimeOffset? ParseDate(string? value) + { + if (string.IsNullOrWhiteSpace(value)) + { + return null; + } + + if (DateTimeOffset.TryParse(value, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal, out var parsed)) + { + return parsed; + } + + return null; + } + + private static string? GuessSeverity(string title, string? summary) + { + var content = $"{title} {summary}".ToLowerInvariant(); + foreach (var keyword in SeverityKeywords) + { + if (content.Contains(keyword, StringComparison.OrdinalIgnoreCase)) + { + return keyword; + } + } + + return "info"; + } + + private static string? GuessCategory(string title, string? summary) + { + var content = $"{title} {summary}".ToLowerInvariant(); + if (content.Contains("maintenance")) + { + return "maintenance"; + } + if (content.Contains("outage") || content.Contains("disruption")) + { + return "outage"; + } + if (content.Contains("degraded")) + { + return "degraded"; + } + if (content.Contains("incident") || content.Contains("investigating")) + { + return "incident"; + } + + return "notice"; + } + + private static string? ExtractRegion(string title, string? 
summary) + { + var content = $"{title} {summary}"; + var match = System.Text.RegularExpressions.Regex.Match(content, @"\b([a-z]{2}-[a-z]+-\d)\b", System.Text.RegularExpressions.RegexOptions.IgnoreCase); + if (match.Success) + { + return match.Groups[1].Value; + } + + return null; + } + + private static string? BuildTags(string title, string? summary, string? severity, string? category) + { + var tags = new List(); + if (!string.IsNullOrWhiteSpace(severity)) + { + tags.Add(severity); + } + + if (!string.IsNullOrWhiteSpace(category) && !tags.Any(tag => string.Equals(tag, category, StringComparison.OrdinalIgnoreCase))) + { + tags.Add(category); + } + + if (title.Contains("region", StringComparison.OrdinalIgnoreCase)) + { + tags.Add("region"); + } + + if (!string.IsNullOrWhiteSpace(summary) && summary.Contains("latency", StringComparison.OrdinalIgnoreCase)) + { + tags.Add("latency"); + } + + return tags.Count > 0 ? string.Join(",", tags) : null; + } + + private static string? NormalizeText(string? 
raw) + { + if (string.IsNullOrWhiteSpace(raw)) + { + return raw; + } + + var decoded = WebUtility.HtmlDecode(raw); + var withoutTags = Regex.Replace(decoded, "<.*?>", string.Empty, RegexOptions.Singleline); + return withoutTags.Trim(); + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs new file mode 100644 index 0000000..663fba8 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs @@ -0,0 +1,17 @@ +namespace AetherGuard.Core.Services.ExternalSignals; + +public sealed class ExternalSignalsOptions +{ + public bool Enabled { get; set; } = false; + public int PollingIntervalSeconds { get; set; } = 300; + public int LookbackHours { get; set; } = 48; + public int MaxItemsPerFeed { get; set; } = 200; + public List Feeds { get; set; } = new(); +} + +public sealed class ExternalSignalFeedOptions +{ + public string Name { get; set; } = string.Empty; + public string Url { get; set; } = string.Empty; + public string? 
DefaultRegion { get; set; } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/appsettings.json b/src/services/core-dotnet/AetherGuard.Core/appsettings.json index dc35cd3..60a5dc2 100644 --- a/src/services/core-dotnet/AetherGuard.Core/appsettings.json +++ b/src/services/core-dotnet/AetherGuard.Core/appsettings.json @@ -70,6 +70,24 @@ "ApplyS3Lifecycle": true, "S3ExpirationDays": 30 }, + "ExternalSignals": { + "Enabled": false, + "PollingIntervalSeconds": 300, + "LookbackHours": 48, + "MaxItemsPerFeed": 200, + "Feeds": [ + { + "Name": "aws-status", + "Url": "https://status.aws.amazon.com/rss/all.rss", + "DefaultRegion": "global" + }, + { + "Name": "gcp-status", + "Url": "https://status.cloud.google.com/en/feed.atom", + "DefaultRegion": "global" + } + ] + }, "MigrationIntervalSeconds": 10, "HeartbeatTimeoutSeconds": 60, "AllowedHosts": "*" diff --git a/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignal.cs b/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignal.cs new file mode 100644 index 0000000..870fb36 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignal.cs @@ -0,0 +1,17 @@ +namespace AetherGuard.Core.Models; + +public sealed class ExternalSignal +{ + public long Id { get; set; } + public string Source { get; set; } = string.Empty; + public string ExternalId { get; set; } = string.Empty; + public string Title { get; set; } = string.Empty; + public string? Summary { get; set; } + public string? Region { get; set; } + public string? Severity { get; set; } + public string? Category { get; set; } + public string? Url { get; set; } + public string? 
Tags { get; set; } + public DateTimeOffset PublishedAt { get; set; } + public DateTimeOffset IngestedAt { get; set; } +} From 66be642a38af503a06d92ff8ffff26f661b08498 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Mon, 2 Feb 2026 19:07:25 +0800 Subject: [PATCH 02/24] feat: Add external signal feed health tracking and status API - Implemented `ExternalSignalFeedState` model to track the health of external signal feeds. - Added `GetFeedStates` endpoint in `ExternalSignalsController` to retrieve feed health status. - Updated `ApplicationDbContext` to include `ExternalSignalFeedState` DbSet and configured its mapping. - Created migration for adding `external_signal_feeds` table to the database. - Enhanced `ExternalSignalIngestionService` to update feed states based on fetch results. - Introduced `ExternalSignalParser` for parsing RSS and Atom feeds. - Added regression tests for feed parsing functionality. - Updated documentation to include new API endpoint for feed health status. --- CHANGELOG.md | 3 + docs/Quickstart.md | 6 + src/services/ai-engine/main.py | 63 ++- .../AetherGuard.Core.Tests.csproj | 19 + .../ExternalSignalParserTests.cs | 71 ++++ .../Controllers/ExternalSignalsController.cs | 21 + .../Data/ApplicationDbContext.cs | 16 + ...2103549_AddExternalSignalFeeds.Designer.cs | 379 ++++++++++++++++++ .../20260202103549_AddExternalSignalFeeds.cs | 48 +++ .../ApplicationDbContextModelSnapshot.cs | 47 +++ .../ExternalSignalIngestionService.cs | 208 ++-------- .../ExternalSignals/ExternalSignalParser.cs | 201 ++++++++++ .../AetherGuard.Core/appsettings.json | 5 + .../models/ExternalSignalFeedState.cs | 13 + 14 files changed, 912 insertions(+), 188 deletions(-) create mode 100644 src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj create mode 100644 src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalParserTests.cs create mode 100644 
src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260202103549_AddExternalSignalFeeds.Designer.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260202103549_AddExternalSignalFeeds.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalParser.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/models/ExternalSignalFeedState.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index fbed843..318dba0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ Semantic Versioning. - Supply-chain workflow for SBOM generation, cosign signing, and SLSA container provenance. - API key protection for telemetry ingestion and snapshot artifact endpoints. - External signals ingestion pipeline (RSS feeds) with persisted `external_signals` table. +- External signal feed health tracking (`external_signal_feeds`) and feed status API. +- Parser regression tests for RSS/Atom feeds. +- AI Engine semantic enrichment stub (`/signals/enrich`) for v2.3 pipeline integration. - v2.3 multimodal predictive architecture document in `docs/ARCHITECTURE-v2.3.md`. - v2.3 delivery roadmap in `docs/ROADMAP-v2.3.md`. - Expanded v2.3 roadmap with model choices, data sources, and validation guidance. diff --git a/docs/Quickstart.md b/docs/Quickstart.md index 02aa07c..0cf0df7 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -52,6 +52,12 @@ Then restart the core service. 
Signals are accessible via: GET /api/v1/signals?limit=50 ``` +Feed health status is available via: + +``` +GET /api/v1/signals/feeds +``` + If you want to simulate migrations, start at least two agents: ```bash diff --git a/src/services/ai-engine/main.py b/src/services/ai-engine/main.py index ef47ce4..9776658 100644 --- a/src/services/ai-engine/main.py +++ b/src/services/ai-engine/main.py @@ -1,8 +1,9 @@ import logging import os +from contextlib import asynccontextmanager from fastapi import FastAPI -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor @@ -14,24 +15,49 @@ from model import RiskScorer logger = logging.getLogger("uvicorn.error") -app = FastAPI() scorer = RiskScorer() +@asynccontextmanager +async def lifespan(app_instance: FastAPI): + configure_tracing() + app_instance.state.scorer = scorer + logger.info("AI Engine Online.") + yield + + +app = FastAPI(lifespan=lifespan) + + class RiskPayload(BaseModel): spot_price_history: list[float] = Field(default_factory=list, alias="spotPriceHistory") rebalance_signal: bool = Field(alias="rebalanceSignal") capacity_score: float = Field(alias="capacityScore") - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) -@app.on_event("startup") -def load_model() -> None: - configure_tracing() - app.state.scorer = scorer - logger.info("AI Engine Online.") +class SignalDocument(BaseModel): + source: str + title: str + summary: str | None = None + url: str | None = None + region: str | None = None + published_at: str | None = Field(default=None, alias="publishedAt") + + model_config = ConfigDict(populate_by_name=True) + + +class EnrichRequest(BaseModel): + documents: list[SignalDocument] + + +class EnrichResponse(BaseModel): + s_v: list[float] = 
Field(alias="S_v") + p_v: float = Field(alias="P_v") + b_s: float = Field(alias="B_s") + + model_config = ConfigDict(populate_by_name=True) @app.get("/") @@ -61,6 +87,25 @@ def analyze(payload: RiskPayload) -> dict: } +@app.post("/signals/enrich", response_model=EnrichResponse) +def enrich_signals(payload: EnrichRequest) -> EnrichResponse: + # Placeholder semantic enrichment: use simple heuristics until NLP pipeline is online. + combined = " ".join( + [doc.title + " " + (doc.summary or "") for doc in payload.documents] + ).lower() + negative_terms = ("outage", "disruption", "degraded", "incident", "latency", "unavailable") + has_negative = any(term in combined for term in negative_terms) + + s_v = [0.1, 0.1, 0.1] + p_v = 0.15 + b_s = 0.0 + if has_negative: + s_v = [0.9, 0.2, 0.1] + p_v = 0.85 + b_s = 0.2 + + return EnrichResponse(S_v=s_v, P_v=p_v, B_s=b_s) + def configure_tracing() -> None: endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") if not endpoint: diff --git a/src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj b/src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj new file mode 100644 index 0000000..b162747 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj @@ -0,0 +1,19 @@ + + + net8.0 + false + enable + + + + + + + + + + + all + + + diff --git a/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalParserTests.cs b/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalParserTests.cs new file mode 100644 index 0000000..c6c4470 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalParserTests.cs @@ -0,0 +1,71 @@ +using AetherGuard.Core.Services.ExternalSignals; +using Xunit; + +namespace AetherGuard.Core.Tests; + +public class ExternalSignalParserTests +{ + [Fact] + public void ParsesRssItems() + { + var rss = """ + + + + + aws-1 + Service disruption in us-east-1 + https://status.aws.amazon.com/ + Investigating 
elevated errors. + Fri, 01 Aug 2025 12:00:00 GMT + + + + """; + + var feed = new ExternalSignalFeedOptions + { + Name = "aws-status", + Url = "https://status.aws.amazon.com/rss/all.rss", + DefaultRegion = "global" + }; + + var results = ExternalSignalParser.ParseFeed(rss, feed); + + Assert.Single(results); + Assert.Equal("aws-1", results[0].ExternalId); + Assert.Equal("aws-status", results[0].Source); + Assert.Equal("us-east-1", results[0].Region); + } + + [Fact] + public void ParsesAtomEntries() + { + var atom = """ + + + + tag:status.cloud.google.com,2025:feed:example + RESOLVED: incident in us-central1 + + 2025-07-23T09:26:58+00:00 + Service recovered. + + + """; + + var feed = new ExternalSignalFeedOptions + { + Name = "gcp-status", + Url = "https://status.cloud.google.com/en/feed.atom", + DefaultRegion = "global" + }; + + var results = ExternalSignalParser.ParseFeed(atom, feed); + + Assert.Single(results); + Assert.Equal("gcp-status", results[0].Source); + Assert.Equal("us-central1", results[0].Region); + Assert.Equal("incident", results[0].Category); + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs b/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs index 82e46e7..536405e 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs @@ -78,4 +78,25 @@ public async Task GetSignals( _logger.LogInformation("Returned {Count} external signals.", results.Count); return Ok(results); } + + [HttpGet("feeds")] + public async Task GetFeedStates(CancellationToken cancellationToken = default) + { + var feeds = await _db.ExternalSignalFeedStates + .AsNoTracking() + .OrderBy(state => state.Name) + .Select(state => new + { + state.Name, + state.Url, + state.LastFetchAt, + state.LastSuccessAt, + state.FailureCount, + state.LastError, + state.LastStatusCode + }) + 
.ToListAsync(cancellationToken); + + return Ok(feeds); + } } diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs b/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs index d545952..04c7261 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs @@ -13,6 +13,7 @@ public ApplicationDbContext(DbContextOptions options) public DbSet Agents => Set(); public DbSet AgentCommands => Set(); public DbSet CommandAudits => Set(); + public DbSet ExternalSignalFeedStates => Set(); public DbSet ExternalSignals => Set(); public DbSet TelemetryRecords => Set(); public DbSet SchemaRegistryEntries => Set(); @@ -108,5 +109,20 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) entity.Property(e => e.PublishedAt).HasColumnName("published_at"); entity.Property(e => e.IngestedAt).HasColumnName("ingested_at"); }); + + modelBuilder.Entity(entity => + { + entity.ToTable("external_signal_feeds"); + entity.HasKey(e => e.Id); + entity.HasIndex(e => e.Name).IsUnique(); + entity.Property(e => e.Id).HasColumnName("id").ValueGeneratedOnAdd(); + entity.Property(e => e.Name).HasColumnName("name"); + entity.Property(e => e.Url).HasColumnName("url"); + entity.Property(e => e.LastFetchAt).HasColumnName("last_fetch_at"); + entity.Property(e => e.LastSuccessAt).HasColumnName("last_success_at"); + entity.Property(e => e.FailureCount).HasColumnName("failure_count"); + entity.Property(e => e.LastError).HasColumnName("last_error"); + entity.Property(e => e.LastStatusCode).HasColumnName("last_status_code"); + }); } } diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260202103549_AddExternalSignalFeeds.Designer.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260202103549_AddExternalSignalFeeds.Designer.cs new file mode 100644 index 0000000..97eff7e --- /dev/null +++ 
b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260202103549_AddExternalSignalFeeds.Designer.cs @@ -0,0 +1,379 @@ +// +using System; +using AetherGuard.Core.Data; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Migrations; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; + +#nullable disable + +namespace AetherGuard.Core.Data.Migrations +{ + [DbContext(typeof(ApplicationDbContext))] + [Migration("20260202103549_AddExternalSignalFeeds")] + partial class AddExternalSignalFeeds + { + /// + protected override void BuildTargetModel(ModelBuilder modelBuilder) + { +#pragma warning disable 612, 618 + modelBuilder + .HasAnnotation("ProductVersion", "8.0.0") + .HasAnnotation("Relational:MaxIdentifierLength", 63); + + NpgsqlModelBuilderExtensions.UseIdentityByDefaultColumns(modelBuilder); + + modelBuilder.Entity("AetherGuard.Core.Models.Agent", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uuid") + .HasColumnName("id"); + + b.Property("AgentToken") + .IsRequired() + .HasColumnType("text") + .HasColumnName("agenttoken"); + + b.Property("Hostname") + .IsRequired() + .HasColumnType("text") + .HasColumnName("hostname"); + + b.Property("LastHeartbeat") + .HasColumnType("timestamp with time zone") + .HasColumnName("lastheartbeat"); + + b.Property("Status") + .IsRequired() + .HasColumnType("text") + .HasColumnName("status"); + + b.HasKey("Id"); + + b.HasIndex("AgentToken") + .IsUnique(); + + b.ToTable("agents", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.AgentCommand", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("integer") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Action") + .IsRequired() + .HasColumnType("text") + .HasColumnName("action"); + + b.Property("AgentId") + 
.HasColumnType("uuid") + .HasColumnName("agent_id"); + + b.Property("CommandId") + .HasColumnType("uuid") + .HasColumnName("command_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("ExpiresAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("expires_at"); + + b.Property("Nonce") + .IsRequired() + .HasColumnType("text") + .HasColumnName("nonce"); + + b.Property("Parameters") + .IsRequired() + .HasColumnType("text") + .HasColumnName("parameters"); + + b.Property("Signature") + .IsRequired() + .HasColumnType("text") + .HasColumnName("signature"); + + b.Property("Status") + .IsRequired() + .HasColumnType("text") + .HasColumnName("status"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at"); + + b.Property("WorkloadId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("workload_id"); + + b.HasKey("Id"); + + b.ToTable("agent_commands", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.CommandAudit", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Action") + .IsRequired() + .HasColumnType("text") + .HasColumnName("action"); + + b.Property("Actor") + .IsRequired() + .HasColumnType("text") + .HasColumnName("actor"); + + b.Property("CommandId") + .HasColumnType("uuid") + .HasColumnName("command_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("Error") + .IsRequired() + .HasColumnType("text") + .HasColumnName("error"); + + b.Property("Result") + .IsRequired() + .HasColumnType("text") + .HasColumnName("result"); + + b.HasKey("Id"); + + b.ToTable("command_audits", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.ExternalSignal", b => + { + b.Property("Id") + 
.ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Category") + .HasColumnType("text") + .HasColumnName("category"); + + b.Property("ExternalId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("external_id"); + + b.Property("IngestedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("ingested_at"); + + b.Property("PublishedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("published_at"); + + b.Property("Region") + .HasColumnType("text") + .HasColumnName("region"); + + b.Property("Severity") + .HasColumnType("text") + .HasColumnName("severity"); + + b.Property("Source") + .IsRequired() + .HasColumnType("text") + .HasColumnName("source"); + + b.Property("Summary") + .HasColumnType("text") + .HasColumnName("summary"); + + b.Property("Tags") + .HasColumnType("text") + .HasColumnName("tags"); + + b.Property("Title") + .IsRequired() + .HasColumnType("text") + .HasColumnName("title"); + + b.Property("Url") + .HasColumnType("text") + .HasColumnName("url"); + + b.HasKey("Id"); + + b.HasIndex("PublishedAt"); + + b.HasIndex("Source", "ExternalId") + .IsUnique(); + + b.ToTable("external_signals", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.ExternalSignalFeedState", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("FailureCount") + .HasColumnType("integer") + .HasColumnName("failure_count"); + + b.Property("LastError") + .HasColumnType("text") + .HasColumnName("last_error"); + + b.Property("LastFetchAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("last_fetch_at"); + + b.Property("LastStatusCode") + .HasColumnType("integer") + .HasColumnName("last_status_code"); + + b.Property("LastSuccessAt") + .HasColumnType("timestamp 
with time zone") + .HasColumnName("last_success_at"); + + b.Property("Name") + .IsRequired() + .HasColumnType("text") + .HasColumnName("name"); + + b.Property("Url") + .IsRequired() + .HasColumnType("text") + .HasColumnName("url"); + + b.HasKey("Id"); + + b.HasIndex("Name") + .IsUnique(); + + b.ToTable("external_signal_feeds", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.SchemaRegistryEntry", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("Schema") + .IsRequired() + .HasColumnType("text") + .HasColumnName("schema"); + + b.Property("Subject") + .IsRequired() + .HasColumnType("text") + .HasColumnName("subject"); + + b.Property("Version") + .HasColumnType("integer") + .HasColumnName("version"); + + b.HasKey("Id"); + + b.HasIndex("Subject", "Version") + .IsUnique(); + + b.ToTable("schema_registry", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.TelemetryRecord", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("Id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Timestamp") + .HasColumnType("timestamp with time zone") + .HasColumnName("Timestamp"); + + b.Property("AgentId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("AgentId"); + + b.Property("AiConfidence") + .HasColumnType("double precision") + .HasColumnName("AiConfidence"); + + b.Property("AiStatus") + .IsRequired() + .HasColumnType("text") + .HasColumnName("AiStatus"); + + b.Property("CpuUsage") + .HasColumnType("double precision") + .HasColumnName("CpuUsage"); + + b.Property("DiskAvailable") + .HasColumnType("bigint") + .HasColumnName("DiskAvailable"); + + b.Property("MemoryUsage") + 
.HasColumnType("double precision") + .HasColumnName("MemoryUsage"); + + b.Property("PredictedCpu") + .HasColumnType("double precision") + .HasColumnName("PredictedCpu"); + + b.Property("RebalanceSignal") + .HasColumnType("boolean") + .HasColumnName("RebalanceSignal"); + + b.Property("RootCause") + .HasColumnType("text") + .HasColumnName("RootCause"); + + b.Property("WorkloadTier") + .IsRequired() + .HasColumnType("text") + .HasColumnName("WorkloadTier"); + + b.HasKey("Id", "Timestamp"); + + b.ToTable("TelemetryRecords", (string)null); + }); +#pragma warning restore 612, 618 + } + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260202103549_AddExternalSignalFeeds.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260202103549_AddExternalSignalFeeds.cs new file mode 100644 index 0000000..b166802 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260202103549_AddExternalSignalFeeds.cs @@ -0,0 +1,48 @@ +using System; +using Microsoft.EntityFrameworkCore.Migrations; +using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; + +#nullable disable + +namespace AetherGuard.Core.Data.Migrations +{ + /// + public partial class AddExternalSignalFeeds : Migration + { + /// + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.CreateTable( + name: "external_signal_feeds", + columns: table => new + { + id = table.Column(type: "bigint", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + name = table.Column(type: "text", nullable: false), + url = table.Column(type: "text", nullable: false), + last_fetch_at = table.Column(type: "timestamp with time zone", nullable: false), + last_success_at = table.Column(type: "timestamp with time zone", nullable: true), + failure_count = table.Column(type: "integer", nullable: false), + last_error = table.Column(type: "text", nullable: true), + last_status_code = 
table.Column(type: "integer", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_external_signal_feeds", x => x.id); + }); + + migrationBuilder.CreateIndex( + name: "IX_external_signal_feeds_name", + table: "external_signal_feeds", + column: "name", + unique: true); + } + + /// + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "external_signal_feeds"); + } + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs index aa841dc..9c007b4 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs @@ -228,6 +228,53 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.ToTable("external_signals", (string)null); }); + modelBuilder.Entity("AetherGuard.Core.Models.ExternalSignalFeedState", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("FailureCount") + .HasColumnType("integer") + .HasColumnName("failure_count"); + + b.Property("LastError") + .HasColumnType("text") + .HasColumnName("last_error"); + + b.Property("LastFetchAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("last_fetch_at"); + + b.Property("LastStatusCode") + .HasColumnType("integer") + .HasColumnName("last_status_code"); + + b.Property("LastSuccessAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("last_success_at"); + + b.Property("Name") + .IsRequired() + .HasColumnType("text") + .HasColumnName("name"); + + b.Property("Url") + .IsRequired() + .HasColumnType("text") + .HasColumnName("url"); + + b.HasKey("Id"); + + b.HasIndex("Name") + 
.IsUnique(); + + b.ToTable("external_signal_feeds", (string)null); + }); + modelBuilder.Entity("AetherGuard.Core.Models.SchemaRegistryEntry", b => { b.Property("Id") diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs index cc0b20e..d6fbb4f 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs @@ -1,7 +1,3 @@ -using System.Globalization; -using System.Net; -using System.Text.RegularExpressions; -using System.Xml.Linq; using AetherGuard.Core.Data; using AetherGuard.Core.Models; using Microsoft.EntityFrameworkCore; @@ -11,19 +7,6 @@ namespace AetherGuard.Core.Services.ExternalSignals; public sealed class ExternalSignalIngestionService : BackgroundService { - private static readonly string[] SeverityKeywords = - [ - "outage", - "degraded", - "disruption", - "investigating", - "incident", - "partial", - "unavailable", - "restored", - "resolved" - ]; - private readonly IHttpClientFactory _httpClientFactory; private readonly IServiceScopeFactory _scopeFactory; private readonly ILogger _logger; @@ -100,12 +83,15 @@ private async Task IngestOnceAsync(CancellationToken cancellationToken) try { var response = await httpClient.GetAsync(feed.Url, cancellationToken); + var statusCode = (int)response.StatusCode; response.EnsureSuccessStatusCode(); var content = await response.Content.ReadAsStringAsync(cancellationToken); - parsed = ParseFeed(content, feed); + parsed = ExternalSignalParser.ParseFeed(content, feed); + await UpdateFeedStateAsync(db, feed, statusCode, null, cancellationToken); } catch (Exception ex) { + await UpdateFeedStateAsync(db, feed, null, ex.Message, cancellationToken); _logger.LogWarning(ex, "Failed to fetch external signal feed 
{Feed}.", feed.Name); continue; } @@ -150,178 +136,42 @@ private async Task IngestOnceAsync(CancellationToken cancellationToken) } } - private static List ParseFeed(string content, ExternalSignalFeedOptions feed) + private static async Task UpdateFeedStateAsync( + ApplicationDbContext db, + ExternalSignalFeedOptions feed, + int? statusCode, + string? error, + CancellationToken cancellationToken) { - var document = XDocument.Parse(content); - var root = document.Root; - if (root is null) - { - return []; - } - - var items = root.Name.LocalName switch - { - "rss" => root.Descendants().Where(e => e.Name.LocalName == "item"), - "feed" => root.Descendants().Where(e => e.Name.LocalName == "entry"), - _ => root.Descendants().Where(e => e.Name.LocalName == "item" || e.Name.LocalName == "entry") - }; - - var signals = new List(); - foreach (var item in items) - { - var title = GetValue(item, "title"); - if (string.IsNullOrWhiteSpace(title)) - { - continue; - } - - var externalId = GetValue(item, "guid") ?? GetValue(item, "id") ?? title; - var url = GetValue(item, "link"); - if (string.IsNullOrWhiteSpace(url)) - { - url = GetLinkHref(item); - } - - var summaryRaw = GetValue(item, "description") ?? GetValue(item, "summary") ?? GetValue(item, "content"); - var summary = NormalizeText(summaryRaw); - var published = ParseDate(GetValue(item, "pubDate")) - ?? ParseDate(GetValue(item, "updated")) - ?? ParseDate(GetValue(item, "published")) - ?? DateTimeOffset.UtcNow; - - var severity = GuessSeverity(title, summary); - var category = GuessCategory(title, summary); - var region = ExtractRegion(title, summary) ?? 
feed.DefaultRegion; - - signals.Add(new ExternalSignal - { - Source = feed.Name, - ExternalId = externalId, - Title = title, - Summary = summary, - Region = region, - Severity = severity, - Category = category, - Url = url, - Tags = BuildTags(title, summary, severity, category), - PublishedAt = published, - IngestedAt = DateTimeOffset.UtcNow - }); - } - - return signals; - } - - private static string? GetValue(XElement parent, string name) - { - return parent.Elements().FirstOrDefault(e => e.Name.LocalName.Equals(name, StringComparison.OrdinalIgnoreCase))?.Value; - } - - private static string? GetLinkHref(XElement parent) - { - var link = parent.Elements().FirstOrDefault(e => e.Name.LocalName.Equals("link", StringComparison.OrdinalIgnoreCase)); - return link?.Attribute("href")?.Value; - } - - private static DateTimeOffset? ParseDate(string? value) - { - if (string.IsNullOrWhiteSpace(value)) - { - return null; - } - - if (DateTimeOffset.TryParse(value, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal, out var parsed)) - { - return parsed; - } - - return null; - } + var state = await db.ExternalSignalFeedStates + .FirstOrDefaultAsync(existing => existing.Name == feed.Name, cancellationToken); - private static string? GuessSeverity(string title, string? summary) - { - var content = $"{title} {summary}".ToLowerInvariant(); - foreach (var keyword in SeverityKeywords) + if (state is null) { - if (content.Contains(keyword, StringComparison.OrdinalIgnoreCase)) + state = new ExternalSignalFeedState { - return keyword; - } - } - - return "info"; - } - - private static string? GuessCategory(string title, string? 
summary) - { - var content = $"{title} {summary}".ToLowerInvariant(); - if (content.Contains("maintenance")) - { - return "maintenance"; - } - if (content.Contains("outage") || content.Contains("disruption")) - { - return "outage"; - } - if (content.Contains("degraded")) - { - return "degraded"; - } - if (content.Contains("incident") || content.Contains("investigating")) - { - return "incident"; + Name = feed.Name, + Url = feed.Url + }; + db.ExternalSignalFeedStates.Add(state); } - return "notice"; - } + state.Url = feed.Url; + state.LastFetchAt = DateTimeOffset.UtcNow; + state.LastStatusCode = statusCode; - private static string? ExtractRegion(string title, string? summary) - { - var content = $"{title} {summary}"; - var match = System.Text.RegularExpressions.Regex.Match(content, @"\b([a-z]{2}-[a-z]+-\d)\b", System.Text.RegularExpressions.RegexOptions.IgnoreCase); - if (match.Success) + if (string.IsNullOrWhiteSpace(error)) { - return match.Groups[1].Value; + state.LastSuccessAt = state.LastFetchAt; + state.LastError = null; + state.FailureCount = 0; } - - return null; - } - - private static string? BuildTags(string title, string? summary, string? severity, string? category) - { - var tags = new List(); - if (!string.IsNullOrWhiteSpace(severity)) - { - tags.Add(severity); - } - - if (!string.IsNullOrWhiteSpace(category) && !tags.Any(tag => string.Equals(tag, category, StringComparison.OrdinalIgnoreCase))) - { - tags.Add(category); - } - - if (title.Contains("region", StringComparison.OrdinalIgnoreCase)) - { - tags.Add("region"); - } - - if (!string.IsNullOrWhiteSpace(summary) && summary.Contains("latency", StringComparison.OrdinalIgnoreCase)) - { - tags.Add("latency"); - } - - return tags.Count > 0 ? string.Join(",", tags) : null; - } - - private static string? NormalizeText(string? raw) - { - if (string.IsNullOrWhiteSpace(raw)) + else { - return raw; + state.LastError = error.Length > 1000 ? 
error[..1000] : error; + state.FailureCount += 1; } - var decoded = WebUtility.HtmlDecode(raw); - var withoutTags = Regex.Replace(decoded, "<.*?>", string.Empty, RegexOptions.Singleline); - return withoutTags.Trim(); + await db.SaveChangesAsync(cancellationToken); } } diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalParser.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalParser.cs new file mode 100644 index 0000000..7ab257a --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalParser.cs @@ -0,0 +1,201 @@ +using System.Globalization; +using System.Net; +using System.Text.RegularExpressions; +using System.Xml.Linq; +using AetherGuard.Core.Models; + +namespace AetherGuard.Core.Services.ExternalSignals; + +public static class ExternalSignalParser +{ + private static readonly string[] SeverityKeywords = + [ + "outage", + "degraded", + "disruption", + "investigating", + "incident", + "partial", + "unavailable", + "restored", + "resolved" + ]; + + public static List ParseFeed(string content, ExternalSignalFeedOptions feed) + { + var document = XDocument.Parse(content); + var root = document.Root; + if (root is null) + { + return []; + } + + var items = root.Name.LocalName switch + { + "rss" => root.Descendants().Where(e => e.Name.LocalName == "item"), + "feed" => root.Descendants().Where(e => e.Name.LocalName == "entry"), + _ => root.Descendants().Where(e => e.Name.LocalName == "item" || e.Name.LocalName == "entry") + }; + + var signals = new List(); + foreach (var item in items) + { + var title = GetValue(item, "title"); + if (string.IsNullOrWhiteSpace(title)) + { + continue; + } + + var externalId = GetValue(item, "guid") ?? GetValue(item, "id") ?? title; + var url = GetValue(item, "link"); + if (string.IsNullOrWhiteSpace(url)) + { + url = GetLinkHref(item); + } + + var summaryRaw = GetValue(item, "description") ?? GetValue(item, "summary") ?? 
GetValue(item, "content"); + var summary = NormalizeText(summaryRaw); + var published = ParseDate(GetValue(item, "pubDate")) + ?? ParseDate(GetValue(item, "updated")) + ?? ParseDate(GetValue(item, "published")) + ?? DateTimeOffset.UtcNow; + + var severity = GuessSeverity(title, summary); + var category = GuessCategory(title, summary); + var region = ExtractRegion(title, summary) ?? feed.DefaultRegion; + + signals.Add(new ExternalSignal + { + Source = feed.Name, + ExternalId = externalId, + Title = title, + Summary = summary, + Region = region, + Severity = severity, + Category = category, + Url = url, + Tags = BuildTags(title, summary, severity, category), + PublishedAt = published, + IngestedAt = DateTimeOffset.UtcNow + }); + } + + return signals; + } + + private static string? GetValue(XElement parent, string name) + { + return parent.Elements().FirstOrDefault(e => e.Name.LocalName.Equals(name, StringComparison.OrdinalIgnoreCase))?.Value; + } + + private static string? GetLinkHref(XElement parent) + { + var link = parent.Elements().FirstOrDefault(e => e.Name.LocalName.Equals("link", StringComparison.OrdinalIgnoreCase)); + return link?.Attribute("href")?.Value; + } + + private static DateTimeOffset? ParseDate(string? value) + { + if (string.IsNullOrWhiteSpace(value)) + { + return null; + } + + if (DateTimeOffset.TryParse(value, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal, out var parsed)) + { + return parsed; + } + + return null; + } + + private static string? GuessSeverity(string title, string? summary) + { + var content = $"{title} {summary}".ToLowerInvariant(); + foreach (var keyword in SeverityKeywords) + { + if (content.Contains(keyword, StringComparison.OrdinalIgnoreCase)) + { + return keyword; + } + } + + return "info"; + } + + private static string? GuessCategory(string title, string? 
summary) + { + var content = $"{title} {summary}".ToLowerInvariant(); + if (content.Contains("maintenance")) + { + return "maintenance"; + } + if (content.Contains("outage") || content.Contains("disruption")) + { + return "outage"; + } + if (content.Contains("degraded")) + { + return "degraded"; + } + if (content.Contains("incident") || content.Contains("investigating")) + { + return "incident"; + } + + return "notice"; + } + + private static string? ExtractRegion(string title, string? summary) + { + var content = $"{title} {summary}"; + var match = Regex.Match( + content, + @"\b([a-z]{2}-[a-z]+-\d|[a-z]{2}-[a-z]+[0-9])\b", + RegexOptions.IgnoreCase); + if (match.Success) + { + return match.Groups[1].Value; + } + + return null; + } + + private static string? BuildTags(string title, string? summary, string? severity, string? category) + { + var tags = new List(); + if (!string.IsNullOrWhiteSpace(severity)) + { + tags.Add(severity); + } + + if (!string.IsNullOrWhiteSpace(category) && !tags.Any(tag => string.Equals(tag, category, StringComparison.OrdinalIgnoreCase))) + { + tags.Add(category); + } + + if (title.Contains("region", StringComparison.OrdinalIgnoreCase)) + { + tags.Add("region"); + } + + if (!string.IsNullOrWhiteSpace(summary) && summary.Contains("latency", StringComparison.OrdinalIgnoreCase)) + { + tags.Add("latency"); + } + + return tags.Count > 0 ? string.Join(",", tags) : null; + } + + private static string? NormalizeText(string? 
raw) + { + if (string.IsNullOrWhiteSpace(raw)) + { + return raw; + } + + var decoded = WebUtility.HtmlDecode(raw); + var withoutTags = Regex.Replace(decoded, "<.*?>", string.Empty, RegexOptions.Singleline); + return withoutTags.Trim(); + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/appsettings.json b/src/services/core-dotnet/AetherGuard.Core/appsettings.json index 60a5dc2..04115d4 100644 --- a/src/services/core-dotnet/AetherGuard.Core/appsettings.json +++ b/src/services/core-dotnet/AetherGuard.Core/appsettings.json @@ -85,6 +85,11 @@ "Name": "gcp-status", "Url": "https://status.cloud.google.com/en/feed.atom", "DefaultRegion": "global" + }, + { + "Name": "azure-status", + "Url": "https://status.azure.com/en-us/status/feed/", + "DefaultRegion": "global" } ] }, diff --git a/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignalFeedState.cs b/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignalFeedState.cs new file mode 100644 index 0000000..337a479 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignalFeedState.cs @@ -0,0 +1,13 @@ +namespace AetherGuard.Core.Models; + +public sealed class ExternalSignalFeedState +{ + public long Id { get; set; } + public string Name { get; set; } = string.Empty; + public string Url { get; set; } = string.Empty; + public DateTimeOffset LastFetchAt { get; set; } + public DateTimeOffset? LastSuccessAt { get; set; } + public int FailureCount { get; set; } + public string? LastError { get; set; } + public int? 
LastStatusCode { get; set; } +} From 8b48bee6122e56c6311d890c7f76e4747f6ae7fb Mon Sep 17 00:00:00 2001 From: JasonEran Date: Mon, 2 Feb 2026 19:48:55 +0800 Subject: [PATCH 03/24] feat: add external signals and feeds integration with dashboard components --- src/web/dashboard/app/DashboardClient.tsx | 105 +++++++++++++- .../dashboard/app/api/signals/feeds/route.ts | 27 ++++ src/web/dashboard/app/api/signals/route.ts | 31 ++++ .../components/ExternalSignalsPanel.tsx | 135 ++++++++++++++++++ src/web/dashboard/lib/api.ts | 33 ++++- src/web/dashboard/types/index.ts | 25 ++++ 6 files changed, 351 insertions(+), 5 deletions(-) create mode 100644 src/web/dashboard/app/api/signals/feeds/route.ts create mode 100644 src/web/dashboard/app/api/signals/route.ts create mode 100644 src/web/dashboard/components/ExternalSignalsPanel.tsx diff --git a/src/web/dashboard/app/DashboardClient.tsx b/src/web/dashboard/app/DashboardClient.tsx index 201e063..84912a9 100644 --- a/src/web/dashboard/app/DashboardClient.tsx +++ b/src/web/dashboard/app/DashboardClient.tsx @@ -5,11 +5,20 @@ import { signOut } from 'next-auth/react'; import AuditLogStream from '../components/AuditLogStream'; import ExplainabilityPanel from '../components/ExplainabilityPanel'; +import ExternalSignalsPanel from '../components/ExternalSignalsPanel'; import FirstRunGuide from '../components/FirstRunGuide'; import ControlPanel from '../components/ControlPanel'; import HistoryChart from '../components/HistoryChart'; -import { fetchAuditLogs, fetchFleetStatus, fetchRiskHistory, sendChaosSignal, RiskPoint } from '../lib/api'; -import type { Agent, AuditLog } from '../types'; +import { + fetchAuditLogs, + fetchExternalSignalFeeds, + fetchExternalSignals, + fetchFleetStatus, + fetchRiskHistory, + sendChaosSignal, + RiskPoint, +} from '../lib/api'; +import type { Agent, AuditLog, ExternalSignal, ExternalSignalFeedState } from '../types'; interface DashboardClientProps { userName: string; @@ -103,6 +112,81 @@ const 
buildMockPayload = (now: number, chaosActive: boolean) => { }, ]; + const mockSignals: ExternalSignal[] = [ + { + id: 1, + source: 'aws-status', + externalId: 'aws-2026-01', + title: 'Elevated latency in us-east-1 availability zones', + summary: 'Investigating increased API latency affecting spot capacity decisions.', + region: 'us-east-1', + severity: chaosActive ? 'critical' : 'warning', + category: 'latency', + url: 'https://status.aws.amazon.com', + tags: 'latency,spot', + publishedAt: new Date(now - 18 * 60_000).toISOString(), + ingestedAt: new Date(now - 15 * 60_000).toISOString(), + }, + { + id: 2, + source: 'gcp-status', + externalId: 'gcp-2026-02', + title: 'Compute Engine spot capacity advisory', + summary: 'Spot instance preemption notices increased in europe-west4.', + region: 'europe-west4', + severity: 'warning', + category: 'capacity', + url: 'https://status.cloud.google.com', + tags: 'capacity,preemption', + publishedAt: new Date(now - 55 * 60_000).toISOString(), + ingestedAt: new Date(now - 50 * 60_000).toISOString(), + }, + { + id: 3, + source: 'azure-status', + externalId: 'azure-2026-03', + title: 'Azure spot price volatility notice', + summary: 'Monitoring pricing volatility in West US 2.', + region: 'westus2', + severity: 'info', + category: 'pricing', + url: 'https://status.azure.com', + tags: 'pricing', + publishedAt: new Date(now - 2 * 60_000 * 60).toISOString(), + ingestedAt: new Date(now - 110 * 60_000).toISOString(), + }, + ]; + + const mockFeeds: ExternalSignalFeedState[] = [ + { + name: 'aws-status', + url: 'https://status.aws.amazon.com/rss/all.rss', + lastFetchAt: new Date(now - 2 * 60_000).toISOString(), + lastSuccessAt: new Date(now - 2 * 60_000).toISOString(), + failureCount: 0, + lastError: null, + lastStatusCode: 200, + }, + { + name: 'gcp-status', + url: 'https://status.cloud.google.com/en/feed.atom', + lastFetchAt: new Date(now - 5 * 60_000).toISOString(), + lastSuccessAt: new Date(now - 5 * 60_000).toISOString(), + 
failureCount: 0, + lastError: null, + lastStatusCode: 200, + }, + { + name: 'azure-status', + url: 'https://status.azure.com/en-us/status/feed/', + lastFetchAt: new Date(now - 4 * 60_000).toISOString(), + lastSuccessAt: new Date(now - 4 * 60_000).toISOString(), + failureCount: chaosActive ? 1 : 0, + lastError: chaosActive ? 'Timeout fetching feed' : null, + lastStatusCode: chaosActive ? 504 : 200, + }, + ]; + if (chaosActive) { mockAudits.unshift({ id: `audit-${now}-chaos`, @@ -114,7 +198,7 @@ const buildMockPayload = (now: number, chaosActive: boolean) => { }); } - return { mockHistory, mockAgents, mockAudits }; + return { mockHistory, mockAgents, mockAudits, mockSignals, mockFeeds }; }; const formatLocalTime = (value?: string) => { @@ -134,7 +218,10 @@ export default function DashboardClient({ userName, userRole }: DashboardClientP const [agents, setAgents] = useState([]); const [history, setHistory] = useState([]); const [auditLogs, setAuditLogs] = useState([]); + const [signals, setSignals] = useState([]); + const [feeds, setFeeds] = useState([]); const [usingMock, setUsingMock] = useState(false); + const [signalsUsingMock, setSignalsUsingMock] = useState(false); const [lastUpdated, setLastUpdated] = useState(''); const [showFirstRunGuide, setShowFirstRunGuide] = useState(true); const mockChaosAtRef = useRef(null); @@ -167,15 +254,19 @@ export default function DashboardClient({ userName, userRole }: DashboardClientP let isMounted = true; const load = async () => { - const [fleetResult, historyResult, auditResult] = await Promise.allSettled([ + const [fleetResult, historyResult, auditResult, signalsResult, feedsResult] = await Promise.allSettled([ fetchFleetStatus(), fetchRiskHistory(), fetchAuditLogs(), + fetchExternalSignals(), + fetchExternalSignalFeeds(), ]); const fleetData = fleetResult.status === 'fulfilled' ? fleetResult.value : []; const historyData = historyResult.status === 'fulfilled' ? 
historyResult.value : []; const auditData = auditResult.status === 'fulfilled' ? auditResult.value : []; + const signalsData = signalsResult.status === 'fulfilled' ? signalsResult.value : []; + const feedsData = feedsResult.status === 'fulfilled' ? feedsResult.value : []; if (!isMounted) { return; @@ -188,7 +279,12 @@ export default function DashboardClient({ userName, userRole }: DashboardClientP const useMockHistory = historyData.length === 0; const useMockAudits = auditData.length === 0; const useMock = useMockFleet || useMockHistory || useMockAudits; + const useMockSignals = signalsData.length === 0; + const useMockFeeds = feedsData.length === 0; setUsingMock(useMock); + setSignalsUsingMock(useMockSignals || useMockFeeds); + setSignals(useMockSignals ? mockPayload.mockSignals : signalsData); + setFeeds(useMockFeeds ? mockPayload.mockFeeds : feedsData); if (useMock) { setAgents(mockPayload.mockAgents); setHistory(mockPayload.mockHistory); @@ -371,6 +467,7 @@ export default function DashboardClient({ userName, userRole }: DashboardClientP
+
Risk Trend
diff --git a/src/web/dashboard/app/api/signals/feeds/route.ts b/src/web/dashboard/app/api/signals/feeds/route.ts new file mode 100644 index 0000000..d10e931 --- /dev/null +++ b/src/web/dashboard/app/api/signals/feeds/route.ts @@ -0,0 +1,27 @@ +import { NextResponse } from 'next/server'; + +export const dynamic = 'force-dynamic'; +export const revalidate = 0; + +const coreBaseUrl = process.env.CORE_API_URL ?? 'http://core-service:8080'; + +export async function GET() { + const response = await fetch(`${coreBaseUrl}/api/v1/signals/feeds`, { cache: 'no-store' }); + const text = await response.text(); + if (!response.ok) { + return NextResponse.json({ error: text || 'Failed to fetch feed status.' }, { status: response.status }); + } + + if (!text) { + return NextResponse.json([], { status: response.status }); + } + + try { + return NextResponse.json(JSON.parse(text), { status: response.status }); + } catch { + return new NextResponse(text, { + status: response.status, + headers: { 'Content-Type': response.headers.get('content-type') ?? 'text/plain' }, + }); + } +} diff --git a/src/web/dashboard/app/api/signals/route.ts b/src/web/dashboard/app/api/signals/route.ts new file mode 100644 index 0000000..a52c206 --- /dev/null +++ b/src/web/dashboard/app/api/signals/route.ts @@ -0,0 +1,31 @@ +import { NextResponse } from 'next/server'; + +export const dynamic = 'force-dynamic'; +export const revalidate = 0; + +const coreBaseUrl = process.env.CORE_API_URL ?? 'http://core-service:8080'; + +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + const query = searchParams.toString(); + const url = query ? `${coreBaseUrl}/api/v1/signals?${query}` : `${coreBaseUrl}/api/v1/signals`; + + const response = await fetch(url, { cache: 'no-store' }); + const text = await response.text(); + if (!response.ok) { + return NextResponse.json({ error: text || 'Failed to fetch external signals.' 
}, { status: response.status }); + } + + if (!text) { + return NextResponse.json([], { status: response.status }); + } + + try { + return NextResponse.json(JSON.parse(text), { status: response.status }); + } catch { + return new NextResponse(text, { + status: response.status, + headers: { 'Content-Type': response.headers.get('content-type') ?? 'text/plain' }, + }); + } +} diff --git a/src/web/dashboard/components/ExternalSignalsPanel.tsx b/src/web/dashboard/components/ExternalSignalsPanel.tsx new file mode 100644 index 0000000..0da4c70 --- /dev/null +++ b/src/web/dashboard/components/ExternalSignalsPanel.tsx @@ -0,0 +1,135 @@ +'use client'; + +import type { ExternalSignal, ExternalSignalFeedState } from '../types'; + +interface ExternalSignalsPanelProps { + signals: ExternalSignal[]; + feeds: ExternalSignalFeedState[]; + usingMock?: boolean; +} + +const formatTimestamp = (value?: string) => { + if (!value) { + return '--'; + } + + const parsed = Date.parse(value); + if (Number.isNaN(parsed)) { + return value; + } + return new Date(parsed).toLocaleString(); +}; + +const renderSeverity = (value?: string | null) => { + if (!value) { + return { label: 'Unknown', style: 'border-slate-700 bg-slate-900/60 text-slate-300' }; + } + + const normalized = value.toLowerCase(); + if (normalized.includes('critical') || normalized.includes('outage')) { + return { label: value, style: 'border-red-500/50 bg-red-500/10 text-red-200' }; + } + + if (normalized.includes('warning') || normalized.includes('degraded')) { + return { label: value, style: 'border-amber-500/50 bg-amber-500/10 text-amber-200' }; + } + + return { label: value, style: 'border-emerald-500/40 bg-emerald-500/10 text-emerald-200' }; +}; + +const renderFeedStatus = (feed: ExternalSignalFeedState) => { + if (feed.failureCount > 0) { + return { label: 'Degraded', style: 'border-amber-500/50 bg-amber-500/10 text-amber-200' }; + } + + if (!feed.lastSuccessAt) { + return { label: 'Pending', style: 'border-slate-700 
bg-slate-900/60 text-slate-300' }; + } + + return { label: 'Healthy', style: 'border-emerald-500/40 bg-emerald-500/10 text-emerald-200' }; +}; + +export default function ExternalSignalsPanel({ signals, feeds, usingMock }: ExternalSignalsPanelProps) { + return ( +
+
+
External Signals
+ + {usingMock ? 'Simulation' : 'Live Feeds'} + +
+ +
+
+ {signals.slice(0, 6).map((signal) => { + const severity = renderSeverity(signal.severity); + return ( +
+
+ {signal.source} + + {severity.label} + +
+
{signal.title}
+ {signal.summary && ( +
{signal.summary}
+ )} +
+ {signal.region ? `Region: ${signal.region}` : 'Global'} + {formatTimestamp(signal.publishedAt)} +
+
+ ); + })} + {signals.length === 0 && ( +
+ No external signals ingested yet. +
+ )} +
+ +
+
Feed Health
+ {feeds.map((feed) => { + const status = renderFeedStatus(feed); + return ( +
+
+ {feed.name} + + {status.label} + +
+
Last fetch: {formatTimestamp(feed.lastFetchAt)}
+
+ Last success: {formatTimestamp(feed.lastSuccessAt ?? undefined)} +
+ {feed.failureCount > 0 && ( +
Failures: {feed.failureCount}
+ )} +
+ ); + })} + {feeds.length === 0 && ( +
+ Feed status will appear after the first sync. +
+ )} +
+
+
+ ); +} diff --git a/src/web/dashboard/lib/api.ts b/src/web/dashboard/lib/api.ts index ae6a9ff..7602f1a 100644 --- a/src/web/dashboard/lib/api.ts +++ b/src/web/dashboard/lib/api.ts @@ -1,4 +1,4 @@ -import type { Agent, AuditLog } from '../types'; +import type { Agent, AuditLog, ExternalSignal, ExternalSignalFeedState } from '../types'; export interface RiskPoint { timestamp: string; @@ -167,3 +167,34 @@ export async function sendChaosSignal(): Promise { throw new Error(message || 'Chaos signal failed.'); } } + +export async function fetchExternalSignals(limit = 6): Promise { + try { + const params = new URLSearchParams({ limit: limit.toString() }); + const response = await fetch(`/api/signals?${params.toString()}`, { cache: 'no-store' }); + if (!response.ok) { + return []; + } + + const data = (await response.json()) as ExternalSignal[]; + return Array.isArray(data) ? data : []; + } catch (error) { + console.error('[Dashboard] Failed to fetch external signals', error); + return []; + } +} + +export async function fetchExternalSignalFeeds(): Promise { + try { + const response = await fetch('/api/signals/feeds', { cache: 'no-store' }); + if (!response.ok) { + return []; + } + + const data = (await response.json()) as ExternalSignalFeedState[]; + return Array.isArray(data) ? 
data : []; + } catch (error) { + console.error('[Dashboard] Failed to fetch external signal feeds', error); + return []; + } +} diff --git a/src/web/dashboard/types/index.ts b/src/web/dashboard/types/index.ts index 5861893..eec6a0f 100644 --- a/src/web/dashboard/types/index.ts +++ b/src/web/dashboard/types/index.ts @@ -22,3 +22,28 @@ export interface AuditLog { error?: string; timestamp: string; } + +export interface ExternalSignal { + id: number; + source: string; + externalId: string; + title: string; + summary?: string | null; + region?: string | null; + severity?: string | null; + category?: string | null; + url?: string | null; + tags?: string | null; + publishedAt: string; + ingestedAt: string; +} + +export interface ExternalSignalFeedState { + name: string; + url: string; + lastFetchAt: string; + lastSuccessAt?: string | null; + failureCount: number; + lastError?: string | null; + lastStatusCode?: number | null; +} From 330de646c9e56ad0b1971a146032005acdfafb90 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 4 Feb 2026 14:02:34 +0800 Subject: [PATCH 04/24] feat: add support for optional HTTP listener when mTLS is enabled for dashboard/AI traffic --- CHANGELOG.md | 1 + README.md | 3 ++- docker-compose.yml | 2 ++ src/services/core-dotnet/AetherGuard.Core/Program.cs | 5 +++++ .../core-dotnet/AetherGuard.Core/Security/MtlsOptions.cs | 2 ++ src/services/core-dotnet/AetherGuard.Core/appsettings.json | 2 ++ 6 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 318dba0..c0e3b4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ Semantic Versioning. - v2.3 delivery roadmap in `docs/ROADMAP-v2.3.md`. - Expanded v2.3 roadmap with model choices, data sources, and validation guidance. - Verification scripts now support API key headers and optional agent build flags. +- Optional HTTP listener when mTLS is enabled to keep dashboard/AI traffic on port 8080. 
### Changed - Agent now injects W3C trace headers for HTTP requests. diff --git a/README.md b/README.md index dc124a7..7ad8553 100644 --- a/README.md +++ b/README.md @@ -358,7 +358,8 @@ sidecars to issue and rotate X.509 SVIDs: - Core serves mTLS on `https://core-service:8443` (host-mapped to 5001). - Agent uses SPIFFE-issued certs from `/run/spiffe/certs` and calls the mTLS endpoint. -- HTTP on `http://core-service:8080` remains for dashboard/AI traffic. +- When `Security__Mtls__AllowHttp=true`, Core also listens on `http://core-service:8080` for + dashboard/AI traffic (host-mapped to 5000). Disable mTLS locally by setting: diff --git a/docker-compose.yml b/docker-compose.yml index a0a4f09..64757f9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,6 +21,8 @@ services: ArtifactBaseUrl: "https://core-service:8443" Security__Mtls__Enabled: "true" Security__Mtls__Port: "8443" + Security__Mtls__AllowHttp: "true" + Security__Mtls__HttpPort: "8080" Security__Mtls__CertificatePath: "/run/spiffe/certs/svid.pem" Security__Mtls__KeyPath: "/run/spiffe/certs/svid_key.pem" Security__Mtls__BundlePath: "/run/spiffe/certs/bundle.pem" diff --git a/src/services/core-dotnet/AetherGuard.Core/Program.cs b/src/services/core-dotnet/AetherGuard.Core/Program.cs index 209083c..6f739c1 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Program.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Program.cs @@ -28,6 +28,11 @@ builder.WebHost.ConfigureKestrel(options => { + if (mtlsOptions.AllowHttp) + { + options.ListenAnyIP(mtlsOptions.HttpPort); + } + options.ListenAnyIP(mtlsOptions.Port, listenOptions => { listenOptions.UseHttps(httpsOptions => diff --git a/src/services/core-dotnet/AetherGuard.Core/Security/MtlsOptions.cs b/src/services/core-dotnet/AetherGuard.Core/Security/MtlsOptions.cs index d7e48db..32151d8 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Security/MtlsOptions.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Security/MtlsOptions.cs @@ -4,6 +4,8 
@@ public sealed record MtlsOptions { public bool Enabled { get; init; } public int Port { get; init; } = 8443; + public bool AllowHttp { get; init; } + public int HttpPort { get; init; } = 8080; public string? CertificatePath { get; init; } public string? KeyPath { get; init; } public string? BundlePath { get; init; } diff --git a/src/services/core-dotnet/AetherGuard.Core/appsettings.json b/src/services/core-dotnet/AetherGuard.Core/appsettings.json index 04115d4..85c330e 100644 --- a/src/services/core-dotnet/AetherGuard.Core/appsettings.json +++ b/src/services/core-dotnet/AetherGuard.Core/appsettings.json @@ -52,6 +52,8 @@ "Mtls": { "Enabled": false, "Port": 8443, + "AllowHttp": false, + "HttpPort": 8080, "CertificatePath": "", "KeyPath": "", "BundlePath": "", From da6c1ce26845f7378465c8fa074612d433859de7 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 11 Feb 2026 14:15:02 +0800 Subject: [PATCH 05/24] feat(v2.3): finish milestone 0 signals retention and normalization - add retention cleanup for external signals - normalize severity/region/tags in parser - add SPIRE join-token runbook + v2.3 smoke test - update Quickstart/SPIRE docs and parser tests --- docs/QA-SmokeTest-v2.3.md | 52 ++++++++++++++ docs/Quickstart.md | 8 +++ docs/Runbook-SPIRE-JoinToken.md | 50 +++++++++++++ docs/SPIRE-mTLS.md | 5 ++ .../ExternalSignalParserTests.cs | 2 + .../ExternalSignalIngestionService.cs | 45 ++++++++++++ .../ExternalSignals/ExternalSignalParser.cs | 71 +++++++++++-------- .../ExternalSignals/ExternalSignalsOptions.cs | 2 + .../AetherGuard.Core/appsettings.json | 2 + 9 files changed, 206 insertions(+), 31 deletions(-) create mode 100644 docs/QA-SmokeTest-v2.3.md create mode 100644 docs/Runbook-SPIRE-JoinToken.md diff --git a/docs/QA-SmokeTest-v2.3.md b/docs/QA-SmokeTest-v2.3.md new file mode 100644 index 0000000..598ef02 --- /dev/null +++ b/docs/QA-SmokeTest-v2.3.md @@ -0,0 +1,52 @@ +# v2.3 Smoke Test Checklist (Signals End-to-End) + +This checklist validates the v2.3 
Milestone 0 flow: external signals ingestion, +API exposure, and dashboard proxying. + +## Preconditions + +- Docker Desktop running +- `COMMAND_API_KEY` set (for other endpoints) + +## Start stack + +```bash +# PowerShell +$env:COMMAND_API_KEY="changeme" +$env:ExternalSignals__Enabled="true" + +docker compose up --build -d +``` + +## Verify Core APIs + +```bash +curl http://localhost:5000/api/v1/signals?limit=1 +curl http://localhost:5000/api/v1/signals/feeds +``` + +Expected: +- `signals` returns at least one item (or empty array on first run) +- `feeds` returns AWS/GCP/Azure feed status with timestamps + +## Verify Dashboard proxy + +```bash +curl http://localhost:3000/api/signals?limit=1 +curl http://localhost:3000/api/signals/feeds +``` + +Expected: +- Dashboard endpoints return the same data as Core +- HTTP 200 responses + +## Optional: UI check + +- Open `http://localhost:3000` +- Confirm the External Signals panel renders items and feed health + +## Stop stack + +```bash +docker compose down +``` diff --git a/docs/Quickstart.md b/docs/Quickstart.md index 0cf0df7..8b3c548 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -26,6 +26,7 @@ docker compose up --build -d ``` For SPIRE-based mTLS details, see `docs/SPIRE-mTLS.md`. +For SPIRE join-token recovery, see `docs/Runbook-SPIRE-JoinToken.md`. For observability setup, see `docs/Observability.md`. Open the dashboard at `http://localhost:3000` and log in with: @@ -42,8 +43,13 @@ External signals ingestion is disabled by default. To enable: ```bash # PowerShell $env:ExternalSignals__Enabled="true" +# Optional retention tuning +$env:ExternalSignals__RetentionDays="30" +$env:ExternalSignals__CleanupBatchSize="500" # Bash export ExternalSignals__Enabled=true +export ExternalSignals__RetentionDays=30 +export ExternalSignals__CleanupBatchSize=500 ``` Then restart the core service. 
Signals are accessible via: @@ -58,6 +64,8 @@ Feed health status is available via: GET /api/v1/signals/feeds ``` +Smoke test checklist: `docs/QA-SmokeTest-v2.3.md`. + If you want to simulate migrations, start at least two agents: ```bash diff --git a/docs/Runbook-SPIRE-JoinToken.md b/docs/Runbook-SPIRE-JoinToken.md new file mode 100644 index 0000000..b0161e5 --- /dev/null +++ b/docs/Runbook-SPIRE-JoinToken.md @@ -0,0 +1,50 @@ +# SPIRE Join-Token Rotation Runbook (Docker Compose) + +This runbook addresses the "join token already used/expired" error when the SPIRE agent +cannot re-attest and SVIDs stop rotating. + +## Symptoms + +- spire-agent logs show: `failed to attest: join token does not exist or has already been used` +- core/agent mTLS requests fail due to expired SVIDs + +## Recovery Steps (Docker Compose) + +1) Stop the stack (optional but recommended): + +```bash +docker compose down +``` + +2) Remove the cached join token from the bootstrap volume: + +```bash +docker run --rm -v aether-guard_spire_bootstrap:/data alpine \ + sh -lc "rm -f /data/join-token /data/trust-bundle.pem" +``` + +3) Re-run the bootstrap container to generate a fresh join token: + +```bash +docker compose run --rm spire-bootstrap +``` + +4) Restart SPIRE agent + helpers so they re-attest and fetch new SVIDs: + +```bash +docker compose restart spire-agent spiffe-helper-core spiffe-helper-agent +``` + +5) Verify in logs: + +```bash +docker compose logs --tail=50 spire-agent +``` + +You should see `Node attestation was successful` and `Renewing X509-SVID`. + +## Notes + +- Join tokens are one-time use by default. +- If the token cache persists across restarts, you must clear the bootstrap volume. +- For production, prefer automated rotation and monitoring around SVID expiration. 
diff --git a/docs/SPIRE-mTLS.md b/docs/SPIRE-mTLS.md index ad88d3f..43d9e2f 100644 --- a/docs/SPIRE-mTLS.md +++ b/docs/SPIRE-mTLS.md @@ -45,6 +45,11 @@ docker compose exec core-service ls -la /run/spiffe/certs docker compose exec agent-service ls -la /run/spiffe/certs ``` +## Join-token rotation + +If the SPIRE agent reports `join token already used/expired`, follow the +runbook at `docs/Runbook-SPIRE-JoinToken.md`. + ## Notes - The SPIRE agent mounts the Docker Engine socket to attest workloads by label. diff --git a/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalParserTests.cs b/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalParserTests.cs index c6c4470..e9b787d 100644 --- a/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalParserTests.cs +++ b/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalParserTests.cs @@ -36,6 +36,7 @@ public void ParsesRssItems() Assert.Equal("aws-1", results[0].ExternalId); Assert.Equal("aws-status", results[0].Source); Assert.Equal("us-east-1", results[0].Region); + Assert.Equal("critical", results[0].Severity); } [Fact] @@ -67,5 +68,6 @@ public void ParsesAtomEntries() Assert.Equal("gcp-status", results[0].Source); Assert.Equal("us-central1", results[0].Region); Assert.Equal("incident", results[0].Category); + Assert.Equal("info", results[0].Severity); } } diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs index d6fbb4f..2118199 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs @@ -134,6 +134,8 @@ private async Task IngestOnceAsync(CancellationToken cancellationToken) _logger.LogInformation("Ingested {Count} signals from {Feed}.", newSignals.Count, 
feed.Name); } + + await CleanupOldSignalsAsync(db, cancellationToken); } private static async Task UpdateFeedStateAsync( @@ -174,4 +176,47 @@ private static async Task UpdateFeedStateAsync( await db.SaveChangesAsync(cancellationToken); } + + private async Task CleanupOldSignalsAsync(ApplicationDbContext db, CancellationToken cancellationToken) + { + if (_options.RetentionDays <= 0) + { + return; + } + + var cutoff = DateTimeOffset.UtcNow.AddDays(-_options.RetentionDays); + var batchSize = Math.Clamp(_options.CleanupBatchSize, 50, 2000); + var totalRemoved = 0; + + while (true) + { + var ids = await db.ExternalSignals + .AsNoTracking() + .Where(signal => signal.PublishedAt < cutoff) + .OrderBy(signal => signal.PublishedAt) + .Select(signal => signal.Id) + .Take(batchSize) + .ToListAsync(cancellationToken); + + if (ids.Count == 0) + { + break; + } + + var toRemove = ids.Select(id => new ExternalSignal { Id = id }).ToList(); + db.ExternalSignals.RemoveRange(toRemove); + var removed = await db.SaveChangesAsync(cancellationToken); + totalRemoved += removed; + + if (ids.Count < batchSize) + { + break; + } + } + + if (totalRemoved > 0) + { + _logger.LogInformation("Removed {Count} external signals older than {Cutoff}.", totalRemoved, cutoff); + } + } } diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalParser.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalParser.cs index 7ab257a..782654a 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalParser.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalParser.cs @@ -8,18 +8,9 @@ namespace AetherGuard.Core.Services.ExternalSignals; public static class ExternalSignalParser { - private static readonly string[] SeverityKeywords = - [ - "outage", - "degraded", - "disruption", - "investigating", - "incident", - "partial", - "unavailable", - "restored", - "resolved" - ]; + 
private static readonly string[] CriticalKeywords = ["outage", "disruption", "unavailable"]; + private static readonly string[] WarningKeywords = ["degraded", "incident", "investigating", "partial"]; + private static readonly string[] InfoKeywords = ["maintenance", "resolved", "restored", "notice"]; public static List ParseFeed(string content, ExternalSignalFeedOptions feed) { @@ -60,9 +51,9 @@ public static List ParseFeed(string content, ExternalSignalFeedO ?? ParseDate(GetValue(item, "published")) ?? DateTimeOffset.UtcNow; - var severity = GuessSeverity(title, summary); - var category = GuessCategory(title, summary); - var region = ExtractRegion(title, summary) ?? feed.DefaultRegion; + var severity = NormalizeSeverity(title, summary); + var category = NormalizeCategory(title, summary); + var region = NormalizeRegion(ExtractRegion(title, summary) ?? feed.DefaultRegion); signals.Add(new ExternalSignal { @@ -109,21 +100,34 @@ public static List ParseFeed(string content, ExternalSignalFeedO return null; } - private static string? GuessSeverity(string title, string? summary) + private static string? NormalizeSeverity(string title, string? summary) { var content = $"{title} {summary}".ToLowerInvariant(); - foreach (var keyword in SeverityKeywords) + if (CriticalKeywords.Any(keyword => content.Contains(keyword, StringComparison.OrdinalIgnoreCase))) { - if (content.Contains(keyword, StringComparison.OrdinalIgnoreCase)) - { - return keyword; - } + return "critical"; + } + + if (content.Contains("resolved", StringComparison.OrdinalIgnoreCase) + || content.Contains("restored", StringComparison.OrdinalIgnoreCase)) + { + return "info"; + } + + if (WarningKeywords.Any(keyword => content.Contains(keyword, StringComparison.OrdinalIgnoreCase))) + { + return "warning"; + } + + if (InfoKeywords.Any(keyword => content.Contains(keyword, StringComparison.OrdinalIgnoreCase))) + { + return "info"; } return "info"; } - private static string? GuessCategory(string title, string? 
summary) + private static string? NormalizeCategory(string title, string? summary) { var content = $"{title} {summary}".ToLowerInvariant(); if (content.Contains("maintenance")) @@ -153,12 +157,7 @@ public static List ParseFeed(string content, ExternalSignalFeedO content, @"\b([a-z]{2}-[a-z]+-\d|[a-z]{2}-[a-z]+[0-9])\b", RegexOptions.IgnoreCase); - if (match.Success) - { - return match.Groups[1].Value; - } - - return null; + return match.Success ? match.Groups[1].Value : null; } private static string? BuildTags(string title, string? summary, string? severity, string? category) @@ -166,12 +165,12 @@ public static List ParseFeed(string content, ExternalSignalFeedO var tags = new List(); if (!string.IsNullOrWhiteSpace(severity)) { - tags.Add(severity); + tags.Add(severity.ToLowerInvariant()); } if (!string.IsNullOrWhiteSpace(category) && !tags.Any(tag => string.Equals(tag, category, StringComparison.OrdinalIgnoreCase))) { - tags.Add(category); + tags.Add(category.ToLowerInvariant()); } if (title.Contains("region", StringComparison.OrdinalIgnoreCase)) @@ -184,7 +183,17 @@ public static List ParseFeed(string content, ExternalSignalFeedO tags.Add("latency"); } - return tags.Count > 0 ? string.Join(",", tags) : null; + return tags.Count > 0 ? string.Join(",", tags.Distinct(StringComparer.OrdinalIgnoreCase)) : null; + } + + private static string? NormalizeRegion(string? region) + { + if (string.IsNullOrWhiteSpace(region)) + { + return null; + } + + return region.Trim().ToLowerInvariant(); } private static string? NormalizeText(string? 
raw) diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs index 663fba8..ab91b68 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs @@ -6,6 +6,8 @@ public sealed class ExternalSignalsOptions public int PollingIntervalSeconds { get; set; } = 300; public int LookbackHours { get; set; } = 48; public int MaxItemsPerFeed { get; set; } = 200; + public int RetentionDays { get; set; } = 30; + public int CleanupBatchSize { get; set; } = 500; public List Feeds { get; set; } = new(); } diff --git a/src/services/core-dotnet/AetherGuard.Core/appsettings.json b/src/services/core-dotnet/AetherGuard.Core/appsettings.json index 85c330e..0ff04ea 100644 --- a/src/services/core-dotnet/AetherGuard.Core/appsettings.json +++ b/src/services/core-dotnet/AetherGuard.Core/appsettings.json @@ -77,6 +77,8 @@ "PollingIntervalSeconds": 300, "LookbackHours": 48, "MaxItemsPerFeed": 200, + "RetentionDays": 30, + "CleanupBatchSize": 500, "Feeds": [ { "Name": "aws-status", From cf49688ebca58622d21c8fa20ac329fe12259192 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 11 Feb 2026 14:28:44 +0800 Subject: [PATCH 06/24] Define and version the S_v/P_v/B_s schema --- docs/ARCHITECTURE-v2.3.md | 1 + src/services/ai-engine/main.py | 31 +++++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/docs/ARCHITECTURE-v2.3.md b/docs/ARCHITECTURE-v2.3.md index bdf46a0..259fb87 100644 --- a/docs/ARCHITECTURE-v2.3.md +++ b/docs/ARCHITECTURE-v2.3.md @@ -65,6 +65,7 @@ Telemetry alone misses off-chart events. We introduce a semantic pipeline for cl - **Model**: a domain-adapted transformer (BERT-class). 
If economic signals are used, FinBERT is a reasonable baseline for finance-domain text; an LLM summarizer handles longer advisories and provider policy updates. - **Outputs (standardized)**: + - `schemaVersion`: semantic vector schema version. - `S_v`: sentiment vector (normalized polarity + severity). - `P_v`: volatility probability (0-1). - `B_s`: supply or capacity bias (long-horizon). diff --git a/src/services/ai-engine/main.py b/src/services/ai-engine/main.py index 9776658..08d0a09 100644 --- a/src/services/ai-engine/main.py +++ b/src/services/ai-engine/main.py @@ -16,6 +16,7 @@ logger = logging.getLogger("uvicorn.error") scorer = RiskScorer() +ENRICHMENT_SCHEMA_VERSION = "1.0" @asynccontextmanager @@ -53,9 +54,19 @@ class EnrichRequest(BaseModel): class EnrichResponse(BaseModel): - s_v: list[float] = Field(alias="S_v") - p_v: float = Field(alias="P_v") - b_s: float = Field(alias="B_s") + schema_version: str = Field(alias="schemaVersion", description="Semantic vector schema version.") + s_v: list[float] = Field( + alias="S_v", + description="Sentiment vector (normalized polarity + severity).", + ) + p_v: float = Field( + alias="P_v", + description="Volatility probability in the range [0, 1].", + ) + b_s: float = Field( + alias="B_s", + description="Supply or capacity bias (long-horizon signal).", + ) model_config = ConfigDict(populate_by_name=True) @@ -104,7 +115,19 @@ def enrich_signals(payload: EnrichRequest) -> EnrichResponse: p_v = 0.85 b_s = 0.2 - return EnrichResponse(S_v=s_v, P_v=p_v, B_s=b_s) + return EnrichResponse(schemaVersion=ENRICHMENT_SCHEMA_VERSION, S_v=s_v, P_v=p_v, B_s=b_s) + + +@app.get("/signals/enrich/schema") +def enrich_schema() -> dict: + return { + "schemaVersion": ENRICHMENT_SCHEMA_VERSION, + "fields": { + "S_v": "Sentiment vector (normalized polarity + severity).", + "P_v": "Volatility probability in [0, 1].", + "B_s": "Supply or capacity bias (long-horizon signal).", + }, + } def configure_tracing() -> None: endpoint = 
os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") From 7cc462702046e9ceef86a52f8bf0a3aca6f53d76 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 11 Feb 2026 14:38:07 +0800 Subject: [PATCH 07/24] Added FinBERT integration + rollback mechanism --- docs/Quickstart.md | 24 ++++ docs/ROADMAP-v2.3.md | 1 + src/services/ai-engine/main.py | 165 +++++++++++++++++++++--- src/services/ai-engine/requirements.txt | 1 + 4 files changed, 173 insertions(+), 18 deletions(-) diff --git a/docs/Quickstart.md b/docs/Quickstart.md index 8b3c548..e5d8505 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -66,6 +66,30 @@ GET /api/v1/signals/feeds Smoke test checklist: `docs/QA-SmokeTest-v2.3.md`. +## Optional: Enable semantic enrichment (v2.3 Milestone 1) + +The AI engine defaults to a FinBERT-based enricher when dependencies are available. +You can force the provider via environment variables: + +```bash +# PowerShell +$env:AI_ENRICH_PROVIDER="finbert" # or "heuristic" +$env:AI_FINBERT_MODEL="ProsusAI/finbert" + +# Bash +export AI_ENRICH_PROVIDER=finbert +export AI_FINBERT_MODEL=ProsusAI/finbert +``` + +Note: the first FinBERT run downloads model weights and can take a few minutes. +Set `AI_ENRICH_PROVIDER=heuristic` if you need a fast, offline fallback. + +Schema details: + +``` +GET /signals/enrich/schema +``` + If you want to simulate migrations, start at least two agents: ```bash diff --git a/docs/ROADMAP-v2.3.md b/docs/ROADMAP-v2.3.md index d8948f6..a3e30f7 100644 --- a/docs/ROADMAP-v2.3.md +++ b/docs/ROADMAP-v2.3.md @@ -37,6 +37,7 @@ minimize integration risk while preserving backward compatibility with the v2.2 - Enrichment throughput supports expected signal volume. - Outputs are versioned and validated. +- FinBERT/semantic enrichment is operational (with heuristic fallback). 
### Milestone 2: Fusion and Forecasting (Offline) diff --git a/src/services/ai-engine/main.py b/src/services/ai-engine/main.py index 08d0a09..6f805b7 100644 --- a/src/services/ai-engine/main.py +++ b/src/services/ai-engine/main.py @@ -1,6 +1,9 @@ import logging import os from contextlib import asynccontextmanager +from dataclasses import dataclass +from functools import lru_cache +from typing import Iterable from fastapi import FastAPI from pydantic import BaseModel, Field, ConfigDict @@ -18,11 +21,17 @@ scorer = RiskScorer() ENRICHMENT_SCHEMA_VERSION = "1.0" +DEFAULT_FINBERT_MODEL = os.getenv("AI_FINBERT_MODEL", "ProsusAI/finbert") +ENRICHMENT_PROVIDER = os.getenv("AI_ENRICH_PROVIDER", "finbert").lower() +ENRICHMENT_MAX_CHARS = int(os.getenv("AI_ENRICH_MAX_CHARS", "2000")) +ENRICHMENT_CACHE_SIZE = int(os.getenv("AI_ENRICH_CACHE_SIZE", "1024")) + @asynccontextmanager async def lifespan(app_instance: FastAPI): configure_tracing() app_instance.state.scorer = scorer + app_instance.state.enricher = build_enricher() logger.info("AI Engine Online.") yield @@ -57,7 +66,7 @@ class EnrichResponse(BaseModel): schema_version: str = Field(alias="schemaVersion", description="Semantic vector schema version.") s_v: list[float] = Field( alias="S_v", - description="Sentiment vector (normalized polarity + severity).", + description="Sentiment vector [negative, neutral, positive], normalized to sum to 1.", ) p_v: float = Field( alias="P_v", @@ -100,22 +109,14 @@ def analyze(payload: RiskPayload) -> dict: @app.post("/signals/enrich", response_model=EnrichResponse) def enrich_signals(payload: EnrichRequest) -> EnrichResponse: - # Placeholder semantic enrichment: use simple heuristics until NLP pipeline is online. 
- combined = " ".join( - [doc.title + " " + (doc.summary or "") for doc in payload.documents] - ).lower() - negative_terms = ("outage", "disruption", "degraded", "incident", "latency", "unavailable") - has_negative = any(term in combined for term in negative_terms) - - s_v = [0.1, 0.1, 0.1] - p_v = 0.15 - b_s = 0.0 - if has_negative: - s_v = [0.9, 0.2, 0.1] - p_v = 0.85 - b_s = 0.2 - - return EnrichResponse(schemaVersion=ENRICHMENT_SCHEMA_VERSION, S_v=s_v, P_v=p_v, B_s=b_s) + enricher: SemanticEnricher = app.state.enricher + result = enricher.enrich(payload.documents) + return EnrichResponse( + schemaVersion=ENRICHMENT_SCHEMA_VERSION, + S_v=result.s_v, + P_v=result.p_v, + B_s=result.b_s, + ) @app.get("/signals/enrich/schema") @@ -123,12 +124,140 @@ def enrich_schema() -> dict: return { "schemaVersion": ENRICHMENT_SCHEMA_VERSION, "fields": { - "S_v": "Sentiment vector (normalized polarity + severity).", + "S_v": "Sentiment vector [negative, neutral, positive], normalized to sum to 1.", "P_v": "Volatility probability in [0, 1].", "B_s": "Supply or capacity bias (long-horizon signal).", }, } + +@dataclass +class EnrichResult: + s_v: list[float] + p_v: float + b_s: float + + +class SemanticEnricher: + def enrich(self, documents: Iterable[SignalDocument]) -> EnrichResult: # pragma: no cover - interface + raise NotImplementedError + + +class HeuristicEnricher(SemanticEnricher): + negative_terms = ("outage", "disruption", "degraded", "incident", "latency", "unavailable") + supply_terms = ("capacity", "shortage", "procurement", "inventory", "supply", "quota") + + def enrich(self, documents: Iterable[SignalDocument]) -> EnrichResult: + combined = " ".join([self._doc_text(doc) for doc in documents]).lower() + negative_hits = sum(term in combined for term in self.negative_terms) + supply_hits = sum(term in combined for term in self.supply_terms) + + neg_score = 0.15 + 0.2 * negative_hits + neutral_score = 0.6 - 0.1 * negative_hits + pos_score = 1.0 - (neg_score + 
neutral_score) + + s_v = normalize_vector([neg_score, neutral_score, max(0.0, pos_score)]) + p_v = clamp(0.15 + 0.25 * negative_hits, 0.0, 1.0) + b_s = clamp(0.05 * supply_hits, 0.0, 1.0) + return EnrichResult(s_v=s_v, p_v=p_v, b_s=b_s) + + @staticmethod + def _doc_text(doc: SignalDocument) -> str: + summary = doc.summary or "" + return f"{doc.title} {summary}".strip() + + +class FinbertEnricher(SemanticEnricher): + def __init__(self, model_id: str, max_chars: int, cache_size: int) -> None: + from transformers import pipeline + import torch + + device = 0 if torch.cuda.is_available() else -1 + self._pipeline = pipeline( + "sentiment-analysis", + model=model_id, + tokenizer=model_id, + return_all_scores=True, + device=device, + ) + self._max_chars = max_chars + self._cache = lru_cache(maxsize=cache_size)(self._score_text) + + def enrich(self, documents: Iterable[SignalDocument]) -> EnrichResult: + scores = [self._cache(self._doc_text(doc)) for doc in documents] + if not scores: + return EnrichResult(s_v=[0.15, 0.7, 0.15], p_v=0.1, b_s=0.0) + + neg = sum(score[0] for score in scores) / len(scores) + neutral = sum(score[1] for score in scores) / len(scores) + pos = sum(score[2] for score in scores) / len(scores) + s_v = normalize_vector([neg, neutral, pos]) + + volatility_boost = max(0.0, neg - pos) + p_v = clamp(0.1 + volatility_boost * 1.2, 0.0, 1.0) + b_s = clamp(self._supply_bias(documents), 0.0, 1.0) + return EnrichResult(s_v=s_v, p_v=p_v, b_s=b_s) + + def _score_text(self, text: str) -> list[float]: + scores_raw = self._pipeline(text[: self._max_chars]) + scores = self._parse_scores(scores_raw) + return normalize_vector(scores) + + @staticmethod + def _parse_scores(raw) -> list[float]: + if isinstance(raw, list) and raw and isinstance(raw[0], list): + entries = raw[0] + else: + entries = raw + + label_map = {"negative": 0, "neutral": 1, "positive": 2} + scores = [0.0, 0.0, 0.0] + for entry in entries: + label = entry.get("label", "").lower() + score = 
float(entry.get("score", 0.0)) + if label in label_map: + scores[label_map[label]] = score + if sum(scores) == 0: + return [0.15, 0.7, 0.15] + return scores + + @staticmethod + def _doc_text(doc: SignalDocument) -> str: + summary = doc.summary or "" + return f"{doc.title} {summary}".strip() + + @staticmethod + def _supply_bias(documents: Iterable[SignalDocument]) -> float: + supply_terms = ("capacity", "shortage", "procurement", "inventory", "supply", "quota") + combined = " ".join([doc.title + " " + (doc.summary or "") for doc in documents]).lower() + hits = sum(term in combined for term in supply_terms) + return 0.05 * hits + + +def build_enricher() -> SemanticEnricher: + if ENRICHMENT_PROVIDER == "finbert": + try: + logger.info("Loading FinBERT model %s for enrichment.", DEFAULT_FINBERT_MODEL) + return FinbertEnricher( + model_id=DEFAULT_FINBERT_MODEL, + max_chars=ENRICHMENT_MAX_CHARS, + cache_size=ENRICHMENT_CACHE_SIZE, + ) + except Exception as exc: + logger.warning("Failed to load FinBERT model, falling back to heuristics: %s", exc) + return HeuristicEnricher() + + +def normalize_vector(values: list[float]) -> list[float]: + total = sum(max(0.0, value) for value in values) + if total <= 0: + return [0.15, 0.7, 0.15] + return [max(0.0, value) / total for value in values] + + +def clamp(value: float, min_value: float, max_value: float) -> float: + return max(min_value, min(max_value, value)) + def configure_tracing() -> None: endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") if not endpoint: diff --git a/src/services/ai-engine/requirements.txt b/src/services/ai-engine/requirements.txt index 680c2e1..7b736c8 100644 --- a/src/services/ai-engine/requirements.txt +++ b/src/services/ai-engine/requirements.txt @@ -8,6 +8,7 @@ opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi opentelemetry-instrumentation-requests torch +transformers>=4.39.0 numpy pandas scikit-learn From d1623269f6c565e048ede386b5aa59b8b6952d74 Mon Sep 17 00:00:00 2001 From: 
JasonEran Date: Wed, 11 Feb 2026 14:49:27 +0800 Subject: [PATCH 08/24] Added separate SUMMARY_SCHEMA_VERSION and /signals/summarize output versioning. --- docs/Quickstart.md | 34 +++++++ src/services/ai-engine/main.py | 160 +++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+) diff --git a/docs/Quickstart.md b/docs/Quickstart.md index e5d8505..4b4b66a 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -90,6 +90,40 @@ Schema details: GET /signals/enrich/schema ``` +## Optional: Enable summarization (v2.3 Milestone 1) + +The AI engine can summarize long advisories with a built-in heuristic summarizer or a remote HTTP summarizer. + +```bash +# PowerShell +$env:AI_SUMMARIZER_PROVIDER="heuristic" # or "http" +$env:AI_SUMMARIZER_ENDPOINT="https://summarizer.example/api" # required for http +$env:AI_SUMMARIZER_MAX_CHARS="600" +$env:AI_SUMMARIZER_CACHE_SIZE="1024" +$env:AI_SUMMARIZER_TIMEOUT="8" + +# Bash +export AI_SUMMARIZER_PROVIDER=heuristic +export AI_SUMMARIZER_ENDPOINT=https://summarizer.example/api +export AI_SUMMARIZER_MAX_CHARS=600 +export AI_SUMMARIZER_CACHE_SIZE=1024 +export AI_SUMMARIZER_TIMEOUT=8 +``` + +Summarize signals: + +``` +POST /signals/summarize +``` + +Example payload: + +```bash +curl -X POST http://localhost:8000/signals/summarize \ + -H "Content-Type: application/json" \ + -d '{"documents":[{"source":"aws","title":"AWS incident update","summary":"Service degradation observed in us-east-1..."}],"maxChars":280}' +``` + If you want to simulate migrations, start at least two agents: ```bash diff --git a/src/services/ai-engine/main.py b/src/services/ai-engine/main.py index 6f805b7..c17054e 100644 --- a/src/services/ai-engine/main.py +++ b/src/services/ai-engine/main.py @@ -1,5 +1,6 @@ import logging import os +import re from contextlib import asynccontextmanager from dataclasses import dataclass from functools import lru_cache @@ -20,11 +21,17 @@ logger = logging.getLogger("uvicorn.error") scorer = RiskScorer() 
ENRICHMENT_SCHEMA_VERSION = "1.0" +SUMMARY_SCHEMA_VERSION = "1.0" DEFAULT_FINBERT_MODEL = os.getenv("AI_FINBERT_MODEL", "ProsusAI/finbert") ENRICHMENT_PROVIDER = os.getenv("AI_ENRICH_PROVIDER", "finbert").lower() ENRICHMENT_MAX_CHARS = int(os.getenv("AI_ENRICH_MAX_CHARS", "2000")) ENRICHMENT_CACHE_SIZE = int(os.getenv("AI_ENRICH_CACHE_SIZE", "1024")) +SUMMARY_PROVIDER = os.getenv("AI_SUMMARIZER_PROVIDER", "heuristic").lower() +SUMMARY_ENDPOINT = os.getenv("AI_SUMMARIZER_ENDPOINT", "") +SUMMARY_MAX_CHARS = int(os.getenv("AI_SUMMARIZER_MAX_CHARS", "600")) +SUMMARY_CACHE_SIZE = int(os.getenv("AI_SUMMARIZER_CACHE_SIZE", "1024")) +SUMMARY_TIMEOUT_SECONDS = float(os.getenv("AI_SUMMARIZER_TIMEOUT", "8")) @asynccontextmanager @@ -32,6 +39,7 @@ async def lifespan(app_instance: FastAPI): configure_tracing() app_instance.state.scorer = scorer app_instance.state.enricher = build_enricher() + app_instance.state.summarizer = build_summarizer() logger.info("AI Engine Online.") yield @@ -80,6 +88,28 @@ class EnrichResponse(BaseModel): model_config = ConfigDict(populate_by_name=True) +class SummarizeRequest(BaseModel): + documents: list[SignalDocument] + max_chars: int | None = Field(default=None, alias="maxChars") + + model_config = ConfigDict(populate_by_name=True) + + +class SummaryItem(BaseModel): + index: int + source: str + title: str + summary: str + truncated: bool + + +class SummarizeResponse(BaseModel): + schema_version: str = Field(alias="schemaVersion") + summaries: list[SummaryItem] + + model_config = ConfigDict(populate_by_name=True) + + @app.get("/") def root() -> dict: return {"status": "AI Engine Online"} @@ -143,6 +173,17 @@ def enrich(self, documents: Iterable[SignalDocument]) -> EnrichResult: # pragma raise NotImplementedError +@dataclass +class SummarizeResult: + summary: str + truncated: bool + + +class SignalSummarizer: + def summarize(self, text: str, max_chars: int) -> SummarizeResult: # pragma: no cover - interface + raise NotImplementedError + + class 
HeuristicEnricher(SemanticEnricher): negative_terms = ("outage", "disruption", "degraded", "incident", "latency", "unavailable") supply_terms = ("capacity", "shortage", "procurement", "inventory", "supply", "quota") @@ -248,6 +289,97 @@ def build_enricher() -> SemanticEnricher: return HeuristicEnricher() +class HeuristicSummarizer(SignalSummarizer): + def __init__(self, max_chars: int, cache_size: int) -> None: + self._max_chars = max_chars + self._cache = lru_cache(maxsize=cache_size)(self._summarize_text) + + def summarize(self, text: str, max_chars: int | None = None) -> SummarizeResult: + limit = max_chars or self._max_chars + clean = normalize_text(text) + if not clean: + return SummarizeResult(summary="", truncated=False) + if limit <= 0: + return SummarizeResult(summary="", truncated=len(clean) > 0) + return self._cache(clean, limit) + + def _summarize_text(self, text: str, limit: int) -> SummarizeResult: + if len(text) <= limit: + return SummarizeResult(summary=text, truncated=False) + + sentences = re.split(r"(?<=[.!?])\s+", text) + summary_parts: list[str] = [] + total_len = 0 + for sentence in sentences: + if not sentence: + continue + next_len = total_len + len(sentence) + (1 if summary_parts else 0) + if next_len > limit: + break + summary_parts.append(sentence) + total_len = next_len + + if not summary_parts: + return SummarizeResult(summary=text[:limit].rstrip(), truncated=True) + + summary = " ".join(summary_parts).rstrip() + return SummarizeResult(summary=summary, truncated=True) + + +class HttpSummarizer(SignalSummarizer): + def __init__(self, endpoint: str, fallback: SignalSummarizer, max_chars: int, cache_size: int, timeout: float) -> None: + self._endpoint = endpoint + self._fallback = fallback + self._max_chars = max_chars + self._timeout = timeout + self._cache = lru_cache(maxsize=cache_size)(self._summarize_remote) + + def summarize(self, text: str, max_chars: int | None = None) -> SummarizeResult: + limit = max_chars or self._max_chars + 
clean = normalize_text(text) + if not clean: + return SummarizeResult(summary="", truncated=False) + if limit <= 0: + return SummarizeResult(summary="", truncated=len(clean) > 0) + return self._cache(clean, limit) + + def _summarize_remote(self, text: str, limit: int) -> SummarizeResult: + try: + import requests + + response = requests.post( + self._endpoint, + json={"text": text, "maxChars": limit}, + timeout=self._timeout, + ) + response.raise_for_status() + payload = response.json() + summary = str(payload.get("summary", "")).strip() + if not summary: + return self._fallback.summarize(text, limit) + return SummarizeResult(summary=summary, truncated=len(text) > len(summary)) + except Exception as exc: + logger.warning("Summarizer remote call failed, falling back: %s", exc) + return self._fallback.summarize(text, limit) + + +def build_summarizer() -> SignalSummarizer: + heuristic = HeuristicSummarizer(max_chars=SUMMARY_MAX_CHARS, cache_size=SUMMARY_CACHE_SIZE) + + if SUMMARY_PROVIDER == "http": + if SUMMARY_ENDPOINT: + return HttpSummarizer( + endpoint=SUMMARY_ENDPOINT, + fallback=heuristic, + max_chars=SUMMARY_MAX_CHARS, + cache_size=SUMMARY_CACHE_SIZE, + timeout=SUMMARY_TIMEOUT_SECONDS, + ) + logger.warning("AI_SUMMARIZER_PROVIDER=http set but AI_SUMMARIZER_ENDPOINT is empty; using heuristic.") + + return heuristic + + def normalize_vector(values: list[float]) -> list[float]: total = sum(max(0.0, value) for value in values) if total <= 0: @@ -258,6 +390,34 @@ def normalize_vector(values: list[float]) -> list[float]: def clamp(value: float, min_value: float, max_value: float) -> float: return max(min_value, min(max_value, value)) + +def normalize_text(text: str) -> str: + return " ".join(text.replace("\n", " ").split()).strip() + + +@app.post("/signals/summarize", response_model=SummarizeResponse) +def summarize_signals(payload: SummarizeRequest) -> SummarizeResponse: + summarizer: SignalSummarizer = app.state.summarizer + max_chars = payload.max_chars if 
payload.max_chars and payload.max_chars > 0 else SUMMARY_MAX_CHARS + summaries: list[SummaryItem] = [] + + for index, doc in enumerate(payload.documents): + source = doc.source + title = doc.title + text = f"{doc.title}. {doc.summary}" if doc.summary else doc.title + result = summarizer.summarize(text, max_chars) + summaries.append( + SummaryItem( + index=index, + source=source, + title=title, + summary=result.summary, + truncated=result.truncated, + ) + ) + + return SummarizeResponse(schemaVersion=SUMMARY_SCHEMA_VERSION, summaries=summaries) + def configure_tracing() -> None: endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") if not endpoint: From 7cb95b315e76fffe605257c8d458a851e7909c70 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 11 Feb 2026 15:17:06 +0800 Subject: [PATCH 09/24] Core integration with AI Engine --- docs/Quickstart.md | 11 + .../Controllers/ExternalSignalsController.cs | 11 + .../Data/ApplicationDbContext.cs | 11 + ...00_AddExternalSignalEnrichment.Designer.cs | 422 ++++++++++++++++++ ...60211152000_AddExternalSignalEnrichment.cs | 129 ++++++ .../ApplicationDbContextModelSnapshot.cs | 44 ++ .../core-dotnet/AetherGuard.Core/Program.cs | 1 + .../ExternalSignalEnrichmentClient.cs | 215 +++++++++ .../ExternalSignalIngestionService.cs | 104 +++++ .../ExternalSignals/ExternalSignalsOptions.cs | 11 + .../AetherGuard.Core/appsettings.json | 8 + .../AetherGuard.Core/models/ExternalSignal.cs | 11 + 12 files changed, 978 insertions(+) create mode 100644 src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260211152000_AddExternalSignalEnrichment.Designer.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260211152000_AddExternalSignalEnrichment.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs diff --git a/docs/Quickstart.md b/docs/Quickstart.md index 4b4b66a..e1c518a 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -81,6 +81,17 
@@ export AI_ENRICH_PROVIDER=finbert export AI_FINBERT_MODEL=ProsusAI/finbert ``` +Core will call the AI engine for enrichment when external signals are enabled. +If the AI engine is not running at the default Docker host (`http://ai-service:8000`), override: + +```bash +# PowerShell +$env:ExternalSignals__Enrichment__BaseUrl="http://localhost:8000" + +# Bash +export ExternalSignals__Enrichment__BaseUrl=http://localhost:8000 +``` + Note: the first FinBERT run downloads model weights and can take a few minutes. Set `AI_ENRICH_PROVIDER=heuristic` if you need a fast, offline fallback. diff --git a/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs b/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs index 536405e..d8540b6 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Controllers/ExternalSignalsController.cs @@ -65,6 +65,17 @@ public async Task GetSignals( signal.ExternalId, signal.Title, signal.Summary, + signal.SummaryDigest, + signal.SummaryDigestTruncated, + signal.SummarySchemaVersion, + signal.EnrichmentSchemaVersion, + signal.SentimentNegative, + signal.SentimentNeutral, + signal.SentimentPositive, + signal.VolatilityProbability, + signal.SupplyBias, + signal.SummarizedAt, + signal.EnrichedAt, signal.Region, signal.Severity, signal.Category, diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs b/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs index 04c7261..189de62 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Data/ApplicationDbContext.cs @@ -101,6 +101,17 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) entity.Property(e => e.ExternalId).HasColumnName("external_id"); entity.Property(e => e.Title).HasColumnName("title"); entity.Property(e => 
e.Summary).HasColumnName("summary"); + entity.Property(e => e.SummaryDigest).HasColumnName("summary_digest"); + entity.Property(e => e.SummaryDigestTruncated).HasColumnName("summary_digest_truncated"); + entity.Property(e => e.SummarySchemaVersion).HasColumnName("summary_schema_version"); + entity.Property(e => e.EnrichmentSchemaVersion).HasColumnName("enrichment_schema_version"); + entity.Property(e => e.SentimentNegative).HasColumnName("sentiment_negative"); + entity.Property(e => e.SentimentNeutral).HasColumnName("sentiment_neutral"); + entity.Property(e => e.SentimentPositive).HasColumnName("sentiment_positive"); + entity.Property(e => e.VolatilityProbability).HasColumnName("volatility_probability"); + entity.Property(e => e.SupplyBias).HasColumnName("supply_bias"); + entity.Property(e => e.SummarizedAt).HasColumnName("summarized_at"); + entity.Property(e => e.EnrichedAt).HasColumnName("enriched_at"); entity.Property(e => e.Region).HasColumnName("region"); entity.Property(e => e.Severity).HasColumnName("severity"); entity.Property(e => e.Category).HasColumnName("category"); diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260211152000_AddExternalSignalEnrichment.Designer.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260211152000_AddExternalSignalEnrichment.Designer.cs new file mode 100644 index 0000000..dae2ed0 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260211152000_AddExternalSignalEnrichment.Designer.cs @@ -0,0 +1,422 @@ +// +using System; +using AetherGuard.Core.Data; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Migrations; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; + +#nullable disable + +namespace AetherGuard.Core.Data.Migrations +{ + [DbContext(typeof(ApplicationDbContext))] + 
[Migration("20260211152000_AddExternalSignalEnrichment")] + partial class AddExternalSignalEnrichment + { + protected override void BuildTargetModel(ModelBuilder modelBuilder) + { +#pragma warning disable 612, 618 + modelBuilder + .HasAnnotation("ProductVersion", "8.0.0") + .HasAnnotation("Relational:MaxIdentifierLength", 63); + + NpgsqlModelBuilderExtensions.UseIdentityByDefaultColumns(modelBuilder); + + modelBuilder.Entity("AetherGuard.Core.Models.Agent", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uuid") + .HasColumnName("id"); + + b.Property("AgentToken") + .IsRequired() + .HasColumnType("text") + .HasColumnName("agenttoken"); + + b.Property("Hostname") + .IsRequired() + .HasColumnType("text") + .HasColumnName("hostname"); + + b.Property("LastHeartbeat") + .HasColumnType("timestamp with time zone") + .HasColumnName("lastheartbeat"); + + b.Property("Status") + .IsRequired() + .HasColumnType("text") + .HasColumnName("status"); + + b.HasKey("Id"); + + b.HasIndex("AgentToken") + .IsUnique(); + + b.ToTable("agents", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.AgentCommand", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("integer") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Action") + .IsRequired() + .HasColumnType("text") + .HasColumnName("action"); + + b.Property("AgentId") + .HasColumnType("uuid") + .HasColumnName("agent_id"); + + b.Property("CommandId") + .HasColumnType("uuid") + .HasColumnName("command_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("ExpiresAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("expires_at"); + + b.Property("Nonce") + .IsRequired() + .HasColumnType("text") + .HasColumnName("nonce"); + + b.Property("Parameters") + .IsRequired() + .HasColumnType("text") + .HasColumnName("parameters"); + + 
b.Property("Signature") + .IsRequired() + .HasColumnType("text") + .HasColumnName("signature"); + + b.Property("Status") + .IsRequired() + .HasColumnType("text") + .HasColumnName("status"); + + b.Property("UpdatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("updated_at"); + + b.Property("WorkloadId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("workload_id"); + + b.HasKey("Id"); + + b.ToTable("agent_commands", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.CommandAudit", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Action") + .IsRequired() + .HasColumnType("text") + .HasColumnName("action"); + + b.Property("Actor") + .IsRequired() + .HasColumnType("text") + .HasColumnName("actor"); + + b.Property("CommandId") + .HasColumnType("uuid") + .HasColumnName("command_id"); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("Error") + .IsRequired() + .HasColumnType("text") + .HasColumnName("error"); + + b.Property("Result") + .IsRequired() + .HasColumnType("text") + .HasColumnName("result"); + + b.HasKey("Id"); + + b.ToTable("command_audits", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.ExternalSignal", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Category") + .HasColumnType("text") + .HasColumnName("category"); + + b.Property("ExternalId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("external_id"); + + b.Property("IngestedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("ingested_at"); + + b.Property("PublishedAt") + .HasColumnType("timestamp with time zone") + 
.HasColumnName("published_at"); + + b.Property("Region") + .HasColumnType("text") + .HasColumnName("region"); + + b.Property("Severity") + .HasColumnType("text") + .HasColumnName("severity"); + + b.Property("Source") + .IsRequired() + .HasColumnType("text") + .HasColumnName("source"); + + b.Property("Summary") + .HasColumnType("text") + .HasColumnName("summary"); + + b.Property("SummaryDigest") + .HasColumnType("text") + .HasColumnName("summary_digest"); + + b.Property("SummaryDigestTruncated") + .HasColumnType("boolean") + .HasColumnName("summary_digest_truncated"); + + b.Property("SummarySchemaVersion") + .HasColumnType("text") + .HasColumnName("summary_schema_version"); + + b.Property("EnrichmentSchemaVersion") + .HasColumnType("text") + .HasColumnName("enrichment_schema_version"); + + b.Property("SentimentNegative") + .HasColumnType("double precision") + .HasColumnName("sentiment_negative"); + + b.Property("SentimentNeutral") + .HasColumnType("double precision") + .HasColumnName("sentiment_neutral"); + + b.Property("SentimentPositive") + .HasColumnType("double precision") + .HasColumnName("sentiment_positive"); + + b.Property("VolatilityProbability") + .HasColumnType("double precision") + .HasColumnName("volatility_probability"); + + b.Property("SupplyBias") + .HasColumnType("double precision") + .HasColumnName("supply_bias"); + + b.Property("SummarizedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("summarized_at"); + + b.Property("EnrichedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("enriched_at"); + + b.Property("Tags") + .HasColumnType("text") + .HasColumnName("tags"); + + b.Property("Title") + .IsRequired() + .HasColumnType("text") + .HasColumnName("title"); + + b.Property("Url") + .HasColumnType("text") + .HasColumnName("url"); + + b.HasKey("Id"); + + b.HasIndex("PublishedAt"); + + b.HasIndex("Source", "ExternalId") + .IsUnique(); + + b.ToTable("external_signals", (string)null); + }); + + 
modelBuilder.Entity("AetherGuard.Core.Models.ExternalSignalFeedState", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("FailureCount") + .HasColumnType("integer") + .HasColumnName("failure_count"); + + b.Property("LastError") + .HasColumnType("text") + .HasColumnName("last_error"); + + b.Property("LastFetchAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("last_fetch_at"); + + b.Property("LastStatusCode") + .HasColumnType("integer") + .HasColumnName("last_status_code"); + + b.Property("LastSuccessAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("last_success_at"); + + b.Property("Name") + .IsRequired() + .HasColumnType("text") + .HasColumnName("name"); + + b.Property("Url") + .IsRequired() + .HasColumnType("text") + .HasColumnName("url"); + + b.HasKey("Id"); + + b.HasIndex("Name") + .IsUnique(); + + b.ToTable("external_signal_feeds", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.SchemaRegistryEntry", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("id"); + + NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("CreatedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("created_at"); + + b.Property("Schema") + .IsRequired() + .HasColumnType("text") + .HasColumnName("schema"); + + b.Property("Subject") + .IsRequired() + .HasColumnType("text") + .HasColumnName("subject"); + + b.Property("Version") + .HasColumnType("integer") + .HasColumnName("version"); + + b.HasKey("Id"); + + b.HasIndex("Subject", "Version") + .IsUnique(); + + b.ToTable("schema_registry", (string)null); + }); + + modelBuilder.Entity("AetherGuard.Core.Models.TelemetryRecord", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint") + .HasColumnName("Id"); + + 
NpgsqlPropertyBuilderExtensions.UseIdentityByDefaultColumn(b.Property("Id")); + + b.Property("Timestamp") + .HasColumnType("timestamp with time zone") + .HasColumnName("Timestamp"); + + b.Property("AgentId") + .IsRequired() + .HasColumnType("text") + .HasColumnName("AgentId"); + + b.Property("AiConfidence") + .HasColumnType("double precision") + .HasColumnName("AiConfidence"); + + b.Property("AiStatus") + .IsRequired() + .HasColumnType("text") + .HasColumnName("AiStatus"); + + b.Property("CpuUsage") + .HasColumnType("double precision") + .HasColumnName("CpuUsage"); + + b.Property("DiskAvailable") + .HasColumnType("bigint") + .HasColumnName("DiskAvailable"); + + b.Property("MemoryUsage") + .HasColumnType("double precision") + .HasColumnName("MemoryUsage"); + + b.Property("PredictedCpu") + .HasColumnType("double precision") + .HasColumnName("PredictedCpu"); + + b.Property("RebalanceSignal") + .HasColumnType("boolean") + .HasColumnName("RebalanceSignal"); + + b.Property("RootCause") + .HasColumnType("text") + .HasColumnName("RootCause"); + + b.Property("WorkloadTier") + .IsRequired() + .HasColumnType("text") + .HasColumnName("WorkloadTier"); + + b.HasKey("Id", "Timestamp"); + + b.ToTable("TelemetryRecords", (string)null); + }); +#pragma warning restore 612, 618 + } + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260211152000_AddExternalSignalEnrichment.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260211152000_AddExternalSignalEnrichment.cs new file mode 100644 index 0000000..b487a80 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/20260211152000_AddExternalSignalEnrichment.cs @@ -0,0 +1,129 @@ +using System; +using Microsoft.EntityFrameworkCore.Migrations; + +#nullable disable + +namespace AetherGuard.Core.Data.Migrations +{ + /// + public partial class AddExternalSignalEnrichment : Migration + { + /// + protected override void Up(MigrationBuilder migrationBuilder) + { + 
migrationBuilder.AddColumn( + name: "summary_digest", + table: "external_signals", + type: "text", + nullable: true); + + migrationBuilder.AddColumn( + name: "summary_digest_truncated", + table: "external_signals", + type: "boolean", + nullable: true); + + migrationBuilder.AddColumn( + name: "summary_schema_version", + table: "external_signals", + type: "text", + nullable: true); + + migrationBuilder.AddColumn( + name: "enrichment_schema_version", + table: "external_signals", + type: "text", + nullable: true); + + migrationBuilder.AddColumn( + name: "sentiment_negative", + table: "external_signals", + type: "double precision", + nullable: true); + + migrationBuilder.AddColumn( + name: "sentiment_neutral", + table: "external_signals", + type: "double precision", + nullable: true); + + migrationBuilder.AddColumn( + name: "sentiment_positive", + table: "external_signals", + type: "double precision", + nullable: true); + + migrationBuilder.AddColumn( + name: "volatility_probability", + table: "external_signals", + type: "double precision", + nullable: true); + + migrationBuilder.AddColumn( + name: "supply_bias", + table: "external_signals", + type: "double precision", + nullable: true); + + migrationBuilder.AddColumn( + name: "summarized_at", + table: "external_signals", + type: "timestamp with time zone", + nullable: true); + + migrationBuilder.AddColumn( + name: "enriched_at", + table: "external_signals", + type: "timestamp with time zone", + nullable: true); + } + + /// + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "summary_digest", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "summary_digest_truncated", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "summary_schema_version", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "enrichment_schema_version", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: 
"sentiment_negative", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "sentiment_neutral", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "sentiment_positive", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "volatility_probability", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "supply_bias", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "summarized_at", + table: "external_signals"); + + migrationBuilder.DropColumn( + name: "enriched_at", + table: "external_signals"); + } + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs index 9c007b4..b8bf36d 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Data/Migrations/ApplicationDbContextModelSnapshot.cs @@ -205,6 +205,50 @@ protected override void BuildModel(ModelBuilder modelBuilder) .HasColumnType("text") .HasColumnName("summary"); + b.Property("SummaryDigest") + .HasColumnType("text") + .HasColumnName("summary_digest"); + + b.Property("SummaryDigestTruncated") + .HasColumnType("boolean") + .HasColumnName("summary_digest_truncated"); + + b.Property("SummarySchemaVersion") + .HasColumnType("text") + .HasColumnName("summary_schema_version"); + + b.Property("EnrichmentSchemaVersion") + .HasColumnType("text") + .HasColumnName("enrichment_schema_version"); + + b.Property("SentimentNegative") + .HasColumnType("double precision") + .HasColumnName("sentiment_negative"); + + b.Property("SentimentNeutral") + .HasColumnType("double precision") + .HasColumnName("sentiment_neutral"); + + b.Property("SentimentPositive") + .HasColumnType("double precision") + .HasColumnName("sentiment_positive"); + + b.Property("VolatilityProbability") + 
.HasColumnType("double precision") + .HasColumnName("volatility_probability"); + + b.Property("SupplyBias") + .HasColumnType("double precision") + .HasColumnName("supply_bias"); + + b.Property("SummarizedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("summarized_at"); + + b.Property("EnrichedAt") + .HasColumnType("timestamp with time zone") + .HasColumnName("enriched_at"); + b.Property("Tags") .HasColumnType("text") .HasColumnName("tags"); diff --git a/src/services/core-dotnet/AetherGuard.Core/Program.cs b/src/services/core-dotnet/AetherGuard.Core/Program.cs index 6f739c1..a748da7 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Program.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Program.cs @@ -61,6 +61,7 @@ client.DefaultRequestHeaders.UserAgent.ParseAdd("Aether-Guard/ExternalSignals"); client.Timeout = TimeSpan.FromSeconds(15); }); +builder.Services.AddHttpClient(); builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs new file mode 100644 index 0000000..ad8f569 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs @@ -0,0 +1,215 @@ +using System.Net.Http.Json; +using System.Text.Json; +using System.Text.Json.Serialization; +using AetherGuard.Core.Models; +using Microsoft.Extensions.Options; + +namespace AetherGuard.Core.Services.ExternalSignals; + +public sealed class ExternalSignalEnrichmentClient +{ + private static readonly JsonSerializerOptions RequestJsonOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + private static readonly JsonSerializerOptions ResponseJsonOptions = new() + { + PropertyNameCaseInsensitive = true + 
}; + + private readonly HttpClient _httpClient; + private readonly ExternalSignalEnrichmentOptions _options; + private readonly ILogger _logger; + private readonly bool _configured; + + public ExternalSignalEnrichmentClient( + HttpClient httpClient, + IOptions options, + ILogger logger) + { + _httpClient = httpClient; + _logger = logger; + _options = options.Value.Enrichment ?? new ExternalSignalEnrichmentOptions(); + _configured = ConfigureClient(_options); + } + + public bool IsEnabled => _options.Enabled && _configured; + public int MaxBatchSize => Math.Clamp(_options.MaxBatchSize, 1, 1000); + public int MaxConcurrency => Math.Clamp(_options.MaxConcurrency, 1, 8); + public int SummaryMaxChars => _options.SummaryMaxChars; + + public async Task SummarizeAsync( + IReadOnlyList signals, + CancellationToken cancellationToken) + { + if (!IsEnabled || SummaryMaxChars <= 0 || signals.Count == 0) + { + return null; + } + + var documents = signals.Select(MapDocument).ToList(); + var request = new SummarizeRequestDto(documents, SummaryMaxChars); + + try + { + using var response = await _httpClient.PostAsJsonAsync( + "signals/summarize", + request, + RequestJsonOptions, + cancellationToken); + + if (!response.IsSuccessStatusCode) + { + _logger.LogWarning("Summarizer returned HTTP {StatusCode}.", response.StatusCode); + return null; + } + + var payload = await response.Content.ReadFromJsonAsync( + ResponseJsonOptions, + cancellationToken); + + if (payload?.Summaries is null || payload.Summaries.Count == 0) + { + _logger.LogWarning("Summarizer returned an empty payload."); + return null; + } + + return new SummarizeResponse(payload.SchemaVersion ?? 
"unknown", payload.Summaries); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Summarizer request failed."); + return null; + } + } + + public async Task EnrichAsync(ExternalSignal signal, CancellationToken cancellationToken) + { + if (!IsEnabled) + { + return null; + } + + var request = new EnrichRequestDto(new[] { MapDocument(signal) }); + + try + { + using var response = await _httpClient.PostAsJsonAsync( + "signals/enrich", + request, + RequestJsonOptions, + cancellationToken); + + if (!response.IsSuccessStatusCode) + { + _logger.LogWarning("Enrichment returned HTTP {StatusCode} for signal {ExternalId}.", + response.StatusCode, + signal.ExternalId); + return null; + } + + var payload = await response.Content.ReadFromJsonAsync( + ResponseJsonOptions, + cancellationToken); + + if (payload?.SentimentVector is null || payload.SentimentVector.Length < 3) + { + _logger.LogWarning("Enrichment payload missing S_v for signal {ExternalId}.", signal.ExternalId); + return null; + } + + return new EnrichResponse( + payload.SchemaVersion ?? 
"unknown", + payload.SentimentVector, + payload.VolatilityProbability, + payload.SupplyBias); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Enrichment request failed for signal {ExternalId}.", signal.ExternalId); + return null; + } + } + + private bool ConfigureClient(ExternalSignalEnrichmentOptions options) + { + if (string.IsNullOrWhiteSpace(options.BaseUrl)) + { + _logger.LogWarning("External signal enrichment base URL is empty."); + return false; + } + + try + { + var baseUrl = options.BaseUrl.TrimEnd('/') + "/"; + _httpClient.BaseAddress ??= new Uri(baseUrl, UriKind.Absolute); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Invalid external signal enrichment base URL: {BaseUrl}", options.BaseUrl); + return false; + } + + var timeoutSeconds = Math.Clamp(options.TimeoutSeconds, 2, 60); + _httpClient.Timeout = TimeSpan.FromSeconds(timeoutSeconds); + return true; + } + + private static SignalDocumentDto MapDocument(ExternalSignal signal) + => new( + Source: signal.Source, + Title: signal.Title, + Summary: signal.Summary, + Url: signal.Url, + Region: signal.Region, + PublishedAt: signal.PublishedAt); + + public sealed record SummarizeResponse(string SchemaVersion, IReadOnlyList Summaries); + public sealed record EnrichResponse(string SchemaVersion, double[] SentimentVector, double VolatilityProbability, double SupplyBias); + + public sealed record SummaryItemDto( + [property: JsonPropertyName("index")] int Index, + [property: JsonPropertyName("summary")] string Summary, + [property: JsonPropertyName("truncated")] bool Truncated); + + private sealed record SignalDocumentDto( + [property: JsonPropertyName("source")] string Source, + [property: JsonPropertyName("title")] string Title, + [property: JsonPropertyName("summary")] string? Summary, + [property: JsonPropertyName("url")] string? Url, + [property: JsonPropertyName("region")] string? Region, + [property: JsonPropertyName("publishedAt")] DateTimeOffset? 
PublishedAt); + + private sealed record SummarizeRequestDto( + [property: JsonPropertyName("documents")] IReadOnlyList Documents, + [property: JsonPropertyName("maxChars")] int? MaxChars); + + private sealed class SummarizeResponseDto + { + [JsonPropertyName("schemaVersion")] + public string? SchemaVersion { get; init; } + + [JsonPropertyName("summaries")] + public List Summaries { get; init; } = new(); + } + + private sealed record EnrichRequestDto( + [property: JsonPropertyName("documents")] IReadOnlyList Documents); + + private sealed class EnrichResponseDto + { + [JsonPropertyName("schemaVersion")] + public string? SchemaVersion { get; init; } + + [JsonPropertyName("S_v")] + public double[]? SentimentVector { get; init; } + + [JsonPropertyName("P_v")] + public double VolatilityProbability { get; init; } + + [JsonPropertyName("B_s")] + public double SupplyBias { get; init; } + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs index 2118199..de13795 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs @@ -133,11 +133,115 @@ private async Task IngestOnceAsync(CancellationToken cancellationToken) await db.SaveChangesAsync(cancellationToken); _logger.LogInformation("Ingested {Count} signals from {Feed}.", newSignals.Count, feed.Name); + + if (_options.Enrichment.Enabled) + { + var enrichmentClient = scope.ServiceProvider.GetRequiredService(); + await EnrichSignalsAsync(db, enrichmentClient, newSignals, cancellationToken); + } } await CleanupOldSignalsAsync(db, cancellationToken); } + private async Task EnrichSignalsAsync( + ApplicationDbContext db, + ExternalSignalEnrichmentClient client, + List signals, + CancellationToken 
cancellationToken) + { + if (!client.IsEnabled || signals.Count == 0) + { + return; + } + + var batch = signals.Take(client.MaxBatchSize).ToList(); + if (batch.Count == 0) + { + return; + } + + var summaryUpdates = 0; + var enrichmentUpdates = 0; + var summarizedAt = DateTimeOffset.UtcNow; + + var summaryResponse = await client.SummarizeAsync(batch, cancellationToken); + if (summaryResponse is not null) + { + var summaryByIndex = summaryResponse.Summaries + .Where(item => item.Index >= 0 && item.Index < batch.Count) + .GroupBy(item => item.Index) + .ToDictionary(group => group.Key, group => group.First()); + + for (var index = 0; index < batch.Count; index++) + { + if (!summaryByIndex.TryGetValue(index, out var item)) + { + continue; + } + + if (string.IsNullOrWhiteSpace(item.Summary)) + { + continue; + } + + var signal = batch[index]; + signal.SummaryDigest = item.Summary; + signal.SummaryDigestTruncated = item.Truncated; + signal.SummarySchemaVersion = summaryResponse.SchemaVersion; + signal.SummarizedAt = summarizedAt; + summaryUpdates++; + } + } + + var semaphore = new SemaphoreSlim(client.MaxConcurrency); + var tasks = batch.Select(async signal => + { + await semaphore.WaitAsync(cancellationToken); + try + { + var result = await client.EnrichAsync(signal, cancellationToken); + return (signal, result); + } + finally + { + semaphore.Release(); + } + }).ToList(); + + var results = await Task.WhenAll(tasks); + var enrichedAt = DateTimeOffset.UtcNow; + + foreach (var (signal, result) in results) + { + if (result is null || result.SentimentVector.Length < 3) + { + continue; + } + + signal.SentimentNegative = Clamp01(result.SentimentVector[0]); + signal.SentimentNeutral = Clamp01(result.SentimentVector[1]); + signal.SentimentPositive = Clamp01(result.SentimentVector[2]); + signal.VolatilityProbability = Clamp01(result.VolatilityProbability); + signal.SupplyBias = Clamp01(result.SupplyBias); + signal.EnrichmentSchemaVersion = result.SchemaVersion; + signal.EnrichedAt 
= enrichedAt; + enrichmentUpdates++; + } + + if (summaryUpdates > 0 || enrichmentUpdates > 0) + { + await db.SaveChangesAsync(cancellationToken); + _logger.LogInformation( + "Enriched external signals: summaries={SummaryCount}, vectors={VectorCount}.", + summaryUpdates, + enrichmentUpdates); + } + } + + private static double Clamp01(double value) + => Math.Clamp(value, 0.0, 1.0); + private static async Task UpdateFeedStateAsync( ApplicationDbContext db, ExternalSignalFeedOptions feed, diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs index ab91b68..e0f8819 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsOptions.cs @@ -9,6 +9,7 @@ public sealed class ExternalSignalsOptions public int RetentionDays { get; set; } = 30; public int CleanupBatchSize { get; set; } = 500; public List Feeds { get; set; } = new(); + public ExternalSignalEnrichmentOptions Enrichment { get; set; } = new(); } public sealed class ExternalSignalFeedOptions @@ -17,3 +18,13 @@ public sealed class ExternalSignalFeedOptions public string Url { get; set; } = string.Empty; public string? 
DefaultRegion { get; set; } } + +public sealed class ExternalSignalEnrichmentOptions +{ + public bool Enabled { get; set; } = true; + public string BaseUrl { get; set; } = "http://ai-service:8000"; + public int TimeoutSeconds { get; set; } = 8; + public int MaxBatchSize { get; set; } = 200; + public int MaxConcurrency { get; set; } = 4; + public int SummaryMaxChars { get; set; } = 280; +} diff --git a/src/services/core-dotnet/AetherGuard.Core/appsettings.json b/src/services/core-dotnet/AetherGuard.Core/appsettings.json index 0ff04ea..38e812e 100644 --- a/src/services/core-dotnet/AetherGuard.Core/appsettings.json +++ b/src/services/core-dotnet/AetherGuard.Core/appsettings.json @@ -79,6 +79,14 @@ "MaxItemsPerFeed": 200, "RetentionDays": 30, "CleanupBatchSize": 500, + "Enrichment": { + "Enabled": true, + "BaseUrl": "http://ai-service:8000", + "TimeoutSeconds": 8, + "MaxBatchSize": 200, + "MaxConcurrency": 4, + "SummaryMaxChars": 280 + }, "Feeds": [ { "Name": "aws-status", diff --git a/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignal.cs b/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignal.cs index 870fb36..da21166 100644 --- a/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignal.cs +++ b/src/services/core-dotnet/AetherGuard.Core/models/ExternalSignal.cs @@ -7,6 +7,17 @@ public sealed class ExternalSignal public string ExternalId { get; set; } = string.Empty; public string Title { get; set; } = string.Empty; public string? Summary { get; set; } + public string? SummaryDigest { get; set; } + public bool? SummaryDigestTruncated { get; set; } + public string? SummarySchemaVersion { get; set; } + public string? EnrichmentSchemaVersion { get; set; } + public double? SentimentNegative { get; set; } + public double? SentimentNeutral { get; set; } + public double? SentimentPositive { get; set; } + public double? VolatilityProbability { get; set; } + public double? SupplyBias { get; set; } + public DateTimeOffset? 
SummarizedAt { get; set; } + public DateTimeOffset? EnrichedAt { get; set; } public string? Region { get; set; } public string? Severity { get; set; } public string? Category { get; set; } From 99a5864c6f92d0d8da8a5da863b2472f8e702f6b Mon Sep 17 00:00:00 2001 From: JasonEran Date: Thu, 19 Feb 2026 14:27:04 +0800 Subject: [PATCH 10/24] feat(v2.3): implement batch enrichment endpoint and update enrichment logic --- CHANGELOG.md | 5 +- README.md | 3 +- docs/QA-SmokeTest-v2.3-M1.md | 89 ++++++++++ docs/Quickstart.md | 17 ++ docs/ROADMAP-v2.3.md | 2 + src/services/ai-engine/main.py | 102 ++++++++++- .../ExternalSignalEnrichmentClientTests.cs | 167 ++++++++++++++++++ .../ExternalSignalEnrichmentClient.cs | 149 +++++++++++++++- .../ExternalSignalIngestionService.cs | 87 ++++++--- 9 files changed, 583 insertions(+), 38 deletions(-) create mode 100644 docs/QA-SmokeTest-v2.3-M1.md create mode 100644 src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalEnrichmentClientTests.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index c0e3b4f..ae3cc1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,9 @@ Semantic Versioning. - External signals ingestion pipeline (RSS feeds) with persisted `external_signals` table. - External signal feed health tracking (`external_signal_feeds`) and feed status API. - Parser regression tests for RSS/Atom feeds. -- AI Engine semantic enrichment stub (`/signals/enrich`) for v2.3 pipeline integration. +- AI Engine semantic enrichment service (`/signals/enrich`) with FinBERT/heuristic fallback. +- Batch semantic enrichment endpoint (`/signals/enrich/batch`) with schema-versioned vectors. +- v2.3 Milestone 1 smoke-test checklist in `docs/QA-SmokeTest-v2.3-M1.md`. - v2.3 multimodal predictive architecture document in `docs/ARCHITECTURE-v2.3.md`. - v2.3 delivery roadmap in `docs/ROADMAP-v2.3.md`. - Expanded v2.3 roadmap with model choices, data sources, and validation guidance. @@ -25,6 +27,7 @@ Semantic Versioning. 
### Changed - Agent now injects W3C trace headers for HTTP requests. - Dashboard dependencies updated to Next.js 16.1.6. +- Core external-signal ingestion now prefers batch enrichment and falls back to per-item enrichment. ### Deprecated - diff --git a/README.md b/README.md index 7ad8553..905fdc0 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ v2.2 reference architecture with a concrete implementation guide. ## Project Status -- Stage: v2.2 baseline delivered (Phase 0-4). v2.3 transition roadmap in docs/ARCHITECTURE-v2.3.md. +- Stage: v2.2 baseline delivered (Phase 0-4). v2.3 Milestone 1 delivered; Milestone 2+ tracked in docs/ROADMAP-v2.3.md. - License: MIT - Authors: Qi Junyi, Xiao Erdong (2026) - Sponsor: https://github.com/sponsors/JasonEran @@ -194,6 +194,7 @@ Open the dashboard at http://localhost:3000. - Observability (OpenTelemetry): docs/Observability.md - v2.3 architecture roadmap: docs/ARCHITECTURE-v2.3.md - v2.3 delivery roadmap: docs/ROADMAP-v2.3.md +- v2.3 Milestone 1 smoke test: docs/QA-SmokeTest-v2.3-M1.md If you want to simulate migrations, start at least two agents: diff --git a/docs/QA-SmokeTest-v2.3-M1.md b/docs/QA-SmokeTest-v2.3-M1.md new file mode 100644 index 0000000..ccff2f7 --- /dev/null +++ b/docs/QA-SmokeTest-v2.3-M1.md @@ -0,0 +1,89 @@ +# v2.3 Milestone 1 Smoke Test Checklist (Semantic Enrichment) + +This checklist validates the v2.3 Milestone 1 flow: FinBERT/heuristic enrichment, +summary generation, schema versioning, and batch enrichment integration. 
+ +## Preconditions + +- Docker Desktop running +- `COMMAND_API_KEY` set +- External signals enabled + +## Start stack + +```bash +# PowerShell +$env:COMMAND_API_KEY="changeme" +$env:ExternalSignals__Enabled="true" + +docker compose up --build -d +``` + +## Verify enrichment schema metadata + +```bash +curl http://localhost:8000/signals/enrich/schema +``` + +Expected: +- `schemaVersion` is present +- `batchEndpoint` is `/signals/enrich/batch` +- `fields` contains `S_v`, `P_v`, `B_s` + +## Verify single enrichment endpoint + +```bash +curl -X POST http://localhost:8000/signals/enrich \ + -H "Content-Type: application/json" \ + -d '{"documents":[{"source":"aws","title":"Service disruption in us-east-1","summary":"Investigating elevated errors."}]}' +``` + +Expected: +- HTTP 200 +- Response includes `schemaVersion`, `S_v`, `P_v`, `B_s` + +## Verify batch enrichment endpoint + +```bash +curl -X POST http://localhost:8000/signals/enrich/batch \ + -H "Content-Type: application/json" \ + -d '{"documents":[{"source":"aws","title":"Service disruption in us-east-1","summary":"Investigating elevated errors."},{"source":"gcp","title":"RESOLVED: incident in us-central1","summary":"Service recovered."}]}' +``` + +Expected: +- HTTP 200 +- `vectors` length matches input document count +- Each vector includes `index`, `S_v`, `P_v`, `B_s` + +## Verify summarization endpoint + +```bash +curl -X POST http://localhost:8000/signals/summarize \ + -H "Content-Type: application/json" \ + -d '{"documents":[{"source":"aws","title":"AWS incident update","summary":"Service degradation observed in us-east-1 with intermittent errors and elevated latency."}],"maxChars":160}' +``` + +Expected: +- HTTP 200 +- Response includes `schemaVersion` and non-empty `summaries` + +## Verify Core integration result + +```bash +curl http://localhost:5000/api/v1/signals?limit=10 +``` + +Expected (after at least one ingest cycle): +- At least one signal has `summarySchemaVersion` or `enrichmentSchemaVersion` +- 
Enriched vectors are persisted (`sentimentNegative`, `sentimentNeutral`, `sentimentPositive`, `volatilityProbability`, `supplyBias`) + +## Optional throughput sanity check + +Send 100+ documents to `/signals/enrich/batch` and confirm request latency remains within your +target budget for deployment (for example, under a few seconds in local Docker with heuristic mode). + +## Stop stack + +```bash +docker compose down +``` diff --git a/docs/Quickstart.md b/docs/Quickstart.md index e1c518a..adfc120 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -101,6 +101,23 @@ Schema details: GET /signals/enrich/schema ``` +Batch enrichment endpoint (recommended for throughput): + +``` +POST /signals/enrich/batch +``` + +Example payload: + +```bash +curl -X POST http://localhost:8000/signals/enrich/batch \ + -H "Content-Type: application/json" \ + -d '{"documents":[{"source":"aws","title":"Service disruption in us-east-1","summary":"Investigating elevated errors."},{"source":"gcp","title":"RESOLVED: incident in us-central1","summary":"Service recovered."}]}' +``` + +Core prefers `/signals/enrich/batch` and falls back to `/signals/enrich` automatically. +Milestone 1 smoke test checklist: `docs/QA-SmokeTest-v2.3-M1.md`. + ## Optional: Enable summarization (v2.3 Milestone 1) The AI engine can summarize long advisories with a built-in heuristic summarizer or a remote HTTP summarizer. diff --git a/docs/ROADMAP-v2.3.md b/docs/ROADMAP-v2.3.md index a3e30f7..55a6bdd 100644 --- a/docs/ROADMAP-v2.3.md +++ b/docs/ROADMAP-v2.3.md @@ -14,6 +14,7 @@ minimize integration risk while preserving backward compatibility with the v2.2 ### Milestone 0: Data Foundation (Signals Ingestion) **Goal**: Introduce external cloud signals and store them alongside telemetry windows. +**Status**: Completed (2026-02-11) - Add connectors for provider status feeds and incident streams. - Normalize signal schema (timestamp, region, severity, source, summary). 
@@ -27,6 +28,7 @@ minimize integration risk while preserving backward compatibility with the v2.2 ### Milestone 1: Semantic Enrichment Service **Goal**: Extract semantic vectors from signals without impacting core latency. +**Status**: Completed (2026-02-19) - Add NLP service for incident sentiment and volatility likelihood (FinBERT or domain BERT). - Add LLM summarizer for longer advisories and policy updates. diff --git a/src/services/ai-engine/main.py b/src/services/ai-engine/main.py index c17054e..28b6d18 100644 --- a/src/services/ai-engine/main.py +++ b/src/services/ai-engine/main.py @@ -74,20 +74,57 @@ class EnrichResponse(BaseModel): schema_version: str = Field(alias="schemaVersion", description="Semantic vector schema version.") s_v: list[float] = Field( alias="S_v", + min_length=3, + max_length=3, description="Sentiment vector [negative, neutral, positive], normalized to sum to 1.", ) p_v: float = Field( alias="P_v", + ge=0.0, + le=1.0, description="Volatility probability in the range [0, 1].", ) b_s: float = Field( alias="B_s", + ge=0.0, + le=1.0, description="Supply or capacity bias (long-horizon signal).", ) model_config = ConfigDict(populate_by_name=True) +class EnrichBatchItem(BaseModel): + index: int = Field(ge=0) + s_v: list[float] = Field( + alias="S_v", + min_length=3, + max_length=3, + description="Sentiment vector [negative, neutral, positive], normalized to sum to 1.", + ) + p_v: float = Field( + alias="P_v", + ge=0.0, + le=1.0, + description="Volatility probability in the range [0, 1].", + ) + b_s: float = Field( + alias="B_s", + ge=0.0, + le=1.0, + description="Supply or capacity bias (long-horizon signal).", + ) + + model_config = ConfigDict(populate_by_name=True) + + +class EnrichBatchResponse(BaseModel): + schema_version: str = Field(alias="schemaVersion", description="Semantic vector schema version.") + vectors: list[EnrichBatchItem] + + model_config = ConfigDict(populate_by_name=True) + + class SummarizeRequest(BaseModel): documents: 
list[SignalDocument] max_chars: int | None = Field(default=None, alias="maxChars") @@ -140,7 +177,7 @@ def analyze(payload: RiskPayload) -> dict: @app.post("/signals/enrich", response_model=EnrichResponse) def enrich_signals(payload: EnrichRequest) -> EnrichResponse: enricher: SemanticEnricher = app.state.enricher - result = enricher.enrich(payload.documents) + result = sanitize_enrich_result(enricher.enrich(payload.documents)) return EnrichResponse( schemaVersion=ENRICHMENT_SCHEMA_VERSION, S_v=result.s_v, @@ -149,10 +186,32 @@ def enrich_signals(payload: EnrichRequest) -> EnrichResponse: ) +@app.post("/signals/enrich/batch", response_model=EnrichBatchResponse) +def enrich_signals_batch(payload: EnrichRequest) -> EnrichBatchResponse: + enricher: SemanticEnricher = app.state.enricher + vectors = [ + EnrichBatchItem( + index=index, + S_v=result.s_v, + P_v=result.p_v, + B_s=result.b_s, + ) + for index, result in enumerate( + [sanitize_enrich_result(result) for result in enricher.enrich_batch(payload.documents)] + ) + ] + + return EnrichBatchResponse( + schemaVersion=ENRICHMENT_SCHEMA_VERSION, + vectors=vectors, + ) + + @app.get("/signals/enrich/schema") def enrich_schema() -> dict: return { "schemaVersion": ENRICHMENT_SCHEMA_VERSION, + "batchEndpoint": "/signals/enrich/batch", "fields": { "S_v": "Sentiment vector [negative, neutral, positive], normalized to sum to 1.", "P_v": "Volatility probability in [0, 1].", @@ -172,6 +231,9 @@ class SemanticEnricher: def enrich(self, documents: Iterable[SignalDocument]) -> EnrichResult: # pragma: no cover - interface raise NotImplementedError + def enrich_batch(self, documents: Iterable[SignalDocument]) -> list[EnrichResult]: + return [self.enrich([document]) for document in documents] + @dataclass class SummarizeResult: @@ -189,7 +251,18 @@ class HeuristicEnricher(SemanticEnricher): supply_terms = ("capacity", "shortage", "procurement", "inventory", "supply", "quota") def enrich(self, documents: Iterable[SignalDocument]) -> 
EnrichResult: - combined = " ".join([self._doc_text(doc) for doc in documents]).lower() + document_list = list(documents) + if not document_list: + return EnrichResult(s_v=[0.15, 0.7, 0.15], p_v=0.1, b_s=0.0) + + # Keep the legacy aggregate behavior for backward compatibility. + combined = " ".join([self._doc_text(doc) for doc in document_list]).lower() + return self._score_text(combined) + + def enrich_batch(self, documents: Iterable[SignalDocument]) -> list[EnrichResult]: + return [self._score_text(self._doc_text(document).lower()) for document in documents] + + def _score_text(self, combined: str) -> EnrichResult: negative_hits = sum(term in combined for term in self.negative_terms) supply_hits = sum(term in combined for term in self.supply_terms) @@ -225,7 +298,8 @@ def __init__(self, model_id: str, max_chars: int, cache_size: int) -> None: self._cache = lru_cache(maxsize=cache_size)(self._score_text) def enrich(self, documents: Iterable[SignalDocument]) -> EnrichResult: - scores = [self._cache(self._doc_text(doc)) for doc in documents] + document_list = list(documents) + scores = [self._cache(self._doc_text(doc)) for doc in document_list] if not scores: return EnrichResult(s_v=[0.15, 0.7, 0.15], p_v=0.1, b_s=0.0) @@ -236,9 +310,20 @@ def enrich(self, documents: Iterable[SignalDocument]) -> EnrichResult: volatility_boost = max(0.0, neg - pos) p_v = clamp(0.1 + volatility_boost * 1.2, 0.0, 1.0) - b_s = clamp(self._supply_bias(documents), 0.0, 1.0) + b_s = clamp(self._supply_bias(document_list), 0.0, 1.0) return EnrichResult(s_v=s_v, p_v=p_v, b_s=b_s) + def enrich_batch(self, documents: Iterable[SignalDocument]) -> list[EnrichResult]: + results: list[EnrichResult] = [] + for document in documents: + scores = self._cache(self._doc_text(document)) + neg = scores[0] + pos = scores[2] + p_v = clamp(0.1 + max(0.0, neg - pos) * 1.2, 0.0, 1.0) + b_s = clamp(self._supply_bias([document]), 0.0, 1.0) + results.append(EnrichResult(s_v=scores, p_v=p_v, b_s=b_s)) + return 
results + def _score_text(self, text: str) -> list[float]: scores_raw = self._pipeline(text[: self._max_chars]) scores = self._parse_scores(scores_raw) @@ -391,6 +476,15 @@ def clamp(value: float, min_value: float, max_value: float) -> float: return max(min_value, min(max_value, value)) +def sanitize_enrich_result(result: EnrichResult) -> EnrichResult: + values = normalize_vector((result.s_v + [0.0, 0.0, 0.0])[:3]) + return EnrichResult( + s_v=values, + p_v=clamp(result.p_v, 0.0, 1.0), + b_s=clamp(result.b_s, 0.0, 1.0), + ) + + def normalize_text(text: str) -> str: return " ".join(text.replace("\n", " ").split()).strip() diff --git a/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalEnrichmentClientTests.cs b/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalEnrichmentClientTests.cs new file mode 100644 index 0000000..bfdc8d9 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core.Tests/ExternalSignalEnrichmentClientTests.cs @@ -0,0 +1,167 @@ +using System; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using AetherGuard.Core.Models; +using AetherGuard.Core.Services.ExternalSignals; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Xunit; + +namespace AetherGuard.Core.Tests; + +public class ExternalSignalEnrichmentClientTests +{ + [Fact] + public async Task EnrichBatchAsync_ReturnsVectors_WhenPayloadIsValid() + { + var json = """ + { + "schemaVersion": "1.0", + "vectors": [ + { "index": 0, "S_v": [0.6, 0.3, 0.1], "P_v": 0.72, "B_s": 0.12 }, + { "index": 1, "S_v": [0.2, 0.5, 0.3], "P_v": 0.35, "B_s": 0.40 } + ] + } + """; + + var client = CreateClient((request, _) => + { + if (request.RequestUri?.AbsolutePath.EndsWith("/signals/enrich/batch", StringComparison.OrdinalIgnoreCase) != true) + { + return new HttpResponseMessage(HttpStatusCode.NotFound); + } + + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = 
new StringContent(json, Encoding.UTF8, "application/json") + }; + }); + + var signals = new[] + { + CreateSignal("signal-1"), + CreateSignal("signal-2") + }; + + var result = await client.EnrichBatchAsync(signals, CancellationToken.None); + + Assert.NotNull(result); + Assert.Equal("1.0", result!.SchemaVersion); + Assert.Equal(2, result.Vectors.Count); + Assert.Equal(0, result.Vectors[0].Index); + Assert.Equal(3, result.Vectors[0].Vector.SentimentVector.Length); + Assert.Equal(0.72, result.Vectors[0].Vector.VolatilityProbability, 3); + } + + [Fact] + public async Task EnrichBatchAsync_FiltersInvalidVectors() + { + var json = """ + { + "schemaVersion": "1.0", + "vectors": [ + { "index": 0, "S_v": [0.6, 0.3], "P_v": 0.72, "B_s": 0.12 }, + { "index": 1, "S_v": [0.2, 0.5, 0.3], "P_v": 0.35, "B_s": 0.40 } + ] + } + """; + + var client = CreateClient((_, _) => new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(json, Encoding.UTF8, "application/json") + }); + + var signals = new[] + { + CreateSignal("signal-1"), + CreateSignal("signal-2") + }; + + var result = await client.EnrichBatchAsync(signals, CancellationToken.None); + + Assert.NotNull(result); + Assert.Single(result!.Vectors); + Assert.Equal(1, result.Vectors[0].Index); + } + + [Fact] + public async Task EnrichAsync_ClampsOutOfRangeValues() + { + var json = """ + { + "schemaVersion": "1.0", + "S_v": [0.6, 0.3, 0.1], + "P_v": 2.5, + "B_s": -1.0 + } + """; + + var client = CreateClient((request, _) => + { + if (request.RequestUri?.AbsolutePath.EndsWith("/signals/enrich", StringComparison.OrdinalIgnoreCase) != true) + { + return new HttpResponseMessage(HttpStatusCode.NotFound); + } + + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(json, Encoding.UTF8, "application/json") + }; + }); + + var result = await client.EnrichAsync(CreateSignal("signal-1"), CancellationToken.None); + + Assert.NotNull(result); + Assert.Equal(1.0, result!.VolatilityProbability, 3); + 
Assert.Equal(0.0, result.SupplyBias, 3); + } + + private static ExternalSignalEnrichmentClient CreateClient( + Func responder) + { + var options = Options.Create(new ExternalSignalsOptions + { + Enrichment = new ExternalSignalEnrichmentOptions + { + Enabled = true, + BaseUrl = "http://localhost:8000", + TimeoutSeconds = 5, + MaxBatchSize = 200, + MaxConcurrency = 4, + SummaryMaxChars = 280 + } + }); + + var httpClient = new HttpClient(new StubHandler(responder)); + return new ExternalSignalEnrichmentClient( + httpClient, + options, + NullLogger.Instance); + } + + private static ExternalSignal CreateSignal(string externalId) + => new() + { + Source = "aws-status", + ExternalId = externalId, + Title = $"Signal {externalId}", + Summary = "Investigating elevated error rates.", + PublishedAt = DateTimeOffset.UtcNow + }; + + private sealed class StubHandler : HttpMessageHandler + { + private readonly Func _responder; + + public StubHandler(Func responder) + { + _responder = responder; + } + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + => Task.FromResult(_responder(request, cancellationToken)); + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs index ad8f569..1b958c1 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs @@ -114,17 +114,19 @@ public ExternalSignalEnrichmentClient( ResponseJsonOptions, cancellationToken); - if (payload?.SentimentVector is null || payload.SentimentVector.Length < 3) + if (payload is null + || !TryBuildEnrichResponse( + payload.SchemaVersion, + payload.SentimentVector, + payload.VolatilityProbability, + payload.SupplyBias, + out var normalized)) { 
_logger.LogWarning("Enrichment payload missing S_v for signal {ExternalId}.", signal.ExternalId); return null; } - return new EnrichResponse( - payload.SchemaVersion ?? "unknown", - payload.SentimentVector, - payload.VolatilityProbability, - payload.SupplyBias); + return normalized; } catch (Exception ex) { @@ -133,6 +135,78 @@ public ExternalSignalEnrichmentClient( } } + public async Task EnrichBatchAsync( + IReadOnlyList signals, + CancellationToken cancellationToken) + { + if (!IsEnabled || signals.Count == 0) + { + return null; + } + + var documents = signals.Take(MaxBatchSize).Select(MapDocument).ToList(); + if (documents.Count == 0) + { + return null; + } + var request = new EnrichRequestDto(documents); + + try + { + using var response = await _httpClient.PostAsJsonAsync( + "signals/enrich/batch", + request, + RequestJsonOptions, + cancellationToken); + + if (!response.IsSuccessStatusCode) + { + _logger.LogWarning("Batch enrichment returned HTTP {StatusCode}.", response.StatusCode); + return null; + } + + var payload = await response.Content.ReadFromJsonAsync( + ResponseJsonOptions, + cancellationToken); + + if (payload?.Vectors is null || payload.Vectors.Count == 0) + { + _logger.LogWarning("Batch enrichment returned an empty payload."); + return null; + } + + var vectors = payload.Vectors + .Where(item => item.Index >= 0 && item.Index < documents.Count) + .Select(item => + { + var isValid = TryBuildEnrichResponse( + payload.SchemaVersion, + item.SentimentVector, + item.VolatilityProbability, + item.SupplyBias, + out var normalized); + return (Index: item.Index, IsValid: isValid, Response: normalized); + }) + .Where(entry => entry.IsValid && entry.Response is not null) + .GroupBy(entry => entry.Index) + .Select(group => new BatchEnrichItem(group.Key, group.First().Response!)) + .ToList(); + + if (vectors.Count == 0) + { + _logger.LogWarning("Batch enrichment payload had no valid vectors."); + return null; + } + + return new 
BatchEnrichResponse(payload.SchemaVersion ?? "unknown", vectors); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Batch enrichment request failed."); + return null; + } + } + private bool ConfigureClient(ExternalSignalEnrichmentOptions options) { if (string.IsNullOrWhiteSpace(options.BaseUrl)) @@ -166,8 +240,47 @@ private static SignalDocumentDto MapDocument(ExternalSignal signal) Region: signal.Region, PublishedAt: signal.PublishedAt); + private static bool TryBuildEnrichResponse( + string? schemaVersion, + double[]? sentimentVector, + double volatilityProbability, + double supplyBias, + out EnrichResponse? response) + { + response = null; + if (sentimentVector is null || sentimentVector.Length < 3) + { + return false; + } + + var values = sentimentVector.Take(3).ToArray(); + if (values.Any(value => double.IsNaN(value) || double.IsInfinity(value))) + { + return false; + } + + if (double.IsNaN(volatilityProbability) || double.IsInfinity(volatilityProbability)) + { + return false; + } + + if (double.IsNaN(supplyBias) || double.IsInfinity(supplyBias)) + { + return false; + } + + response = new EnrichResponse( + schemaVersion ?? "unknown", + values, + Math.Clamp(volatilityProbability, 0.0, 1.0), + Math.Clamp(supplyBias, 0.0, 1.0)); + return true; + } + public sealed record SummarizeResponse(string SchemaVersion, IReadOnlyList Summaries); public sealed record EnrichResponse(string SchemaVersion, double[] SentimentVector, double VolatilityProbability, double SupplyBias); + public sealed record BatchEnrichResponse(string SchemaVersion, IReadOnlyList Vectors); + public sealed record BatchEnrichItem(int Index, EnrichResponse Vector); public sealed record SummaryItemDto( [property: JsonPropertyName("index")] int Index, @@ -212,4 +325,28 @@ private sealed class EnrichResponseDto [JsonPropertyName("B_s")] public double SupplyBias { get; init; } } + + private sealed class EnrichBatchResponseDto + { + [JsonPropertyName("schemaVersion")] + public string? 
SchemaVersion { get; init; } + + [JsonPropertyName("vectors")] + public List Vectors { get; init; } = new(); + } + + private sealed class EnrichBatchVectorDto + { + [JsonPropertyName("index")] + public int Index { get; init; } + + [JsonPropertyName("S_v")] + public double[]? SentimentVector { get; init; } + + [JsonPropertyName("P_v")] + public double VolatilityProbability { get; init; } + + [JsonPropertyName("B_s")] + public double SupplyBias { get; init; } + } } diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs index de13795..e39c8c3 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs @@ -194,39 +194,60 @@ private async Task EnrichSignalsAsync( } } - var semaphore = new SemaphoreSlim(client.MaxConcurrency); - var tasks = batch.Select(async signal => + var enrichedAt = DateTimeOffset.UtcNow; + var batchResponse = await client.EnrichBatchAsync(batch, cancellationToken); + + if (batchResponse is not null) { - await semaphore.WaitAsync(cancellationToken); - try - { - var result = await client.EnrichAsync(signal, cancellationToken); - return (signal, result); - } - finally + var vectorsByIndex = batchResponse.Vectors + .Where(item => item.Index >= 0 && item.Index < batch.Count) + .GroupBy(item => item.Index) + .ToDictionary(group => group.Key, group => group.First().Vector); + + for (var index = 0; index < batch.Count; index++) { - semaphore.Release(); - } - }).ToList(); + if (!vectorsByIndex.TryGetValue(index, out var vector)) + { + continue; + } - var results = await Task.WhenAll(tasks); - var enrichedAt = DateTimeOffset.UtcNow; + if (vector.SentimentVector.Length < 3) + { + continue; + } - foreach (var (signal, result) in results) + 
ApplyEnrichment(batch[index], vector, enrichedAt); + enrichmentUpdates++; + } + } + else { - if (result is null || result.SentimentVector.Length < 3) + var semaphore = new SemaphoreSlim(client.MaxConcurrency); + var tasks = batch.Select(async signal => { - continue; - } + await semaphore.WaitAsync(cancellationToken); + try + { + var result = await client.EnrichAsync(signal, cancellationToken); + return (signal, result); + } + finally + { + semaphore.Release(); + } + }).ToList(); - signal.SentimentNegative = Clamp01(result.SentimentVector[0]); - signal.SentimentNeutral = Clamp01(result.SentimentVector[1]); - signal.SentimentPositive = Clamp01(result.SentimentVector[2]); - signal.VolatilityProbability = Clamp01(result.VolatilityProbability); - signal.SupplyBias = Clamp01(result.SupplyBias); - signal.EnrichmentSchemaVersion = result.SchemaVersion; - signal.EnrichedAt = enrichedAt; - enrichmentUpdates++; + var results = await Task.WhenAll(tasks); + foreach (var (signal, result) in results) + { + if (result is null || result.SentimentVector.Length < 3) + { + continue; + } + + ApplyEnrichment(signal, result, enrichedAt); + enrichmentUpdates++; + } } if (summaryUpdates > 0 || enrichmentUpdates > 0) @@ -242,6 +263,20 @@ private async Task EnrichSignalsAsync( private static double Clamp01(double value) => Math.Clamp(value, 0.0, 1.0); + private static void ApplyEnrichment( + ExternalSignal signal, + ExternalSignalEnrichmentClient.EnrichResponse result, + DateTimeOffset enrichedAt) + { + signal.SentimentNegative = Clamp01(result.SentimentVector[0]); + signal.SentimentNeutral = Clamp01(result.SentimentVector[1]); + signal.SentimentPositive = Clamp01(result.SentimentVector[2]); + signal.VolatilityProbability = Clamp01(result.VolatilityProbability); + signal.SupplyBias = Clamp01(result.SupplyBias); + signal.EnrichmentSchemaVersion = result.SchemaVersion; + signal.EnrichedAt = enrichedAt; + } + private static async Task UpdateFeedStateAsync( ApplicationDbContext db, 
ExternalSignalFeedOptions feed, From d77f3b2a4fcde3deec834ac1d1aaa1734fd23f02 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Sun, 22 Feb 2026 13:20:50 +0800 Subject: [PATCH 11/24] feat(v2.3): enhance observability with custom metrics and tracing for external signals --- docs/Observability.md | 37 ++++ docs/QA-SmokeTest-v2.3-M1.md | 15 ++ src/services/ai-engine/main.py | 203 +++++++++++++----- .../core-dotnet/AetherGuard.Core/Program.cs | 3 + .../ExternalSignalEnrichmentClient.cs | 120 +++++++++++ .../ExternalSignalIngestionService.cs | 189 ++++++++++------ .../ExternalSignalsTelemetry.cs | 134 ++++++++++++ 7 files changed, 581 insertions(+), 120 deletions(-) create mode 100644 src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsTelemetry.cs diff --git a/docs/Observability.md b/docs/Observability.md index efc05b9..ffb984c 100644 --- a/docs/Observability.md +++ b/docs/Observability.md @@ -12,6 +12,9 @@ The Docker Compose stack includes an OpenTelemetry Collector and Jaeger for dist - Core API (.NET): ASP.NET Core, HttpClient, EF Core, runtime/process metrics. - AI Engine (FastAPI): FastAPI + requests spans. - Dashboard (Next.js): Node.js auto-instrumentation for server routes. +- External Signals pipeline: + - Core custom traces/metrics for summarize + enrich (batch and fallback paths). + - AI custom traces/metrics for `/signals/enrich`, `/signals/enrich/batch`, and `/signals/summarize`. 
## Configuration knobs @@ -28,6 +31,32 @@ AI + Dashboard (environment): - `OTEL_TRACES_EXPORTER` - `OTEL_METRICS_EXPORTER` +## Custom metric names (v2.3 Milestone 1) + +Core (`AetherGuard.Core.ExternalSignals` meter): + +- `aetherguard.external_signals.client.requests` +- `aetherguard.external_signals.client.failures` +- `aetherguard.external_signals.client.duration.ms` +- `aetherguard.external_signals.client.documents` +- `aetherguard.external_signals.pipeline.runs` +- `aetherguard.external_signals.pipeline.fallbacks` +- `aetherguard.external_signals.pipeline.duration.ms` +- `aetherguard.external_signals.pipeline.batch.size` +- `aetherguard.external_signals.pipeline.updates` + +AI (`aether_guard.ai.signals` meter): + +- `aetherguard.ai.signals.requests` +- `aetherguard.ai.signals.errors` +- `aetherguard.ai.signals.duration.ms` +- `aetherguard.ai.signals.documents` + +## Trace entry points (v2.3 Milestone 1) + +- Core spans: `external_signals.client.*`, `external_signals.pipeline.enrich` +- AI spans: `ai.signals.enrich`, `ai.signals.enrich.batch`, `ai.signals.summarize` + ## Troubleshooting If traces are missing: @@ -35,3 +64,11 @@ If traces are missing: - Verify `otel-collector` and `jaeger` are running: `docker compose ps`. - Check collector logs: `docker compose logs -f otel-collector`. - Ensure OTLP ports are free (4317/4318). + +Quick trace check: + +1. Trigger signal enrichment by enabling external signals and waiting for one ingest cycle. +2. Open Jaeger (`http://localhost:16686`) and search services: + - `aether-guard-core` + - `aether-guard-ai` +3. Validate a trace chain includes Core outbound enrichment call and AI endpoint span. 
diff --git a/docs/QA-SmokeTest-v2.3-M1.md b/docs/QA-SmokeTest-v2.3-M1.md index ccff2f7..554cd74 100644 --- a/docs/QA-SmokeTest-v2.3-M1.md +++ b/docs/QA-SmokeTest-v2.3-M1.md @@ -82,6 +82,21 @@ Expected (after at least one ingest cycle): Send 100+ documents to `/signals/enrich/batch` and confirm request latency remains within your target budget for deployment (for example, under a few seconds in local Docker with heuristic mode). +## Verify observability (traces + metrics naming) + +1. Open Jaeger: `http://localhost:16686`. +2. Search traces for service `aether-guard-core` and look for spans: + - `external_signals.client.summarize` + - `external_signals.client.enrich_batch` (or fallback `external_signals.client.enrich`) + - `external_signals.pipeline.enrich` +3. Search traces for service `aether-guard-ai` and look for spans: + - `ai.signals.enrich` + - `ai.signals.enrich.batch` + - `ai.signals.summarize` +4. Confirm metric names are documented and exported by configured OTLP pipeline: + - Core: `aetherguard.external_signals.*` + - AI: `aetherguard.ai.signals.*` + ## Stop stack ```bash diff --git a/src/services/ai-engine/main.py b/src/services/ai-engine/main.py index 28b6d18..0e7025a 100644 --- a/src/services/ai-engine/main.py +++ b/src/services/ai-engine/main.py @@ -1,18 +1,23 @@ import logging import os import re -from contextlib import asynccontextmanager +import time +from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass from functools import lru_cache -from typing import Iterable +from typing import Iterable, Iterator from fastapi import FastAPI from pydantic import BaseModel, Field, ConfigDict -from opentelemetry import trace +from opentelemetry import trace, metrics +from opentelemetry.trace import Status, StatusCode from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from 
opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor @@ -32,6 +37,14 @@ SUMMARY_MAX_CHARS = int(os.getenv("AI_SUMMARIZER_MAX_CHARS", "600")) SUMMARY_CACHE_SIZE = int(os.getenv("AI_SUMMARIZER_CACHE_SIZE", "1024")) SUMMARY_TIMEOUT_SECONDS = float(os.getenv("AI_SUMMARIZER_TIMEOUT", "8")) +SIGNALS_TELEMETRY_METER = "aether_guard.ai.signals" +SIGNALS_TELEMETRY_TRACER = "aether_guard.ai.signals" + +signals_tracer = trace.get_tracer(SIGNALS_TELEMETRY_TRACER) +signals_request_counter = None +signals_error_counter = None +signals_latency_histogram = None +signals_document_histogram = None @asynccontextmanager @@ -47,6 +60,85 @@ async def lifespan(app_instance: FastAPI): app = FastAPI(lifespan=lifespan) +def initialize_signals_metrics() -> None: + global signals_request_counter + global signals_error_counter + global signals_latency_histogram + global signals_document_histogram + + meter = metrics.get_meter(SIGNALS_TELEMETRY_METER) + signals_request_counter = meter.create_counter( + "aetherguard.ai.signals.requests", + description="Count of AI signal API requests.", + ) + signals_error_counter = meter.create_counter( + "aetherguard.ai.signals.errors", + description="Count of AI signal API request failures.", + ) + signals_latency_histogram = meter.create_histogram( + "aetherguard.ai.signals.duration.ms", + unit="ms", + description="Latency of AI signal API requests.", + ) + signals_document_histogram = meter.create_histogram( + "aetherguard.ai.signals.documents", + unit="documents", + description="Document count per AI signal API request.", + ) + + +def record_signal_request( + endpoint: str, + provider: 
str, + documents: int, + duration_ms: float, + outcome: str, + error_type: str | None = None, +) -> None: + base_attributes = { + "endpoint": endpoint, + "provider": provider, + } + outcome_attributes = { + **base_attributes, + "outcome": outcome, + } + + if signals_request_counter is not None: + signals_request_counter.add(1, outcome_attributes) + if signals_latency_histogram is not None: + signals_latency_histogram.record(duration_ms, base_attributes) + if signals_document_histogram is not None: + signals_document_histogram.record(max(0, documents), base_attributes) + + if error_type and signals_error_counter is not None: + signals_error_counter.add(1, {**base_attributes, "error.type": error_type}) + + +@contextmanager +def observe_signal_endpoint(endpoint: str, provider: str, documents: int) -> Iterator[object]: + start = time.perf_counter() + outcome = "success" + error_type: str | None = None + + with signals_tracer.start_as_current_span(f"ai{endpoint.replace('/', '.')}") as span: + span.set_attribute("ai.signals.endpoint", endpoint) + span.set_attribute("ai.signals.provider", provider) + span.set_attribute("ai.signals.documents", documents) + try: + yield span + span.set_status(Status(StatusCode.OK)) + except Exception as exc: + outcome = "error" + error_type = type(exc).__name__ + span.record_exception(exc) + span.set_status(Status(StatusCode.ERROR)) + raise + finally: + duration_ms = (time.perf_counter() - start) * 1000 + record_signal_request(endpoint, provider, documents, duration_ms, outcome, error_type) + + class RiskPayload(BaseModel): spot_price_history: list[float] = Field(default_factory=list, alias="spotPriceHistory") rebalance_signal: bool = Field(alias="rebalanceSignal") @@ -177,34 +269,39 @@ def analyze(payload: RiskPayload) -> dict: @app.post("/signals/enrich", response_model=EnrichResponse) def enrich_signals(payload: EnrichRequest) -> EnrichResponse: enricher: SemanticEnricher = app.state.enricher - result = 
sanitize_enrich_result(enricher.enrich(payload.documents)) - return EnrichResponse( - schemaVersion=ENRICHMENT_SCHEMA_VERSION, - S_v=result.s_v, - P_v=result.p_v, - B_s=result.b_s, - ) + with observe_signal_endpoint("/signals/enrich", ENRICHMENT_PROVIDER, len(payload.documents)) as span: + result = sanitize_enrich_result(enricher.enrich(payload.documents)) + span.set_attribute("ai.signals.schema_version", ENRICHMENT_SCHEMA_VERSION) + return EnrichResponse( + schemaVersion=ENRICHMENT_SCHEMA_VERSION, + S_v=result.s_v, + P_v=result.p_v, + B_s=result.b_s, + ) @app.post("/signals/enrich/batch", response_model=EnrichBatchResponse) def enrich_signals_batch(payload: EnrichRequest) -> EnrichBatchResponse: enricher: SemanticEnricher = app.state.enricher - vectors = [ - EnrichBatchItem( - index=index, - S_v=result.s_v, - P_v=result.p_v, - B_s=result.b_s, - ) - for index, result in enumerate( - [sanitize_enrich_result(result) for result in enricher.enrich_batch(payload.documents)] - ) - ] + with observe_signal_endpoint("/signals/enrich/batch", ENRICHMENT_PROVIDER, len(payload.documents)) as span: + vectors = [ + EnrichBatchItem( + index=index, + S_v=result.s_v, + P_v=result.p_v, + B_s=result.b_s, + ) + for index, result in enumerate( + [sanitize_enrich_result(result) for result in enricher.enrich_batch(payload.documents)] + ) + ] - return EnrichBatchResponse( - schemaVersion=ENRICHMENT_SCHEMA_VERSION, - vectors=vectors, - ) + span.set_attribute("ai.signals.schema_version", ENRICHMENT_SCHEMA_VERSION) + span.set_attribute("ai.signals.vectors", len(vectors)) + return EnrichBatchResponse( + schemaVersion=ENRICHMENT_SCHEMA_VERSION, + vectors=vectors, + ) @app.get("/signals/enrich/schema") @@ -492,38 +589,46 @@ def normalize_text(text: str) -> str: @app.post("/signals/summarize", response_model=SummarizeResponse) def summarize_signals(payload: SummarizeRequest) -> SummarizeResponse: summarizer: SignalSummarizer = app.state.summarizer - max_chars = payload.max_chars if 
payload.max_chars and payload.max_chars > 0 else SUMMARY_MAX_CHARS - summaries: list[SummaryItem] = [] - - for index, doc in enumerate(payload.documents): - source = doc.source - title = doc.title - text = f"{doc.title}. {doc.summary}" if doc.summary else doc.title - result = summarizer.summarize(text, max_chars) - summaries.append( - SummaryItem( - index=index, - source=source, - title=title, - summary=result.summary, - truncated=result.truncated, + with observe_signal_endpoint("/signals/summarize", SUMMARY_PROVIDER, len(payload.documents)) as span: + max_chars = payload.max_chars if payload.max_chars and payload.max_chars > 0 else SUMMARY_MAX_CHARS + summaries: list[SummaryItem] = [] + + for index, doc in enumerate(payload.documents): + source = doc.source + title = doc.title + text = f"{doc.title}. {doc.summary}" if doc.summary else doc.title + result = summarizer.summarize(text, max_chars) + summaries.append( + SummaryItem( + index=index, + source=source, + title=title, + summary=result.summary, + truncated=result.truncated, + ) ) - ) - return SummarizeResponse(schemaVersion=SUMMARY_SCHEMA_VERSION, summaries=summaries) + span.set_attribute("ai.signals.schema_version", SUMMARY_SCHEMA_VERSION) + span.set_attribute("ai.signals.max_chars", max_chars) + return SummarizeResponse(schemaVersion=SUMMARY_SCHEMA_VERSION, summaries=summaries) def configure_tracing() -> None: endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") - if not endpoint: - return - service_name = os.getenv("OTEL_SERVICE_NAME", "aether-guard-ai") resource = Resource.create({"service.name": service_name}) - provider = TracerProvider(resource=resource) - exporter = OTLPSpanExporter(endpoint=endpoint) - provider.add_span_processor(BatchSpanProcessor(exporter)) - trace.set_tracer_provider(provider) + if endpoint: + trace_provider = TracerProvider(resource=resource) + span_exporter = OTLPSpanExporter(endpoint=endpoint) + trace_provider.add_span_processor(BatchSpanProcessor(span_exporter)) + 
trace.set_tracer_provider(trace_provider) + + metric_exporter = OTLPMetricExporter(endpoint=endpoint) + metric_reader = PeriodicExportingMetricReader(metric_exporter) + metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) + metrics.set_meter_provider(metric_provider) + + initialize_signals_metrics() RequestsInstrumentor().instrument() FastAPIInstrumentor.instrument_app(app) diff --git a/src/services/core-dotnet/AetherGuard.Core/Program.cs b/src/services/core-dotnet/AetherGuard.Core/Program.cs index a748da7..c681211 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Program.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Program.cs @@ -2,6 +2,7 @@ using AetherGuard.Core.Observability; using AetherGuard.Core.Security; using AetherGuard.Core.Services; +using AetherGuard.Core.Services.ExternalSignals; using AetherGuard.Core.Services.Messaging; using Microsoft.EntityFrameworkCore; using Microsoft.AspNetCore.Server.Kestrel.Https; @@ -98,6 +99,7 @@ .AddHttpClientInstrumentation() .AddEntityFrameworkCoreInstrumentation() .AddSource("AetherGuard.Core.Messaging") + .AddSource(ExternalSignalsTelemetry.ActivitySourceName) .AddOtlpExporter(options => { options.Endpoint = new Uri(otelOptions.OtlpEndpoint); @@ -113,6 +115,7 @@ .AddHttpClientInstrumentation() .AddRuntimeInstrumentation() .AddProcessInstrumentation() + .AddMeter(ExternalSignalsTelemetry.MeterName) .AddOtlpExporter(options => { options.Endpoint = new Uri(otelOptions.OtlpEndpoint); diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs index 1b958c1..0bfd250 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalEnrichmentClient.cs @@ -1,3 +1,4 @@ +using System.Diagnostics; 
using System.Net.Http.Json; using System.Text.Json; using System.Text.Json.Serialization; @@ -51,6 +52,9 @@ public ExternalSignalEnrichmentClient( var documents = signals.Select(MapDocument).ToList(); var request = new SummarizeRequestDto(documents, SummaryMaxChars); + const string operation = "summarize"; + using var activity = StartClientActivity(operation, documents.Count); + var stopwatch = Stopwatch.StartNew(); try { @@ -63,6 +67,14 @@ public ExternalSignalEnrichmentClient( if (!response.IsSuccessStatusCode) { _logger.LogWarning("Summarizer returned HTTP {StatusCode}.", response.StatusCode); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "http_error", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetStatus(ActivityStatusCode.Error, "HTTP error"); + activity?.SetTag("http.status_code", (int)response.StatusCode); return null; } @@ -73,14 +85,36 @@ public ExternalSignalEnrichmentClient( if (payload?.Summaries is null || payload.Summaries.Count == 0) { _logger.LogWarning("Summarizer returned an empty payload."); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "invalid_payload", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetStatus(ActivityStatusCode.Error, "Invalid payload"); return null; } + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "success", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetStatus(ActivityStatusCode.Ok); return new SummarizeResponse(payload.SchemaVersion ?? 
"unknown", payload.Summaries); } catch (Exception ex) { _logger.LogWarning(ex, "Summarizer request failed."); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "exception", + stopwatch.Elapsed.TotalMilliseconds); + activity?.SetTag("exception.type", ex.GetType().FullName); + activity?.SetTag("exception.message", ex.Message); + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); return null; } } @@ -93,6 +127,11 @@ public ExternalSignalEnrichmentClient( } var request = new EnrichRequestDto(new[] { MapDocument(signal) }); + const string operation = "enrich"; + const int documentCount = 1; + using var activity = StartClientActivity(operation, documentCount); + activity?.SetTag("external_signals.external_id", signal.ExternalId); + var stopwatch = Stopwatch.StartNew(); try { @@ -107,6 +146,14 @@ public ExternalSignalEnrichmentClient( _logger.LogWarning("Enrichment returned HTTP {StatusCode} for signal {ExternalId}.", response.StatusCode, signal.ExternalId); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documentCount, + "http_error", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetStatus(ActivityStatusCode.Error, "HTTP error"); + activity?.SetTag("http.status_code", (int)response.StatusCode); return null; } @@ -123,14 +170,36 @@ public ExternalSignalEnrichmentClient( out var normalized)) { _logger.LogWarning("Enrichment payload missing S_v for signal {ExternalId}.", signal.ExternalId); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documentCount, + "invalid_payload", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetStatus(ActivityStatusCode.Error, "Invalid payload"); return null; } + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documentCount, + "success", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetStatus(ActivityStatusCode.Ok); return normalized; } catch (Exception ex) { 
_logger.LogWarning(ex, "Enrichment request failed for signal {ExternalId}.", signal.ExternalId); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documentCount, + "exception", + stopwatch.Elapsed.TotalMilliseconds); + activity?.SetTag("exception.type", ex.GetType().FullName); + activity?.SetTag("exception.message", ex.Message); + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); return null; } } @@ -150,6 +219,9 @@ public ExternalSignalEnrichmentClient( return null; } var request = new EnrichRequestDto(documents); + const string operation = "enrich_batch"; + using var activity = StartClientActivity(operation, documents.Count); + var stopwatch = Stopwatch.StartNew(); try { @@ -162,6 +234,14 @@ public ExternalSignalEnrichmentClient( if (!response.IsSuccessStatusCode) { _logger.LogWarning("Batch enrichment returned HTTP {StatusCode}.", response.StatusCode); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "http_error", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetStatus(ActivityStatusCode.Error, "HTTP error"); + activity?.SetTag("http.status_code", (int)response.StatusCode); return null; } @@ -172,6 +252,13 @@ public ExternalSignalEnrichmentClient( if (payload?.Vectors is null || payload.Vectors.Count == 0) { _logger.LogWarning("Batch enrichment returned an empty payload."); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "invalid_payload", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetStatus(ActivityStatusCode.Error, "Invalid payload"); return null; } @@ -195,14 +282,37 @@ public ExternalSignalEnrichmentClient( if (vectors.Count == 0) { _logger.LogWarning("Batch enrichment payload had no valid vectors."); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "invalid_payload", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + 
activity?.SetStatus(ActivityStatusCode.Error, "Invalid payload"); return null; } + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "success", + stopwatch.Elapsed.TotalMilliseconds, + (int)response.StatusCode); + activity?.SetTag("external_signals.vectors.count", vectors.Count); + activity?.SetStatus(ActivityStatusCode.Ok); return new BatchEnrichResponse(payload.SchemaVersion ?? "unknown", vectors); } catch (Exception ex) { _logger.LogWarning(ex, "Batch enrichment request failed."); + ExternalSignalsTelemetry.RecordClientRequest( + operation, + documents.Count, + "exception", + stopwatch.Elapsed.TotalMilliseconds); + activity?.SetTag("exception.type", ex.GetType().FullName); + activity?.SetTag("exception.message", ex.Message); + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); return null; } } @@ -240,6 +350,16 @@ private static SignalDocumentDto MapDocument(ExternalSignal signal) Region: signal.Region, PublishedAt: signal.PublishedAt); + private static Activity? StartClientActivity(string operation, int documentCount) + { + var activity = ExternalSignalsTelemetry.ActivitySource.StartActivity( + $"external_signals.client.{operation}", + ActivityKind.Client); + activity?.SetTag("external_signals.operation", operation); + activity?.SetTag("external_signals.documents", documentCount); + return activity; + } + private static bool TryBuildEnrichResponse( string? schemaVersion, double[]? 
sentimentVector, diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs index e39c8c3..48b198b 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalIngestionService.cs @@ -1,3 +1,4 @@ +using System.Diagnostics; using AetherGuard.Core.Data; using AetherGuard.Core.Models; using Microsoft.EntityFrameworkCore; @@ -137,7 +138,7 @@ private async Task IngestOnceAsync(CancellationToken cancellationToken) if (_options.Enrichment.Enabled) { var enrichmentClient = scope.ServiceProvider.GetRequiredService(); - await EnrichSignalsAsync(db, enrichmentClient, newSignals, cancellationToken); + await EnrichSignalsAsync(db, enrichmentClient, newSignals, feed.Name, cancellationToken); } } @@ -148,6 +149,7 @@ private async Task EnrichSignalsAsync( ApplicationDbContext db, ExternalSignalEnrichmentClient client, List signals, + string source, CancellationToken cancellationToken) { if (!client.IsEnabled || signals.Count == 0) @@ -161,102 +163,147 @@ private async Task EnrichSignalsAsync( return; } + var stopwatch = Stopwatch.StartNew(); + using var activity = ExternalSignalsTelemetry.ActivitySource.StartActivity( + "external_signals.pipeline.enrich", + ActivityKind.Internal); + activity?.SetTag("external_signals.source", source); + activity?.SetTag("external_signals.batch.size", batch.Count); + var summaryUpdates = 0; var enrichmentUpdates = 0; var summarizedAt = DateTimeOffset.UtcNow; + var mode = "batch"; - var summaryResponse = await client.SummarizeAsync(batch, cancellationToken); - if (summaryResponse is not null) + try { - var summaryByIndex = summaryResponse.Summaries - .Where(item => item.Index >= 0 && item.Index < batch.Count) - .GroupBy(item => item.Index) - 
.ToDictionary(group => group.Key, group => group.First()); - - for (var index = 0; index < batch.Count; index++) + var summaryResponse = await client.SummarizeAsync(batch, cancellationToken); + if (summaryResponse is not null) { - if (!summaryByIndex.TryGetValue(index, out var item)) - { - continue; - } + var summaryByIndex = summaryResponse.Summaries + .Where(item => item.Index >= 0 && item.Index < batch.Count) + .GroupBy(item => item.Index) + .ToDictionary(group => group.Key, group => group.First()); - if (string.IsNullOrWhiteSpace(item.Summary)) + for (var index = 0; index < batch.Count; index++) { - continue; + if (!summaryByIndex.TryGetValue(index, out var item)) + { + continue; + } + + if (string.IsNullOrWhiteSpace(item.Summary)) + { + continue; + } + + var signal = batch[index]; + signal.SummaryDigest = item.Summary; + signal.SummaryDigestTruncated = item.Truncated; + signal.SummarySchemaVersion = summaryResponse.SchemaVersion; + signal.SummarizedAt = summarizedAt; + summaryUpdates++; } - - var signal = batch[index]; - signal.SummaryDigest = item.Summary; - signal.SummaryDigestTruncated = item.Truncated; - signal.SummarySchemaVersion = summaryResponse.SchemaVersion; - signal.SummarizedAt = summarizedAt; - summaryUpdates++; } - } - var enrichedAt = DateTimeOffset.UtcNow; - var batchResponse = await client.EnrichBatchAsync(batch, cancellationToken); + var enrichedAt = DateTimeOffset.UtcNow; + var batchResponse = await client.EnrichBatchAsync(batch, cancellationToken); - if (batchResponse is not null) - { - var vectorsByIndex = batchResponse.Vectors - .Where(item => item.Index >= 0 && item.Index < batch.Count) - .GroupBy(item => item.Index) - .ToDictionary(group => group.Key, group => group.First().Vector); - - for (var index = 0; index < batch.Count; index++) + if (batchResponse is not null) { - if (!vectorsByIndex.TryGetValue(index, out var vector)) - { - continue; - } + var vectorsByIndex = batchResponse.Vectors + .Where(item => item.Index >= 0 && item.Index 
< batch.Count) + .GroupBy(item => item.Index) + .ToDictionary(group => group.Key, group => group.First().Vector); - if (vector.SentimentVector.Length < 3) + for (var index = 0; index < batch.Count; index++) { - continue; + if (!vectorsByIndex.TryGetValue(index, out var vector)) + { + continue; + } + + if (vector.SentimentVector.Length < 3) + { + continue; + } + + ApplyEnrichment(batch[index], vector, enrichedAt); + enrichmentUpdates++; } - - ApplyEnrichment(batch[index], vector, enrichedAt); - enrichmentUpdates++; } - } - else - { - var semaphore = new SemaphoreSlim(client.MaxConcurrency); - var tasks = batch.Select(async signal => + else { - await semaphore.WaitAsync(cancellationToken); - try + mode = "single_fallback"; + ExternalSignalsTelemetry.RecordPipelineFallback(source, batch.Count, "batch_unavailable"); + activity?.SetTag("external_signals.fallback", true); + + var semaphore = new SemaphoreSlim(client.MaxConcurrency); + var tasks = batch.Select(async signal => { - var result = await client.EnrichAsync(signal, cancellationToken); - return (signal, result); - } - finally + await semaphore.WaitAsync(cancellationToken); + try + { + var result = await client.EnrichAsync(signal, cancellationToken); + return (signal, result); + } + finally + { + semaphore.Release(); + } + }).ToList(); + + var results = await Task.WhenAll(tasks); + foreach (var (signal, result) in results) { - semaphore.Release(); - } - }).ToList(); + if (result is null || result.SentimentVector.Length < 3) + { + continue; + } - var results = await Task.WhenAll(tasks); - foreach (var (signal, result) in results) - { - if (result is null || result.SentimentVector.Length < 3) - { - continue; + ApplyEnrichment(signal, result, enrichedAt); + enrichmentUpdates++; } + } - ApplyEnrichment(signal, result, enrichedAt); - enrichmentUpdates++; + if (summaryUpdates > 0 || enrichmentUpdates > 0) + { + await db.SaveChangesAsync(cancellationToken); + _logger.LogInformation( + "Enriched external signals: 
summaries={SummaryCount}, vectors={VectorCount}.", + summaryUpdates, + enrichmentUpdates); } - } - if (summaryUpdates > 0 || enrichmentUpdates > 0) + ExternalSignalsTelemetry.RecordPipelineUpdates(source, summaryUpdates, enrichmentUpdates); + var outcome = summaryUpdates > 0 || enrichmentUpdates > 0 ? "success" : "no_updates"; + ExternalSignalsTelemetry.RecordPipelineRun( + source, + batch.Count, + mode, + outcome, + stopwatch.Elapsed.TotalMilliseconds); + activity?.SetTag("external_signals.summary_updates", summaryUpdates); + activity?.SetTag("external_signals.vector_updates", enrichmentUpdates); + activity?.SetTag("external_signals.mode", mode); + activity?.SetStatus(ActivityStatusCode.Ok); + } + catch (Exception ex) { - await db.SaveChangesAsync(cancellationToken); - _logger.LogInformation( - "Enriched external signals: summaries={SummaryCount}, vectors={VectorCount}.", - summaryUpdates, - enrichmentUpdates); + ExternalSignalsTelemetry.RecordPipelineRun( + source, + batch.Count, + mode, + "error", + stopwatch.Elapsed.TotalMilliseconds); + activity?.SetTag("exception.type", ex.GetType().FullName); + activity?.SetTag("exception.message", ex.Message); + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); + throw; + } + finally + { + activity?.SetTag("external_signals.pipeline_duration_ms", stopwatch.Elapsed.TotalMilliseconds); } } diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsTelemetry.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsTelemetry.cs new file mode 100644 index 0000000..23b029e --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ExternalSignals/ExternalSignalsTelemetry.cs @@ -0,0 +1,134 @@ +using System.Diagnostics; +using System.Diagnostics.Metrics; + +namespace AetherGuard.Core.Services.ExternalSignals; + +public static class ExternalSignalsTelemetry +{ + public const string ActivitySourceName = "AetherGuard.Core.ExternalSignals"; + public 
const string MeterName = "AetherGuard.Core.ExternalSignals"; + + public static readonly ActivitySource ActivitySource = new(ActivitySourceName); + + private static readonly Meter Meter = new(MeterName, "1.0.0"); + private static readonly Counter ClientRequestCounter = Meter.CreateCounter( + "aetherguard.external_signals.client.requests", + description: "Number of external-signal enrichment client requests."); + private static readonly Counter ClientFailureCounter = Meter.CreateCounter( + "aetherguard.external_signals.client.failures", + description: "Number of external-signal enrichment client failures."); + private static readonly Histogram ClientLatencyHistogram = Meter.CreateHistogram( + "aetherguard.external_signals.client.duration.ms", + unit: "ms", + description: "Latency of external-signal enrichment client requests."); + private static readonly Histogram ClientDocumentHistogram = Meter.CreateHistogram( + "aetherguard.external_signals.client.documents", + unit: "documents", + description: "Number of documents sent per client request."); + private static readonly Counter PipelineRunCounter = Meter.CreateCounter( + "aetherguard.external_signals.pipeline.runs", + description: "Number of enrichment pipeline runs."); + private static readonly Counter PipelineFallbackCounter = Meter.CreateCounter( + "aetherguard.external_signals.pipeline.fallbacks", + description: "Number of fallback executions in enrichment pipeline."); + private static readonly Histogram PipelineLatencyHistogram = Meter.CreateHistogram( + "aetherguard.external_signals.pipeline.duration.ms", + unit: "ms", + description: "Latency of enrichment pipeline runs."); + private static readonly Histogram PipelineBatchHistogram = Meter.CreateHistogram( + "aetherguard.external_signals.pipeline.batch.size", + unit: "signals", + description: "Batch sizes processed by enrichment pipeline."); + private static readonly Counter PipelineUpdateCounter = Meter.CreateCounter( + 
"aetherguard.external_signals.pipeline.updates", + description: "Number of persisted enrichment updates."); + + public static void RecordClientRequest( + string operation, + int documentCount, + string outcome, + double durationMs, + int? statusCode = null) + { + var tags = CreateClientTags(operation, outcome, statusCode); + ClientRequestCounter.Add(1, tags); + ClientLatencyHistogram.Record(durationMs, tags); + ClientDocumentHistogram.Record(Math.Max(0, documentCount), tags); + + if (!string.Equals(outcome, "success", StringComparison.OrdinalIgnoreCase)) + { + ClientFailureCounter.Add(1, tags); + } + } + + public static void RecordPipelineRun( + string source, + int batchSize, + string mode, + string outcome, + double durationMs) + { + var tags = CreatePipelineTags(source, mode, outcome); + PipelineRunCounter.Add(1, tags); + PipelineLatencyHistogram.Record(durationMs, tags); + PipelineBatchHistogram.Record(Math.Max(0, batchSize), tags); + } + + public static void RecordPipelineFallback(string source, int batchSize, string reason) + { + var tags = CreatePipelineTags(source, "fallback", reason); + PipelineFallbackCounter.Add(1, tags); + PipelineBatchHistogram.Record(Math.Max(0, batchSize), tags); + } + + public static void RecordPipelineUpdates(string source, int summaryUpdates, int vectorUpdates) + { + var summaryTags = CreateUpdateTags(source, "summary"); + var vectorTags = CreateUpdateTags(source, "vector"); + + if (summaryUpdates > 0) + { + PipelineUpdateCounter.Add(summaryUpdates, summaryTags); + } + + if (vectorUpdates > 0) + { + PipelineUpdateCounter.Add(vectorUpdates, vectorTags); + } + } + + private static TagList CreateClientTags(string operation, string outcome, int? 
statusCode) + { + var tags = new TagList + { + { "operation", operation }, + { "outcome", outcome } + }; + + if (statusCode.HasValue) + { + tags.Add("http.status_code", statusCode.Value); + } + + return tags; + } + + private static TagList CreatePipelineTags(string source, string mode, string outcome) + { + return new TagList + { + { "source", source }, + { "mode", mode }, + { "outcome", outcome } + }; + } + + private static TagList CreateUpdateTags(string source, string updateType) + { + return new TagList + { + { "source", source }, + { "update_type", updateType } + }; + } +} From 670370339eafbec6ae27edc59309cd47c832ce25 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Sun, 22 Feb 2026 13:41:06 +0800 Subject: [PATCH 12/24] feat(v2.3): add OTLP endpoint resolution for tracing and metrics configuration --- src/services/ai-engine/main.py | 37 ++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/services/ai-engine/main.py b/src/services/ai-engine/main.py index 0e7025a..7b88fd0 100644 --- a/src/services/ai-engine/main.py +++ b/src/services/ai-engine/main.py @@ -612,18 +612,47 @@ def summarize_signals(payload: SummarizeRequest) -> SummarizeResponse: span.set_attribute("ai.signals.max_chars", max_chars) return SummarizeResponse(schemaVersion=SUMMARY_SCHEMA_VERSION, summaries=summaries) +def resolve_otlp_endpoint(base_endpoint: str | None, signal_endpoint: str | None, signal: str) -> str | None: + if signal_endpoint: + return signal_endpoint + + if not base_endpoint: + return None + + normalized = base_endpoint.rstrip("/") + signal_suffix = f"/v1/{signal}" + + if normalized.endswith(signal_suffix): + return normalized + if normalized.endswith("/v1"): + return f"{normalized}/{signal}" + + return f"{normalized}{signal_suffix}" + + def configure_tracing() -> None: - endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") + base_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") + trace_endpoint = resolve_otlp_endpoint( + 
base_endpoint, + os.getenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"), + "traces", + ) + metric_endpoint = resolve_otlp_endpoint( + base_endpoint, + os.getenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT"), + "metrics", + ) service_name = os.getenv("OTEL_SERVICE_NAME", "aether-guard-ai") resource = Resource.create({"service.name": service_name}) - if endpoint: + if trace_endpoint: trace_provider = TracerProvider(resource=resource) - span_exporter = OTLPSpanExporter(endpoint=endpoint) + span_exporter = OTLPSpanExporter(endpoint=trace_endpoint) trace_provider.add_span_processor(BatchSpanProcessor(span_exporter)) trace.set_tracer_provider(trace_provider) - metric_exporter = OTLPMetricExporter(endpoint=endpoint) + if metric_endpoint: + metric_exporter = OTLPMetricExporter(endpoint=metric_endpoint) metric_reader = PeriodicExportingMetricReader(metric_exporter) metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) metrics.set_meter_provider(metric_provider) From 44ced1cec83e32b2ccef46f537b6560e23e6d93d Mon Sep 17 00:00:00 2001 From: JasonEran Date: Sun, 22 Feb 2026 14:02:02 +0800 Subject: [PATCH 13/24] feat(v2.3): add M2 data acquisition scripts and provenance docs --- README.md | 2 + docs/Data-Provenance-v2.3-M2.md | 44 +++ scripts/data_acquisition/README.md | 117 ++++++ .../data_acquisition/fetch_cluster_traces.py | 246 ++++++++++++ .../fetch_incident_archives.py | 320 ++++++++++++++++ .../data_acquisition/fetch_spot_history.py | 361 ++++++++++++++++++ 6 files changed, 1090 insertions(+) create mode 100644 docs/Data-Provenance-v2.3-M2.md create mode 100644 scripts/data_acquisition/README.md create mode 100644 scripts/data_acquisition/fetch_cluster_traces.py create mode 100644 scripts/data_acquisition/fetch_incident_archives.py create mode 100644 scripts/data_acquisition/fetch_spot_history.py diff --git a/README.md b/README.md index 905fdc0..eafcfdd 100644 --- a/README.md +++ b/README.md @@ -195,6 +195,8 @@ Open the dashboard at http://localhost:3000. 
- v2.3 architecture roadmap: docs/ARCHITECTURE-v2.3.md - v2.3 delivery roadmap: docs/ROADMAP-v2.3.md - v2.3 Milestone 1 smoke test: docs/QA-SmokeTest-v2.3-M1.md +- v2.3 M2 data provenance: docs/Data-Provenance-v2.3-M2.md +- v2.3 M2 data acquisition scripts: scripts/data_acquisition/README.md If you want to simulate migrations, start at least two agents: diff --git a/docs/Data-Provenance-v2.3-M2.md b/docs/Data-Provenance-v2.3-M2.md new file mode 100644 index 0000000..42b008c --- /dev/null +++ b/docs/Data-Provenance-v2.3-M2.md @@ -0,0 +1,44 @@ +# v2.3 M2 Data Provenance + +This document tracks data provenance for Milestone 2 (Fusion + Forecasting, offline replay/backtesting). + +## Scope + +The following sources are collected by scripts in `scripts/data_acquisition/`: + +- Spot pricing signals +- Cluster telemetry traces +- Cloud incident/status archives + +## Source Catalog + +| Dataset domain | Source | Access path | License / terms | License link | Notes | +|---|---|---|---|---|---| +| Spot pricing | AWS public spot snapshot (`spot.js`) | `fetch_spot_history.py --source spot-js` | AWS Site Terms | https://aws.amazon.com/terms/ | Snapshot feed; collect repeatedly to build local history. | +| Spot pricing (historical) | AWS EC2 Spot Price History API | `fetch_spot_history.py --source ec2-api` | AWS Service Terms | https://aws.amazon.com/service-terms/ | Requires AWS credentials and API permissions. | +| Cluster traces | Google ClusterData 2011-2 | `fetch_cluster_traces.py` | CC-BY 4.0 | https://creativecommons.org/licenses/by/4.0/ | Public trace dataset; provenance points to dataset description doc. | +| Incident archives | AWS Service Health RSS | `fetch_incident_archives.py` | AWS Site Terms | https://aws.amazon.com/terms/ | Feed content usage follows provider terms. | +| Incident archives | Google Cloud Status Atom | `fetch_incident_archives.py` | Google Terms of Service | https://policies.google.com/terms | Feed content usage follows provider terms. 
| +| Incident archives | Azure Status feed | `fetch_incident_archives.py` | Microsoft Terms of Use | https://www.microsoft.com/legal/terms-of-use | Feed content usage follows provider terms. | + +## Reproducibility + +Each acquisition script writes: + +1. Data artifacts under `Data/replay/*` +2. A machine-readable provenance manifest under `Data/replay/provenance/*_manifest.json` + +Manifests include: + +- Generation timestamp (UTC) +- Executed command +- Source URLs +- Output file paths +- Record/file counts +- Integrity metadata (for trace archives: size + SHA-256) + +## Governance Notes + +- `Data/` is excluded from git; datasets remain local unless explicitly exported. +- Downstream training/evaluation outputs should reference the exact provenance manifest used. +- Before publishing derived datasets, review each provider's terms and attribution requirements. diff --git a/scripts/data_acquisition/README.md b/scripts/data_acquisition/README.md new file mode 100644 index 0000000..5ee7e6a --- /dev/null +++ b/scripts/data_acquisition/README.md @@ -0,0 +1,117 @@ +# v2.3 M2 Data Acquisition Scripts + +This folder contains reproducible scripts for Milestone 2 dataset collection: + +- Spot pricing data (for market context features). +- Cluster trace archives (for telemetry replay). +- Incident/status feed archives (for semantic signals). + +All scripts write two outputs: + +1. Data files under an output root (default `Data/replay`). +2. A provenance manifest under `Data/replay/provenance`. + +`Data/` is gitignored in this repository, so downloaded artifacts stay local. 
+ +## Prerequisites + +- Python 3.10+ +- Internet access +- Optional for EC2 history mode in spot script: + - `pip install boto3` + - AWS credentials with `ec2:DescribeSpotPriceHistory` + +## Quick Start + +From repository root: + +```bash +python scripts/data_acquisition/fetch_spot_history.py --source spot-js +python scripts/data_acquisition/fetch_cluster_traces.py --max-files 2 +python scripts/data_acquisition/fetch_incident_archives.py --max-items-per-feed 100 +``` + +## Script Details + +### 1) Spot data + +```bash +python scripts/data_acquisition/fetch_spot_history.py --source spot-js +``` + +- `spot-js` mode: no credentials; snapshots public AWS spot price feed. +- `ec2-api` mode: historical pull from EC2 API (needs AWS credentials). + +Example EC2 API pull: + +```bash +python scripts/data_acquisition/fetch_spot_history.py \ + --source ec2-api \ + --region us-east-1 \ + --instance-types c5.large,m5.large \ + --start-time 2026-01-01T00:00:00Z \ + --end-time 2026-01-31T23:59:59Z +``` + +### 2) Cluster traces (public sample) + +```bash +python scripts/data_acquisition/fetch_cluster_traces.py --max-files 4 +``` + +- Defaults to Google `clusterdata-2011-2` public sample files. +- Supports custom manifests with `--manifest-file`. + +Manifest format: + +```json +[ + { + "name": "task_events_part_00000", + "target": "task_events/part-00000-of-00500.csv.gz", + "url": "https://storage.googleapis.com/clusterdata-2011-2/task_events/part-00000-of-00500.csv.gz" + } +] +``` + +### 3) Incident archives + +```bash +python scripts/data_acquisition/fetch_incident_archives.py --max-items-per-feed 200 +``` + +- Defaults to AWS/GCP/Azure status feeds. +- Supports custom feed catalog with `--feeds-file`. 
+ +Feed file format: + +```json +[ + { + "name": "aws-status", + "url": "https://status.aws.amazon.com/rss/all.rss", + "license": "AWS Site Terms", + "license_url": "https://aws.amazon.com/terms/" + } +] +``` + +## Output Layout + +Default output root (`Data/replay`): + +```text +Data/replay/ + spot_history/ + cluster_traces/ + google_clusterdata_2011_2/ + incident_archives/ + provenance/ + spot_history_manifest.json + cluster_traces_manifest.json + incident_archives_manifest.json +``` + +## Provenance and Licensing + +See `docs/Data-Provenance-v2.3-M2.md` for the source catalog, license/terms links, and governance notes. diff --git a/scripts/data_acquisition/fetch_cluster_traces.py b/scripts/data_acquisition/fetch_cluster_traces.py new file mode 100644 index 0000000..f61d445 --- /dev/null +++ b/scripts/data_acquisition/fetch_cluster_traces.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +"""Download public cluster trace files for offline replay/backtesting.""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +USER_AGENT = "Aether-Guard-Data-Acquisition/1.0" + +DEFAULT_FILES = [ + { + "name": "machine_events_part_00000", + "target": "machine_events/part-00000-of-00001.csv.gz", + "url": "https://storage.googleapis.com/clusterdata-2011-2/machine_events/part-00000-of-00001.csv.gz", + }, + { + "name": "machine_attributes_part_00000", + "target": "machine_attributes/part-00000-of-00001.csv.gz", + "url": "https://storage.googleapis.com/clusterdata-2011-2/machine_attributes/part-00000-of-00001.csv.gz", + }, + { + "name": "task_events_part_00000", + "target": "task_events/part-00000-of-00500.csv.gz", + "url": "https://storage.googleapis.com/clusterdata-2011-2/task_events/part-00000-of-00500.csv.gz", + }, + { + "name": "task_usage_part_00000", + 
"target": "task_usage/part-00000-of-00500.csv.gz", + "url": "https://storage.googleapis.com/clusterdata-2011-2/task_usage/part-00000-of-00500.csv.gz", + }, +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--output-dir", + default="Data/replay", + help="Root output directory for trace archives and provenance metadata.", + ) + parser.add_argument( + "--manifest-file", + default="", + help="Optional JSON manifest to replace defaults (list of {name,target,url}).", + ) + parser.add_argument( + "--max-files", + type=int, + default=4, + help="Maximum number of files to download from the manifest.", + ) + parser.add_argument( + "--timeout-seconds", + type=int, + default=120, + help="HTTP timeout per file.", + ) + parser.add_argument( + "--skip-existing", + action="store_true", + help="Skip files already present on disk.", + ) + return parser.parse_args() + + +def utc_now_iso() -> str: + return datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +def run_id() -> str: + return datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%SZ") + + +def write_json(path: Path, payload: dict[str, Any]) -> None: + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + + +def load_manifest(path: str) -> list[dict[str, str]]: + if not path: + return list(DEFAULT_FILES) + + manifest_path = Path(path) + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + + if isinstance(payload, dict): + entries = payload.get("files", []) + else: + entries = payload + + if not isinstance(entries, list): + raise ValueError("Manifest must be a list or an object with a 'files' list.") + + normalized: list[dict[str, str]] = [] + for entry in entries: + if not isinstance(entry, dict): + raise ValueError("Manifest entries must be objects.") + name = str(entry.get("name", "")).strip() + target = str(entry.get("target", "")).strip() + url = 
str(entry.get("url", "")).strip() + if not name or not target or not url: + raise ValueError("Each manifest entry requires name, target, and url.") + normalized.append({"name": name, "target": target, "url": url}) + return normalized + + +def hash_file(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + while True: + chunk = handle.read(1024 * 1024) + if not chunk: + break + digest.update(chunk) + return digest.hexdigest() + + +def download_file(url: str, target: Path, timeout_seconds: int) -> tuple[int, str, str | None]: + request = Request(url, headers={"User-Agent": USER_AGENT}) + digest = hashlib.sha256() + total_bytes = 0 + last_modified: str | None = None + + with urlopen(request, timeout=timeout_seconds) as response: + last_modified = response.headers.get("Last-Modified") + with target.open("wb") as handle: + while True: + chunk = response.read(1024 * 1024) + if not chunk: + break + handle.write(chunk) + digest.update(chunk) + total_bytes += len(chunk) + + return total_bytes, digest.hexdigest(), last_modified + + +def main() -> int: + args = parse_args() + + if args.max_files <= 0: + print("--max-files must be > 0", file=sys.stderr) + return 2 + + try: + manifest_entries = load_manifest(args.manifest_file) + except Exception as exc: + print(f"Manifest load failed: {exc}", file=sys.stderr) + return 2 + + selected_entries = manifest_entries[: args.max_files] + output_root = Path(args.output_dir) + traces_root = output_root / "cluster_traces" / "google_clusterdata_2011_2" + provenance_dir = output_root / "provenance" + traces_root.mkdir(parents=True, exist_ok=True) + provenance_dir.mkdir(parents=True, exist_ok=True) + + results: list[dict[str, Any]] = [] + success_count = 0 + failure_count = 0 + + for entry in selected_entries: + target = traces_root / entry["target"] + target.parent.mkdir(parents=True, exist_ok=True) + + if args.skip_existing and target.exists(): + record = { + "name": entry["name"], + "url": entry["url"], + 
"target": str(target), + "status": "skipped", + "bytes": target.stat().st_size, + "sha256": hash_file(target), + "last_modified": None, + } + results.append(record) + success_count += 1 + print(f"Skipped existing file: {target}") + continue + + try: + total_bytes, sha256, last_modified = download_file( + entry["url"], + target, + args.timeout_seconds, + ) + record = { + "name": entry["name"], + "url": entry["url"], + "target": str(target), + "status": "downloaded", + "bytes": total_bytes, + "sha256": sha256, + "last_modified": last_modified, + } + results.append(record) + success_count += 1 + print(f"Downloaded {entry['name']} -> {target} ({total_bytes} bytes)") + except (HTTPError, URLError, TimeoutError) as exc: + record = { + "name": entry["name"], + "url": entry["url"], + "target": str(target), + "status": "failed", + "error": str(exc), + } + results.append(record) + failure_count += 1 + print(f"Failed {entry['name']}: {exc}", file=sys.stderr) + + manifest = { + "dataset": "google_clusterdata_2011_2_sample", + "generated_at_utc": utc_now_iso(), + "command": " ".join(sys.argv), + "license": "CC-BY-4.0", + "license_url": "https://creativecommons.org/licenses/by/4.0/", + "source_terms_url": "https://raw.githubusercontent.com/google/cluster-data/master/ClusterData2011_2.md", + "files": results, + "summary": { + "requested": len(selected_entries), + "successful": success_count, + "failed": failure_count, + }, + } + manifest_path = provenance_dir / "cluster_traces_manifest.json" + write_json(manifest_path, manifest) + print(f"Wrote provenance manifest to {manifest_path}") + + if success_count == 0: + return 1 + if failure_count > 0: + return 2 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/data_acquisition/fetch_incident_archives.py b/scripts/data_acquisition/fetch_incident_archives.py new file mode 100644 index 0000000..97e1004 --- /dev/null +++ b/scripts/data_acquisition/fetch_incident_archives.py @@ -0,0 +1,320 @@ 
+#!/usr/bin/env python3 +"""Collect incident/status feed archives for semantic replay datasets.""" + +from __future__ import annotations + +import argparse +import json +import re +import sys +import xml.etree.ElementTree as ET +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime +from pathlib import Path +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +USER_AGENT = "Aether-Guard-Data-Acquisition/1.0" + +DEFAULT_FEEDS = [ + { + "name": "aws-status", + "url": "https://status.aws.amazon.com/rss/all.rss", + "license": "AWS Site Terms", + "license_url": "https://aws.amazon.com/terms/", + }, + { + "name": "gcp-status", + "url": "https://status.cloud.google.com/en/feed.atom", + "license": "Google Terms of Service", + "license_url": "https://policies.google.com/terms", + }, + { + "name": "azure-status", + "url": "https://status.azure.com/en-us/status/feed/", + "license": "Microsoft Terms of Use", + "license_url": "https://www.microsoft.com/legal/terms-of-use", + }, +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--output-dir", + default="Data/replay", + help="Root output directory for incident archives and provenance metadata.", + ) + parser.add_argument( + "--feeds-file", + default="", + help="Optional JSON file with feed entries ({name,url,license,license_url}).", + ) + parser.add_argument( + "--max-items-per-feed", + type=int, + default=200, + help="Maximum items to emit per feed.", + ) + parser.add_argument( + "--timeout-seconds", + type=int, + default=30, + help="HTTP timeout for feed requests.", + ) + return parser.parse_args() + + +def utc_now() -> datetime: + return datetime.now(tz=timezone.utc) + + +def utc_now_iso() -> str: + return utc_now().strftime("%Y-%m-%dT%H:%M:%SZ") + + +def run_id() -> str: + return utc_now().strftime("%Y%m%dT%H%M%SZ") + + +def write_json(path: Path, 
payload: dict[str, Any]) -> None: + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + + +def local_name(tag: str) -> str: + if "}" in tag: + return tag.rsplit("}", 1)[-1] + if ":" in tag: + return tag.rsplit(":", 1)[-1] + return tag + + +def first_text(element: ET.Element, names: set[str]) -> str | None: + for child in list(element): + if local_name(child.tag) in names: + text = "".join(child.itertext()).strip() + if text: + return text + return None + + +def first_link(element: ET.Element) -> str | None: + for child in list(element): + if local_name(child.tag) != "link": + continue + href = child.attrib.get("href") + if href: + return href.strip() + text = "".join(child.itertext()).strip() + if text: + return text + return None + + +def strip_html(text: str) -> str: + no_tags = re.sub(r"<[^>]+>", " ", text, flags=re.DOTALL) + return " ".join(no_tags.split()).strip() + + +def normalize_timestamp(value: str | None) -> str | None: + if not value: + return None + + raw = value.strip() + if not raw: + return None + + try: + parsed = datetime.fromisoformat(raw.replace("Z", "+00:00")) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + except ValueError: + pass + + try: + parsed = parsedate_to_datetime(raw) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + except (TypeError, ValueError): + return None + + +def load_feeds(path: str) -> list[dict[str, str]]: + if not path: + return list(DEFAULT_FEEDS) + + payload = json.loads(Path(path).read_text(encoding="utf-8")) + if not isinstance(payload, list): + raise ValueError("Feeds file must be a JSON list.") + + normalized: list[dict[str, str]] = [] + for item in payload: + if not isinstance(item, dict): + raise ValueError("Feed entries must be objects.") + 
def parse_feed_items(
    xml_payload: str,
    feed_name: str,
    feed_url: str,
    max_items: int,
    fetched_at_utc: str,
) -> list[dict[str, Any]]:
    """Parse RSS <item> / Atom <entry> elements into flat incident records.

    Entries without a title are skipped; external_id falls back from guid/id
    to link to title. Raises ET.ParseError on malformed XML (handled by the
    caller's per-feed error path).
    """
    root = ET.fromstring(xml_payload)
    # Namespace-agnostic scan: match on local tag names anywhere in the tree.
    entries = [node for node in root.iter() if local_name(node.tag) in {"item", "entry"}]

    records: list[dict[str, Any]] = []
    for entry in entries:
        if len(records) >= max_items:
            break

        title = first_text(entry, {"title"})
        if not title:
            continue

        summary = first_text(entry, {"description", "summary", "content"}) or ""
        link = first_link(entry)
        guid = first_text(entry, {"guid", "id"})
        published_raw = first_text(entry, {"pubDate", "published", "updated", "date"})
        published_at = normalize_timestamp(published_raw)
        external_id = guid or link or title

        records.append(
            {
                "source": feed_name,
                "source_url": feed_url,
                "external_id": external_id,
                "title": title.strip(),
                "summary": strip_html(summary),
                "url": link,
                "published_at": published_at,
                "fetched_at_utc": fetched_at_utc,
            }
        )

    return records


def fetch_url(url: str, timeout_seconds: int) -> str:
    """Fetch a feed URL and decode the body using the server-declared charset.

    Fix: the previous version hard-decoded every response as strict UTF-8, so a
    single non-UTF-8 byte raised UnicodeDecodeError and discarded the whole
    feed. We now honor the Content-Type charset, fall back to UTF-8, and
    replace undecodable bytes instead of failing.
    """
    request = Request(url, headers={"User-Agent": USER_AGENT})
    with urlopen(request, timeout=timeout_seconds) as response:
        raw = response.read()
        charset = response.headers.get_content_charset() or "utf-8"
    try:
        return raw.decode(charset, errors="replace")
    except LookupError:  # server advertised an unknown/typo codec name
        return raw.decode("utf-8", errors="replace")


def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write one JSON object per line (UTF-8, ensure_ascii=False preserves text)."""
    with path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row, ensure_ascii=False))
            handle.write("\n")
print("--max-items-per-feed must be > 0", file=sys.stderr) + return 2 + + try: + feeds = load_feeds(args.feeds_file) + except Exception as exc: + print(f"Failed to load feeds: {exc}", file=sys.stderr) + return 2 + + output_root = Path(args.output_dir) + incidents_dir = output_root / "incident_archives" + provenance_dir = output_root / "provenance" + incidents_dir.mkdir(parents=True, exist_ok=True) + provenance_dir.mkdir(parents=True, exist_ok=True) + + feed_results: list[dict[str, Any]] = [] + total_records = 0 + success_count = 0 + + stamp = run_id() + fetched_at = utc_now_iso() + + for feed in feeds: + name = feed["name"] + url = feed["url"] + output_path = incidents_dir / f"{name}_{stamp}.jsonl" + + try: + xml_payload = fetch_url(url, args.timeout_seconds) + records = parse_feed_items( + xml_payload=xml_payload, + feed_name=name, + feed_url=url, + max_items=args.max_items_per_feed, + fetched_at_utc=fetched_at, + ) + write_jsonl(output_path, records) + feed_results.append( + { + "name": name, + "url": url, + "license": feed.get("license", "Unknown"), + "license_url": feed.get("license_url", ""), + "status": "downloaded", + "record_count": len(records), + "output_jsonl": str(output_path), + } + ) + success_count += 1 + total_records += len(records) + print(f"Fetched {len(records)} records from {name} -> {output_path}") + except (ET.ParseError, HTTPError, URLError, TimeoutError, UnicodeDecodeError) as exc: + feed_results.append( + { + "name": name, + "url": url, + "license": feed.get("license", "Unknown"), + "license_url": feed.get("license_url", ""), + "status": "failed", + "error": str(exc), + } + ) + print(f"Failed {name}: {exc}", file=sys.stderr) + + manifest = { + "dataset": "incident_archives", + "generated_at_utc": utc_now_iso(), + "command": " ".join(sys.argv), + "feeds": feed_results, + "summary": { + "requested_feeds": len(feeds), + "successful_feeds": success_count, + "total_records": total_records, + }, + } + manifest_path = provenance_dir / 
"incident_archives_manifest.json" + write_json(manifest_path, manifest) + print(f"Wrote provenance manifest to {manifest_path}") + + if success_count == 0: + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/data_acquisition/fetch_spot_history.py b/scripts/data_acquisition/fetch_spot_history.py new file mode 100644 index 0000000..ee19802 --- /dev/null +++ b/scripts/data_acquisition/fetch_spot_history.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python3 +"""Collect spot price data for v2.3 replay/backtesting.""" + +from __future__ import annotations + +import argparse +import csv +import json +import re +import sys +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +SPOT_JS_URL = "https://spot-price.s3.amazonaws.com/spot.js" +AWS_TERMS_URL = "https://aws.amazon.com/terms/" +USER_AGENT = "Aether-Guard-Data-Acquisition/1.0" +CSV_FIELDS = [ + "timestamp_utc", + "region", + "availability_zone", + "instance_type", + "instance_family", + "operating_system", + "currency", + "spot_price_usd", + "product_description", + "source", + "source_url", +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--source", + choices=("spot-js", "ec2-api"), + default="spot-js", + help="Data source mode. Use ec2-api for historical pulls with AWS credentials.", + ) + parser.add_argument( + "--output-dir", + default="Data/replay", + help="Root output directory for generated CSV and provenance metadata.", + ) + parser.add_argument( + "--regions", + default="", + help="Comma-separated region allowlist. Empty means all regions in source.", + ) + parser.add_argument( + "--instance-types", + default="", + help="Comma-separated instance type allowlist. 
Empty means all instance types.", + ) + parser.add_argument( + "--region", + default="us-east-1", + help="Single AWS region for ec2-api mode.", + ) + parser.add_argument( + "--product-descriptions", + default="Linux/UNIX", + help="Comma-separated EC2 product descriptions for ec2-api mode.", + ) + parser.add_argument( + "--start-time", + default="", + help="UTC start time for ec2-api mode in ISO-8601 (for example 2026-01-01T00:00:00Z).", + ) + parser.add_argument( + "--end-time", + default="", + help="UTC end time for ec2-api mode in ISO-8601.", + ) + parser.add_argument( + "--max-records", + type=int, + default=500_000, + help="Maximum rows to emit before truncating.", + ) + parser.add_argument( + "--timeout-seconds", + type=int, + default=30, + help="HTTP timeout in seconds.", + ) + return parser.parse_args() + + +def utc_now() -> datetime: + return datetime.now(tz=timezone.utc) + + +def format_utc(value: datetime) -> str: + return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +def run_id() -> str: + return utc_now().strftime("%Y%m%dT%H%M%SZ") + + +def parse_csv_list(value: str) -> set[str]: + return {item.strip() for item in value.split(",") if item.strip()} + + +def parse_iso8601(value: str) -> datetime: + normalized = value.strip().replace("Z", "+00:00") + parsed = datetime.fromisoformat(normalized) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + + +def parse_price(raw_value: Any) -> float | None: + try: + return float(raw_value) + except (TypeError, ValueError): + return None + + +def parse_spot_js_payload(payload: str) -> dict[str, Any]: + match = re.search(r"callback\((.*)\)\s*;?\s*$", payload, flags=re.DOTALL) + if not match: + raise ValueError("Unexpected spot.js format: callback(...) 
wrapper not found.") + return json.loads(match.group(1)) + + +def fetch_spot_js(timeout_seconds: int) -> tuple[dict[str, Any], str | None]: + request = Request(SPOT_JS_URL, headers={"User-Agent": USER_AGENT}) + with urlopen(request, timeout=timeout_seconds) as response: + payload = response.read().decode("utf-8") + last_modified = response.headers.get("Last-Modified") + return parse_spot_js_payload(payload), last_modified + + +def collect_from_spot_js( + timeout_seconds: int, + regions: set[str], + instance_types: set[str], + max_records: int, +) -> tuple[list[dict[str, Any]], dict[str, Any]]: + data, last_modified = fetch_spot_js(timeout_seconds) + rows: list[dict[str, Any]] = [] + snapshot_time = format_utc(utc_now()) + + for region_data in data.get("config", {}).get("regions", []): + region = str(region_data.get("region", "")).strip() + if regions and region not in regions: + continue + + for family_data in region_data.get("instanceTypes", []): + family = str(family_data.get("type", "")).strip() + for size_data in family_data.get("sizes", []): + instance_type = str(size_data.get("size", "")).strip() + if instance_types and instance_type not in instance_types: + continue + + for value_column in size_data.get("valueColumns", []): + operating_system = str(value_column.get("name", "unknown")).strip() + for currency, raw_price in value_column.get("prices", {}).items(): + price = parse_price(raw_price) + row = { + "timestamp_utc": snapshot_time, + "region": region, + "availability_zone": "", + "instance_type": instance_type, + "instance_family": family, + "operating_system": operating_system, + "currency": currency, + "spot_price_usd": price if currency == "USD" else None, + "product_description": operating_system, + "source": "aws_spot_js_snapshot", + "source_url": SPOT_JS_URL, + } + rows.append(row) + if len(rows) >= max_records: + metadata = { + "mode": "spot-js", + "source_url": SPOT_JS_URL, + "last_modified": last_modified, + "spot_js_version": data.get("vers"), 
+ "truncated": True, + } + return rows, metadata + + metadata = { + "mode": "spot-js", + "source_url": SPOT_JS_URL, + "last_modified": last_modified, + "spot_js_version": data.get("vers"), + "truncated": False, + } + return rows, metadata + + +def collect_from_ec2_api( + region: str, + product_descriptions: list[str], + instance_types: set[str], + start_time: datetime, + end_time: datetime, + max_records: int, +) -> tuple[list[dict[str, Any]], dict[str, Any]]: + try: + import boto3 # type: ignore[import-untyped] + except Exception as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "ec2-api mode requires boto3. Install with: pip install boto3" + ) from exc + + client = boto3.client("ec2", region_name=region) + request_kwargs: dict[str, Any] = { + "StartTime": start_time, + "EndTime": end_time, + "ProductDescriptions": product_descriptions, + "MaxResults": 1000, + } + if instance_types: + request_kwargs["InstanceTypes"] = sorted(instance_types) + + rows: list[dict[str, Any]] = [] + next_token = None + + while True: + if next_token: + request_kwargs["NextToken"] = next_token + elif "NextToken" in request_kwargs: + request_kwargs.pop("NextToken", None) + + response = client.describe_spot_price_history(**request_kwargs) + for item in response.get("SpotPriceHistory", []): + instance_type = str(item.get("InstanceType", "")).strip() + row = { + "timestamp_utc": format_utc(item["Timestamp"]), + "region": region, + "availability_zone": item.get("AvailabilityZone", ""), + "instance_type": instance_type, + "instance_family": instance_type.split(".")[0] if "." 
def write_rows_csv(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write spot rows as CSV with the canonical CSV_FIELDS header."""
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)


def write_json(path: Path, payload: dict[str, Any]) -> None:
    """Serialize payload as stable (sorted-key) pretty JSON plus trailing newline."""
    body = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(body + "\n", encoding="utf-8")
= collect_from_spot_js( + timeout_seconds=args.timeout_seconds, + regions=regions, + instance_types=instance_types, + max_records=args.max_records, + ) + else: + end_time = parse_iso8601(args.end_time) if args.end_time else utc_now() + start_time = parse_iso8601(args.start_time) if args.start_time else end_time - timedelta(days=30) + product_descriptions = sorted(parse_csv_list(args.product_descriptions)) or ["Linux/UNIX"] + rows, metadata = collect_from_ec2_api( + region=args.region, + product_descriptions=product_descriptions, + instance_types=instance_types, + start_time=start_time, + end_time=end_time, + max_records=args.max_records, + ) + except (HTTPError, URLError, TimeoutError, RuntimeError, ValueError) as exc: + print(f"Spot data collection failed: {exc}", file=sys.stderr) + return 1 + except Exception as exc: # pragma: no cover - safety net + print(f"Unexpected spot data collection failure: {exc}", file=sys.stderr) + return 1 + + if not rows: + print("No spot data rows collected.", file=sys.stderr) + return 1 + + stamp = run_id() + csv_path = spot_dir / f"spot_history_{args.source}_{stamp}.csv" + write_rows_csv(csv_path, rows) + + manifest = { + "dataset": "spot_history", + "generated_at_utc": format_utc(utc_now()), + "command": command, + "record_count": len(rows), + "output_csv": str(csv_path), + "license": "AWS Site Terms", + "license_url": AWS_TERMS_URL, + "metadata": metadata, + } + manifest_path = provenance_dir / "spot_history_manifest.json" + write_json(manifest_path, manifest) + + print(f"Wrote {len(rows)} rows to {csv_path}") + print(f"Wrote provenance manifest to {manifest_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 43cc805ce57d4b5d6bf1c9208c92da5639eef20c Mon Sep 17 00:00:00 2001 From: JasonEran Date: Tue, 24 Feb 2026 10:39:17 +0800 Subject: [PATCH 14/24] feat(v2.3): add TSMixer baseline training and ONNX export workflow --- README.md | 2 + docs/AI-TSMixer-Baseline-v2.3-M2.md | 48 ++ 
scripts/model_training/README.md | 52 ++ scripts/model_training/requirements.txt | 5 + .../model_training/train_tsmixer_baseline.py | 720 ++++++++++++++++++ 5 files changed, 827 insertions(+) create mode 100644 docs/AI-TSMixer-Baseline-v2.3-M2.md create mode 100644 scripts/model_training/README.md create mode 100644 scripts/model_training/requirements.txt create mode 100644 scripts/model_training/train_tsmixer_baseline.py diff --git a/README.md b/README.md index eafcfdd..171b484 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,8 @@ Open the dashboard at http://localhost:3000. - v2.3 Milestone 1 smoke test: docs/QA-SmokeTest-v2.3-M1.md - v2.3 M2 data provenance: docs/Data-Provenance-v2.3-M2.md - v2.3 M2 data acquisition scripts: scripts/data_acquisition/README.md +- v2.3 M2 TSMixer baseline guide: docs/AI-TSMixer-Baseline-v2.3-M2.md +- v2.3 M2 model training scripts: scripts/model_training/README.md If you want to simulate migrations, start at least two agents: diff --git a/docs/AI-TSMixer-Baseline-v2.3-M2.md b/docs/AI-TSMixer-Baseline-v2.3-M2.md new file mode 100644 index 0000000..0fda203 --- /dev/null +++ b/docs/AI-TSMixer-Baseline-v2.3-M2.md @@ -0,0 +1,48 @@ +# v2.3 M2 TSMixer Baseline + ONNX Export + +This guide documents the baseline workflow delivered for issue #35. + +## Goal + +Train a lightweight time-series model for `P(preempt)` baseline inference and export ONNX artifacts for agent-side runtime integration. + +## Entry Points + +- Training script: `scripts/model_training/train_tsmixer_baseline.py` +- Script usage: `scripts/model_training/README.md` +- Dependency file: `scripts/model_training/requirements.txt` + +## Reproducible Training + +The script is reproducible by design: + +- Fixed seed controls Python, NumPy, and PyTorch RNG. +- Deterministic PyTorch execution is enabled. +- Dataset split is deterministic for the same seed and input. +- Run configuration and metrics are written into `training_summary.json`. + +## Dataset Modes + +1. 
Real dataset mode: pass a spot-history CSV generated by data acquisition scripts. +2. Synthetic fallback mode: automatically used when the input dataset cannot produce enough windows. + +The fallback reason is persisted in the summary metadata. + +## ONNX Export and Validation + +The script exports `tsmixer_baseline.onnx` and validates it by default: + +1. ONNX structure check (`onnx.checker`). +2. Inference parity check (PyTorch vs ONNX Runtime logits on held-out samples). + +Validation details are saved under `onnx_validation` in `training_summary.json`. + +## Artifacts + +Each run's output directory contains: + +- `tsmixer_baseline.pt` +- `tsmixer_baseline.onnx` +- `training_summary.json` + +These artifacts should be versioned by downstream issue #38 once the model governance flow is implemented. diff --git a/scripts/model_training/README.md b/scripts/model_training/README.md new file mode 100644 index 0000000..da0ee5d --- /dev/null +++ b/scripts/model_training/README.md @@ -0,0 +1,52 @@ +# TSMixer Baseline Training (v2.3 M2) + +This folder contains the baseline training workflow for issue #35: + +- Reproducible TSMixer training (seeded, deterministic flags enabled). +- ONNX export for agent-side inference. +- ONNX validation (checker + onnxruntime parity). 
+ +## Prerequisites + +- Python 3.10+ +- Install training dependencies: + +```bash +python -m pip install -r scripts/model_training/requirements.txt +``` + +## Quick Smoke Run + +Run with the synthetic fallback dataset (always available): + +```bash +python scripts/model_training/train_tsmixer_baseline.py \ + --epochs 6 \ + --batch-size 128 \ + --output-dir .tmp/tsmixer-baseline-smoke +``` + +## Train With Acquired Spot Dataset + +```bash +python scripts/model_training/train_tsmixer_baseline.py \ + --dataset-csv Data/replay/spot_history/spot_history_spot-js_YYYYMMDDTHHMMSSZ.csv \ + --epochs 20 \ + --output-dir .tmp/tsmixer-baseline-data +``` + +If the real dataset yields insufficient windows, the script switches to the deterministic synthetic fallback and records the reason. + +## Outputs + +Each run writes: + +- `tsmixer_baseline.pt`: PyTorch checkpoint (`state_dict` + normalization metadata) +- `tsmixer_baseline.onnx`: exported ONNX model +- `training_summary.json`: config, dataset source, metrics, and ONNX validation report + +## Reproducibility Notes + +- `--seed` controls Python, NumPy, and PyTorch RNGs. +- Deterministic PyTorch mode is enabled (`torch.use_deterministic_algorithms`). +- Data split uses deterministic shuffling with the same seed. 
diff --git a/scripts/model_training/requirements.txt b/scripts/model_training/requirements.txt new file mode 100644 index 0000000..8594f17 --- /dev/null +++ b/scripts/model_training/requirements.txt @@ -0,0 +1,5 @@ +torch +numpy +pandas +onnx +onnxruntime diff --git a/scripts/model_training/train_tsmixer_baseline.py b/scripts/model_training/train_tsmixer_baseline.py new file mode 100644 index 0000000..9987960 --- /dev/null +++ b/scripts/model_training/train_tsmixer_baseline.py @@ -0,0 +1,720 @@ +#!/usr/bin/env python3 +"""Train a reproducible TSMixer baseline and export ONNX for agent inference.""" + +from __future__ import annotations + +import argparse +import json +import random +import sys +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +import torch +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--dataset-csv", + default="", + help="Optional spot dataset CSV path (for example Data/replay/spot_history/*.csv).", + ) + parser.add_argument( + "--output-dir", + default=".tmp/tsmixer-baseline", + help="Output directory for model, ONNX, and training metadata.", + ) + parser.add_argument("--seed", type=int, default=42, help="Global random seed.") + parser.add_argument("--window-size", type=int, default=24, help="Input time steps.") + parser.add_argument("--horizon", type=int, default=1, help="Label horizon in steps.") + parser.add_argument( + "--label-threshold", + type=float, + default=0.03, + help="Future return threshold used to derive labels when label column is absent.", + ) + parser.add_argument("--epochs", type=int, default=20, help="Training epochs.") + parser.add_argument("--batch-size", type=int, default=128, help="Training batch size.") + parser.add_argument("--learning-rate", 
type=float, default=1e-3, help="Adam learning rate.") + parser.add_argument("--weight-decay", type=float, default=1e-4, help="AdamW weight decay.") + parser.add_argument("--dropout", type=float, default=0.1, help="Dropout in classifier head.") + parser.add_argument("--hidden-size", type=int, default=64, help="TSMixer hidden size.") + parser.add_argument("--num-blocks", type=int, default=3, help="TSMixer block count.") + parser.add_argument("--val-ratio", type=float, default=0.2, help="Validation split ratio.") + parser.add_argument("--test-ratio", type=float, default=0.1, help="Test split ratio.") + parser.add_argument( + "--target-column", + default="label_preempt", + help="Optional binary target column in source CSV.", + ) + parser.add_argument( + "--price-column", + default="spot_price_usd", + help="Price column used for feature extraction and derived labels.", + ) + parser.add_argument( + "--timestamp-column", + default="timestamp_utc", + help="Timestamp column used for deterministic ordering when available.", + ) + parser.add_argument( + "--max-rows", + type=int, + default=0, + help="Optional cap on loaded rows from CSV (0 means unlimited).", + ) + parser.add_argument( + "--synthetic-series", + type=int, + default=32, + help="Synthetic series count used when real dataset has insufficient windows.", + ) + parser.add_argument( + "--synthetic-length", + type=int, + default=240, + help="Length of each synthetic series when fallback is used.", + ) + parser.add_argument( + "--onnx-opset", + type=int, + default=17, + help="ONNX opset version for export.", + ) + parser.add_argument( + "--skip-onnx-validation", + action="store_true", + help="Export ONNX but skip checker/runtime validation.", + ) + return parser.parse_args() + + +def now_utc_iso() -> str: + return datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +def set_global_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) 
+ torch.use_deterministic_algorithms(True, warn_only=True) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def safe_float_series(series: pd.Series) -> np.ndarray: + values = pd.to_numeric(series, errors="coerce").astype(float).to_numpy() + return values[np.isfinite(values)] + + +def build_feature_matrix(prices: np.ndarray) -> np.ndarray: + if prices.ndim != 1: + raise ValueError("prices must be 1D") + if prices.size < 5: + raise ValueError("at least 5 points are required") + + returns = np.zeros_like(prices, dtype=np.float32) + returns[1:] = (prices[1:] - prices[:-1]) / np.maximum(np.abs(prices[:-1]), 1e-6) + + rolling_mean = pd.Series(prices).rolling(window=3, min_periods=1).mean().to_numpy(dtype=np.float32) + rolling_std = ( + pd.Series(prices) + .rolling(window=3, min_periods=1) + .std() + .fillna(0.0) + .to_numpy(dtype=np.float32) + ) + + features = np.column_stack( + [ + prices.astype(np.float32), + returns.astype(np.float32), + rolling_mean.astype(np.float32), + rolling_std.astype(np.float32), + ] + ) + return features + + +def build_derived_labels(prices: np.ndarray, horizon: int, threshold: float) -> np.ndarray: + labels = np.zeros(prices.size, dtype=np.float32) + if horizon <= 0: + return labels + for idx in range(prices.size - horizon): + current_price = prices[idx] + future_price = prices[idx + horizon] + future_return = (future_price - current_price) / max(abs(current_price), 1e-6) + labels[idx + horizon] = 1.0 if future_return >= threshold else 0.0 + return labels + + +def windows_from_series( + prices: np.ndarray, + labels: np.ndarray, + window_size: int, + horizon: int, +) -> tuple[list[np.ndarray], list[float]]: + features = build_feature_matrix(prices) + xs: list[np.ndarray] = [] + ys: list[float] = [] + max_index = prices.size - horizon + for end_idx in range(window_size - 1, max_index): + start_idx = end_idx - window_size + 1 + target_index = end_idx + horizon + window = features[start_idx : end_idx + 1] 
+ xs.append(window.astype(np.float32)) + ys.append(float(labels[target_index])) + return xs, ys + + +def prepare_windows_from_dataframe( + df: pd.DataFrame, + *, + price_column: str, + target_column: str, + timestamp_column: str, + window_size: int, + horizon: int, + threshold: float, +) -> tuple[np.ndarray, np.ndarray, dict[str, Any]]: + if price_column not in df.columns: + raise ValueError(f"Missing required price column: {price_column}") + + working_df = df.copy() + if timestamp_column in working_df.columns: + working_df[timestamp_column] = pd.to_datetime(working_df[timestamp_column], errors="coerce", utc=True) + working_df = working_df.sort_values(timestamp_column) + + group_columns = [column for column in ("region", "instance_type") if column in working_df.columns] + if group_columns: + grouped = working_df.groupby(group_columns, dropna=False) + group_frames = [frame for _, frame in grouped] + else: + group_frames = [working_df] + + windows: list[np.ndarray] = [] + labels: list[float] = [] + series_count = 0 + + for frame in group_frames: + prices = safe_float_series(frame[price_column]) + if prices.size < (window_size + horizon + 2): + continue + + if target_column in frame.columns: + target_values = pd.to_numeric(frame[target_column], errors="coerce").fillna(0.0).astype(np.float32).to_numpy() + target_values = np.where(target_values > 0.5, 1.0, 0.0).astype(np.float32) + if target_values.size != prices.size: + # If coercion dropped values via safe_float_series alignment mismatch, fallback to derived labels. 
+ target_values = build_derived_labels(prices, horizon, threshold) + else: + target_values = build_derived_labels(prices, horizon, threshold) + + xs, ys = windows_from_series(prices, target_values, window_size, horizon) + if not xs: + continue + + windows.extend(xs) + labels.extend(ys) + series_count += 1 + + if not windows: + raise ValueError("No training windows could be generated from the provided dataset.") + + x = np.stack(windows).astype(np.float32) + y = np.asarray(labels, dtype=np.float32) + metadata = { + "source": "dataset_csv", + "series_count": series_count, + "window_count": int(x.shape[0]), + } + return x, y, metadata + + +def generate_synthetic_windows( + *, + seed: int, + series_count: int, + series_length: int, + window_size: int, + horizon: int, + threshold: float, +) -> tuple[np.ndarray, np.ndarray, dict[str, Any]]: + rng = np.random.default_rng(seed) + windows: list[np.ndarray] = [] + labels: list[float] = [] + + for series_idx in range(series_count): + trend = rng.normal(0.0002, 0.0004) + noise = rng.normal(0.0, 0.01, size=series_length).astype(np.float32) + shocks = np.zeros(series_length, dtype=np.float32) + shock_positions = rng.choice(series_length, size=max(1, series_length // 60), replace=False) + shocks[shock_positions] = rng.normal(0.08, 0.03, size=shock_positions.shape[0]).astype(np.float32) + + price = np.empty(series_length, dtype=np.float32) + price[0] = 1.0 + float(rng.normal(0.0, 0.05)) + for idx in range(1, series_length): + drift = trend + noise[idx] + shocks[idx] + price[idx] = max(0.05, price[idx - 1] * (1.0 + drift)) + + derived_labels = build_derived_labels(price, horizon, threshold) + xs, ys = windows_from_series(price, derived_labels, window_size, horizon) + windows.extend(xs) + labels.extend(ys) + + x = np.stack(windows).astype(np.float32) + y = np.asarray(labels, dtype=np.float32) + metadata = { + "source": "synthetic_fallback", + "series_count": series_count, + "series_length": series_length, + "window_count": 
int(x.shape[0]), + } + return x, y, metadata + + +def split_dataset( + x: np.ndarray, + y: np.ndarray, + *, + val_ratio: float, + test_ratio: float, + seed: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + if x.shape[0] != y.shape[0]: + raise ValueError("Feature and label count mismatch.") + if not (0.0 < val_ratio < 0.5) or not (0.0 <= test_ratio < 0.5): + raise ValueError("val_ratio/test_ratio must be in a reasonable range.") + if val_ratio + test_ratio >= 0.8: + raise ValueError("val_ratio + test_ratio is too large.") + + rng = np.random.default_rng(seed) + indices = np.arange(x.shape[0]) + rng.shuffle(indices) + + x_shuffled = x[indices] + y_shuffled = y[indices] + + total = x.shape[0] + test_count = int(total * test_ratio) + val_count = int(total * val_ratio) + train_count = total - val_count - test_count + + if train_count <= 0 or val_count <= 0 or test_count <= 0: + raise ValueError("Dataset split is too small; increase sample size.") + + x_train = x_shuffled[:train_count] + y_train = y_shuffled[:train_count] + x_val = x_shuffled[train_count : train_count + val_count] + y_val = y_shuffled[train_count : train_count + val_count] + x_test = x_shuffled[train_count + val_count :] + y_test = y_shuffled[train_count + val_count :] + return x_train, y_train, x_val, y_val, x_test, y_test + + +def standardize_features( + x_train: np.ndarray, + x_val: np.ndarray, + x_test: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + train_mean = x_train.mean(axis=(0, 1), keepdims=True).astype(np.float32) + train_std = x_train.std(axis=(0, 1), keepdims=True).astype(np.float32) + train_std = np.where(train_std < 1e-6, 1.0, train_std).astype(np.float32) + + x_train_norm = ((x_train - train_mean) / train_std).astype(np.float32) + x_val_norm = ((x_val - train_mean) / train_std).astype(np.float32) + x_test_norm = ((x_test - train_mean) / train_std).astype(np.float32) + return x_train_norm, x_val_norm, 
x_test_norm, train_mean, train_std + + +class MixerBlock(nn.Module): + def __init__(self, time_steps: int, channels: int, hidden_size: int) -> None: + super().__init__() + self.time_norm = nn.LayerNorm(channels) + self.time_mlp = nn.Sequential( + nn.Linear(time_steps, hidden_size), + nn.GELU(), + nn.Linear(hidden_size, time_steps), + ) + self.feature_norm = nn.LayerNorm(channels) + self.feature_mlp = nn.Sequential( + nn.Linear(channels, hidden_size), + nn.GELU(), + nn.Linear(hidden_size, channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + mixed_time = self.time_norm(x).transpose(1, 2) + mixed_time = self.time_mlp(mixed_time).transpose(1, 2) + x = residual + mixed_time + + residual = x + mixed_feature = self.feature_mlp(self.feature_norm(x)) + return residual + mixed_feature + + +class TSMixerBinaryClassifier(nn.Module): + def __init__( + self, + *, + time_steps: int, + channels: int, + hidden_size: int, + num_blocks: int, + dropout: float, + ) -> None: + super().__init__() + self.blocks = nn.ModuleList( + [MixerBlock(time_steps=time_steps, channels=channels, hidden_size=hidden_size) for _ in range(num_blocks)] + ) + self.head = nn.Sequential( + nn.LayerNorm(channels), + nn.Flatten(), + nn.Linear(time_steps * channels, hidden_size), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_size, 1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + logits = self.head(x).squeeze(-1) + return logits + + +@dataclass +class Metrics: + loss: float + accuracy: float + positive_rate: float + + +def evaluate( + model: nn.Module, + x: np.ndarray, + y: np.ndarray, +) -> Metrics: + model.eval() + with torch.no_grad(): + inputs = torch.from_numpy(x) + labels = torch.from_numpy(y) + logits = model(inputs) + loss = nn.functional.binary_cross_entropy_with_logits(logits, labels).item() + probabilities = torch.sigmoid(logits) + predictions = (probabilities >= 0.5).float() + accuracy = 
float((predictions == labels).float().mean().item()) + positive_rate = float(predictions.mean().item()) + return Metrics(loss=loss, accuracy=accuracy, positive_rate=positive_rate) + + +def train_model( + model: nn.Module, + *, + x_train: np.ndarray, + y_train: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + epochs: int, + batch_size: int, + learning_rate: float, + weight_decay: float, +) -> tuple[nn.Module, list[dict[str, float]], float]: + train_dataset = TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False) + + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + positive_count = float(np.sum(y_train)) + negative_count = float(y_train.shape[0] - positive_count) + if positive_count > 0 and negative_count > 0: + pos_weight = torch.tensor([negative_count / positive_count], dtype=torch.float32) + criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) + else: + criterion = nn.BCEWithLogitsLoss() + + best_val_loss = float("inf") + best_state: dict[str, Any] | None = None + history: list[dict[str, float]] = [] + + for epoch in range(1, epochs + 1): + model.train() + running_loss = 0.0 + seen = 0 + for batch_x, batch_y in train_loader: + optimizer.zero_grad(set_to_none=True) + logits = model(batch_x) + loss = criterion(logits, batch_y) + loss.backward() + optimizer.step() + running_loss += float(loss.item()) * batch_x.shape[0] + seen += batch_x.shape[0] + + train_loss = running_loss / max(seen, 1) + val_metrics = evaluate(model, x_val, y_val) + history.append( + { + "epoch": float(epoch), + "train_loss": float(train_loss), + "val_loss": float(val_metrics.loss), + "val_accuracy": float(val_metrics.accuracy), + } + ) + if val_metrics.loss < best_val_loss: + best_val_loss = val_metrics.loss + best_state = {name: value.detach().cpu().clone() for name, value in model.state_dict().items()} + + if best_state is None: 
+ raise RuntimeError("Training finished without capturing best state.") + + model.load_state_dict(best_state) + return model, history, best_val_loss + + +def export_onnx( + model: nn.Module, + *, + input_shape: tuple[int, int, int], + output_path: Path, + opset: int, +) -> None: + model.eval() + dummy_input = torch.randn(*input_shape, dtype=torch.float32) + torch.onnx.export( + model, + dummy_input, + output_path.as_posix(), + export_params=True, + opset_version=opset, + do_constant_folding=True, + input_names=["telemetry_window"], + output_names=["preempt_logit"], + dynamic_axes={"telemetry_window": {0: "batch_size"}, "preempt_logit": {0: "batch_size"}}, + dynamo=False, + ) + + +def validate_onnx( + model: nn.Module, + onnx_path: Path, + sample_inputs: np.ndarray, +) -> dict[str, Any]: + import onnx + import onnxruntime as ort + + onnx_model = onnx.load(onnx_path.as_posix()) + onnx.checker.check_model(onnx_model) + + session = ort.InferenceSession(onnx_path.as_posix(), providers=["CPUExecutionProvider"]) + onnx_input_name = session.get_inputs()[0].name + onnx_output_name = session.get_outputs()[0].name + + with torch.no_grad(): + torch_logits = model(torch.from_numpy(sample_inputs)).detach().cpu().numpy().reshape(-1) + onnx_logits = session.run([onnx_output_name], {onnx_input_name: sample_inputs})[0].reshape(-1) + + max_abs_diff = float(np.max(np.abs(torch_logits - onnx_logits))) + mean_abs_diff = float(np.mean(np.abs(torch_logits - onnx_logits))) + return { + "onnx_checker_passed": True, + "onnxruntime_parity_passed": bool(max_abs_diff <= 1e-4), + "max_abs_diff": max_abs_diff, + "mean_abs_diff": mean_abs_diff, + "sample_size": int(sample_inputs.shape[0]), + } + + +def write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + + +def main() -> int: + args = parse_args() + 
set_global_seed(args.seed) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + dataset_metadata: dict[str, Any] + x: np.ndarray + y: np.ndarray + + dataset_path = Path(args.dataset_csv) if args.dataset_csv else None + if dataset_path and dataset_path.exists(): + dataset_df = pd.read_csv(dataset_path) + if args.max_rows > 0: + dataset_df = dataset_df.head(args.max_rows) + try: + x, y, dataset_metadata = prepare_windows_from_dataframe( + dataset_df, + price_column=args.price_column, + target_column=args.target_column, + timestamp_column=args.timestamp_column, + window_size=args.window_size, + horizon=args.horizon, + threshold=args.label_threshold, + ) + except Exception as exc: + print(f"Dataset windows unavailable ({exc}); switching to synthetic fallback.") + x, y, dataset_metadata = generate_synthetic_windows( + seed=args.seed, + series_count=args.synthetic_series, + series_length=args.synthetic_length, + window_size=args.window_size, + horizon=args.horizon, + threshold=args.label_threshold, + ) + dataset_metadata["fallback_reason"] = str(exc) + dataset_metadata["requested_dataset"] = str(dataset_path) + else: + x, y, dataset_metadata = generate_synthetic_windows( + seed=args.seed, + series_count=args.synthetic_series, + series_length=args.synthetic_length, + window_size=args.window_size, + horizon=args.horizon, + threshold=args.label_threshold, + ) + if dataset_path: + dataset_metadata["requested_dataset"] = str(dataset_path) + dataset_metadata["fallback_reason"] = "dataset path does not exist" + + x_train, y_train, x_val, y_val, x_test, y_test = split_dataset( + x, + y, + val_ratio=args.val_ratio, + test_ratio=args.test_ratio, + seed=args.seed, + ) + x_train, x_val, x_test, train_mean, train_std = standardize_features(x_train, x_val, x_test) + + channels = int(x_train.shape[2]) + model = TSMixerBinaryClassifier( + time_steps=args.window_size, + channels=channels, + hidden_size=args.hidden_size, + num_blocks=args.num_blocks, 
+ dropout=args.dropout, + ) + model, history, best_val_loss = train_model( + model, + x_train=x_train, + y_train=y_train, + x_val=x_val, + y_val=y_val, + epochs=args.epochs, + batch_size=args.batch_size, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + ) + + train_metrics = evaluate(model, x_train, y_train) + val_metrics = evaluate(model, x_val, y_val) + test_metrics = evaluate(model, x_test, y_test) + + model_path = output_dir / "tsmixer_baseline.pt" + torch.save( + { + "state_dict": model.state_dict(), + "window_size": args.window_size, + "channels": channels, + "train_mean": train_mean.squeeze(0).tolist(), + "train_std": train_std.squeeze(0).tolist(), + "created_at_utc": now_utc_iso(), + }, + model_path, + ) + + onnx_path = output_dir / "tsmixer_baseline.onnx" + export_onnx( + model, + input_shape=(1, args.window_size, channels), + output_path=onnx_path, + opset=args.onnx_opset, + ) + + onnx_validation: dict[str, Any] = {"skipped": bool(args.skip_onnx_validation)} + if not args.skip_onnx_validation: + sample_size = int(min(16, x_test.shape[0])) + sample_inputs = x_test[:sample_size] + onnx_validation = validate_onnx(model, onnx_path, sample_inputs) + + summary = { + "run_at_utc": now_utc_iso(), + "command": " ".join(["python"] + sys.argv), + "config": { + "seed": args.seed, + "window_size": args.window_size, + "horizon": args.horizon, + "label_threshold": args.label_threshold, + "epochs": args.epochs, + "batch_size": args.batch_size, + "learning_rate": args.learning_rate, + "weight_decay": args.weight_decay, + "dropout": args.dropout, + "hidden_size": args.hidden_size, + "num_blocks": args.num_blocks, + "val_ratio": args.val_ratio, + "test_ratio": args.test_ratio, + "onnx_opset": args.onnx_opset, + }, + "dataset": dataset_metadata, + "shapes": { + "x_train": list(x_train.shape), + "x_val": list(x_val.shape), + "x_test": list(x_test.shape), + }, + "label_balance": { + "train_positive_rate": float(np.mean(y_train)), + "val_positive_rate": 
float(np.mean(y_val)), + "test_positive_rate": float(np.mean(y_test)), + }, + "metrics": { + "best_val_loss": float(best_val_loss), + "train": train_metrics.__dict__, + "val": val_metrics.__dict__, + "test": test_metrics.__dict__, + }, + "artifacts": { + "torch_model": str(model_path), + "onnx_model": str(onnx_path), + }, + "onnx_validation": onnx_validation, + "history": history, + } + summary_path = output_dir / "training_summary.json" + write_json(summary_path, summary) + + print(f"Model saved: {model_path}") + print(f"ONNX saved: {onnx_path}") + print(f"Summary saved: {summary_path}") + print( + "Metrics:" + f" train_acc={train_metrics.accuracy:.4f}" + f" val_acc={val_metrics.accuracy:.4f}" + f" test_acc={test_metrics.accuracy:.4f}" + ) + if not args.skip_onnx_validation: + print( + "ONNX validation:" + f" checker={onnx_validation.get('onnx_checker_passed')}" + f" parity={onnx_validation.get('onnxruntime_parity_passed')}" + f" max_abs_diff={onnx_validation.get('max_abs_diff')}" + ) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 8f2b5ba67fcf007e690764d213fc7be4a77b2981 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Tue, 24 Feb 2026 10:55:08 +0800 Subject: [PATCH 15/24] feat(v2.3): add fusion baseline training and input contract --- README.md | 1 + docs/AI-Fusion-Model-v2.3-M2.md | 76 +++++ scripts/model_training/README.md | 42 ++- scripts/model_training/requirements.txt | 1 + .../model_training/train_fusion_baseline.py | 318 ++++++++++++++++++ 5 files changed, 437 insertions(+), 1 deletion(-) create mode 100644 docs/AI-Fusion-Model-v2.3-M2.md create mode 100644 scripts/model_training/train_fusion_baseline.py diff --git a/README.md b/README.md index 171b484..bde1e36 100644 --- a/README.md +++ b/README.md @@ -198,6 +198,7 @@ Open the dashboard at http://localhost:3000. 
- v2.3 M2 data provenance: docs/Data-Provenance-v2.3-M2.md - v2.3 M2 data acquisition scripts: scripts/data_acquisition/README.md - v2.3 M2 TSMixer baseline guide: docs/AI-TSMixer-Baseline-v2.3-M2.md +- v2.3 M2 fusion baseline guide: docs/AI-Fusion-Model-v2.3-M2.md - v2.3 M2 model training scripts: scripts/model_training/README.md If you want to simulate migrations, start at least two agents: diff --git a/docs/AI-Fusion-Model-v2.3-M2.md b/docs/AI-Fusion-Model-v2.3-M2.md new file mode 100644 index 0000000..b6616e4 --- /dev/null +++ b/docs/AI-Fusion-Model-v2.3-M2.md @@ -0,0 +1,76 @@ +# v2.3 M2 Fusion Model Baseline (`P(preempt)`) + +This document defines the input contract and offline baseline evaluation for issue #36. + +## Goal + +Fuse telemetry windows with semantic exogenous vectors (`S_v`, `P_v`, `B_s`) and produce `P(preempt)`. + +## Training Entry Point + +- `scripts/model_training/train_fusion_baseline.py` + +## Input Contract (Offline CSV) + +### Required semantic columns + +- `s_v_negative` +- `s_v_neutral` +- `s_v_positive` +- `p_v` +- `b_s` + +### Telemetry columns + +Configured by `--telemetry-columns`. Default order: + +- `spot_price_usd` +- `cpu_utilization` +- `memory_utilization` +- `network_io` + +At least one configured telemetry column must exist in the dataset. + +### Optional label + +- `label_preempt` (binary, 0/1) + +If missing, labels are derived by: + +- future return >= `--label-threshold` OR +- current `p_v >= 0.75` + +### Windowing semantics + +- `window_size`: telemetry lookback length. +- `horizon`: prediction target offset. +- Per training sample: + - telemetry tensor: `[window_size, telemetry_dim]` + - semantic tensor: `[semantic_dim]` at the end of window + - label: binary target at `end + horizon` + +## Offline Baseline Evaluation + +The script trains and evaluates two models on the same split: + +1. Telemetry-only baseline +2. 
Fusion baseline (telemetry branch + semantic branch) + +Outputs: + +- `telemetry_only_baseline.pt` +- `fusion_baseline.pt` +- `fusion_evaluation_summary.json` + +Summary includes: + +- train/val/test metrics: loss, accuracy, precision, recall, F1, AUROC, average precision +- comparison deltas: + - `test_f1_delta_fusion_minus_telemetry` + - `test_auroc_delta_fusion_minus_telemetry` + +## Reproducibility + +- Fixed `--seed` for Python/NumPy/PyTorch RNG. +- Deterministic PyTorch algorithms enabled. +- Deterministic split and normalization based on train partition. diff --git a/scripts/model_training/README.md b/scripts/model_training/README.md index da0ee5d..0a6b019 100644 --- a/scripts/model_training/README.md +++ b/scripts/model_training/README.md @@ -1,10 +1,11 @@ # TSMixer Baseline Training (v2.3 M2) -This folder contains the baseline training workflow for issue #35: +This folder contains the baseline training workflows for Milestone 2: - Reproducible TSMixer training (seeded, deterministic flags enabled). - ONNX export for agent-side inference. - ONNX validation (checker + onnxruntime parity). +- Fusion baseline training with semantic vectors (`S_v`, `P_v`, `B_s`) and offline comparison. ## Prerequisites @@ -50,3 +51,42 @@ Each run writes: - `--seed` controls Python, NumPy, and PyTorch RNGs. - Deterministic PyTorch mode is enabled (`torch.use_deterministic_algorithms`). - Data split uses deterministic shuffling with the same seed. + +## Fusion Baseline (Issue #36) + +Train telemetry-only and fusion models on the same dataset, then export a comparison summary: + +```bash +python scripts/model_training/train_fusion_baseline.py \ + --epochs 12 \ + --output-dir .tmp/fusion-baseline-smoke +``` + +If the provided CSV does not contain the required semantic contract columns, the script falls back to deterministic synthetic data and records the reason. 
+ +### Fusion Input Contract (CSV) + +Required semantic columns: + +- `s_v_negative` +- `s_v_neutral` +- `s_v_positive` +- `p_v` +- `b_s` + +Telemetry columns are configurable via `--telemetry-columns` and default to: + +- `spot_price_usd` +- `cpu_utilization` +- `memory_utilization` +- `network_io` + +Optional label: + +- `label_preempt` (binary). If missing, labels are derived from future return + `p_v`. + +Fusion outputs: + +- `telemetry_only_baseline.pt` +- `fusion_baseline.pt` +- `fusion_evaluation_summary.json` (contains offline baseline metrics and deltas) diff --git a/scripts/model_training/requirements.txt b/scripts/model_training/requirements.txt index 8594f17..7f85367 100644 --- a/scripts/model_training/requirements.txt +++ b/scripts/model_training/requirements.txt @@ -3,3 +3,4 @@ numpy pandas onnx onnxruntime +scikit-learn diff --git a/scripts/model_training/train_fusion_baseline.py b/scripts/model_training/train_fusion_baseline.py new file mode 100644 index 0000000..191e8ce --- /dev/null +++ b/scripts/model_training/train_fusion_baseline.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 +"""Train fusion baseline (telemetry window + S_v/P_v/B_s) for P(preempt).""" + +from __future__ import annotations + +import argparse +import json +import random +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +import torch +from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score, roc_auc_score +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + + +TELEMETRY_DEFAULT = ["spot_price_usd", "cpu_utilization", "memory_utilization", "network_io"] +SEMANTIC_DEFAULT = ["s_v_negative", "s_v_neutral", "s_v_positive", "p_v", "b_s"] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dataset-csv", default="", help="Optional CSV with telemetry + 
semantics + optional label.") + parser.add_argument("--output-dir", default=".tmp/fusion-baseline", help="Output directory.") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--window-size", type=int, default=24) + parser.add_argument("--horizon", type=int, default=1) + parser.add_argument("--epochs", type=int, default=12) + parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--learning-rate", type=float, default=1e-3) + parser.add_argument("--weight-decay", type=float, default=1e-4) + parser.add_argument("--hidden-size", type=int, default=64) + parser.add_argument("--dropout", type=float, default=0.1) + parser.add_argument("--val-ratio", type=float, default=0.2) + parser.add_argument("--test-ratio", type=float, default=0.1) + parser.add_argument("--label-column", default="label_preempt") + parser.add_argument("--label-threshold", type=float, default=0.03) + parser.add_argument("--telemetry-columns", default=",".join(TELEMETRY_DEFAULT)) + parser.add_argument("--semantic-columns", default=",".join(SEMANTIC_DEFAULT)) + parser.add_argument("--synthetic-series", type=int, default=48) + parser.add_argument("--synthetic-length", type=int, default=240) + return parser.parse_args() + + +def utc_now() -> str: + return datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.use_deterministic_algorithms(True, warn_only=True) + + +def derive_labels(price: np.ndarray, p_v: np.ndarray, horizon: int, threshold: float) -> np.ndarray: + y = np.zeros_like(price, dtype=np.float32) + for idx in range(0, len(price) - horizon): + f = idx + horizon + ret = (price[f] - price[idx]) / max(abs(price[idx]), 1e-6) + y[f] = 1.0 if (ret >= threshold or p_v[idx] >= 0.75) else 0.0 + return y + + +def build_windows(tel: np.ndarray, sem: np.ndarray, y: np.ndarray, window: int, horizon: int) 
-> tuple[np.ndarray, np.ndarray, np.ndarray]: + xt: list[np.ndarray] = [] + xs: list[np.ndarray] = [] + yy: list[float] = [] + for end in range(window - 1, tel.shape[0] - horizon): + start = end - window + 1 + target = end + horizon + tw = tel[start : end + 1] + sv = sem[end] + if not np.isfinite(tw).all() or not np.isfinite(sv).all(): + continue + xt.append(tw.astype(np.float32)) + xs.append(sv.astype(np.float32)) + yy.append(float(y[target])) + if not xt: + raise ValueError("No windows produced.") + return np.stack(xt), np.stack(xs), np.asarray(yy, dtype=np.float32) + + +def load_dataset(args: argparse.Namespace) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, Any]]: + telemetry_cols = [c.strip() for c in args.telemetry_columns.split(",") if c.strip()] + semantic_cols = [c.strip() for c in args.semantic_columns.split(",") if c.strip()] + path = Path(args.dataset_csv) if args.dataset_csv else None + fallback_reason: str | None = None + if path and path.exists(): + df = pd.read_csv(path) + tel_cols = [c for c in telemetry_cols if c in df.columns] + if tel_cols and all(c in df.columns for c in semantic_cols): + tel = np.column_stack([pd.to_numeric(df[c], errors="coerce").to_numpy(dtype=np.float32) for c in tel_cols]) + sem = np.column_stack([pd.to_numeric(df[c], errors="coerce").to_numpy(dtype=np.float32) for c in semantic_cols]) + if args.label_column in df.columns: + y = np.where(pd.to_numeric(df[args.label_column], errors="coerce").fillna(0.0).to_numpy() >= 0.5, 1.0, 0.0).astype(np.float32) + else: + y = derive_labels(tel[:, 0], sem[:, 3], args.horizon, args.label_threshold) + x_tel, x_sem, y = build_windows(tel, sem, y, args.window_size, args.horizon) + return x_tel, x_sem, y, { + "source": "dataset_csv", + "telemetry_columns": tel_cols, + "semantic_columns": semantic_cols, + "window_count": int(x_tel.shape[0]), + } + missing_sem = [c for c in semantic_cols if c not in df.columns] + if not tel_cols: + fallback_reason = f"missing telemetry columns: 
{telemetry_cols}" + elif missing_sem: + fallback_reason = f"missing semantic columns: {missing_sem}" + else: + fallback_reason = "dataset windows not generated from provided CSV" + elif path: + fallback_reason = "dataset path does not exist" + + rng = np.random.default_rng(args.seed) + tel_all: list[np.ndarray] = [] + sem_all: list[np.ndarray] = [] + y_all: list[np.ndarray] = [] + for _ in range(args.synthetic_series): + n = args.synthetic_length + price = np.empty(n, dtype=np.float32) + cpu = np.empty(n, dtype=np.float32) + mem = np.empty(n, dtype=np.float32) + net = np.empty(n, dtype=np.float32) + price[0] = 1.0 + rng.normal(0, 0.05) + cpu[0], mem[0], net[0] = 0.45, 0.5, 0.4 + spikes = set(rng.choice(n, size=max(2, n // 45), replace=False).tolist()) + shock = np.zeros(n, dtype=np.float32) + for i in range(1, n): + s = float(rng.normal(0.06, 0.02)) if i in spikes else 0.0 + shock[i] = 1.0 if s else 0.0 + step = 0.0002 + rng.normal(0, 0.01) + s + price[i] = max(0.05, price[i - 1] * (1.0 + step)) + cpu[i] = float(np.clip(0.65 * cpu[i - 1] + 0.35 * (0.45 + step * 2.2 + rng.normal(0, 0.03)), 0, 1)) + mem[i] = float(np.clip(0.75 * mem[i - 1] + 0.25 * (0.5 + abs(step) * 2.0 + rng.normal(0, 0.02)), 0, 1)) + net[i] = float(np.clip(0.60 * net[i - 1] + 0.40 * (0.38 + s * 1.8 + rng.normal(0, 0.03)), 0, 1)) + ret = np.zeros(n, dtype=np.float32) + ret[1:] = (price[1:] - price[:-1]) / np.maximum(np.abs(price[:-1]), 1e-6) + vol = pd.Series(ret).rolling(window=5, min_periods=1).std().fillna(0).to_numpy(dtype=np.float32) + trend = pd.Series(price).rolling(window=12, min_periods=1).mean().to_numpy(dtype=np.float32) + s_neg = 1.0 / (1.0 + np.exp(-( -ret * 12 + shock * 1.5))) + s_pos = 1.0 / (1.0 + np.exp(-( ret * 10 - shock * 0.2))) + s_neu = np.clip(1.0 - np.abs(s_pos - s_neg), 0, 1) + norm = np.maximum(s_neg + s_neu + s_pos, 1e-6) + s_neg, s_neu, s_pos = s_neg / norm, s_neu / norm, s_pos / norm + p_v = 1.0 / (1.0 + np.exp(-(vol * 35 + shock * 1.8))) + b_s = 1.0 / (1.0 + 
np.exp(-((trend - price) * 4.0))) + tel = np.column_stack([price, cpu, mem, net]).astype(np.float32) + sem = np.column_stack([s_neg, s_neu, s_pos, p_v, b_s]).astype(np.float32) + y = derive_labels(price, p_v.astype(np.float32), args.horizon, args.label_threshold) + xt, xs, yy = build_windows(tel, sem, y, args.window_size, args.horizon) + tel_all.append(xt) + sem_all.append(xs) + y_all.append(yy) + x_tel = np.concatenate(tel_all, axis=0) + x_sem = np.concatenate(sem_all, axis=0) + y = np.concatenate(y_all, axis=0) + metadata = { + "source": "synthetic_fallback", + "telemetry_columns": TELEMETRY_DEFAULT, + "semantic_columns": SEMANTIC_DEFAULT, + "window_count": int(x_tel.shape[0]), + } + if path: + metadata["requested_dataset"] = str(path) + if fallback_reason: + metadata["fallback_reason"] = fallback_reason + return x_tel, x_sem, y, metadata + + +def split_standardize(x_tel: np.ndarray, x_sem: np.ndarray, y: np.ndarray, args: argparse.Namespace) -> tuple[list[np.ndarray], dict[str, Any]]: + idx = np.arange(x_tel.shape[0]) + np.random.default_rng(args.seed).shuffle(idx) + x_tel, x_sem, y = x_tel[idx], x_sem[idx], y[idx] + total = x_tel.shape[0] + n_test = int(total * args.test_ratio) + n_val = int(total * args.val_ratio) + n_train = total - n_test - n_val + train_tel, val_tel, test_tel = x_tel[:n_train], x_tel[n_train : n_train + n_val], x_tel[n_train + n_val :] + train_sem, val_sem, test_sem = x_sem[:n_train], x_sem[n_train : n_train + n_val], x_sem[n_train + n_val :] + train_y, val_y, test_y = y[:n_train], y[n_train : n_train + n_val], y[n_train + n_val :] + tel_mean = train_tel.mean(axis=(0, 1), keepdims=True) + tel_std = np.where(train_tel.std(axis=(0, 1), keepdims=True) < 1e-6, 1.0, train_tel.std(axis=(0, 1), keepdims=True)) + sem_mean = train_sem.mean(axis=0, keepdims=True) + sem_std = np.where(train_sem.std(axis=0, keepdims=True) < 1e-6, 1.0, train_sem.std(axis=0, keepdims=True)) + train_tel, val_tel, test_tel = (train_tel - tel_mean) / tel_std, (val_tel - 
tel_mean) / tel_std, (test_tel - tel_mean) / tel_std + train_sem, val_sem, test_sem = (train_sem - sem_mean) / sem_std, (val_sem - sem_mean) / sem_std, (test_sem - sem_mean) / sem_std + stats = {"tel_mean": tel_mean.squeeze(0).tolist(), "tel_std": tel_std.squeeze(0).tolist(), "sem_mean": sem_mean.tolist(), "sem_std": sem_std.tolist()} + return [train_tel.astype(np.float32), train_sem.astype(np.float32), train_y.astype(np.float32), val_tel.astype(np.float32), val_sem.astype(np.float32), val_y.astype(np.float32), test_tel.astype(np.float32), test_sem.astype(np.float32), test_y.astype(np.float32)], stats + + +class TelemetryOnly(nn.Module): + def __init__(self, window: int, tel_dim: int, hidden: int, dropout: float): + super().__init__() + self.net = nn.Sequential(nn.Flatten(), nn.Linear(window * tel_dim, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, 1)) + + def forward(self, x_tel: torch.Tensor, x_sem: torch.Tensor) -> torch.Tensor: + del x_sem + return self.net(x_tel).squeeze(-1) + + +class Fusion(nn.Module): + def __init__(self, window: int, tel_dim: int, sem_dim: int, hidden: int, dropout: float): + super().__init__() + self.tel = nn.Sequential(nn.Flatten(), nn.Linear(window * tel_dim, hidden), nn.GELU()) + self.sem = nn.Sequential(nn.Linear(sem_dim, hidden), nn.GELU()) + self.cls = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, 1)) + + def forward(self, x_tel: torch.Tensor, x_sem: torch.Tensor) -> torch.Tensor: + return self.cls(torch.cat([self.tel(x_tel), self.sem(x_sem)], dim=1)).squeeze(-1) + + +def train(model: nn.Module, train_tel: np.ndarray, train_sem: np.ndarray, train_y: np.ndarray, val_tel: np.ndarray, val_sem: np.ndarray, val_y: np.ndarray, args: argparse.Namespace) -> tuple[nn.Module, list[dict[str, float]]]: + ds = TensorDataset(torch.from_numpy(train_tel), torch.from_numpy(train_sem), torch.from_numpy(train_y)) + dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True) + pos = 
float(np.sum(train_y)); neg = float(len(train_y) - pos) + crit = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([neg / pos], dtype=torch.float32)) if pos > 0 and neg > 0 else nn.BCEWithLogitsLoss() + opt = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) + hist: list[dict[str, float]] = [] + best = None; best_loss = float("inf") + for ep in range(1, args.epochs + 1): + model.train(); total = 0.0; count = 0 + for bt, bs, by in dl: + opt.zero_grad(set_to_none=True); lg = model(bt, bs); loss = crit(lg, by); loss.backward(); opt.step() + total += float(loss.item()) * bt.shape[0]; count += bt.shape[0] + val = evaluate(model, val_tel, val_sem, val_y) + hist.append({"epoch": float(ep), "train_loss": float(total / max(count, 1)), "val_loss": float(val["loss"]), "val_f1": float(val["f1"])}) + if val["loss"] < best_loss: best_loss = val["loss"]; best = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + if best is None: raise RuntimeError("No best checkpoint.") + model.load_state_dict(best) + return model, hist + + +def evaluate(model: nn.Module, x_tel: np.ndarray, x_sem: np.ndarray, y: np.ndarray) -> dict[str, Any]: + model.eval() + with torch.no_grad(): + lg = model(torch.from_numpy(x_tel), torch.from_numpy(x_sem)) + loss = float(nn.functional.binary_cross_entropy_with_logits(lg, torch.from_numpy(y)).item()) + prob = torch.sigmoid(lg).cpu().numpy().reshape(-1) + pred = (prob >= 0.5).astype(np.float32) + out: dict[str, Any] = { + "loss": loss, + "accuracy": float(np.mean(pred == y)), + "precision": float(precision_score(y, pred, zero_division=0)), + "recall": float(recall_score(y, pred, zero_division=0)), + "f1": float(f1_score(y, pred, zero_division=0)), + "positive_rate": float(np.mean(pred)), + "auroc": None, + "average_precision": None, + } + if len(np.unique(y)) > 1: + out["auroc"] = float(roc_auc_score(y, prob)) + out["average_precision"] = float(average_precision_score(y, prob)) + return out + + +def main() 
-> int: + args = parse_args(); set_seed(args.seed) + out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True) + x_tel, x_sem, y, ds_meta = load_dataset(args) + splits, norm = split_standardize(x_tel, x_sem, y, args) + tr_t, tr_s, tr_y, va_t, va_s, va_y, te_t, te_s, te_y = splits + tel_dim, sem_dim = tr_t.shape[2], tr_s.shape[1] + + set_seed(args.seed) + tel_model = TelemetryOnly(args.window_size, tel_dim, args.hidden_size, args.dropout) + tel_model, tel_hist = train(tel_model, tr_t, tr_s, tr_y, va_t, va_s, va_y, args) + + set_seed(args.seed + 1) + fus_model = Fusion(args.window_size, tel_dim, sem_dim, args.hidden_size, args.dropout) + fus_model, fus_hist = train(fus_model, tr_t, tr_s, tr_y, va_t, va_s, va_y, args) + + tel_m = {"train": evaluate(tel_model, tr_t, tr_s, tr_y), "val": evaluate(tel_model, va_t, va_s, va_y), "test": evaluate(tel_model, te_t, te_s, te_y)} + fus_m = {"train": evaluate(fus_model, tr_t, tr_s, tr_y), "val": evaluate(fus_model, va_t, va_s, va_y), "test": evaluate(fus_model, te_t, te_s, te_y)} + + tel_path = out_dir / "telemetry_only_baseline.pt"; fus_path = out_dir / "fusion_baseline.pt" + torch.save({"state_dict": tel_model.state_dict(), "normalization": norm, "created_at_utc": utc_now()}, tel_path) + torch.save({"state_dict": fus_model.state_dict(), "normalization": norm, "created_at_utc": utc_now()}, fus_path) + + summary = { + "run_at_utc": utc_now(), + "command": " ".join(["python"] + sys.argv), + "dataset": ds_meta, + "config": {k: getattr(args, k) for k in ["seed", "window_size", "horizon", "epochs", "batch_size", "learning_rate", "weight_decay", "hidden_size", "dropout", "label_column", "label_threshold"]}, + "label_balance": {"train_positive_rate": float(np.mean(tr_y)), "val_positive_rate": float(np.mean(va_y)), "test_positive_rate": float(np.mean(te_y))}, + "models": {"telemetry_only": {"metrics": tel_m, "history": tel_hist, "artifact": str(tel_path)}, "fusion": {"metrics": fus_m, "history": fus_hist, "artifact": 
str(fus_path)}}, + "comparison": { + "test_f1_delta_fusion_minus_telemetry": float(fus_m["test"]["f1"] - tel_m["test"]["f1"]), + "test_auroc_delta_fusion_minus_telemetry": None if (fus_m["test"]["auroc"] is None or tel_m["test"]["auroc"] is None) else float(fus_m["test"]["auroc"] - tel_m["test"]["auroc"]), + }, + } + summary_path = out_dir / "fusion_evaluation_summary.json" + with summary_path.open("w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, sort_keys=True) + f.write("\n") + + print(f"Telemetry-only model saved: {tel_path}") + print(f"Fusion model saved: {fus_path}") + print(f"Summary saved: {summary_path}") + print( + f"Test F1: telemetry={tel_m['test']['f1']:.4f} " + f"fusion={fus_m['test']['f1']:.4f} " + f"delta={summary['comparison']['test_f1_delta_fusion_minus_telemetry']:.4f}" + ) + if summary["comparison"]["test_auroc_delta_fusion_minus_telemetry"] is not None: + print(f"Test AUROC delta: {summary['comparison']['test_auroc_delta_fusion_minus_telemetry']:.4f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From cfc80641713149145818e07be9bb86ebe98f84ac Mon Sep 17 00:00:00 2001 From: JasonEran Date: Tue, 24 Feb 2026 11:05:51 +0800 Subject: [PATCH 16/24] feat(v2.3): add backtesting harness for fusion vs v2.2 heuristic --- README.md | 1 + docs/AI-Backtesting-v2.3-M2.md | 56 +++ scripts/model_training/README.md | 17 + .../model_training/backtest_fusion_vs_v22.py | 441 ++++++++++++++++++ 4 files changed, 515 insertions(+) create mode 100644 docs/AI-Backtesting-v2.3-M2.md create mode 100644 scripts/model_training/backtest_fusion_vs_v22.py diff --git a/README.md b/README.md index bde1e36..f503235 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,7 @@ Open the dashboard at http://localhost:3000. 
- v2.3 M2 data acquisition scripts: scripts/data_acquisition/README.md - v2.3 M2 TSMixer baseline guide: docs/AI-TSMixer-Baseline-v2.3-M2.md - v2.3 M2 fusion baseline guide: docs/AI-Fusion-Model-v2.3-M2.md +- v2.3 M2 backtesting guide: docs/AI-Backtesting-v2.3-M2.md - v2.3 M2 model training scripts: scripts/model_training/README.md If you want to simulate migrations, start at least two agents: diff --git a/docs/AI-Backtesting-v2.3-M2.md b/docs/AI-Backtesting-v2.3-M2.md new file mode 100644 index 0000000..4de31da --- /dev/null +++ b/docs/AI-Backtesting-v2.3-M2.md @@ -0,0 +1,56 @@ +# v2.3 M2 Backtesting Harness (v2.3 Fusion vs v2.2 Heuristic) + +This document describes the offline backtesting runner delivered for issue #37. + +## Goal + +Validate v2.3 fusion model improvements against v2.2 heuristic decisions on held-out windows. + +## Runner + +- `scripts/model_training/backtest_fusion_vs_v22.py` + +## Inputs + +- Fusion checkpoint (`fusion_baseline.pt`) from issue #36. +- Optional dataset CSV with telemetry + semantic columns. +- If dataset contract is not met, deterministic synthetic fallback is used and `fallback_reason` is recorded. + +## Compared Strategies + +1. **v2.2 heuristic** + - Uses legacy `RiskScorer` decision (`CRITICAL` => positive preemption signal). +2. **v2.3 fusion** + - Uses fusion model probability with configurable decision threshold. + +## Held-Out Backtest Protocol + +- Build chronological windows from replay dataset. +- Reserve the tail portion (`backtest_ratio`) as held-out period. +- Evaluate both strategies on the same held-out windows. 
+
+## Metrics
+
+Reported per strategy:
+
+- Accuracy
+- Precision
+- Recall
+- F1
+- AUROC (if both classes present)
+- Average Precision (if both classes present)
+- Positive prediction rate
+
+Reported deltas:
+
+- `f1_delta_fusion_minus_v22`
+- `recall_delta_fusion_minus_v22`
+- `precision_delta_fusion_minus_v22`
+- `auroc_delta_fusion_minus_v22`
+
+## Outputs
+
+Each run's output directory contains:
+
+- `backtest_summary.json`
+- `backtest_report.md`
diff --git a/scripts/model_training/README.md b/scripts/model_training/README.md
index 0a6b019..c98f8bd 100644
--- a/scripts/model_training/README.md
+++ b/scripts/model_training/README.md
@@ -90,3 +90,20 @@ Fusion outputs:
 - `telemetry_only_baseline.pt`
 - `fusion_baseline.pt`
 - `fusion_evaluation_summary.json` (contains offline baseline metrics and deltas)
+
+## Backtest Harness (Issue #37)
+
+Compare v2.3 fusion model against v2.2 heuristic on held-out windows:
+
+```bash
+python scripts/model_training/backtest_fusion_vs_v22.py \
+  --fusion-checkpoint .tmp/fusion-baseline-smoke/fusion_baseline.pt \
+  --output-dir .tmp/backtest-fusion-vs-v22
+```
+
+If the checkpoint is missing, the script auto-trains a fusion baseline and then runs the backtest. 
+ +Backtest outputs: + +- `backtest_summary.json` +- `backtest_report.md` diff --git a/scripts/model_training/backtest_fusion_vs_v22.py b/scripts/model_training/backtest_fusion_vs_v22.py new file mode 100644 index 0000000..61be65f --- /dev/null +++ b/scripts/model_training/backtest_fusion_vs_v22.py @@ -0,0 +1,441 @@ +#!/usr/bin/env python3 +"""Backtest v2.3 fusion model against v2.2 heuristic on held-out windows.""" + +from __future__ import annotations + +import argparse +import json +import random +import subprocess +import sys +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +import torch +from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score, roc_auc_score +from torch import nn + + +TELEMETRY_DEFAULT = ["spot_price_usd", "cpu_utilization", "memory_utilization", "network_io"] +SEMANTIC_DEFAULT = ["s_v_negative", "s_v_neutral", "s_v_positive", "p_v", "b_s"] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dataset-csv", default="", help="Optional replay dataset CSV.") + parser.add_argument("--output-dir", default=".tmp/backtest-fusion-vs-v22", help="Output directory for report artifacts.") + parser.add_argument( + "--fusion-checkpoint", + default=".tmp/fusion-baseline-smoke/fusion_baseline.pt", + help="Path to fusion model checkpoint from #36 training.", + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--window-size", type=int, default=24) + parser.add_argument("--horizon", type=int, default=1) + parser.add_argument("--backtest-ratio", type=float, default=0.3, help="Fraction of timeline used as held-out backtest.") + parser.add_argument("--decision-threshold", type=float, default=0.5, help="Probability threshold for positive fusion decision.") + parser.add_argument("--label-column", 
default="label_preempt") + parser.add_argument("--label-threshold", type=float, default=0.03) + parser.add_argument("--telemetry-columns", default=",".join(TELEMETRY_DEFAULT)) + parser.add_argument("--semantic-columns", default=",".join(SEMANTIC_DEFAULT)) + parser.add_argument("--synthetic-series", type=int, default=48) + parser.add_argument("--synthetic-length", type=int, default=240) + parser.add_argument("--autotrain-if-missing", action="store_true", default=True) + parser.add_argument("--autotrain-epochs", type=int, default=8) + parser.add_argument("--autotrain-output-dir", default=".tmp/fusion-baseline-autotrain") + return parser.parse_args() + + +def utc_now() -> str: + return datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def parse_csv_columns(value: str) -> list[str]: + return [item.strip() for item in value.split(",") if item.strip()] + + +def derive_labels(price: np.ndarray, p_v: np.ndarray, horizon: int, threshold: float) -> np.ndarray: + labels = np.zeros_like(price, dtype=np.float32) + for idx in range(len(price) - horizon): + f = idx + horizon + ret = (price[f] - price[idx]) / max(abs(price[idx]), 1e-6) + labels[f] = 1.0 if (ret >= threshold or p_v[idx] >= 0.75) else 0.0 + return labels + + +def build_windows( + telemetry: np.ndarray, + semantics: np.ndarray, + labels: np.ndarray, + *, + window_size: int, + horizon: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + x_tel: list[np.ndarray] = [] + x_sem: list[np.ndarray] = [] + y: list[float] = [] + for end in range(window_size - 1, telemetry.shape[0] - horizon): + start = end - window_size + 1 + target = end + horizon + tw = telemetry[start : end + 1] + sv = semantics[end] + if not np.isfinite(tw).all() or not np.isfinite(sv).all(): + continue + x_tel.append(tw.astype(np.float32)) + x_sem.append(sv.astype(np.float32)) + 
y.append(float(labels[target])) + if not x_tel: + raise ValueError("No valid windows produced.") + return np.stack(x_tel), np.stack(x_sem), np.asarray(y, dtype=np.float32) + + +def load_dataset(args: argparse.Namespace) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, Any]]: + telemetry_cols = parse_csv_columns(args.telemetry_columns) + semantic_cols = parse_csv_columns(args.semantic_columns) + path = Path(args.dataset_csv) if args.dataset_csv else None + fallback_reason: str | None = None + + if path and path.exists(): + df = pd.read_csv(path) + tel_cols = [col for col in telemetry_cols if col in df.columns] + missing_sem = [col for col in semantic_cols if col not in df.columns] + if tel_cols and not missing_sem: + telemetry = np.column_stack( + [pd.to_numeric(df[col], errors="coerce").to_numpy(dtype=np.float32) for col in tel_cols] + ) + semantics = np.column_stack( + [pd.to_numeric(df[col], errors="coerce").to_numpy(dtype=np.float32) for col in semantic_cols] + ) + if args.label_column in df.columns: + labels = np.where( + pd.to_numeric(df[args.label_column], errors="coerce").fillna(0.0).to_numpy() >= 0.5, + 1.0, + 0.0, + ).astype(np.float32) + else: + labels = derive_labels(telemetry[:, 0], semantics[:, 3], args.horizon, args.label_threshold) + x_tel, x_sem, y = build_windows( + telemetry, + semantics, + labels, + window_size=args.window_size, + horizon=args.horizon, + ) + return x_tel, x_sem, y, { + "source": "dataset_csv", + "telemetry_columns": tel_cols, + "semantic_columns": semantic_cols, + "window_count": int(x_tel.shape[0]), + } + if not tel_cols: + fallback_reason = f"missing telemetry columns: {telemetry_cols}" + else: + fallback_reason = f"missing semantic columns: {missing_sem}" + elif path: + fallback_reason = "dataset path does not exist" + + rng = np.random.default_rng(args.seed) + tel_all: list[np.ndarray] = [] + sem_all: list[np.ndarray] = [] + y_all: list[np.ndarray] = [] + for _ in range(args.synthetic_series): + n = 
args.synthetic_length + price = np.empty(n, dtype=np.float32) + cpu = np.empty(n, dtype=np.float32) + mem = np.empty(n, dtype=np.float32) + net = np.empty(n, dtype=np.float32) + price[0] = 1.0 + rng.normal(0, 0.05) + cpu[0], mem[0], net[0] = 0.45, 0.5, 0.4 + spike_positions = set(rng.choice(n, size=max(2, n // 45), replace=False).tolist()) + shock = np.zeros(n, dtype=np.float32) + for i in range(1, n): + s = float(rng.normal(0.06, 0.02)) if i in spike_positions else 0.0 + shock[i] = 1.0 if s else 0.0 + step = 0.0002 + rng.normal(0, 0.01) + s + price[i] = max(0.05, price[i - 1] * (1.0 + step)) + cpu[i] = float(np.clip(0.65 * cpu[i - 1] + 0.35 * (0.45 + step * 2.2 + rng.normal(0, 0.03)), 0, 1)) + mem[i] = float(np.clip(0.75 * mem[i - 1] + 0.25 * (0.50 + abs(step) * 2.0 + rng.normal(0, 0.02)), 0, 1)) + net[i] = float(np.clip(0.60 * net[i - 1] + 0.40 * (0.38 + s * 1.8 + rng.normal(0, 0.03)), 0, 1)) + + ret = np.zeros(n, dtype=np.float32) + ret[1:] = (price[1:] - price[:-1]) / np.maximum(np.abs(price[:-1]), 1e-6) + vol = pd.Series(ret).rolling(window=5, min_periods=1).std().fillna(0.0).to_numpy(dtype=np.float32) + trend = pd.Series(price).rolling(window=12, min_periods=1).mean().to_numpy(dtype=np.float32) + s_neg = 1.0 / (1.0 + np.exp(-(-ret * 12 + shock * 1.5))) + s_pos = 1.0 / (1.0 + np.exp(-(ret * 10 - shock * 0.2))) + s_neu = np.clip(1.0 - np.abs(s_pos - s_neg), 0, 1) + norm = np.maximum(s_neg + s_neu + s_pos, 1e-6) + s_neg, s_neu, s_pos = s_neg / norm, s_neu / norm, s_pos / norm + p_v = 1.0 / (1.0 + np.exp(-(vol * 35 + shock * 1.8))) + b_s = 1.0 / (1.0 + np.exp(-((trend - price) * 4.0))) + + tel = np.column_stack([price, cpu, mem, net]).astype(np.float32) + sem = np.column_stack([s_neg, s_neu, s_pos, p_v, b_s]).astype(np.float32) + labels = derive_labels(price, p_v.astype(np.float32), args.horizon, args.label_threshold) + x_tel, x_sem, y = build_windows(tel, sem, labels, window_size=args.window_size, horizon=args.horizon) + tel_all.append(x_tel) + 
sem_all.append(x_sem) + y_all.append(y) + + x_tel = np.concatenate(tel_all, axis=0) + x_sem = np.concatenate(sem_all, axis=0) + y = np.concatenate(y_all, axis=0) + metadata: dict[str, Any] = { + "source": "synthetic_fallback", + "telemetry_columns": TELEMETRY_DEFAULT, + "semantic_columns": SEMANTIC_DEFAULT, + "window_count": int(x_tel.shape[0]), + } + if path: + metadata["requested_dataset"] = str(path) + if fallback_reason: + metadata["fallback_reason"] = fallback_reason + return x_tel, x_sem, y, metadata + + +class FusionModel(nn.Module): + def __init__(self, window: int, tel_dim: int, sem_dim: int, hidden: int, dropout: float) -> None: + super().__init__() + self.tel = nn.Sequential(nn.Flatten(), nn.Linear(window * tel_dim, hidden), nn.GELU()) + self.sem = nn.Sequential(nn.Linear(sem_dim, hidden), nn.GELU()) + self.cls = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, 1)) + + def forward(self, x_tel: torch.Tensor, x_sem: torch.Tensor) -> torch.Tensor: + return self.cls(torch.cat([self.tel(x_tel), self.sem(x_sem)], dim=1)).squeeze(-1) + + +@dataclass +class BacktestMetrics: + accuracy: float + precision: float + recall: float + f1: float + auroc: float | None + average_precision: float | None + positive_rate: float + + +def calc_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: np.ndarray) -> BacktestMetrics: + accuracy = float(np.mean(y_true == y_pred)) + precision = float(precision_score(y_true, y_pred, zero_division=0)) + recall = float(recall_score(y_true, y_pred, zero_division=0)) + f1 = float(f1_score(y_true, y_pred, zero_division=0)) + positive_rate = float(np.mean(y_pred)) + if len(np.unique(y_true)) < 2: + auroc = None + avg_precision = None + else: + auroc = float(roc_auc_score(y_true, y_prob)) + avg_precision = float(average_precision_score(y_true, y_prob)) + return BacktestMetrics( + accuracy=accuracy, + precision=precision, + recall=recall, + f1=f1, + auroc=auroc, + average_precision=avg_precision, 
+ positive_rate=positive_rate, + ) + + +def ensure_checkpoint(args: argparse.Namespace) -> Path: + checkpoint = Path(args.fusion_checkpoint) + if checkpoint.exists(): + return checkpoint + + if not args.autotrain_if_missing: + raise FileNotFoundError(f"Fusion checkpoint not found: {checkpoint}") + + train_script = Path(__file__).resolve().parent / "train_fusion_baseline.py" + command = [ + sys.executable, + str(train_script), + "--epochs", + str(args.autotrain_epochs), + "--output-dir", + args.autotrain_output_dir, + "--seed", + str(args.seed), + ] + if args.dataset_csv: + command.extend(["--dataset-csv", args.dataset_csv]) + result = subprocess.run(command, check=False, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Autotrain failed.\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}") + return Path(args.autotrain_output_dir) / "fusion_baseline.pt" + + +def load_v22_scorer() -> Any: + ai_engine_dir = Path(__file__).resolve().parents[2] / "src" / "services" / "ai-engine" + if str(ai_engine_dir) not in sys.path: + sys.path.insert(0, str(ai_engine_dir)) + from model import RiskScorer # type: ignore + + return RiskScorer() + + +def normalize_with_checkpoint( + x_tel: np.ndarray, + x_sem: np.ndarray, + normalization: dict[str, Any] | None, +) -> tuple[np.ndarray, np.ndarray]: + if not normalization: + tel_mean = x_tel.mean(axis=(0, 1), keepdims=True) + tel_std = np.where(x_tel.std(axis=(0, 1), keepdims=True) < 1e-6, 1.0, x_tel.std(axis=(0, 1), keepdims=True)) + sem_mean = x_sem.mean(axis=0, keepdims=True) + sem_std = np.where(x_sem.std(axis=0, keepdims=True) < 1e-6, 1.0, x_sem.std(axis=0, keepdims=True)) + else: + tel_mean = np.asarray(normalization.get("tel_mean"), dtype=np.float32) + tel_std = np.asarray(normalization.get("tel_std"), dtype=np.float32) + sem_mean = np.asarray(normalization.get("sem_mean"), dtype=np.float32) + sem_std = np.asarray(normalization.get("sem_std"), dtype=np.float32) + tel_mean = tel_mean.reshape(1, 1, 
-1) + tel_std = np.where(tel_std.reshape(1, 1, -1) < 1e-6, 1.0, tel_std.reshape(1, 1, -1)) + sem_mean = sem_mean.reshape(1, -1) + sem_std = np.where(sem_std.reshape(1, -1) < 1e-6, 1.0, sem_std.reshape(1, -1)) + return ((x_tel - tel_mean) / tel_std).astype(np.float32), ((x_sem - sem_mean) / sem_std).astype(np.float32) + + +def main() -> int: + args = parse_args() + set_seed(args.seed) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + x_tel, x_sem, y, dataset_meta = load_dataset(args) + total = x_tel.shape[0] + holdout_count = int(total * args.backtest_ratio) + if holdout_count < 100: + holdout_count = min(total, 100) + start = max(0, total - holdout_count) + + x_tel_holdout_raw = x_tel[start:] + x_sem_holdout_raw = x_sem[start:] + y_holdout = y[start:] + + checkpoint_path = ensure_checkpoint(args) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + state_dict = checkpoint.get("state_dict", checkpoint) + normalization = checkpoint.get("normalization") + + hidden = int(state_dict["tel.1.weight"].shape[0]) + tel_dim = int(x_tel_holdout_raw.shape[2]) + sem_dim = int(x_sem_holdout_raw.shape[1]) + model = FusionModel(window=args.window_size, tel_dim=tel_dim, sem_dim=sem_dim, hidden=hidden, dropout=0.0) + model.load_state_dict(state_dict) + model.eval() + + x_tel_holdout, x_sem_holdout = normalize_with_checkpoint(x_tel_holdout_raw, x_sem_holdout_raw, normalization) + with torch.no_grad(): + logits = model(torch.from_numpy(x_tel_holdout), torch.from_numpy(x_sem_holdout)) + fusion_prob = torch.sigmoid(logits).cpu().numpy().reshape(-1) + fusion_pred = (fusion_prob >= args.decision_threshold).astype(np.float32) + + scorer = load_v22_scorer() + heuristic_prob = np.zeros_like(y_holdout, dtype=np.float32) + heuristic_pred = np.zeros_like(y_holdout, dtype=np.float32) + for idx, window in enumerate(x_tel_holdout_raw): + spot_history = [float(value) for value in window[:, 0]] + assessment = scorer.assess_risk( + 
spot_price_history=spot_history, + rebalance_signal=False, + capacity_score=0.5, + ) + is_critical = str(assessment.Priority).upper() == "CRITICAL" + heuristic_prob[idx] = 1.0 if is_critical else 0.0 + heuristic_pred[idx] = heuristic_prob[idx] + + fusion_metrics = calc_metrics(y_holdout, fusion_pred, fusion_prob) + heuristic_metrics = calc_metrics(y_holdout, heuristic_pred, heuristic_prob) + + auroc_delta = None + if fusion_metrics.auroc is not None and heuristic_metrics.auroc is not None: + auroc_delta = fusion_metrics.auroc - heuristic_metrics.auroc + + summary = { + "run_at_utc": utc_now(), + "command": " ".join(["python"] + sys.argv), + "dataset": dataset_meta, + "config": { + "seed": args.seed, + "window_size": args.window_size, + "horizon": args.horizon, + "backtest_ratio": args.backtest_ratio, + "decision_threshold": args.decision_threshold, + "fusion_checkpoint": str(checkpoint_path), + }, + "counts": { + "total_windows": int(total), + "holdout_windows": int(len(y_holdout)), + "holdout_positive_rate": float(np.mean(y_holdout)), + }, + "metrics": { + "v22_heuristic": heuristic_metrics.__dict__, + "v23_fusion": fusion_metrics.__dict__, + }, + "comparison": { + "f1_delta_fusion_minus_v22": float(fusion_metrics.f1 - heuristic_metrics.f1), + "auroc_delta_fusion_minus_v22": auroc_delta, + "recall_delta_fusion_minus_v22": float(fusion_metrics.recall - heuristic_metrics.recall), + "precision_delta_fusion_minus_v22": float(fusion_metrics.precision - heuristic_metrics.precision), + }, + } + + summary_path = output_dir / "backtest_summary.json" + with summary_path.open("w", encoding="utf-8") as handle: + json.dump(summary, handle, indent=2, sort_keys=True) + handle.write("\n") + + report = output_dir / "backtest_report.md" + with report.open("w", encoding="utf-8") as handle: + handle.write("# Backtest Report: v2.3 Fusion vs v2.2 Heuristic\n\n") + handle.write(f"- Generated at UTC: {summary['run_at_utc']}\n") + handle.write(f"- Hold-out windows: 
{summary['counts']['holdout_windows']}\n") + handle.write(f"- Hold-out positive rate: {summary['counts']['holdout_positive_rate']:.4f}\n\n") + handle.write("| Strategy | Accuracy | Precision | Recall | F1 | AUROC | AP |\n") + handle.write("|---|---:|---:|---:|---:|---:|---:|\n") + handle.write( + f"| v2.2 heuristic | {heuristic_metrics.accuracy:.4f} | {heuristic_metrics.precision:.4f} | " + f"{heuristic_metrics.recall:.4f} | {heuristic_metrics.f1:.4f} | " + f"{'n/a' if heuristic_metrics.auroc is None else f'{heuristic_metrics.auroc:.4f}'} | " + f"{'n/a' if heuristic_metrics.average_precision is None else f'{heuristic_metrics.average_precision:.4f}'} |\n" + ) + handle.write( + f"| v2.3 fusion | {fusion_metrics.accuracy:.4f} | {fusion_metrics.precision:.4f} | " + f"{fusion_metrics.recall:.4f} | {fusion_metrics.f1:.4f} | " + f"{'n/a' if fusion_metrics.auroc is None else f'{fusion_metrics.auroc:.4f}'} | " + f"{'n/a' if fusion_metrics.average_precision is None else f'{fusion_metrics.average_precision:.4f}'} |\n\n" + ) + handle.write("## Deltas (fusion - v2.2)\n\n") + handle.write(f"- F1 delta: {summary['comparison']['f1_delta_fusion_minus_v22']:.4f}\n") + handle.write(f"- Recall delta: {summary['comparison']['recall_delta_fusion_minus_v22']:.4f}\n") + handle.write(f"- Precision delta: {summary['comparison']['precision_delta_fusion_minus_v22']:.4f}\n") + if auroc_delta is None: + handle.write("- AUROC delta: n/a\n") + else: + handle.write(f"- AUROC delta: {auroc_delta:.4f}\n") + + print(f"Backtest summary saved: {summary_path}") + print(f"Backtest report saved: {report}") + print( + "F1 comparison:" + f" v2.2={heuristic_metrics.f1:.4f}" + f" fusion={fusion_metrics.f1:.4f}" + f" delta={summary['comparison']['f1_delta_fusion_minus_v22']:.4f}" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 86bcc96abb1361a7141130dded8bf22ba7df489b Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 25 Feb 2026 10:15:01 +0800 Subject: [PATCH 17/24] feat(v2.3): 
add model artifact versioning and reproducibility checks --- README.md | 1 + docs/AI-Artifact-Versioning-v2.3-M2.md | 77 +++++++ docs/AI-TSMixer-Baseline-v2.3-M2.md | 1 + scripts/model_training/README.md | 35 ++++ scripts/model_training/artifact_registry.py | 163 +++++++++++++++ .../model_training/backtest_fusion_vs_v22.py | 111 ++++++++-- .../model_training/train_fusion_baseline.py | 132 ++++++++++-- .../model_training/train_tsmixer_baseline.py | 126 ++++++++--- .../model_training/verify_reproducible_run.py | 195 ++++++++++++++++++ 9 files changed, 776 insertions(+), 65 deletions(-) create mode 100644 docs/AI-Artifact-Versioning-v2.3-M2.md create mode 100644 scripts/model_training/artifact_registry.py create mode 100644 scripts/model_training/verify_reproducible_run.py diff --git a/README.md b/README.md index f503235..3aab2cd 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,7 @@ Open the dashboard at http://localhost:3000. - v2.3 M2 TSMixer baseline guide: docs/AI-TSMixer-Baseline-v2.3-M2.md - v2.3 M2 fusion baseline guide: docs/AI-Fusion-Model-v2.3-M2.md - v2.3 M2 backtesting guide: docs/AI-Backtesting-v2.3-M2.md +- v2.3 M2 artifact versioning + reproducibility guide: docs/AI-Artifact-Versioning-v2.3-M2.md - v2.3 M2 model training scripts: scripts/model_training/README.md If you want to simulate migrations, start at least two agents: diff --git a/docs/AI-Artifact-Versioning-v2.3-M2.md b/docs/AI-Artifact-Versioning-v2.3-M2.md new file mode 100644 index 0000000..7fdcbe7 --- /dev/null +++ b/docs/AI-Artifact-Versioning-v2.3-M2.md @@ -0,0 +1,77 @@ +# v2.3 M2 Model Artifact Versioning + Reproducible Runs + +This document captures the delivery for issue #38 (`[Ops] Model artifact versioning + reproducible runs`). + +## Goal + +Make offline model artifacts release-safe and reproducible for Milestone 2. 
+ +## Implemented Components + +- Shared utility: `scripts/model_training/artifact_registry.py` +- Repro check runner: `scripts/model_training/verify_reproducible_run.py` +- Integrated into: + - `scripts/model_training/train_tsmixer_baseline.py` + - `scripts/model_training/train_fusion_baseline.py` + - `scripts/model_training/backtest_fusion_vs_v22.py` + +## Artifact Naming / Versioning Scheme + +Each run computes: + +- `run_version` (CLI flag, default `v2.3-m2`) +- deterministic `run_id` = `<run_version>-<12-char-fingerprint>` +- full `run_fingerprint_sha256` derived from config + dataset descriptor + git commit + +Each run outputs: + +- base artifacts (legacy names kept for compatibility) +- `run_manifest.json` with file hashes and provenance metadata +- `versioned/` copies named: + - `<pipeline>-<run_id>-<role>.<ext>` + +## Run Manifest Schema (`run_manifest.json`) + +`schema_version: v1` payload includes: + +- pipeline metadata (`pipeline`, `run_version`, `run_id`, `run_fingerprint_sha256`) +- git metadata (`commit`, `dirty_worktree`) +- deterministic run config +- dataset/input descriptors (including file hash when a file path exists) +- key metrics used for promotion decisions +- artifact inventory (`path`, `sha256`, `bytes`) + +## Reproducibility Verification + +Use `verify_reproducible_run.py` to execute the same command twice and compare artifact hashes. 
+ +TSMixer example: + +```bash +python scripts/model_training/verify_reproducible_run.py \ + --script scripts/model_training/train_tsmixer_baseline.py \ + --base-output-dir .tmp/repro-check/tsmixer \ + --artifacts tsmixer_baseline.pt,tsmixer_baseline.onnx,training_summary.json,run_manifest.json \ + -- --epochs 6 --batch-size 128 +``` + +Fusion example: + +```bash +python scripts/model_training/verify_reproducible_run.py \ + --script scripts/model_training/train_fusion_baseline.py \ + --base-output-dir .tmp/repro-check/fusion \ + --artifacts telemetry_only_baseline.pt,fusion_baseline.pt,fusion_evaluation_summary.json,run_manifest.json \ + -- --epochs 8 --batch-size 128 +``` + +Verification report location: + +- `<base-output-dir>/reproducibility_check.json` + +Acceptance is met when `all_artifacts_identical` is `true`. + +## Acceptance Criteria Mapping + +- [x] Artifact naming/versioning scheme +- [x] Re-run produces identical outputs (validated by hash comparison) diff --git a/docs/AI-TSMixer-Baseline-v2.3-M2.md b/docs/AI-TSMixer-Baseline-v2.3-M2.md index 0fda203..72d9ded 100644 --- a/docs/AI-TSMixer-Baseline-v2.3-M2.md +++ b/docs/AI-TSMixer-Baseline-v2.3-M2.md @@ -46,3 +46,4 @@ Per run output directory contains: - `training_summary.json` These artifacts should be versioned by downstream issue #38 once model governance flow is implemented. +Artifact versioning is now implemented in issue #38 (see `docs/AI-Artifact-Versioning-v2.3-M2.md`). diff --git a/scripts/model_training/README.md b/scripts/model_training/README.md index c98f8bd..624d839 100644 --- a/scripts/model_training/README.md +++ b/scripts/model_training/README.md @@ -6,6 +6,7 @@ This folder contains the baseline training workflows for Milestone 2: - ONNX export for agent-side inference. - ONNX validation (checker + onnxruntime parity). - Fusion baseline training with semantic vectors (`S_v`, `P_v`, `B_s`) and offline comparison. +- Artifact manifests + versioned file naming for release-safe model governance. 
## Prerequisites @@ -45,6 +46,8 @@ Each run writes: - `tsmixer_baseline.pt`: PyTorch checkpoint (`state_dict` + normalization metadata) - `tsmixer_baseline.onnx`: exported ONNX model - `training_summary.json`: config, dataset source, metrics, and ONNX validation report +- `run_manifest.json`: versioned artifact inventory with SHA256 hashes and git metadata +- `versioned/`: deterministic names following `<pipeline>-<run_id>-<role>.<ext>` ## Reproducibility Notes @@ -90,6 +93,8 @@ Fusion outputs: - `telemetry_only_baseline.pt` - `fusion_baseline.pt` - `fusion_evaluation_summary.json` (contains offline baseline metrics and deltas) +- `run_manifest.json` +- `versioned/` ## Backtest Harness (Issue #37) @@ -107,3 +112,33 @@ Backtest outputs: - `backtest_summary.json` - `backtest_report.md` +- `run_manifest.json` +- `versioned/` + +## Artifact Versioning + Reproducibility (Issue #38) + +All model/backtest scripts now: + +- Accept `--run-version` (default `v2.3-m2`). +- Produce deterministic `run_id` and `run_fingerprint_sha256`. +- Generate `run_manifest.json` with artifact hashes and git commit metadata. 
+ +Quick reproducibility check example (TSMixer): + +```bash +python scripts/model_training/verify_reproducible_run.py \ + --script scripts/model_training/train_tsmixer_baseline.py \ + --base-output-dir .tmp/repro-check/tsmixer \ + --artifacts tsmixer_baseline.pt,tsmixer_baseline.onnx,training_summary.json,run_manifest.json \ + -- --epochs 6 --batch-size 128 +``` + +Quick reproducibility check example (Fusion): + +```bash +python scripts/model_training/verify_reproducible_run.py \ + --script scripts/model_training/train_fusion_baseline.py \ + --base-output-dir .tmp/repro-check/fusion \ + --artifacts telemetry_only_baseline.pt,fusion_baseline.pt,fusion_evaluation_summary.json,run_manifest.json \ + -- --epochs 8 --batch-size 128 +``` diff --git a/scripts/model_training/artifact_registry.py b/scripts/model_training/artifact_registry.py new file mode 100644 index 0000000..69615c4 --- /dev/null +++ b/scripts/model_training/artifact_registry.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Utilities for deterministic model artifact versioning and run manifests.""" + +from __future__ import annotations + +import hashlib +import json +import shutil +import subprocess +from pathlib import Path +from typing import Any + + +def canonical_json_bytes(payload: Any) -> bytes: + return json.dumps(payload, ensure_ascii=True, sort_keys=True, separators=(",", ":")).encode("utf-8") + + +def sha256_bytes(payload: bytes) -> str: + return hashlib.sha256(payload).hexdigest() + + +def sha256_file(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + while True: + chunk = handle.read(1024 * 1024) + if not chunk: + break + digest.update(chunk) + return digest.hexdigest() + + +def describe_dataset_file(dataset_path: Path | None) -> dict[str, Any] | None: + if dataset_path is None or not dataset_path.exists(): + return None + return { + "path": dataset_path.as_posix(), + "sha256": sha256_file(dataset_path), + "bytes": int(dataset_path.stat().st_size), + } + + +def 
git_commit(repo_root: Path) -> str: + try: + output = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=repo_root, + check=True, + capture_output=True, + text=True, + ) + return output.stdout.strip() or "unknown" + except Exception: + return "unknown" + + +def git_is_dirty(repo_root: Path) -> bool: + try: + output = subprocess.run( + ["git", "status", "--porcelain"], + cwd=repo_root, + check=True, + capture_output=True, + text=True, + ) + return bool(output.stdout.strip()) + except Exception: + return False + + +def build_run_identity( + *, + pipeline: str, + run_version: str, + config: dict[str, Any], + dataset: dict[str, Any], + git_sha: str, +) -> tuple[str, str]: + payload = { + "pipeline": pipeline, + "run_version": run_version, + "git_sha": git_sha, + "config": config, + "dataset": dataset, + } + fingerprint = sha256_bytes(canonical_json_bytes(payload)) + run_id = f"{run_version}-{fingerprint[:12]}" + return run_id, fingerprint + + +def _versioned_filename(pipeline: str, run_id: str, role: str, source_path: Path) -> str: + suffix = "".join(source_path.suffixes) or ".bin" + role_slug = role.replace("_", "-") + return f"{pipeline}-{run_id}-{role_slug}{suffix}" + + +def materialize_versioned_artifacts( + *, + output_dir: Path, + pipeline: str, + run_id: str, + artifacts: dict[str, Path], +) -> dict[str, Path]: + versioned_dir = output_dir / "versioned" + versioned_dir.mkdir(parents=True, exist_ok=True) + result: dict[str, Path] = {} + for role, source_path in sorted(artifacts.items()): + target_path = versioned_dir / _versioned_filename(pipeline, run_id, role, source_path) + shutil.copy2(source_path, target_path) + result[role] = target_path + return result + + +def artifact_inventory(artifacts: dict[str, Path]) -> dict[str, dict[str, Any]]: + inventory: dict[str, dict[str, Any]] = {} + for role, path in sorted(artifacts.items()): + inventory[role] = { + "path": path.as_posix(), + "sha256": sha256_file(path), + "bytes": int(path.stat().st_size), + } + return 
inventory + + +def write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + + +def write_run_manifest( + *, + output_dir: Path, + pipeline: str, + run_version: str, + run_id: str, + run_fingerprint: str, + git_sha: str, + git_dirty: bool, + config: dict[str, Any], + dataset: dict[str, Any], + metrics: dict[str, Any], + artifacts: dict[str, Path], +) -> Path: + manifest_path = output_dir / "run_manifest.json" + manifest = { + "schema_version": "v1", + "pipeline": pipeline, + "run_version": run_version, + "run_id": run_id, + "run_fingerprint_sha256": run_fingerprint, + "git": { + "commit": git_sha, + "dirty_worktree": git_dirty, + }, + "config": config, + "dataset": dataset, + "metrics": metrics, + "artifacts": artifact_inventory(artifacts), + } + write_json(manifest_path, manifest) + return manifest_path diff --git a/scripts/model_training/backtest_fusion_vs_v22.py b/scripts/model_training/backtest_fusion_vs_v22.py index 61be65f..7ae3150 100644 --- a/scripts/model_training/backtest_fusion_vs_v22.py +++ b/scripts/model_training/backtest_fusion_vs_v22.py @@ -4,12 +4,10 @@ from __future__ import annotations import argparse -import json import random import subprocess import sys from dataclasses import dataclass -from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -19,6 +17,16 @@ from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score, roc_auc_score from torch import nn +from artifact_registry import ( + build_run_identity, + describe_dataset_file, + git_commit, + git_is_dirty, + materialize_versioned_artifacts, + write_json, + write_run_manifest, +) + TELEMETRY_DEFAULT = ["spot_price_usd", "cpu_utilization", "memory_utilization", "network_io"] SEMANTIC_DEFAULT = ["s_v_negative", "s_v_neutral", "s_v_positive", "p_v", 
"b_s"] @@ -47,13 +55,10 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--autotrain-if-missing", action="store_true", default=True) parser.add_argument("--autotrain-epochs", type=int, default=8) parser.add_argument("--autotrain-output-dir", default=".tmp/fusion-baseline-autotrain") + parser.add_argument("--run-version", default="v2.3-m2") return parser.parse_args() -def utc_now() -> str: - return datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - - def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) @@ -313,6 +318,8 @@ def main() -> int: output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) + dataset_path = Path(args.dataset_csv) if args.dataset_csv else None + dataset_file_info = describe_dataset_file(dataset_path) x_tel, x_sem, y, dataset_meta = load_dataset(args) total = x_tel.shape[0] holdout_count = int(total * args.backtest_ratio) @@ -325,10 +332,44 @@ def main() -> int: y_holdout = y[start:] checkpoint_path = ensure_checkpoint(args) + checkpoint_file_info = describe_dataset_file(checkpoint_path) checkpoint = torch.load(checkpoint_path, map_location="cpu") state_dict = checkpoint.get("state_dict", checkpoint) normalization = checkpoint.get("normalization") + run_config = { + "seed": args.seed, + "window_size": args.window_size, + "horizon": args.horizon, + "backtest_ratio": args.backtest_ratio, + "decision_threshold": args.decision_threshold, + "label_column": args.label_column, + "label_threshold": args.label_threshold, + "telemetry_columns": args.telemetry_columns, + "semantic_columns": args.semantic_columns, + "synthetic_series": args.synthetic_series, + "synthetic_length": args.synthetic_length, + "fusion_checkpoint": checkpoint_path.as_posix(), + "autotrain_if_missing": bool(args.autotrain_if_missing), + "autotrain_epochs": args.autotrain_epochs, + "autotrain_output_dir": args.autotrain_output_dir, + } + dataset_for_identity = dict(dataset_meta) + if dataset_file_info: + 
dataset_for_identity["dataset_file"] = dataset_file_info + if checkpoint_file_info: + dataset_for_identity["fusion_checkpoint_file"] = checkpoint_file_info + repo_root = Path(__file__).resolve().parents[2] + git_sha = git_commit(repo_root) + git_dirty = git_is_dirty(repo_root) + run_id, run_fingerprint = build_run_identity( + pipeline="fusion-backtest", + run_version=args.run_version, + config=run_config, + dataset=dataset_for_identity, + git_sha=git_sha, + ) + hidden = int(state_dict["tel.1.weight"].shape[0]) tel_dim = int(x_tel_holdout_raw.shape[2]) sem_dim = int(x_sem_holdout_raw.shape[1]) @@ -364,16 +405,18 @@ def main() -> int: auroc_delta = fusion_metrics.auroc - heuristic_metrics.auroc summary = { - "run_at_utc": utc_now(), + "pipeline": "fusion-backtest", + "run_version": args.run_version, + "run_id": run_id, "command": " ".join(["python"] + sys.argv), "dataset": dataset_meta, - "config": { - "seed": args.seed, - "window_size": args.window_size, - "horizon": args.horizon, - "backtest_ratio": args.backtest_ratio, - "decision_threshold": args.decision_threshold, - "fusion_checkpoint": str(checkpoint_path), + "dataset_file": dataset_file_info, + "fusion_checkpoint_file": checkpoint_file_info, + "config": run_config, + "reproducibility": { + "run_fingerprint_sha256": run_fingerprint, + "git_commit": git_sha, + "git_dirty_worktree": git_dirty, }, "counts": { "total_windows": int(total), @@ -393,14 +436,12 @@ def main() -> int: } summary_path = output_dir / "backtest_summary.json" - with summary_path.open("w", encoding="utf-8") as handle: - json.dump(summary, handle, indent=2, sort_keys=True) - handle.write("\n") + write_json(summary_path, summary) report = output_dir / "backtest_report.md" with report.open("w", encoding="utf-8") as handle: handle.write("# Backtest Report: v2.3 Fusion vs v2.2 Heuristic\n\n") - handle.write(f"- Generated at UTC: {summary['run_at_utc']}\n") + handle.write(f"- Run id: {run_id}\n") handle.write(f"- Hold-out windows: 
{summary['counts']['holdout_windows']}\n") handle.write(f"- Hold-out positive rate: {summary['counts']['holdout_positive_rate']:.4f}\n\n") handle.write("| Strategy | Accuracy | Precision | Recall | F1 | AUROC | AP |\n") @@ -426,8 +467,42 @@ def main() -> int: else: handle.write(f"- AUROC delta: {auroc_delta:.4f}\n") + base_artifacts = { + "backtest_summary": summary_path, + "backtest_report": report, + } + versioned_artifacts = materialize_versioned_artifacts( + output_dir=output_dir, + pipeline="fusion-backtest", + run_id=run_id, + artifacts=base_artifacts, + ) + manifest_artifacts = dict(base_artifacts) + for role, path in versioned_artifacts.items(): + manifest_artifacts[f"versioned_{role}"] = path + manifest_artifacts["input_fusion_checkpoint"] = checkpoint_path + manifest_path = write_run_manifest( + output_dir=output_dir, + pipeline="fusion-backtest", + run_version=args.run_version, + run_id=run_id, + run_fingerprint=run_fingerprint, + git_sha=git_sha, + git_dirty=git_dirty, + config=run_config, + dataset=dataset_for_identity, + metrics={ + "v22_heuristic": summary["metrics"]["v22_heuristic"], + "v23_fusion": summary["metrics"]["v23_fusion"], + "comparison": summary["comparison"], + }, + artifacts=manifest_artifacts, + ) + print(f"Backtest summary saved: {summary_path}") print(f"Backtest report saved: {report}") + print(f"Run manifest saved: {manifest_path}") + print(f"Run id: {run_id}") print( "F1 comparison:" f" v2.2={heuristic_metrics.f1:.4f}" diff --git a/scripts/model_training/train_fusion_baseline.py b/scripts/model_training/train_fusion_baseline.py index 191e8ce..754324d 100644 --- a/scripts/model_training/train_fusion_baseline.py +++ b/scripts/model_training/train_fusion_baseline.py @@ -4,10 +4,8 @@ from __future__ import annotations import argparse -import json import random import sys -from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -18,6 +16,16 @@ from torch import nn from torch.utils.data import 
DataLoader, TensorDataset +from artifact_registry import ( + build_run_identity, + describe_dataset_file, + git_commit, + git_is_dirty, + materialize_versioned_artifacts, + write_json, + write_run_manifest, +) + TELEMETRY_DEFAULT = ["spot_price_usd", "cpu_utilization", "memory_utilization", "network_io"] SEMANTIC_DEFAULT = ["s_v_negative", "s_v_neutral", "s_v_positive", "p_v", "b_s"] @@ -44,13 +52,10 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--semantic-columns", default=",".join(SEMANTIC_DEFAULT)) parser.add_argument("--synthetic-series", type=int, default=48) parser.add_argument("--synthetic-length", type=int, default=240) + parser.add_argument("--run-version", default="v2.3-m2") return parser.parse_args() -def utc_now() -> str: - return datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - - def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) @@ -262,9 +267,47 @@ def evaluate(model: nn.Module, x_tel: np.ndarray, x_sem: np.ndarray, y: np.ndarr def main() -> int: - args = parse_args(); set_seed(args.seed) - out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True) + args = parse_args() + set_seed(args.seed) + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + dataset_path = Path(args.dataset_csv) if args.dataset_csv else None + dataset_file_info = describe_dataset_file(dataset_path) x_tel, x_sem, y, ds_meta = load_dataset(args) + run_config = { + "seed": args.seed, + "window_size": args.window_size, + "horizon": args.horizon, + "epochs": args.epochs, + "batch_size": args.batch_size, + "learning_rate": args.learning_rate, + "weight_decay": args.weight_decay, + "hidden_size": args.hidden_size, + "dropout": args.dropout, + "val_ratio": args.val_ratio, + "test_ratio": args.test_ratio, + "label_column": args.label_column, + "label_threshold": args.label_threshold, + "telemetry_columns": args.telemetry_columns, + "semantic_columns": args.semantic_columns, + 
"synthetic_series": args.synthetic_series, + "synthetic_length": args.synthetic_length, + } + ds_for_identity = dict(ds_meta) + if dataset_file_info: + ds_for_identity["dataset_file"] = dataset_file_info + repo_root = Path(__file__).resolve().parents[2] + git_sha = git_commit(repo_root) + git_dirty = git_is_dirty(repo_root) + run_id, run_fingerprint = build_run_identity( + pipeline="fusion-baseline", + run_version=args.run_version, + config=run_config, + dataset=ds_for_identity, + git_sha=git_sha, + ) + splits, norm = split_standardize(x_tel, x_sem, y, args) tr_t, tr_s, tr_y, va_t, va_s, va_y, te_t, te_s, te_y = splits tel_dim, sem_dim = tr_t.shape[2], tr_s.shape[1] @@ -280,15 +323,40 @@ def main() -> int: tel_m = {"train": evaluate(tel_model, tr_t, tr_s, tr_y), "val": evaluate(tel_model, va_t, va_s, va_y), "test": evaluate(tel_model, te_t, te_s, te_y)} fus_m = {"train": evaluate(fus_model, tr_t, tr_s, tr_y), "val": evaluate(fus_model, va_t, va_s, va_y), "test": evaluate(fus_model, te_t, te_s, te_y)} - tel_path = out_dir / "telemetry_only_baseline.pt"; fus_path = out_dir / "fusion_baseline.pt" - torch.save({"state_dict": tel_model.state_dict(), "normalization": norm, "created_at_utc": utc_now()}, tel_path) - torch.save({"state_dict": fus_model.state_dict(), "normalization": norm, "created_at_utc": utc_now()}, fus_path) + tel_path = out_dir / "telemetry_only_baseline.pt" + fus_path = out_dir / "fusion_baseline.pt" + torch.save( + { + "state_dict": tel_model.state_dict(), + "normalization": norm, + "run_id": run_id, + "run_fingerprint_sha256": run_fingerprint, + }, + tel_path, + ) + torch.save( + { + "state_dict": fus_model.state_dict(), + "normalization": norm, + "run_id": run_id, + "run_fingerprint_sha256": run_fingerprint, + }, + fus_path, + ) summary = { - "run_at_utc": utc_now(), + "pipeline": "fusion-baseline", + "run_version": args.run_version, + "run_id": run_id, "command": " ".join(["python"] + sys.argv), "dataset": ds_meta, - "config": {k: getattr(args, k) 
for k in ["seed", "window_size", "horizon", "epochs", "batch_size", "learning_rate", "weight_decay", "hidden_size", "dropout", "label_column", "label_threshold"]}, + "dataset_file": dataset_file_info, + "config": run_config, + "reproducibility": { + "run_fingerprint_sha256": run_fingerprint, + "git_commit": git_sha, + "git_dirty_worktree": git_dirty, + }, "label_balance": {"train_positive_rate": float(np.mean(tr_y)), "val_positive_rate": float(np.mean(va_y)), "test_positive_rate": float(np.mean(te_y))}, "models": {"telemetry_only": {"metrics": tel_m, "history": tel_hist, "artifact": str(tel_path)}, "fusion": {"metrics": fus_m, "history": fus_hist, "artifact": str(fus_path)}}, "comparison": { @@ -297,13 +365,45 @@ def main() -> int: }, } summary_path = out_dir / "fusion_evaluation_summary.json" - with summary_path.open("w", encoding="utf-8") as f: - json.dump(summary, f, indent=2, sort_keys=True) - f.write("\n") + write_json(summary_path, summary) + + base_artifacts = { + "telemetry_model": tel_path, + "fusion_model": fus_path, + "evaluation_summary": summary_path, + } + versioned_artifacts = materialize_versioned_artifacts( + output_dir=out_dir, + pipeline="fusion-baseline", + run_id=run_id, + artifacts=base_artifacts, + ) + manifest_artifacts = dict(base_artifacts) + for role, path in versioned_artifacts.items(): + manifest_artifacts[f"versioned_{role}"] = path + manifest_path = write_run_manifest( + output_dir=out_dir, + pipeline="fusion-baseline", + run_version=args.run_version, + run_id=run_id, + run_fingerprint=run_fingerprint, + git_sha=git_sha, + git_dirty=git_dirty, + config=run_config, + dataset=ds_for_identity, + metrics={ + "telemetry_only_test": tel_m["test"], + "fusion_test": fus_m["test"], + "comparison": summary["comparison"], + }, + artifacts=manifest_artifacts, + ) print(f"Telemetry-only model saved: {tel_path}") print(f"Fusion model saved: {fus_path}") print(f"Summary saved: {summary_path}") + print(f"Run manifest saved: {manifest_path}") + 
print(f"Run id: {run_id}") print( f"Test F1: telemetry={tel_m['test']['f1']:.4f} " f"fusion={fus_m['test']['f1']:.4f} " diff --git a/scripts/model_training/train_tsmixer_baseline.py b/scripts/model_training/train_tsmixer_baseline.py index 9987960..829dccb 100644 --- a/scripts/model_training/train_tsmixer_baseline.py +++ b/scripts/model_training/train_tsmixer_baseline.py @@ -4,11 +4,9 @@ from __future__ import annotations import argparse -import json import random import sys from dataclasses import dataclass -from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -18,6 +16,16 @@ from torch import nn from torch.utils.data import DataLoader, TensorDataset +from artifact_registry import ( + build_run_identity, + describe_dataset_file, + git_commit, + git_is_dirty, + materialize_versioned_artifacts, + write_json, + write_run_manifest, +) + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) @@ -93,13 +101,14 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Export ONNX but skip checker/runtime validation.", ) + parser.add_argument( + "--run-version", + default="v2.3-m2", + help="Version prefix used by artifact naming/manifest.", + ) return parser.parse_args() -def now_utc_iso() -> str: - return datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - - def set_global_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) @@ -532,13 +541,6 @@ def validate_onnx( } -def write_json(path: Path, payload: dict[str, Any]) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w", encoding="utf-8") as handle: - json.dump(payload, handle, indent=2, sort_keys=True) - handle.write("\n") - - def main() -> int: args = parse_args() set_global_seed(args.seed) @@ -590,6 +592,44 @@ def main() -> int: dataset_metadata["requested_dataset"] = str(dataset_path) dataset_metadata["fallback_reason"] = "dataset path does not exist" + run_config = { + "seed": 
args.seed, + "window_size": args.window_size, + "horizon": args.horizon, + "label_threshold": args.label_threshold, + "epochs": args.epochs, + "batch_size": args.batch_size, + "learning_rate": args.learning_rate, + "weight_decay": args.weight_decay, + "dropout": args.dropout, + "hidden_size": args.hidden_size, + "num_blocks": args.num_blocks, + "val_ratio": args.val_ratio, + "test_ratio": args.test_ratio, + "target_column": args.target_column, + "price_column": args.price_column, + "timestamp_column": args.timestamp_column, + "max_rows": args.max_rows, + "synthetic_series": args.synthetic_series, + "synthetic_length": args.synthetic_length, + "onnx_opset": args.onnx_opset, + "skip_onnx_validation": bool(args.skip_onnx_validation), + } + dataset_for_identity = dict(dataset_metadata) + dataset_file_info = describe_dataset_file(dataset_path) + if dataset_file_info: + dataset_for_identity["dataset_file"] = dataset_file_info + repo_root = Path(__file__).resolve().parents[2] + git_sha = git_commit(repo_root) + git_dirty = git_is_dirty(repo_root) + run_id, run_fingerprint = build_run_identity( + pipeline="tsmixer-baseline", + run_version=args.run_version, + config=run_config, + dataset=dataset_for_identity, + git_sha=git_sha, + ) + x_train, y_train, x_val, y_val, x_test, y_test = split_dataset( x, y, @@ -631,7 +671,8 @@ def main() -> int: "channels": channels, "train_mean": train_mean.squeeze(0).tolist(), "train_std": train_std.squeeze(0).tolist(), - "created_at_utc": now_utc_iso(), + "run_id": run_id, + "run_fingerprint_sha256": run_fingerprint, }, model_path, ) @@ -651,25 +692,18 @@ def main() -> int: onnx_validation = validate_onnx(model, onnx_path, sample_inputs) summary = { - "run_at_utc": now_utc_iso(), + "pipeline": "tsmixer-baseline", + "run_version": args.run_version, + "run_id": run_id, "command": " ".join(["python"] + sys.argv), - "config": { - "seed": args.seed, - "window_size": args.window_size, - "horizon": args.horizon, - "label_threshold": 
args.label_threshold, - "epochs": args.epochs, - "batch_size": args.batch_size, - "learning_rate": args.learning_rate, - "weight_decay": args.weight_decay, - "dropout": args.dropout, - "hidden_size": args.hidden_size, - "num_blocks": args.num_blocks, - "val_ratio": args.val_ratio, - "test_ratio": args.test_ratio, - "onnx_opset": args.onnx_opset, - }, + "config": run_config, "dataset": dataset_metadata, + "dataset_file": dataset_file_info, + "reproducibility": { + "run_fingerprint_sha256": run_fingerprint, + "git_commit": git_sha, + "git_dirty_worktree": git_dirty, + }, "shapes": { "x_train": list(x_train.shape), "x_val": list(x_val.shape), @@ -696,9 +730,39 @@ def main() -> int: summary_path = output_dir / "training_summary.json" write_json(summary_path, summary) + base_artifacts = { + "torch_model": model_path, + "onnx_model": onnx_path, + "training_summary": summary_path, + } + versioned_artifacts = materialize_versioned_artifacts( + output_dir=output_dir, + pipeline="tsmixer-baseline", + run_id=run_id, + artifacts=base_artifacts, + ) + manifest_artifacts = dict(base_artifacts) + for role, path in versioned_artifacts.items(): + manifest_artifacts[f"versioned_{role}"] = path + manifest_path = write_run_manifest( + output_dir=output_dir, + pipeline="tsmixer-baseline", + run_version=args.run_version, + run_id=run_id, + run_fingerprint=run_fingerprint, + git_sha=git_sha, + git_dirty=git_dirty, + config=run_config, + dataset=dataset_for_identity, + metrics=summary["metrics"], + artifacts=manifest_artifacts, + ) + print(f"Model saved: {model_path}") print(f"ONNX saved: {onnx_path}") print(f"Summary saved: {summary_path}") + print(f"Run manifest saved: {manifest_path}") + print(f"Run id: {run_id}") print( "Metrics:" f" train_acc={train_metrics.accuracy:.4f}" diff --git a/scripts/model_training/verify_reproducible_run.py b/scripts/model_training/verify_reproducible_run.py new file mode 100644 index 0000000..de92bcb --- /dev/null +++ 
b/scripts/model_training/verify_reproducible_run.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +"""Run a training/eval script twice and verify artifact hashes are identical.""" + +from __future__ import annotations + +import argparse +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any + +from artifact_registry import sha256_file, write_json + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--script", + required=True, + help="Script path to run twice (for example scripts/model_training/train_tsmixer_baseline.py).", + ) + parser.add_argument( + "--base-output-dir", + required=True, + help="Base directory where run_a and run_b outputs are generated.", + ) + parser.add_argument( + "--artifacts", + required=True, + help="Comma-separated artifact paths relative to each run output directory.", + ) + parser.add_argument( + "--python-executable", + default=sys.executable, + help="Python executable used to launch the target script.", + ) + parser.add_argument( + "--report-name", + default="reproducibility_check.json", + help="Output report file name written under base-output-dir.", + ) + parser.add_argument( + "--keep-runs", + action="store_true", + help="Keep run_a and run_b folders if verification fails.", + ) + args, passthrough = parser.parse_known_args() + if passthrough and passthrough[0] == "--": + passthrough = passthrough[1:] + return args, passthrough + + +def run_once( + *, + python_executable: str, + script_path: Path, + output_dir: Path, + passthrough: list[str], +) -> subprocess.CompletedProcess[str]: + command = [ + python_executable, + script_path.as_posix(), + "--output-dir", + output_dir.as_posix(), + *passthrough, + ] + return subprocess.run(command, check=False, capture_output=True, text=True) + + +def snapshot_artifacts(*, source_dir: Path, snapshot_dir: Path, artifacts: list[str]) -> None: + if snapshot_dir.exists(): 
+ shutil.rmtree(snapshot_dir) + snapshot_dir.mkdir(parents=True, exist_ok=True) + for artifact in artifacts: + artifact_rel = artifact.strip() + if not artifact_rel: + continue + source_path = source_dir / artifact_rel + target_path = snapshot_dir / artifact_rel + if not source_path.exists(): + continue + target_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(source_path, target_path) + + +def compare_artifacts( + *, + run_a_dir: Path, + run_b_dir: Path, + artifacts: list[str], +) -> tuple[bool, list[dict[str, Any]]]: + all_equal = True + comparisons: list[dict[str, Any]] = [] + for artifact in artifacts: + artifact_rel = artifact.strip() + if not artifact_rel: + continue + path_a = run_a_dir / artifact_rel + path_b = run_b_dir / artifact_rel + exists_a = path_a.exists() + exists_b = path_b.exists() + hash_a = sha256_file(path_a) if exists_a else None + hash_b = sha256_file(path_b) if exists_b else None + identical = bool(exists_a and exists_b and hash_a == hash_b) + if not identical: + all_equal = False + comparisons.append( + { + "artifact": artifact_rel, + "run_a_exists": exists_a, + "run_b_exists": exists_b, + "run_a_sha256": hash_a, + "run_b_sha256": hash_b, + "identical": identical, + } + ) + return all_equal, comparisons + + +def main() -> int: + args, passthrough = parse_args() + script_path = Path(args.script) + if not script_path.exists(): + print(f"Script not found: {script_path}", file=sys.stderr) + return 2 + + base_output_dir = Path(args.base_output_dir) + execution_dir = base_output_dir / "exec" + run_a_dir = base_output_dir / "run_a" + run_b_dir = base_output_dir / "run_b" + for run_dir in (execution_dir, run_a_dir, run_b_dir): + if run_dir.exists(): + shutil.rmtree(run_dir) + run_dir.mkdir(parents=True, exist_ok=True) + + artifacts = [item for item in args.artifacts.split(",") if item.strip()] + + first = run_once( + python_executable=args.python_executable, + script_path=script_path, + output_dir=execution_dir, + 
passthrough=passthrough, + ) + snapshot_artifacts(source_dir=execution_dir, snapshot_dir=run_a_dir, artifacts=artifacts) + + if execution_dir.exists(): + shutil.rmtree(execution_dir) + execution_dir.mkdir(parents=True, exist_ok=True) + + second = run_once( + python_executable=args.python_executable, + script_path=script_path, + output_dir=execution_dir, + passthrough=passthrough, + ) + snapshot_artifacts(source_dir=execution_dir, snapshot_dir=run_b_dir, artifacts=artifacts) + + all_equal, comparisons = compare_artifacts(run_a_dir=run_a_dir, run_b_dir=run_b_dir, artifacts=artifacts) + + report = { + "script": script_path.as_posix(), + "python_executable": args.python_executable, + "passthrough_args": passthrough, + "first_run_return_code": int(first.returncode), + "second_run_return_code": int(second.returncode), + "all_artifacts_identical": bool(all_equal and first.returncode == 0 and second.returncode == 0), + "artifacts": comparisons, + "first_run_stdout": first.stdout, + "first_run_stderr": first.stderr, + "second_run_stdout": second.stdout, + "second_run_stderr": second.stderr, + } + report_path = base_output_dir / args.report_name + write_json(report_path, report) + + print(f"Report saved: {report_path}") + print(f"First run rc={first.returncode}, second run rc={second.returncode}") + print(f"Artifacts identical: {report['all_artifacts_identical']}") + for item in comparisons: + state = "OK" if item["identical"] else "DIFF" + print(f"- {state}: {item['artifact']}") + + if report["all_artifacts_identical"]: + return 0 + + if not args.keep_runs: + for run_dir in (run_a_dir, run_b_dir): + if run_dir.exists(): + shutil.rmtree(run_dir) + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) From 95b4f4f6166f5e6e91c508ed4dcbe18425171a0f Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 25 Feb 2026 10:31:53 +0800 Subject: [PATCH 18/24] feat(v2.3): extend heartbeat payload with semantic features --- README.md | 1 + 
docs/PROTO-Heartbeat-Semantic-v2.3-M3.md | 60 ++++++++++++++++ .../Controllers/AgentController.cs | 18 ++++- .../Services/AgentWorkflowService.cs | 71 ++++++++++++++++++- src/shared/protos/agent_service.proto | 13 ++++ 5 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 docs/PROTO-Heartbeat-Semantic-v2.3-M3.md diff --git a/README.md b/README.md index 3aab2cd..fa02094 100644 --- a/README.md +++ b/README.md @@ -202,6 +202,7 @@ Open the dashboard at http://localhost:3000. - v2.3 M2 backtesting guide: docs/AI-Backtesting-v2.3-M2.md - v2.3 M2 artifact versioning + reproducibility guide: docs/AI-Artifact-Versioning-v2.3-M2.md - v2.3 M2 model training scripts: scripts/model_training/README.md +- v2.3 M3 heartbeat semantic payload contract: docs/PROTO-Heartbeat-Semantic-v2.3-M3.md If you want to simulate migrations, start at least two agents: diff --git a/docs/PROTO-Heartbeat-Semantic-v2.3-M3.md b/docs/PROTO-Heartbeat-Semantic-v2.3-M3.md new file mode 100644 index 0000000..8051de0 --- /dev/null +++ b/docs/PROTO-Heartbeat-Semantic-v2.3-M3.md @@ -0,0 +1,60 @@ +# v2.3 M3 Heartbeat Semantic Payload Contract + +This document captures issue #39: extending heartbeat payloads with semantic features while keeping backward compatibility. + +## Scope + +- Protobuf contract update: `src/shared/protos/agent_service.proto` +- Core heartbeat response wiring: + - gRPC: `AgentWorkflowService` + - legacy REST bridge: `AgentController` + +## Protobuf Additions + +New message: + +- `SemanticHeartbeatFeatures` + - `schema_version` + - `s_v_negative` + - `s_v_neutral` + - `s_v_positive` + - `p_v` + - `b_s` + - `source` + - `generated_at_unix` + - `fallback_used` + +`HeartbeatResponse` additive field: + +- `semantic_features = 3` + +No existing field numbers were changed, so wire compatibility is preserved. 
+ +## JSON Transcoding Shape + +`POST /api/v2/agent/heartbeat` response now includes: + +```json +{ + "status": "active", + "commands": [], + "semanticFeatures": { + "schemaVersion": "1.0", + "sVNegative": 0.33, + "sVNeutral": 0.34, + "sVPositive": 0.33, + "pV": 0.5, + "bS": 0.5, + "source": "fallback:neutral", + "generatedAtUnix": 1767000000, + "fallbackUsed": true + } +} +``` + +Older agents that ignore unknown JSON/protobuf fields remain compatible. + +## Runtime Source Selection + +Core heartbeat selects the latest enriched external signal vector when available. +If no enriched signal is available, it emits a neutral fallback vector and marks `fallback_used=true`. diff --git a/src/services/core-dotnet/AetherGuard.Core/Controllers/AgentController.cs b/src/services/core-dotnet/AetherGuard.Core/Controllers/AgentController.cs index c3ca040..40f7b0f 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Controllers/AgentController.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Controllers/AgentController.cs @@ -90,7 +90,23 @@ public async Task Heartbeat([FromBody] HeartbeatRequest request, }) .ToArray(); - return Ok(new { status = result.Payload.Status, commands = commandPayload }); + var semantic = result.Payload.SemanticFeatures; + var semanticPayload = semantic is null + ? 
null + : new + { + schemaVersion = semantic.SchemaVersion, + sVNegative = semantic.SVNegative, + sVNeutral = semantic.SVNeutral, + sVPositive = semantic.SVPositive, + pV = semantic.PV, + bS = semantic.BS, + source = semantic.Source, + generatedAtUnix = semantic.GeneratedAtUnix, + fallbackUsed = semantic.FallbackUsed + }; + + return Ok(new { status = result.Payload.Status, commands = commandPayload, semanticFeatures = semanticPayload }); } [HttpGet("poll")] diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/AgentWorkflowService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/AgentWorkflowService.cs index dce0888..e6a1576 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/AgentWorkflowService.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/AgentWorkflowService.cs @@ -107,7 +107,11 @@ public async Task> HeartbeatAsync(HeartbeatRequest await _context.SaveChangesAsync(cancellationToken); - var response = new HeartbeatResponse { Status = "active" }; + var response = new HeartbeatResponse + { + Status = "active", + SemanticFeatures = await BuildSemanticFeaturesAsync(cancellationToken) + }; response.Commands.AddRange(pendingCommands.Select(MapCommand)); return ApiResult.Ok(response); @@ -216,6 +220,71 @@ public async Task> FeedbackAsync(FeedbackRequest req return ApiResult.Ok(new FeedbackResponse { Status = command.Status }); } + private async Task BuildSemanticFeaturesAsync(CancellationToken cancellationToken) + { + var latest = await _context.ExternalSignals + .AsNoTracking() + .Where(signal => + signal.SentimentNegative.HasValue && + signal.SentimentNeutral.HasValue && + signal.SentimentPositive.HasValue && + signal.VolatilityProbability.HasValue && + signal.SupplyBias.HasValue) + .OrderByDescending(signal => signal.EnrichedAt ?? signal.PublishedAt) + .FirstOrDefaultAsync(cancellationToken); + + if (latest is null) + { + return BuildFallbackSemanticFeatures(); + } + + var generatedAt = latest.EnrichedAt ?? 
latest.PublishedAt; + return new SemanticHeartbeatFeatures + { + SchemaVersion = string.IsNullOrWhiteSpace(latest.EnrichmentSchemaVersion) + ? "1.0" + : latest.EnrichmentSchemaVersion, + SVNegative = Clamp01(latest.SentimentNegative), + SVNeutral = Clamp01(latest.SentimentNeutral), + SVPositive = Clamp01(latest.SentimentPositive), + PV = Clamp01(latest.VolatilityProbability), + BS = Clamp01(latest.SupplyBias), + Source = latest.Source, + GeneratedAtUnix = generatedAt.ToUnixTimeSeconds(), + FallbackUsed = false + }; + } + + private static SemanticHeartbeatFeatures BuildFallbackSemanticFeatures() + => new() + { + SchemaVersion = "1.0", + SVNegative = 0.33, + SVNeutral = 0.34, + SVPositive = 0.33, + PV = 0.5, + BS = 0.5, + Source = "fallback:neutral", + GeneratedAtUnix = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), + FallbackUsed = true + }; + + private static double Clamp01(double? value) + { + var resolved = value ?? 0.0; + if (resolved < 0.0) + { + return 0.0; + } + + if (resolved > 1.0) + { + return 1.0; + } + + return resolved; + } + private static AetherGuard.Grpc.V1.AgentCommand MapCommand(CoreAgentCommand command) { var parameters = GrpcParameterConverter.ParseJsonStruct(command.Parameters); diff --git a/src/shared/protos/agent_service.proto b/src/shared/protos/agent_service.proto index a92fc1c..2e69af8 100644 --- a/src/shared/protos/agent_service.proto +++ b/src/shared/protos/agent_service.proto @@ -77,9 +77,22 @@ message AgentConfig { string node_mode = 5; } +message SemanticHeartbeatFeatures { + string schema_version = 1; + double s_v_negative = 2; + double s_v_neutral = 3; + double s_v_positive = 4; + double p_v = 5; + double b_s = 6; + string source = 7; + int64 generated_at_unix = 8; + bool fallback_used = 9; +} + message HeartbeatResponse { string status = 1; repeated aetherguard.common.v1.AgentCommand commands = 2; + SemanticHeartbeatFeatures semantic_features = 3; } message PollCommandsRequest { From 0783022fbac40a6a03713033e4598923321e4dd6 Mon Sep 17 
00:00:00 2001 From: JasonEran Date: Wed, 25 Feb 2026 10:50:55 +0800 Subject: [PATCH 19/24] feat(v2.3): add agent local inference runtime with rollout gates --- README.md | 1 + docs/Agent-ONNX-Inference-v2.3-M3.md | 51 ++++ src/services/agent-cpp/CMakeLists.txt | 48 ++++ src/services/agent-cpp/InferenceEngine.cpp | 231 ++++++++++++++++++ src/services/agent-cpp/InferenceEngine.hpp | 46 ++++ src/services/agent-cpp/NetworkClient.cpp | 36 ++- src/services/agent-cpp/NetworkClient.hpp | 5 +- src/services/agent-cpp/SemanticFeatures.hpp | 16 ++ src/services/agent-cpp/main.cpp | 40 ++- .../agent-cpp/tests/InferenceEngineTests.cpp | 76 ++++++ 10 files changed, 540 insertions(+), 10 deletions(-) create mode 100644 docs/Agent-ONNX-Inference-v2.3-M3.md create mode 100644 src/services/agent-cpp/InferenceEngine.cpp create mode 100644 src/services/agent-cpp/InferenceEngine.hpp create mode 100644 src/services/agent-cpp/SemanticFeatures.hpp create mode 100644 src/services/agent-cpp/tests/InferenceEngineTests.cpp diff --git a/README.md b/README.md index fa02094..6a4cb16 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,7 @@ Open the dashboard at http://localhost:3000. - v2.3 M2 artifact versioning + reproducibility guide: docs/AI-Artifact-Versioning-v2.3-M2.md - v2.3 M2 model training scripts: scripts/model_training/README.md - v2.3 M3 heartbeat semantic payload contract: docs/PROTO-Heartbeat-Semantic-v2.3-M3.md +- v2.3 M3 agent ONNX inference + gating: docs/Agent-ONNX-Inference-v2.3-M3.md If you want to simulate migrations, start at least two agents: diff --git a/docs/Agent-ONNX-Inference-v2.3-M3.md b/docs/Agent-ONNX-Inference-v2.3-M3.md new file mode 100644 index 0000000..1e13b6e --- /dev/null +++ b/docs/Agent-ONNX-Inference-v2.3-M3.md @@ -0,0 +1,51 @@ +# v2.3 M3 Agent Local ONNX Inference + Feature Gating + +This document captures issue #40. + +## Scope + +- Agent local inference runtime integration points. +- Feature gate and rollback switches for safe rollout. 
+ +## Implementation + +### New agent components + +- `src/services/agent-cpp/InferenceEngine.hpp` +- `src/services/agent-cpp/InferenceEngine.cpp` +- `src/services/agent-cpp/SemanticFeatures.hpp` + +### Heartbeat semantic payload consumption + +- `NetworkClient::SendHeartbeat(...)` now parses `semanticFeatures` from heartbeat response. +- Parsed fields include: + - `schemaVersion` + - `sVNegative`, `sVNeutral`, `sVPositive` + - `pV`, `bS` + - `source`, `generatedAtUnix`, `fallbackUsed` + +### ONNX runtime integration mode + +- Build switch: `AETHER_ENABLE_ONNX_RUNTIME` +- CMake input when enabled: + - `ONNXRUNTIME_ROOT` (must contain `include/` and `lib/`) +- Runtime model path: + - `AG_ONNX_MODEL_PATH` + +If ONNX runtime is unavailable (or model path invalid), the engine can fail-open to fallback scoring. + +## Feature Gate + Rollback + +- `AG_M3_ONLINE_INFERENCE_ENABLED` + Enables local inference path. +- `AG_M3_FORCE_V22_FALLBACK` + Forces rollback behavior and bypasses ONNX usage. +- `AG_ONNX_FAIL_OPEN` + If `true`, agent continues with fallback scoring when ONNX init/runtime fails. +- `AG_ONNX_DECISION_THRESHOLD` + Decision threshold for preemption recommendation. + +## Validation + +- Core heartbeat contract remains compatible (semantic fields are additive). +- Agent fallback path is covered by `AetherAgentInferenceTests`. diff --git a/src/services/agent-cpp/CMakeLists.txt b/src/services/agent-cpp/CMakeLists.txt index fde1cfe..e960d2f 100644 --- a/src/services/agent-cpp/CMakeLists.txt +++ b/src/services/agent-cpp/CMakeLists.txt @@ -11,6 +11,7 @@ option(AETHER_ENABLE_GRPC "Build gRPC stubs for agent protos." ON) option(AETHER_USE_LOCAL_PROTOBUF "Use third_party/protobuf when available." ON) option(AETHER_GRPC_USE_ARCHIVE "Use gRPC release archive (no submodules)." OFF) option(AETHER_ENABLE_OTEL "Enable OpenTelemetry spans for the agent." OFF) +option(AETHER_ENABLE_ONNX_RUNTIME "Enable ONNX Runtime local inference on the agent." 
OFF) # CMake < 3.24 doesn't understand DOWNLOAD_EXTRACT_TIMESTAMP, which breaks ExternalProject. set(_AETHER_FC_EXTRACT_TIMESTAMP) @@ -253,6 +254,7 @@ add_executable(AetherAgent CriuManager.cpp CommandDispatcher.cpp CommandPoller.cpp + InferenceEngine.cpp main.cpp LifecycleManager.cpp NetworkClient.cpp @@ -268,8 +270,32 @@ target_link_libraries(AetherAgent target_compile_definitions(AetherAgent PRIVATE $<$:AETHER_ENABLE_OTEL=1> + $<$:AETHER_ENABLE_ONNX_RUNTIME=1> ) +if(AETHER_ENABLE_ONNX_RUNTIME) + set(ONNXRUNTIME_ROOT "" CACHE PATH "Path to ONNX Runtime distribution root (contains include/ and lib/).") + if(NOT ONNXRUNTIME_ROOT) + message(FATAL_ERROR "AETHER_ENABLE_ONNX_RUNTIME=ON requires ONNXRUNTIME_ROOT to be set.") + endif() + + set(ONNXRUNTIME_INCLUDE_DIR "${ONNXRUNTIME_ROOT}/include") + if(NOT EXISTS "${ONNXRUNTIME_INCLUDE_DIR}") + message(FATAL_ERROR "ONNX Runtime include directory not found: ${ONNXRUNTIME_INCLUDE_DIR}") + endif() + + find_library(ONNXRUNTIME_LIBRARY + NAMES onnxruntime onnxruntime.lib + PATHS "${ONNXRUNTIME_ROOT}/lib" + NO_DEFAULT_PATH) + if(NOT ONNXRUNTIME_LIBRARY) + message(FATAL_ERROR "Unable to find ONNX Runtime library under ${ONNXRUNTIME_ROOT}/lib") + endif() + + target_include_directories(AetherAgent PRIVATE "${ONNXRUNTIME_INCLUDE_DIR}") + target_link_libraries(AetherAgent PRIVATE "${ONNXRUNTIME_LIBRARY}") +endif() + if(AETHER_ENABLE_OTEL) set(_AETHER_OTEL_TARGETS opentelemetry::trace @@ -308,3 +334,25 @@ target_include_directories(AetherAgentTests ) add_test(NAME AetherAgentTests COMMAND AetherAgentTests) + +add_executable(AetherAgentInferenceTests + InferenceEngine.cpp + tests/InferenceEngineTests.cpp +) + +target_include_directories(AetherAgentInferenceTests + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_compile_definitions(AetherAgentInferenceTests + PRIVATE + $<$:AETHER_ENABLE_ONNX_RUNTIME=1> +) + +if(AETHER_ENABLE_ONNX_RUNTIME) + target_include_directories(AetherAgentInferenceTests PRIVATE "${ONNXRUNTIME_INCLUDE_DIR}") + 
target_link_libraries(AetherAgentInferenceTests PRIVATE "${ONNXRUNTIME_LIBRARY}") +endif() + +add_test(NAME AetherAgentInferenceTests COMMAND AetherAgentInferenceTests) diff --git a/src/services/agent-cpp/InferenceEngine.cpp b/src/services/agent-cpp/InferenceEngine.cpp new file mode 100644 index 0000000..267770e --- /dev/null +++ b/src/services/agent-cpp/InferenceEngine.cpp @@ -0,0 +1,231 @@ +#include "InferenceEngine.hpp" + +#include +#include +#include +#include +#include +#include + +#if defined(AETHER_ENABLE_ONNX_RUNTIME) +#include +#endif + +namespace { +constexpr double kDefaultFallbackNeutral = 0.5; +constexpr size_t kFallbackWindow = 24; +constexpr size_t kFallbackChannels = 4; +} // namespace + +struct InferenceEngine::OnnxSessionHandle { +#if defined(AETHER_ENABLE_ONNX_RUNTIME) + Ort::Env env; + Ort::SessionOptions options; + std::unique_ptr session; + std::string inputName; + std::string outputName; + size_t windowSize = kFallbackWindow; + size_t channels = kFallbackChannels; + + OnnxSessionHandle() + : env(ORT_LOGGING_LEVEL_WARNING, "aether-agent-onnx") {} +#endif +}; + +InferenceEngine::InferenceEngine(InferenceRuntimeConfig config) + : config_(std::move(config)), + initStatus_("not_initialized"), + onnxSession_(std::make_unique()) {} + +InferenceEngine::~InferenceEngine() = default; + +bool InferenceEngine::Initialize() +{ + if (!config_.enabled) { + initStatus_ = "feature_gate_disabled"; + return true; + } + + if (config_.forceV22Fallback) { + initStatus_ = "rollback_forced_v22_fallback"; + return true; + } + + if (config_.modelPath.empty()) { + initStatus_ = "onnx_model_path_missing"; + return config_.failOpen; + } + + if (!std::filesystem::exists(config_.modelPath)) { + initStatus_ = "onnx_model_not_found"; + return config_.failOpen; + } + +#if defined(AETHER_ENABLE_ONNX_RUNTIME) + try { + onnxSession_->options.SetIntraOpNumThreads(1); + onnxSession_->options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + onnxSession_->session 
= std::make_unique( + onnxSession_->env, + config_.modelPath.c_str(), + onnxSession_->options); + + Ort::AllocatorWithDefaultOptions allocator; + auto inputName = onnxSession_->session->GetInputNameAllocated(0, allocator); + auto outputName = onnxSession_->session->GetOutputNameAllocated(0, allocator); + onnxSession_->inputName = inputName.get(); + onnxSession_->outputName = outputName.get(); + + auto inputType = onnxSession_->session->GetInputTypeInfo(0); + auto tensorInfo = inputType.GetTensorTypeAndShapeInfo(); + auto shape = tensorInfo.GetShape(); + if (shape.size() == 3) { + if (shape[1] > 0) { + onnxSession_->windowSize = static_cast(shape[1]); + } + if (shape[2] > 0) { + onnxSession_->channels = static_cast(shape[2]); + } + } + + initStatus_ = "onnx_runtime_ready"; + return true; + } catch (const std::exception& ex) { + initStatus_ = std::string("onnx_runtime_init_failed:") + ex.what(); + return config_.failOpen; + } +#else + initStatus_ = "compiled_without_onnx_runtime"; + return config_.failOpen; +#endif +} + +InferenceDecision InferenceEngine::Evaluate(const SemanticHeartbeatFeatures& semanticFeatures) const +{ + if (!config_.enabled) { + return EvaluateFallback(semanticFeatures, "feature_gate_disabled"); + } + + if (config_.forceV22Fallback) { + return EvaluateFallback(semanticFeatures, "rollback_forced_v22_fallback"); + } + +#if defined(AETHER_ENABLE_ONNX_RUNTIME) + if (onnxSession_ && onnxSession_->session) { + try { + const auto window = onnxSession_->windowSize; + const auto channels = onnxSession_->channels; + std::vector input(window * channels, 0.0f); + + for (size_t t = 0; t < window; ++t) { + for (size_t c = 0; c < channels; ++c) { + double value = kDefaultFallbackNeutral; + switch (c % 4) { + case 0: + value = Clamp01(semanticFeatures.pV); + break; + case 1: + value = Clamp01(semanticFeatures.sVNegative); + break; + case 2: + value = Clamp01(semanticFeatures.sVPositive); + break; + case 3: + value = Clamp01(semanticFeatures.bS); + break; + 
default: + break; + } + input[t * channels + c] = static_cast(value); + } + } + + const std::vector inputShape{ + 1, + static_cast(window), + static_cast(channels)}; + auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + Ort::Value inputTensor = Ort::Value::CreateTensor( + memoryInfo, + input.data(), + input.size(), + inputShape.data(), + inputShape.size()); + + const char* inputNames[] = {onnxSession_->inputName.c_str()}; + const char* outputNames[] = {onnxSession_->outputName.c_str()}; + auto outputs = onnxSession_->session->Run( + Ort::RunOptions{nullptr}, + inputNames, + &inputTensor, + 1, + outputNames, + 1); + + if (outputs.empty() || !outputs[0].IsTensor()) { + return EvaluateFallback(semanticFeatures, "onnx_output_invalid"); + } + + auto* logits = outputs[0].GetTensorMutableData(); + const double probability = Sigmoid(static_cast(logits[0])); + InferenceDecision decision; + decision.enabled = true; + decision.usedOnnxRuntime = true; + decision.fallbackApplied = false; + decision.probability = Clamp01(probability); + decision.shouldPreempt = decision.probability >= config_.decisionThreshold; + decision.reason = "onnx_runtime"; + return decision; + } catch (const std::exception& ex) { + return EvaluateFallback( + semanticFeatures, + std::string("onnx_runtime_error:") + ex.what()); + } + } +#endif + + return EvaluateFallback(semanticFeatures, initStatus_); +} + +const std::string& InferenceEngine::InitializationStatus() const +{ + return initStatus_; +} + +double InferenceEngine::Clamp01(double value) +{ + if (value < 0.0) { + return 0.0; + } + if (value > 1.0) { + return 1.0; + } + return value; +} + +double InferenceEngine::Sigmoid(double value) +{ + return 1.0 / (1.0 + std::exp(-value)); +} + +InferenceDecision InferenceEngine::EvaluateFallback( + const SemanticHeartbeatFeatures& semanticFeatures, + const std::string& reason) const +{ + const double volatility = Clamp01(semanticFeatures.pV); + const double negative = 
Clamp01(semanticFeatures.sVNegative); + const double supplyStress = Clamp01(1.0 - semanticFeatures.bS); + double probability = 0.60 * volatility + 0.25 * negative + 0.15 * supplyStress; + if (semanticFeatures.fallbackUsed) { + probability = std::max(probability, 0.5); + } + + InferenceDecision decision; + decision.enabled = config_.enabled && !config_.forceV22Fallback; + decision.usedOnnxRuntime = false; + decision.fallbackApplied = true; + decision.probability = Clamp01(probability); + decision.shouldPreempt = decision.probability >= config_.decisionThreshold; + decision.reason = "fallback:" + reason; + return decision; +} diff --git a/src/services/agent-cpp/InferenceEngine.hpp b/src/services/agent-cpp/InferenceEngine.hpp new file mode 100644 index 0000000..1c09277 --- /dev/null +++ b/src/services/agent-cpp/InferenceEngine.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include "SemanticFeatures.hpp" + +#include +#include + +struct InferenceRuntimeConfig { + bool enabled = false; + bool forceV22Fallback = false; + bool failOpen = true; + double decisionThreshold = 0.65; + std::string modelPath; +}; + +struct InferenceDecision { + bool enabled = false; + bool usedOnnxRuntime = false; + bool fallbackApplied = true; + bool shouldPreempt = false; + double probability = 0.0; + std::string reason; +}; + +class InferenceEngine { +public: + explicit InferenceEngine(InferenceRuntimeConfig config); + ~InferenceEngine(); + + bool Initialize(); + InferenceDecision Evaluate(const SemanticHeartbeatFeatures& semanticFeatures) const; + const std::string& InitializationStatus() const; + +private: + InferenceRuntimeConfig config_; + mutable std::string initStatus_; + + struct OnnxSessionHandle; + std::unique_ptr onnxSession_; + + static double Clamp01(double value); + static double Sigmoid(double value); + InferenceDecision EvaluateFallback( + const SemanticHeartbeatFeatures& semanticFeatures, + const std::string& reason) const; +}; diff --git a/src/services/agent-cpp/NetworkClient.cpp 
b/src/services/agent-cpp/NetworkClient.cpp index 3dfb926..04fb75d 100644 --- a/src/services/agent-cpp/NetworkClient.cpp +++ b/src/services/agent-cpp/NetworkClient.cpp @@ -172,8 +172,12 @@ bool NetworkClient::SendHeartbeat( const std::string& agentId, const std::string& state, const std::string& tier, - std::vector& outCommands) { + std::vector& outCommands, + SemanticHeartbeatFeatures* outSemanticFeatures) { outCommands.clear(); + if (outSemanticFeatures != nullptr) { + *outSemanticFeatures = SemanticHeartbeatFeatures{}; + } nlohmann::json payload = { {"agentId", agentId}, @@ -218,15 +222,31 @@ bool NetworkClient::SendHeartbeat( if (requestOk && statusOk) { auto json = nlohmann::json::parse(response.text, nullptr, false); - if (!json.is_discarded() && json.contains("commands") && json["commands"].is_array()) { - for (const auto& item : json["commands"]) { - AgentCommand command; - command.id = item.value("id", 0); - command.type = item.value("type", ""); - if (command.id > 0 && !command.type.empty()) { - outCommands.push_back(std::move(command)); + if (!json.is_discarded()) { + if (json.contains("commands") && json["commands"].is_array()) { + for (const auto& item : json["commands"]) { + AgentCommand command; + command.id = item.value("id", 0); + command.type = item.value("type", item.value("action", "")); + if (command.id > 0 && !command.type.empty()) { + outCommands.push_back(std::move(command)); + } } } + + if (outSemanticFeatures != nullptr && json.contains("semanticFeatures") && json["semanticFeatures"].is_object()) { + const auto& semantic = json["semanticFeatures"]; + outSemanticFeatures->present = true; + outSemanticFeatures->schemaVersion = semantic.value("schemaVersion", ""); + outSemanticFeatures->sVNegative = semantic.value("sVNegative", 0.33); + outSemanticFeatures->sVNeutral = semantic.value("sVNeutral", 0.34); + outSemanticFeatures->sVPositive = semantic.value("sVPositive", 0.33); + outSemanticFeatures->pV = semantic.value("pV", 0.5); + 
outSemanticFeatures->bS = semantic.value("bS", 0.5); + outSemanticFeatures->source = semantic.value("source", ""); + outSemanticFeatures->generatedAtUnix = semantic.value("generatedAtUnix", 0LL); + outSemanticFeatures->fallbackUsed = semantic.value("fallbackUsed", true); + } } return true; } diff --git a/src/services/agent-cpp/NetworkClient.hpp b/src/services/agent-cpp/NetworkClient.hpp index 15d17b8..5e12e6b 100644 --- a/src/services/agent-cpp/NetworkClient.hpp +++ b/src/services/agent-cpp/NetworkClient.hpp @@ -1,5 +1,7 @@ #pragma once +#include "SemanticFeatures.hpp" + #include #include @@ -68,7 +70,8 @@ class NetworkClient { const std::string& agentId, const std::string& state, const std::string& tier, - std::vector& outCommands); + std::vector& outCommands, + SemanticHeartbeatFeatures* outSemanticFeatures = nullptr); bool PollCommands(const std::string& agentId, std::vector& outCommands); bool SendFeedback(const std::string& agentId, const CommandFeedback& feedback); bool UploadSnapshot(const std::string& url, const std::string& filePath); diff --git a/src/services/agent-cpp/SemanticFeatures.hpp b/src/services/agent-cpp/SemanticFeatures.hpp new file mode 100644 index 0000000..d85d1c7 --- /dev/null +++ b/src/services/agent-cpp/SemanticFeatures.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include + +struct SemanticHeartbeatFeatures { + bool present = false; + std::string schemaVersion; + double sVNegative = 0.33; + double sVNeutral = 0.34; + double sVPositive = 0.33; + double pV = 0.50; + double bS = 0.50; + std::string source; + long long generatedAtUnix = 0; + bool fallbackUsed = true; +}; diff --git a/src/services/agent-cpp/main.cpp b/src/services/agent-cpp/main.cpp index 049293c..966703b 100644 --- a/src/services/agent-cpp/main.cpp +++ b/src/services/agent-cpp/main.cpp @@ -1,4 +1,5 @@ #include "CommandPoller.hpp" +#include "InferenceEngine.hpp" #include "LifecycleManager.hpp" #include "NetworkClient.hpp" #include "Tracing.hpp" @@ -108,6 +109,19 @@ bool 
GetEnvBool(const char* name, bool defaultValue) { return defaultValue; } +double GetEnvDouble(const char* name, double defaultValue) { + EnvValue value = GetEnvValue(name); + if (!value.found || value.value.empty()) { + return defaultValue; + } + + try { + return std::stod(value.value); + } catch (...) { + return defaultValue; + } +} + bool WaitForTlsFiles(const TlsSettings& settings, int timeoutSeconds) { if (settings.certPath.empty() || settings.keyPath.empty() || settings.caPath.empty()) { return false; @@ -207,6 +221,21 @@ int main() { std::cerr << "[Agent] Pre-flight check failed. Proceeding with heartbeat in IDLE state." << std::endl; } + InferenceRuntimeConfig inferenceConfig; + inferenceConfig.enabled = GetEnvBool("AG_M3_ONLINE_INFERENCE_ENABLED", false); + inferenceConfig.forceV22Fallback = GetEnvBool("AG_M3_FORCE_V22_FALLBACK", false); + inferenceConfig.failOpen = GetEnvBool("AG_ONNX_FAIL_OPEN", true); + inferenceConfig.modelPath = GetEnvOrDefault("AG_ONNX_MODEL_PATH", ""); + inferenceConfig.decisionThreshold = GetEnvDouble("AG_ONNX_DECISION_THRESHOLD", 0.65); + + InferenceEngine inferenceEngine(inferenceConfig); + const bool inferenceInitialized = inferenceEngine.Initialize(); + std::cout << "[Agent] Inference init status: " << inferenceEngine.InitializationStatus() << std::endl; + if (!inferenceInitialized && !inferenceConfig.failOpen) { + std::cerr << "[Agent] Inference initialization failed and fail-open is disabled. Exiting." << std::endl; + return 1; + } + const std::string tier = "T2"; const std::string state = "IDLE"; CommandDispatcher dispatcher(client, lifecycle, agentId); @@ -215,10 +244,19 @@ int main() { while (true) { std::vector commands; - bool heartbeatSent = client.SendHeartbeat(token, agentId, state, tier, commands); + SemanticHeartbeatFeatures semanticFeatures; + bool heartbeatSent = client.SendHeartbeat(token, agentId, state, tier, commands, &semanticFeatures); if (heartbeatSent) { std::cout << "[Agent] Heartbeat sent." 
<< std::endl; + if (semanticFeatures.present) { + InferenceDecision decision = inferenceEngine.Evaluate(semanticFeatures); + std::cout << "[Agent][Inference] prob=" << decision.probability + << " preempt=" << (decision.shouldPreempt ? "yes" : "no") + << " mode=" << (decision.usedOnnxRuntime ? "onnx" : "fallback") + << " reason=" << decision.reason + << std::endl; + } } else { std::cerr << "[Agent] Failed to send heartbeat" << std::endl; } diff --git a/src/services/agent-cpp/tests/InferenceEngineTests.cpp b/src/services/agent-cpp/tests/InferenceEngineTests.cpp new file mode 100644 index 0000000..424459a --- /dev/null +++ b/src/services/agent-cpp/tests/InferenceEngineTests.cpp @@ -0,0 +1,76 @@ +#include "InferenceEngine.hpp" + +#include +#include + +namespace { +int Fail(const std::string& message) { + std::cerr << message << std::endl; + return 1; +} +} // namespace + +int main() { + SemanticHeartbeatFeatures semantic; + semantic.present = true; + semantic.sVNegative = 0.8; + semantic.sVNeutral = 0.1; + semantic.sVPositive = 0.1; + semantic.pV = 0.9; + semantic.bS = 0.1; + semantic.fallbackUsed = false; + + { + InferenceRuntimeConfig config; + config.enabled = false; + InferenceEngine engine(config); + if (!engine.Initialize()) { + return Fail("Initialization should succeed when feature gate is disabled."); + } + const auto decision = engine.Evaluate(semantic); + if (decision.enabled) { + return Fail("Decision should be marked disabled when feature gate is off."); + } + if (!decision.fallbackApplied) { + return Fail("Fallback should be applied when feature gate is disabled."); + } + } + + { + InferenceRuntimeConfig config; + config.enabled = true; + config.forceV22Fallback = true; + InferenceEngine engine(config); + if (!engine.Initialize()) { + return Fail("Initialization should succeed when rollback fallback is forced."); + } + const auto decision = engine.Evaluate(semantic); + if (!decision.fallbackApplied) { + return Fail("Rollback should force fallback mode."); 
+ } + if (decision.reason.find("rollback_forced_v22_fallback") == std::string::npos) { + return Fail("Expected rollback reason in fallback decision."); + } + } + + { + InferenceRuntimeConfig config; + config.enabled = true; + config.failOpen = true; + config.decisionThreshold = 0.7; + config.modelPath = "missing-model.onnx"; + InferenceEngine engine(config); + if (!engine.Initialize()) { + return Fail("Initialization should fail-open when model file is missing."); + } + const auto decision = engine.Evaluate(semantic); + if (!decision.fallbackApplied) { + return Fail("Missing model should use fallback inference."); + } + if (!decision.shouldPreempt) { + return Fail("High-risk semantic vector should trigger preempt recommendation."); + } + } + + return 0; +} From e8178971a8ba4b317e2449578a6d9ec291925da1 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 25 Feb 2026 10:56:37 +0800 Subject: [PATCH 20/24] feat(v2.3): add per-agent inference rollout gating and fallback --- README.md | 1 + docs/Core-Semantic-Rollout-v2.3-M3.md | 38 +++++++++ src/services/agent-cpp/NetworkClient.cpp | 1 + src/services/agent-cpp/NetworkClient.hpp | 1 + src/services/agent-cpp/main.cpp | 8 +- .../AetherGuard.Core.Tests.csproj | 1 + .../AgentWorkflowServiceTests.cs | 77 +++++++++++++++++++ .../Controllers/AgentController.cs | 3 +- .../core-dotnet/AetherGuard.Core/Program.cs | 2 + .../Services/AgentInferenceOptions.cs | 7 ++ .../Services/AgentWorkflowService.cs | 41 ++++++++-- src/shared/protos/agent_service.proto | 1 + 12 files changed, 174 insertions(+), 7 deletions(-) create mode 100644 docs/Core-Semantic-Rollout-v2.3-M3.md create mode 100644 src/services/core-dotnet/AetherGuard.Core.Tests/AgentWorkflowServiceTests.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Services/AgentInferenceOptions.cs diff --git a/README.md b/README.md index 6a4cb16..4a95a53 100644 --- a/README.md +++ b/README.md @@ -204,6 +204,7 @@ Open the dashboard at http://localhost:3000. 
- v2.3 M2 model training scripts: scripts/model_training/README.md - v2.3 M3 heartbeat semantic payload contract: docs/PROTO-Heartbeat-Semantic-v2.3-M3.md - v2.3 M3 agent ONNX inference + gating: docs/Agent-ONNX-Inference-v2.3-M3.md +- v2.3 M3 core semantic rollout + per-agent gating: docs/Core-Semantic-Rollout-v2.3-M3.md If you want to simulate migrations, start at least two agents: diff --git a/docs/Core-Semantic-Rollout-v2.3-M3.md b/docs/Core-Semantic-Rollout-v2.3-M3.md new file mode 100644 index 0000000..ff46557 --- /dev/null +++ b/docs/Core-Semantic-Rollout-v2.3-M3.md @@ -0,0 +1,38 @@ +# v2.3 M3 Core Push + Per-Agent Gating + +This document captures issue #41. + +## Scope + +- Control-plane rollout gate for agent local inference. +- Safe fallback behavior preserved for v2.2 compatibility. + +## Changes + +- Protobuf (`agent_service.proto`) + - `AgentConfig.enable_local_inference = 6` (additive field). +- Core rollout options: + - `AgentInference:EnableLocalInferenceRollout` + - `AgentInference:RolloutPercentage` +- Register response config now carries `enableLocalInference`. +- Heartbeat semantic payload remains active and includes fallback vectors when enrichment is unavailable. + +## Rollout Logic + +`AgentWorkflowService` computes a deterministic per-agent rollout bucket from stable agent key (hostname): + +- Global rollout disabled => `enable_local_inference=false` +- Rollout 0 => disabled for all agents +- Rollout 100 => enabled for all agents +- Rollout N (1-99) => deterministic subset enabled + +## Compatibility + +- New proto field is additive; existing agents remain wire-compatible. +- Agent runtime still honors rollback (`AG_M3_FORCE_V22_FALLBACK`) and fail-open behavior. 
+ +## Validation + +- New tests in `AetherGuard.Core.Tests`: + - rollout disabled -> local inference off + - rollout 100% -> local inference on diff --git a/src/services/agent-cpp/NetworkClient.cpp b/src/services/agent-cpp/NetworkClient.cpp index 04fb75d..771ef75 100644 --- a/src/services/agent-cpp/NetworkClient.cpp +++ b/src/services/agent-cpp/NetworkClient.cpp @@ -162,6 +162,7 @@ bool NetworkClient::Register( outConfig->enableEbpf = config.value("enableEbpf", false); outConfig->enableNetTopology = config.value("enableNetTopology", false); outConfig->enableChaos = config.value("enableChaos", false); + outConfig->enableLocalInference = config.value("enableLocalInference", false); outConfig->nodeMode = config.value("nodeMode", ""); } return !outToken.empty() && !outAgentId.empty(); diff --git a/src/services/agent-cpp/NetworkClient.hpp b/src/services/agent-cpp/NetworkClient.hpp index 5e12e6b..bf174c2 100644 --- a/src/services/agent-cpp/NetworkClient.hpp +++ b/src/services/agent-cpp/NetworkClient.hpp @@ -42,6 +42,7 @@ struct AgentConfig { bool enableEbpf = false; bool enableNetTopology = false; bool enableChaos = false; + bool enableLocalInference = false; std::string nodeMode; }; diff --git a/src/services/agent-cpp/main.cpp b/src/services/agent-cpp/main.cpp index 966703b..95b37d4 100644 --- a/src/services/agent-cpp/main.cpp +++ b/src/services/agent-cpp/main.cpp @@ -207,6 +207,7 @@ int main() { if (!agentConfig.nodeMode.empty()) { std::cout << "[Agent] Config: snapshot=" << (agentConfig.enableSnapshot ? "on" : "off") << ", ebpf=" << (agentConfig.enableEbpf ? "on" : "off") + << ", local_inference=" << (agentConfig.enableLocalInference ? 
"on" : "off") << ", node_mode=" << agentConfig.nodeMode << std::endl; } break; @@ -222,12 +223,17 @@ int main() { } InferenceRuntimeConfig inferenceConfig; - inferenceConfig.enabled = GetEnvBool("AG_M3_ONLINE_INFERENCE_ENABLED", false); + const bool localInferenceRequested = GetEnvBool("AG_M3_ONLINE_INFERENCE_ENABLED", false); + inferenceConfig.enabled = localInferenceRequested && agentConfig.enableLocalInference; inferenceConfig.forceV22Fallback = GetEnvBool("AG_M3_FORCE_V22_FALLBACK", false); inferenceConfig.failOpen = GetEnvBool("AG_ONNX_FAIL_OPEN", true); inferenceConfig.modelPath = GetEnvOrDefault("AG_ONNX_MODEL_PATH", ""); inferenceConfig.decisionThreshold = GetEnvDouble("AG_ONNX_DECISION_THRESHOLD", 0.65); + if (localInferenceRequested && !agentConfig.enableLocalInference) { + std::cout << "[Agent] Local inference requested but disabled by control-plane gate." << std::endl; + } + InferenceEngine inferenceEngine(inferenceConfig); const bool inferenceInitialized = inferenceEngine.Initialize(); std::cout << "[Agent] Inference init status: " << inferenceEngine.InitializationStatus() << std::endl; diff --git a/src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj b/src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj index b162747..ddb0829 100644 --- a/src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj +++ b/src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj @@ -10,6 +10,7 @@ + diff --git a/src/services/core-dotnet/AetherGuard.Core.Tests/AgentWorkflowServiceTests.cs b/src/services/core-dotnet/AetherGuard.Core.Tests/AgentWorkflowServiceTests.cs new file mode 100644 index 0000000..1ad1486 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core.Tests/AgentWorkflowServiceTests.cs @@ -0,0 +1,77 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using AetherGuard.Core.Data; +using AetherGuard.Core.Services; +using AetherGuard.Grpc.V1; 
+using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Xunit; + +namespace AetherGuard.Core.Tests; + +public class AgentWorkflowServiceTests +{ + [Fact] + public async Task RegisterAsync_DisablesLocalInference_WhenRolloutDisabled() + { + await using var db = CreateDbContext(); + var service = CreateService( + db, + new AgentInferenceOptions + { + EnableLocalInferenceRollout = false, + RolloutPercentage = 100 + }); + + var result = await service.RegisterAsync(new RegisterRequest + { + Hostname = "agent-rollout-disabled" + }, CancellationToken.None); + + Assert.True(result.Success); + Assert.NotNull(result.Payload); + Assert.False(result.Payload!.Config.EnableLocalInference); + } + + [Fact] + public async Task RegisterAsync_EnablesLocalInference_WhenRolloutAt100Percent() + { + await using var db = CreateDbContext(); + var service = CreateService( + db, + new AgentInferenceOptions + { + EnableLocalInferenceRollout = true, + RolloutPercentage = 100 + }); + + var result = await service.RegisterAsync(new RegisterRequest + { + Hostname = "agent-rollout-100" + }, CancellationToken.None); + + Assert.True(result.Success); + Assert.NotNull(result.Payload); + Assert.True(result.Payload!.Config.EnableLocalInference); + } + + private static AgentWorkflowService CreateService( + ApplicationDbContext db, + AgentInferenceOptions options) + { + return new AgentWorkflowService( + db, + NullLogger.Instance, + Options.Create(options)); + } + + private static ApplicationDbContext CreateDbContext() + { + var dbOptions = new DbContextOptionsBuilder() + .UseInMemoryDatabase($"agent-workflow-tests-{Guid.NewGuid():N}") + .Options; + return new ApplicationDbContext(dbOptions); + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Controllers/AgentController.cs b/src/services/core-dotnet/AetherGuard.Core/Controllers/AgentController.cs index 40f7b0f..a21af7e 100644 --- 
a/src/services/core-dotnet/AetherGuard.Core/Controllers/AgentController.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Controllers/AgentController.cs @@ -52,7 +52,8 @@ public async Task RegisterAgent([FromBody] RegisterAgentRequest r enableEbpf = config.EnableEbpf, enableNetTopology = config.EnableNetTopology, enableChaos = config.EnableChaos, - nodeMode = config.NodeMode + nodeMode = config.NodeMode, + enableLocalInference = config.EnableLocalInference }; return Ok(new { token = result.Payload?.Token, agentId = result.Payload?.AgentId, config = configPayload }); diff --git a/src/services/core-dotnet/AetherGuard.Core/Program.cs b/src/services/core-dotnet/AetherGuard.Core/Program.cs index c681211..812ce1c 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Program.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Program.cs @@ -78,6 +78,8 @@ builder.Services.AddHostedService(); builder.Services.Configure( builder.Configuration.GetSection("ExternalSignals")); +builder.Services.Configure( + builder.Configuration.GetSection("AgentInference")); builder.Services.AddHostedService(); var otelOptions = builder.Configuration.GetSection("OpenTelemetry").Get() diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/AgentInferenceOptions.cs b/src/services/core-dotnet/AetherGuard.Core/Services/AgentInferenceOptions.cs new file mode 100644 index 0000000..abab91b --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Services/AgentInferenceOptions.cs @@ -0,0 +1,7 @@ +namespace AetherGuard.Core.Services; + +public sealed class AgentInferenceOptions +{ + public bool EnableLocalInferenceRollout { get; set; } + public int RolloutPercentage { get; set; } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/AgentWorkflowService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/AgentWorkflowService.cs index e6a1576..33ba140 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/AgentWorkflowService.cs +++ 
b/src/services/core-dotnet/AetherGuard.Core/Services/AgentWorkflowService.cs @@ -4,6 +4,7 @@ using AetherGuard.Grpc.V1; using Microsoft.AspNetCore.Http; using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Options; namespace AetherGuard.Core.Services; @@ -11,11 +12,16 @@ public class AgentWorkflowService { private readonly ApplicationDbContext _context; private readonly ILogger _logger; + private readonly AgentInferenceOptions _inferenceOptions; - public AgentWorkflowService(ApplicationDbContext context, ILogger logger) + public AgentWorkflowService( + ApplicationDbContext context, + ILogger logger, + IOptions inferenceOptions) { _context = context; _logger = logger; + _inferenceOptions = inferenceOptions.Value; } public async Task> RegisterAsync(RegisterRequest request, CancellationToken cancellationToken) @@ -37,7 +43,7 @@ public async Task> RegisterAsync(RegisterRequest req { Token = existingAgent.AgentToken, AgentId = existingAgent.Id.ToString(), - Config = BuildAgentConfig(request.Capabilities) + Config = BuildAgentConfig(request.Capabilities, existingAgent.Hostname) }); } @@ -57,25 +63,50 @@ public async Task> RegisterAsync(RegisterRequest req { Token = agent.AgentToken, AgentId = agent.Id.ToString(), - Config = BuildAgentConfig(request.Capabilities) + Config = BuildAgentConfig(request.Capabilities, agent.Hostname) }); } - private static AgentConfig BuildAgentConfig(AgentCapabilities? capabilities) + private AgentConfig BuildAgentConfig(AgentCapabilities? capabilities, string stableAgentKey) { var criuAvailable = capabilities?.CriuAvailable == true; + var inferenceEnabled = IsInferenceEnabledForAgent(stableAgentKey); var config = new AgentConfig { EnableSnapshot = criuAvailable && capabilities?.SupportsSnapshot != false, EnableEbpf = capabilities?.EbpfAvailable == true, EnableNetTopology = capabilities?.SupportsNetTopology == true, EnableChaos = capabilities?.SupportsChaos == true, - NodeMode = criuAvailable ? 
"STATEFUL" : "STATELESS" + NodeMode = criuAvailable ? "STATEFUL" : "STATELESS", + EnableLocalInference = inferenceEnabled }; return config; } + private bool IsInferenceEnabledForAgent(string stableAgentKey) + { + if (!_inferenceOptions.EnableLocalInferenceRollout) + { + return false; + } + + var rolloutPercentage = Math.Clamp(_inferenceOptions.RolloutPercentage, 0, 100); + if (rolloutPercentage <= 0) + { + return false; + } + + if (rolloutPercentage >= 100) + { + return true; + } + + var key = string.IsNullOrWhiteSpace(stableAgentKey) ? "unknown-agent" : stableAgentKey.Trim(); + var bucket = (int)((uint)StringComparer.Ordinal.GetHashCode(key) % 100); + return bucket < rolloutPercentage; + } + public async Task> HeartbeatAsync(HeartbeatRequest request, CancellationToken cancellationToken) { if (request is null || string.IsNullOrWhiteSpace(request.Token)) diff --git a/src/shared/protos/agent_service.proto b/src/shared/protos/agent_service.proto index 2e69af8..429a10f 100644 --- a/src/shared/protos/agent_service.proto +++ b/src/shared/protos/agent_service.proto @@ -75,6 +75,7 @@ message AgentConfig { bool enable_net_topology = 3; bool enable_chaos = 4; string node_mode = 5; + bool enable_local_inference = 6; } message SemanticHeartbeatFeatures { From 71b66d419d78217bda4b27ef1b2a1cd1efeece78 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 25 Feb 2026 11:04:53 +0800 Subject: [PATCH 21/24] feat(v2.3): add canary rollback runbook and evaluator --- README.md | 2 + docs/QA-Canary-Rollback-v2.3-M3.md | 56 +++++++ scripts/qa/README.md | 24 +++ scripts/qa/evaluate_m3_canary.py | 247 +++++++++++++++++++++++++++++ 4 files changed, 329 insertions(+) create mode 100644 docs/QA-Canary-Rollback-v2.3-M3.md create mode 100644 scripts/qa/README.md create mode 100644 scripts/qa/evaluate_m3_canary.py diff --git a/README.md b/README.md index 4a95a53..d1959ab 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,8 @@ Open the dashboard at http://localhost:3000. 
- v2.3 M3 heartbeat semantic payload contract: docs/PROTO-Heartbeat-Semantic-v2.3-M3.md - v2.3 M3 agent ONNX inference + gating: docs/Agent-ONNX-Inference-v2.3-M3.md - v2.3 M3 core semantic rollout + per-agent gating: docs/Core-Semantic-Rollout-v2.3-M3.md +- v2.3 M3 canary + rollback plan: docs/QA-Canary-Rollback-v2.3-M3.md +- v2.3 M3 canary evaluator script: scripts/qa/README.md If you want to simulate migrations, start at least two agents: diff --git a/docs/QA-Canary-Rollback-v2.3-M3.md b/docs/QA-Canary-Rollback-v2.3-M3.md new file mode 100644 index 0000000..0a6d654 --- /dev/null +++ b/docs/QA-Canary-Rollback-v2.3-M3.md @@ -0,0 +1,56 @@ +# v2.3 M3 Canary + Rollback Plan + +This document captures issue #42. + +## Goal + +Ship online inference safely with explicit rollout gates and automatic rollback triggers. + +## Canary Plan + +1. **Stage 0 (shadow)** + Enable heartbeat semantic payload and local inference logging only. +2. **Stage 1 (1-5%)** + Set `AgentInference:EnableLocalInferenceRollout=true`, `RolloutPercentage=5`. +3. **Stage 2 (10-25%)** + Increase rollout only if no rollback trigger is hit for two full canary windows. +4. **Stage 3 (50%+)** + Expand progressively with the same guardrails. +5. **Stage 4 (100%)** + Promote after stable windows and no critical incidents. + +## Automated Rollback Triggers + +Use `scripts/qa/evaluate_m3_canary.py` with canary metrics input. + +Rollback is required when any critical threshold is breached: + +- `critical_incident_count > 0` +- `heartbeat_failure_rate > 0.05` +- `inference_error_rate > 0.02` +- `p95_inference_latency_ms > 50` +- `false_positive_rate_delta > 0.10` +- `preempt_decision_rate_delta > 0.15` + +## Automation Output + +Script output: + +- JSON decision artifact (`promote` / `hold` / `rollback`) +- optional markdown summary +- exit code suitable for CI gate: + - `0` promote + - `10` hold + - `20` rollback + +## Rollback Procedure (Immediate) + +1. Set `AG_M3_FORCE_V22_FALLBACK=true` on canary agents. 
+2. Disable local inference rollout (`AgentInference:RolloutPercentage=0`). +3. Keep semantic heartbeat transport enabled for observability. +4. Open incident ticket and attach decision artifacts. + +## Evidence Artifacts + +- Script: `scripts/qa/evaluate_m3_canary.py` +- Script usage: `scripts/qa/README.md` diff --git a/scripts/qa/README.md b/scripts/qa/README.md new file mode 100644 index 0000000..7998fc4 --- /dev/null +++ b/scripts/qa/README.md @@ -0,0 +1,24 @@ +# v2.3 M3 Canary QA Scripts + +## `evaluate_m3_canary.py` + +Evaluate canary window metrics and produce a deterministic decision: + +- `promote` +- `hold` +- `rollback` + +### Usage + +```bash +python scripts/qa/evaluate_m3_canary.py \ + --input .tmp/canary-input.json \ + --output .tmp/canary-decision.json \ + --summary-md .tmp/canary-decision.md +``` + +Exit codes: + +- `0` => promote +- `10` => hold +- `20` => rollback diff --git a/scripts/qa/evaluate_m3_canary.py b/scripts/qa/evaluate_m3_canary.py new file mode 100644 index 0000000..0c57fdf --- /dev/null +++ b/scripts/qa/evaluate_m3_canary.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +"""Evaluate Milestone 3 canary metrics and emit promote/hold/rollback decision.""" + +from __future__ import annotations + +import argparse +import json +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + + +@dataclass(frozen=True) +class Threshold: + warning: float | int | None + rollback: float | int + direction: str # "upper_is_bad" or "lower_is_bad" + label: str + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--input", required=True, help="Input canary metrics JSON.") + parser.add_argument("--output", required=True, help="Output decision JSON.") + parser.add_argument( + "--summary-md", + default="", + help="Optional markdown summary output path.", + ) + return parser.parse_args() + + +def now_utc_iso() -> str: + 
return datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + +def load_json(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8-sig") as handle: + payload = json.load(handle) + if not isinstance(payload, dict): + raise ValueError("Input JSON must be an object.") + return payload + + +def write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + + +def resolve_metric(metrics: dict[str, Any], key: str) -> float: + value = metrics.get(key) + if value is None: + return float("nan") + try: + return float(value) + except (TypeError, ValueError): + return float("nan") + + +def evaluate_metric( + *, + metric_key: str, + value: float, + threshold: Threshold, +) -> dict[str, Any]: + if value != value: # NaN + return { + "metric": metric_key, + "label": threshold.label, + "status": "missing", + "severity": "warning", + "value": None, + "warning_threshold": threshold.warning, + "rollback_threshold": threshold.rollback, + "message": f"Metric {metric_key} is missing.", + } + + if threshold.direction == "upper_is_bad": + rollback_hit = value > float(threshold.rollback) + warning_hit = threshold.warning is not None and value > float(threshold.warning) + else: + rollback_hit = value < float(threshold.rollback) + warning_hit = threshold.warning is not None and value < float(threshold.warning) + + if rollback_hit: + status = "rollback" + severity = "critical" + elif warning_hit: + status = "warning" + severity = "warning" + else: + status = "ok" + severity = "info" + + return { + "metric": metric_key, + "label": threshold.label, + "status": status, + "severity": severity, + "value": value, + "warning_threshold": threshold.warning, + "rollback_threshold": threshold.rollback, + "message": f"{metric_key}={value}", + } + + +def build_thresholds() -> dict[str, Threshold]: + return { + 
"critical_incident_count": Threshold( + warning=0, + rollback=0, + direction="upper_is_bad", + label="Critical incidents", + ), + "heartbeat_failure_rate": Threshold( + warning=0.02, + rollback=0.05, + direction="upper_is_bad", + label="Heartbeat failure rate", + ), + "inference_error_rate": Threshold( + warning=0.01, + rollback=0.02, + direction="upper_is_bad", + label="Inference error rate", + ), + "p95_inference_latency_ms": Threshold( + warning=35.0, + rollback=50.0, + direction="upper_is_bad", + label="P95 inference latency (ms)", + ), + "false_positive_rate_delta": Threshold( + warning=0.06, + rollback=0.10, + direction="upper_is_bad", + label="False positive delta vs v2.2", + ), + "preempt_decision_rate_delta": Threshold( + warning=0.10, + rollback=0.15, + direction="upper_is_bad", + label="Preempt decision rate delta vs v2.2", + ), + } + + +def decide(checks: list[dict[str, Any]]) -> tuple[str, bool, list[str]]: + has_rollback = any(item["status"] == "rollback" for item in checks) + has_warning_or_missing = any(item["status"] in ("warning", "missing") for item in checks) + + if has_rollback: + return ( + "rollback", + True, + [ + "Disable AG_M3_ONLINE_INFERENCE_ENABLED for canary agents.", + "Set AG_M3_FORCE_V22_FALLBACK=true for immediate rollback.", + "Open incident and attach canary_decision.json for audit.", + ], + ) + + if has_warning_or_missing: + return ( + "hold", + False, + [ + "Keep canary scope unchanged.", + "Collect one more observation window and re-evaluate.", + ], + ) + + return ( + "promote", + False, + [ + "Increase rollout percentage in AgentInference options.", + "Continue monitoring rollback guardrails.", + ], + ) + + +def write_summary_markdown(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + handle.write("# M3 Canary Decision Summary\n\n") + handle.write(f"- Generated at UTC: {payload['generated_at_utc']}\n") + handle.write(f"- Decision: 
**{payload['decision']}**\n") + handle.write(f"- Rollback required: {payload['rollback_required']}\n\n") + handle.write("| Metric | Value | Status | Rollback Threshold |\n") + handle.write("|---|---:|---|---:|\n") + for check in payload["checks"]: + value = "n/a" if check["value"] is None else f"{check['value']:.6f}" + handle.write( + f"| {check['metric']} | {value} | {check['status']} | {check['rollback_threshold']} |\n" + ) + handle.write("\n## Actions\n\n") + for action in payload["actions"]: + handle.write(f"- {action}\n") + + +def main() -> int: + args = parse_args() + input_path = Path(args.input) + output_path = Path(args.output) + payload = load_json(input_path) + metrics = payload.get("metrics") + if not isinstance(metrics, dict): + raise ValueError("Input JSON must contain object field: metrics") + + checks: list[dict[str, Any]] = [] + for key, threshold in build_thresholds().items(): + value = resolve_metric(metrics, key) + checks.append(evaluate_metric(metric_key=key, value=value, threshold=threshold)) + + decision, rollback_required, actions = decide(checks) + decision_payload = { + "generated_at_utc": now_utc_iso(), + "input_path": input_path.as_posix(), + "decision": decision, + "rollback_required": rollback_required, + "checks": checks, + "actions": actions, + } + write_json(output_path, decision_payload) + + if args.summary_md: + write_summary_markdown(Path(args.summary_md), decision_payload) + + print(f"Decision: {decision}") + print(f"Rollback required: {rollback_required}") + print(f"Decision file: {output_path}") + if args.summary_md: + print(f"Summary file: {args.summary_md}") + + if decision == "rollback": + return 20 + if decision == "hold": + return 10 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 112ab1a05505041dcd3fdc48949afdeb723994c0 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 25 Feb 2026 11:14:53 +0800 Subject: [PATCH 22/24] feat(v2.3): implement dynamic risk alpha policy with guardrails --- 
README.md | 1 + docs/Core-Dynamic-Risk-v2.3-M4.md | 61 ++++++++ .../DynamicRiskPolicyTests.cs | 129 ++++++++++++++++ .../core-dotnet/AetherGuard.Core/Program.cs | 3 + .../Services/DynamicRiskOptions.cs | 13 ++ .../Services/DynamicRiskPolicy.cs | 133 +++++++++++++++++ .../Services/MigrationOrchestrator.cs | 140 ++++++++++++++++-- 7 files changed, 469 insertions(+), 11 deletions(-) create mode 100644 docs/Core-Dynamic-Risk-v2.3-M4.md create mode 100644 src/services/core-dotnet/AetherGuard.Core.Tests/DynamicRiskPolicyTests.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Services/DynamicRiskOptions.cs create mode 100644 src/services/core-dotnet/AetherGuard.Core/Services/DynamicRiskPolicy.cs diff --git a/README.md b/README.md index d1959ab..ac793df 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,7 @@ Open the dashboard at http://localhost:3000. - v2.3 M3 core semantic rollout + per-agent gating: docs/Core-Semantic-Rollout-v2.3-M3.md - v2.3 M3 canary + rollback plan: docs/QA-Canary-Rollback-v2.3-M3.md - v2.3 M3 canary evaluator script: scripts/qa/README.md +- v2.3 M4 dynamic risk allocation (core): docs/Core-Dynamic-Risk-v2.3-M4.md If you want to simulate migrations, start at least two agents: diff --git a/docs/Core-Dynamic-Risk-v2.3-M4.md b/docs/Core-Dynamic-Risk-v2.3-M4.md new file mode 100644 index 0000000..c6749fd --- /dev/null +++ b/docs/Core-Dynamic-Risk-v2.3-M4.md @@ -0,0 +1,61 @@ +# v2.3 M4 Dynamic Risk Allocation (Core) + +This document captures issue #43. + +## Alpha Computation + +Core computes dynamic `alpha` via `DynamicRiskPolicy`: + +``` +sentiment_pressure = max(0, S_neg - S_pos) +alpha = clamp( + base_alpha + + volatility_weight * P_v + + sentiment_weight * sentiment_pressure, + min_alpha, + max_alpha +) +``` + +When `rebalanceSignal=true`, `alpha` is forced to `max_alpha`. 
+ +## Decision Score + +``` +decision_score = clamp(P_preempt * alpha, 0, 1) +migrate if decision_score >= decision_threshold +``` + +`P_preempt` is derived from AI analysis confidence/prediction with fallback handling for `Unavailable`. + +## Guardrails + +Implemented in migration orchestration: + +- **Cooldown guardrail**: block migration if source agent migrated within `CooldownMinutes`. +- **Max-rate guardrail**: block migration if completed migrations in the last hour exceed `MaxMigrationsPerHour`. + +Guardrail blocks take precedence over score-based migration. + +## Config Section + +`DynamicRisk`: + +- `BaseAlpha` +- `VolatilityWeight` +- `SentimentWeight` +- `MinAlpha` +- `MaxAlpha` +- `DecisionThreshold` +- `CooldownMinutes` +- `MaxMigrationsPerHour` + +## Tests + +`DynamicRiskPolicyTests` cover edge cases: + +- alpha upper clamp +- alpha lower clamp +- cooldown guardrail block +- max-rate guardrail block +- threshold crossing migration decision diff --git a/src/services/core-dotnet/AetherGuard.Core.Tests/DynamicRiskPolicyTests.cs b/src/services/core-dotnet/AetherGuard.Core.Tests/DynamicRiskPolicyTests.cs new file mode 100644 index 0000000..5a59c17 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core.Tests/DynamicRiskPolicyTests.cs @@ -0,0 +1,129 @@ +using AetherGuard.Core.Services; +using Xunit; + +namespace AetherGuard.Core.Tests; + +public class DynamicRiskPolicyTests +{ + [Fact] + public void ComputeAlpha_ClampsToMax() + { + var options = new DynamicRiskOptions + { + BaseAlpha = 1.4, + VolatilityWeight = 0.8, + SentimentWeight = 0.6, + MinAlpha = 0.5, + MaxAlpha = 1.6 + }; + + var alpha = DynamicRiskPolicy.ComputeAlpha( + options, + new DynamicRiskInput( + PreemptProbability: 0.6, + RebalanceSignal: false, + VolatilityProbability: 1.0, + SentimentNegative: 1.0, + SentimentPositive: 0.0)); + + Assert.Equal(1.6, alpha, 3); + } + + [Fact] + public void Evaluate_ReturnsCooldownBlock_WhenCooldownGuardrailActive() + { + var policy = new 
DynamicRiskPolicy(new DynamicRiskOptions()); + var decision = policy.Evaluate( + new DynamicRiskInput( + PreemptProbability: 0.9, + RebalanceSignal: false, + VolatilityProbability: 0.8, + SentimentNegative: 0.9, + SentimentPositive: 0.1), + new RiskGuardrailState( + CooldownActive: true, + MaxRateExceeded: false, + RecentMigrationsLastHour: 1, + MaxMigrationsPerHour: 30)); + + Assert.False(decision.ShouldMigrate); + Assert.Equal("guardrail_cooldown_active", decision.Reason); + } + + [Fact] + public void Evaluate_ReturnsMaxRateBlock_WhenRateGuardrailExceeded() + { + var policy = new DynamicRiskPolicy(new DynamicRiskOptions()); + var decision = policy.Evaluate( + new DynamicRiskInput( + PreemptProbability: 0.9, + RebalanceSignal: false, + VolatilityProbability: 0.8, + SentimentNegative: 0.9, + SentimentPositive: 0.1), + new RiskGuardrailState( + CooldownActive: false, + MaxRateExceeded: true, + RecentMigrationsLastHour: 31, + MaxMigrationsPerHour: 30)); + + Assert.False(decision.ShouldMigrate); + Assert.Equal("guardrail_max_rate_exceeded", decision.Reason); + } + + [Fact] + public void Evaluate_ReturnsMigrate_WhenScoreExceedsThreshold() + { + var policy = new DynamicRiskPolicy(new DynamicRiskOptions + { + BaseAlpha = 1.0, + VolatilityWeight = 0.4, + SentimentWeight = 0.3, + MinAlpha = 0.5, + MaxAlpha = 1.6, + DecisionThreshold = 0.65 + }); + + var decision = policy.Evaluate( + new DynamicRiskInput( + PreemptProbability: 0.7, + RebalanceSignal: false, + VolatilityProbability: 0.9, + SentimentNegative: 0.8, + SentimentPositive: 0.2), + new RiskGuardrailState( + CooldownActive: false, + MaxRateExceeded: false, + RecentMigrationsLastHour: 0, + MaxMigrationsPerHour: 30)); + + Assert.True(decision.ShouldMigrate); + Assert.Equal("decision_score_above_threshold", decision.Reason); + } + + [Fact] + public void Evaluate_RebalanceSignal_UsesMaxAlpha() + { + var policy = new DynamicRiskPolicy(new DynamicRiskOptions + { + MaxAlpha = 1.4, + DecisionThreshold = 0.9 + }); + + var 
decision = policy.Evaluate( + new DynamicRiskInput( + PreemptProbability: 0.2, + RebalanceSignal: true, + VolatilityProbability: 0.1, + SentimentNegative: 0.1, + SentimentPositive: 0.9), + new RiskGuardrailState( + CooldownActive: false, + MaxRateExceeded: false, + RecentMigrationsLastHour: 0, + MaxMigrationsPerHour: 30)); + + Assert.Equal(1.4, decision.Alpha, 3); + Assert.True(decision.ShouldMigrate); + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Program.cs b/src/services/core-dotnet/AetherGuard.Core/Program.cs index 812ce1c..5cadf5a 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Program.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Program.cs @@ -68,6 +68,7 @@ builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); +builder.Services.AddSingleton(); builder.Services.AddScoped(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); @@ -80,6 +81,8 @@ builder.Configuration.GetSection("ExternalSignals")); builder.Services.Configure( builder.Configuration.GetSection("AgentInference")); +builder.Services.Configure( + builder.Configuration.GetSection("DynamicRisk")); builder.Services.AddHostedService(); var otelOptions = builder.Configuration.GetSection("OpenTelemetry").Get() diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/DynamicRiskOptions.cs b/src/services/core-dotnet/AetherGuard.Core/Services/DynamicRiskOptions.cs new file mode 100644 index 0000000..023cb66 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Services/DynamicRiskOptions.cs @@ -0,0 +1,13 @@ +namespace AetherGuard.Core.Services; + +public sealed class DynamicRiskOptions +{ + public double BaseAlpha { get; set; } = 0.8; + public double VolatilityWeight { get; set; } = 0.4; + public double SentimentWeight { get; set; } = 0.3; + public double MinAlpha { get; set; } = 0.5; + public double MaxAlpha { get; set; } = 1.6; + public double DecisionThreshold { get; set; } = 0.75; + public int CooldownMinutes { 
get; set; } = 2; + public int MaxMigrationsPerHour { get; set; } = 30; +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/DynamicRiskPolicy.cs b/src/services/core-dotnet/AetherGuard.Core/Services/DynamicRiskPolicy.cs new file mode 100644 index 0000000..1f860a7 --- /dev/null +++ b/src/services/core-dotnet/AetherGuard.Core/Services/DynamicRiskPolicy.cs @@ -0,0 +1,133 @@ +using Microsoft.Extensions.Options; + +namespace AetherGuard.Core.Services; + +public sealed record DynamicRiskInput( + double PreemptProbability, + bool RebalanceSignal, + double VolatilityProbability, + double SentimentNegative, + double SentimentPositive); + +public sealed record RiskGuardrailState( + bool CooldownActive, + bool MaxRateExceeded, + int RecentMigrationsLastHour, + int MaxMigrationsPerHour); + +public sealed record DynamicRiskDecision( + bool ShouldMigrate, + double Alpha, + double DecisionScore, + string Reason); + +public sealed class DynamicRiskPolicy +{ + private readonly DynamicRiskOptions _options; + + public DynamicRiskPolicy(IOptions options) + : this(options.Value) + { + } + + public DynamicRiskPolicy(DynamicRiskOptions options) + { + _options = NormalizeOptions(options); + } + + public DynamicRiskOptions Options => _options; + + public DynamicRiskDecision Evaluate( + DynamicRiskInput input, + RiskGuardrailState guardrails) + { + if (guardrails.CooldownActive) + { + return new DynamicRiskDecision( + ShouldMigrate: false, + Alpha: 0.0, + DecisionScore: 0.0, + Reason: "guardrail_cooldown_active"); + } + + if (guardrails.MaxRateExceeded) + { + return new DynamicRiskDecision( + ShouldMigrate: false, + Alpha: 0.0, + DecisionScore: 0.0, + Reason: "guardrail_max_rate_exceeded"); + } + + var alpha = ComputeAlpha(_options, input); + var preemptProbability = Clamp01(input.RebalanceSignal + ? 
Math.Max(input.PreemptProbability, 1.0) + : input.PreemptProbability); + var decisionScore = Clamp01(preemptProbability * alpha); + var shouldMigrate = decisionScore >= _options.DecisionThreshold; + + return new DynamicRiskDecision( + ShouldMigrate: shouldMigrate, + Alpha: alpha, + DecisionScore: decisionScore, + Reason: shouldMigrate ? "decision_score_above_threshold" : "decision_score_below_threshold"); + } + + public static double ComputeAlpha(DynamicRiskOptions options, DynamicRiskInput input) + { + var normalized = NormalizeOptions(options); + return ComputeAlphaInternal(normalized, input); + } + + private static double ComputeAlphaInternal(DynamicRiskOptions options, DynamicRiskInput input) + { + if (input.RebalanceSignal) + { + return options.MaxAlpha; + } + + var volatility = Clamp01(input.VolatilityProbability); + var sentimentPressure = Math.Max(0.0, Clamp01(input.SentimentNegative) - Clamp01(input.SentimentPositive)); + var alpha = options.BaseAlpha + + options.VolatilityWeight * volatility + + options.SentimentWeight * sentimentPressure; + + return Clamp(alpha, options.MinAlpha, options.MaxAlpha); + } + + private static DynamicRiskOptions NormalizeOptions(DynamicRiskOptions options) + { + var minAlpha = options.MinAlpha <= 0 ? 0.1 : options.MinAlpha; + var maxAlpha = options.MaxAlpha < minAlpha ? 
minAlpha : options.MaxAlpha; + + return new DynamicRiskOptions + { + BaseAlpha = Clamp(options.BaseAlpha, minAlpha, maxAlpha), + VolatilityWeight = Math.Max(0.0, options.VolatilityWeight), + SentimentWeight = Math.Max(0.0, options.SentimentWeight), + MinAlpha = minAlpha, + MaxAlpha = maxAlpha, + DecisionThreshold = Clamp01(options.DecisionThreshold), + CooldownMinutes = Math.Max(0, options.CooldownMinutes), + MaxMigrationsPerHour = Math.Max(1, options.MaxMigrationsPerHour) + }; + } + + private static double Clamp01(double value) + => Clamp(value, 0.0, 1.0); + + private static double Clamp(double value, double min, double max) + { + if (value < min) + { + return min; + } + + if (value > max) + { + return max; + } + + return value; + } +} diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/MigrationOrchestrator.cs b/src/services/core-dotnet/AetherGuard.Core/Services/MigrationOrchestrator.cs index 3f57092..7e1418a 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/MigrationOrchestrator.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/MigrationOrchestrator.cs @@ -15,6 +15,7 @@ public class MigrationOrchestrator private readonly IServiceScopeFactory _serviceScopeFactory; private readonly CommandService _commandService; private readonly AnalysisService _analysisService; + private readonly DynamicRiskPolicy _dynamicRiskPolicy; private readonly SnapshotStorageService _snapshotStorage; private readonly IConfiguration _configuration; private readonly IWebHostEnvironment _environment; @@ -24,6 +25,7 @@ public MigrationOrchestrator( IServiceScopeFactory serviceScopeFactory, CommandService commandService, AnalysisService analysisService, + DynamicRiskPolicy dynamicRiskPolicy, SnapshotStorageService snapshotStorage, IConfiguration configuration, IWebHostEnvironment environment, @@ -32,6 +34,7 @@ public MigrationOrchestrator( _serviceScopeFactory = serviceScopeFactory; _commandService = commandService; _analysisService = analysisService; + 
_dynamicRiskPolicy = dynamicRiskPolicy; _snapshotStorage = snapshotStorage; _configuration = configuration; _environment = environment; @@ -65,11 +68,12 @@ public async Task RunMigrationCycle(string sourceAgentId, CancellationToken canc return; } - if (await HasRecentMigrationAsync(context, sourceAgentId, cancellationToken)) - { - _logger.LogInformation("Recent migration found for {SourceAgentId}. Skipping cycle.", sourceAgentId); - return; - } + var cooldownWindow = TimeSpan.FromMinutes(_dynamicRiskPolicy.Options.CooldownMinutes); + var cooldownActive = cooldownWindow > TimeSpan.Zero + && await HasRecentMigrationAsync(context, sourceAgentId, cooldownWindow, cancellationToken); + var migrationsLastHour = await CountRecentMigrationsAsync(context, TimeSpan.FromHours(1), cancellationToken); + var maxMigrationsPerHour = _dynamicRiskPolicy.Options.MaxMigrationsPerHour; + var maxRateExceeded = migrationsLastHour >= maxMigrationsPerHour; var latestTelemetry = await context.TelemetryRecords .AsNoTracking() @@ -94,15 +98,40 @@ public async Task RunMigrationCycle(string sourceAgentId, CancellationToken canc diskAvailable); var riskResult = await _analysisService.AnalyzeAsync(riskPayload); - var isCritical = string.Equals(riskResult.Status, "CRITICAL", StringComparison.OrdinalIgnoreCase) - || (rebalanceSignal && string.Equals(riskResult.Status, "Unavailable", StringComparison.OrdinalIgnoreCase)); - - if (!isCritical) + var semantic = await GetLatestSemanticSnapshotAsync(context, cancellationToken); + var preemptProbability = ResolvePreemptProbability(riskResult, rebalanceSignal); + var riskInput = new DynamicRiskInput( + PreemptProbability: preemptProbability, + RebalanceSignal: rebalanceSignal, + VolatilityProbability: semantic.VolatilityProbability, + SentimentNegative: semantic.SentimentNegative, + SentimentPositive: semantic.SentimentPositive); + var guardrails = new RiskGuardrailState( + CooldownActive: cooldownActive, + MaxRateExceeded: maxRateExceeded, + 
RecentMigrationsLastHour: migrationsLastHour, + MaxMigrationsPerHour: maxMigrationsPerHour); + var decision = _dynamicRiskPolicy.Evaluate(riskInput, guardrails); + + if (!decision.ShouldMigrate) { - _logger.LogInformation("Risk check for {SourceAgentId} returned {Status}", sourceAgentId, riskResult.Status); + _logger.LogInformation( + "Risk check skipped migration for {SourceAgentId}: reason={Reason}, alpha={Alpha:F3}, score={Score:F3}, status={Status}", + sourceAgentId, + decision.Reason, + decision.Alpha, + decision.DecisionScore, + riskResult.Status); return; } + _logger.LogInformation( + "Risk decision approved migration for {SourceAgentId}: alpha={Alpha:F3}, score={Score:F3}, p_preempt={PreemptProbability:F3}", + sourceAgentId, + decision.Alpha, + decision.DecisionScore, + preemptProbability); + var targetAgent = await FindIdleTargetAsync(context, sourceId, cancellationToken); if (targetAgent is null) { @@ -200,9 +229,15 @@ private async Task HasPendingCommandsAsync( private async Task HasRecentMigrationAsync( ApplicationDbContext context, string sourceAgentId, + TimeSpan window, CancellationToken cancellationToken) { - var cutoff = DateTime.UtcNow.AddMinutes(-2); + if (window <= TimeSpan.Zero) + { + return false; + } + + var cutoff = DateTime.UtcNow - window; return await context.CommandAudits .AsNoTracking() .AnyAsync( @@ -212,6 +247,84 @@ private async Task HasRecentMigrationAsync( cancellationToken); } + private async Task CountRecentMigrationsAsync( + ApplicationDbContext context, + TimeSpan window, + CancellationToken cancellationToken) + { + if (window <= TimeSpan.Zero) + { + return 0; + } + + var cutoff = DateTime.UtcNow - window; + return await context.CommandAudits + .AsNoTracking() + .Where(audit => audit.Action == "Migration Completed" && audit.CreatedAt >= cutoff) + .CountAsync(cancellationToken); + } + + private async Task GetLatestSemanticSnapshotAsync( + ApplicationDbContext context, + CancellationToken cancellationToken) + { + var signal = 
await context.ExternalSignals + .AsNoTracking() + .Where(item => + item.SentimentNegative.HasValue && + item.SentimentPositive.HasValue && + item.VolatilityProbability.HasValue) + .OrderByDescending(item => item.EnrichedAt ?? item.PublishedAt) + .FirstOrDefaultAsync(cancellationToken); + + if (signal is null) + { + return new SemanticSnapshot(0.5, 0.33, 0.33); + } + + return new SemanticSnapshot( + VolatilityProbability: Clamp01(signal.VolatilityProbability ?? 0.5), + SentimentNegative: Clamp01(signal.SentimentNegative ?? 0.33), + SentimentPositive: Clamp01(signal.SentimentPositive ?? 0.33)); + } + + private static double ResolvePreemptProbability(AnalysisResult riskResult, bool rebalanceSignal) + { + var confidence = Clamp01(riskResult.Confidence); + var prediction = Clamp01(riskResult.Prediction / 100.0); + var statusCritical = string.Equals(riskResult.Status, "CRITICAL", StringComparison.OrdinalIgnoreCase); + var unavailable = string.Equals(riskResult.Status, "Unavailable", StringComparison.OrdinalIgnoreCase); + + var baseProbability = statusCritical + ? 
Math.Max(confidence, prediction) + : prediction; + if (rebalanceSignal && unavailable) + { + baseProbability = Math.Max(baseProbability, 0.9); + } + if (rebalanceSignal && statusCritical) + { + baseProbability = Math.Max(baseProbability, 0.95); + } + + return Clamp01(baseProbability); + } + + private static double Clamp01(double value) + { + if (value < 0.0) + { + return 0.0; + } + + if (value > 1.0) + { + return 1.0; + } + + return value; + } + private async Task FindIdleTargetAsync( ApplicationDbContext context, Guid sourceAgentId, @@ -332,6 +445,11 @@ private string BuildDownloadUrl(string workloadId) return $"{baseUrl.TrimEnd('/')}/download/{safeWorkloadId}"; } + private sealed record SemanticSnapshot( + double VolatilityProbability, + double SentimentNegative, + double SentimentPositive); + private enum CommandOutcome { Completed, From ebb8b4ba8ebcc6d9655c88ea39fe8c2948273e4f Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 25 Feb 2026 11:31:54 +0800 Subject: [PATCH 23/24] feat(m4): add dashboard explainability with alpha and top signals (#44) --- README.md | 1 + docs/Web-Explainability-v2.3-M4.md | 55 +++++ .../Controllers/DashboardController.cs | 32 ++- .../Grpc/ControlPlaneGrpcService.cs | 6 +- .../Services/ControlPlaneService.cs | 214 +++++++++++++++++- src/shared/protos/control_plane.proto | 13 ++ src/web/dashboard/app/DashboardClient.tsx | 134 +++++++++++ .../components/ExplainabilityPanel.tsx | 70 ++++++ src/web/dashboard/lib/api.ts | 16 ++ src/web/dashboard/types/index.ts | 13 ++ 10 files changed, 544 insertions(+), 10 deletions(-) create mode 100644 docs/Web-Explainability-v2.3-M4.md diff --git a/README.md b/README.md index ac793df..83d7180 100644 --- a/README.md +++ b/README.md @@ -208,6 +208,7 @@ Open the dashboard at http://localhost:3000. 
- v2.3 M3 canary + rollback plan: docs/QA-Canary-Rollback-v2.3-M3.md - v2.3 M3 canary evaluator script: scripts/qa/README.md - v2.3 M4 dynamic risk allocation (core): docs/Core-Dynamic-Risk-v2.3-M4.md +- v2.3 M4 dashboard explainability: docs/Web-Explainability-v2.3-M4.md If you want to simulate migrations, start at least two agents: diff --git a/docs/Web-Explainability-v2.3-M4.md b/docs/Web-Explainability-v2.3-M4.md new file mode 100644 index 0000000..7970a95 --- /dev/null +++ b/docs/Web-Explainability-v2.3-M4.md @@ -0,0 +1,55 @@ +# v2.3 M4 Dashboard Explainability + +Issue: #44 +Epic: #13 + +## Goal + +Expose dynamic-risk explainability to operators in the dashboard: + +- `alpha` +- `P_preempt` +- top fused signals +- decision rationale + confidence + +## API Surface + +`GET /api/v1/dashboard/latest` now returns extra fields under `analysis`: + +- `alpha` +- `preemptProbability` +- `decisionScore` +- `rationale` +- `topSignals[]` (`key`, `label`, `value`, `source`, `detail`) + +`DashboardAnalysis` in `src/shared/protos/control_plane.proto` is updated with additive fields to keep gRPC/JSON-transcoding parity. + +## Explainability Derivation + +Core computes explainability with existing dynamic-risk policy math: + +- `P_preempt` from AI analysis confidence/prediction + rebalance overrides. +- `alpha` from dynamic policy (`volatility`, `sentiment pressure`, rebalance force-to-max path). +- `decisionScore = clamp(P_preempt * alpha, 0, 1)`. +- `topSignals` ranked from fused inputs (telemetry + AI + enriched external signal semantics). + +When no enriched external signal exists, semantic values use safe fallback defaults and indicate fallback source. 
+ +## Dashboard UI + +`Explainability` panel now shows: + +- AI status, confidence, predicted CPU +- `alpha`, `P_preempt`, decision score +- top 3 fused signals with source/detail +- decision rationale +- root cause + +## Local Validation + +- `dotnet build src/services/core-dotnet/AetherGuard.Core/AetherGuard.Core.csproj -c Release` +- `dotnet test src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj -c Release` +- `npm run lint` (in `src/web/dashboard`) +- `npm run build` (in `src/web/dashboard`) + +All commands passed. diff --git a/src/services/core-dotnet/AetherGuard.Core/Controllers/DashboardController.cs b/src/services/core-dotnet/AetherGuard.Core/Controllers/DashboardController.cs index b49dca7..6985bd8 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Controllers/DashboardController.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Controllers/DashboardController.cs @@ -17,9 +17,9 @@ public DashboardController(ControlPlaneService controlPlaneService) } [HttpGet("latest")] - public IActionResult GetLatest() + public async Task GetLatest() { - var result = _controlPlaneService.GetLatest(); + var result = await _controlPlaneService.GetLatestAsync(HttpContext.RequestAborted); if (!result.Success) { return StatusCode(result.StatusCode); @@ -43,7 +43,19 @@ result.Payload.Analysis is null result.Payload.Analysis.Status, result.Payload.Analysis.Confidence, ClampPrediction(result.Payload.Analysis.PredictedCpu), - result.Payload.Analysis.RootCause)); + result.Payload.Analysis.RootCause, + result.Payload.Analysis.Alpha, + result.Payload.Analysis.PreemptProbability, + result.Payload.Analysis.DecisionScore, + result.Payload.Analysis.Rationale, + result.Payload.Analysis.TopSignals + .Select(signal => new DashboardSignalDto( + signal.Key, + signal.Label, + signal.Value, + signal.Source, + signal.Detail)) + .ToArray())); return Ok(response); } @@ -96,7 +108,19 @@ private sealed record DashboardAnalysisDto( string Status, double Confidence, 
double PredictedCpu, - string RootCause); + string RootCause, + double Alpha, + double PreemptProbability, + double DecisionScore, + string Rationale, + DashboardSignalDto[] TopSignals); + + private sealed record DashboardSignalDto( + string Key, + string Label, + double Value, + string Source, + string Detail); private sealed record TelemetryHistoryDto( long Id, diff --git a/src/services/core-dotnet/AetherGuard.Core/Grpc/ControlPlaneGrpcService.cs b/src/services/core-dotnet/AetherGuard.Core/Grpc/ControlPlaneGrpcService.cs index b8265d5..1310547 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Grpc/ControlPlaneGrpcService.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Grpc/ControlPlaneGrpcService.cs @@ -21,12 +21,12 @@ public override async Task QueueCommand( return HandleResult(result); } - public override Task GetDashboardLatest( + public override async Task GetDashboardLatest( GetDashboardLatestRequest request, ServerCallContext context) { - var result = _controlPlaneService.GetLatest(); - return Task.FromResult(HandleResult(result)); + var result = await _controlPlaneService.GetLatestAsync(context.CancellationToken); + return HandleResult(result); } public override async Task GetDashboardHistory( diff --git a/src/services/core-dotnet/AetherGuard.Core/Services/ControlPlaneService.cs b/src/services/core-dotnet/AetherGuard.Core/Services/ControlPlaneService.cs index 7bc3a22..465edc2 100644 --- a/src/services/core-dotnet/AetherGuard.Core/Services/ControlPlaneService.cs +++ b/src/services/core-dotnet/AetherGuard.Core/Services/ControlPlaneService.cs @@ -1,4 +1,5 @@ using AetherGuard.Core.Data; +using AetherGuard.Core.Models; using AetherGuard.Grpc.V1; using Microsoft.AspNetCore.Http; using Microsoft.EntityFrameworkCore; @@ -10,15 +11,18 @@ public class ControlPlaneService private readonly ApplicationDbContext _context; private readonly CommandService _commandService; private readonly TelemetryStore _telemetryStore; + private readonly DynamicRiskPolicy 
_dynamicRiskPolicy; public ControlPlaneService( ApplicationDbContext context, CommandService commandService, - TelemetryStore telemetryStore) + TelemetryStore telemetryStore, + DynamicRiskPolicy dynamicRiskPolicy) { _context = context; _commandService = commandService; _telemetryStore = telemetryStore; + _dynamicRiskPolicy = dynamicRiskPolicy; } public async Task> QueueCommandAsync( @@ -61,7 +65,7 @@ public async Task> QueueCommandAsync( return ApiResult.Ok(response, StatusCodes.Status202Accepted); } - public ApiResult GetLatest() + public async Task> GetLatestAsync(CancellationToken cancellationToken) { var latest = _telemetryStore.GetLatest(); if (latest is null) @@ -83,13 +87,34 @@ public ApiResult GetLatest() DashboardAnalysis? analysis = null; if (latest.Analysis is not null) { + var semantic = await GetLatestSemanticSnapshotAsync(cancellationToken); + var preemptProbability = ResolvePreemptProbability(latest.Analysis, latest.Telemetry.RebalanceSignal); + var riskInput = new DynamicRiskInput( + PreemptProbability: preemptProbability, + RebalanceSignal: latest.Telemetry.RebalanceSignal, + VolatilityProbability: semantic.VolatilityProbability, + SentimentNegative: semantic.SentimentNegative, + SentimentPositive: semantic.SentimentPositive); + var decision = _dynamicRiskPolicy.Evaluate( + riskInput, + new RiskGuardrailState( + CooldownActive: false, + MaxRateExceeded: false, + RecentMigrationsLastHour: 0, + MaxMigrationsPerHour: _dynamicRiskPolicy.Options.MaxMigrationsPerHour)); + analysis = new DashboardAnalysis { Status = latest.Analysis.Status, Confidence = latest.Analysis.Confidence, PredictedCpu = ClampPrediction(latest.Analysis.Prediction), - RootCause = latest.Analysis.RootCause ?? string.Empty + RootCause = latest.Analysis.RootCause ?? 
string.Empty, + Alpha = decision.Alpha, + PreemptProbability = preemptProbability, + DecisionScore = decision.DecisionScore, + Rationale = BuildDecisionRationale(decision, latest, semantic) }; + analysis.TopSignals.AddRange(BuildTopSignals(latest, latest.Analysis, preemptProbability, semantic)); } var response = new DashboardLatestResponse @@ -142,4 +167,187 @@ private static double ClampPrediction(double prediction) private static double ClampPrediction(double? prediction) => prediction.HasValue ? Math.Clamp(prediction.Value, 0, 100) : 0; + + private async Task GetLatestSemanticSnapshotAsync(CancellationToken cancellationToken) + { + var signal = await _context.ExternalSignals + .AsNoTracking() + .Where(item => + item.SentimentNegative.HasValue && + item.SentimentPositive.HasValue && + item.VolatilityProbability.HasValue) + .OrderByDescending(item => item.EnrichedAt ?? item.PublishedAt) + .Select(item => new + { + item.Source, + item.Title, + item.SummaryDigest, + item.Summary, + item.SentimentNegative, + item.SentimentPositive, + item.VolatilityProbability + }) + .FirstOrDefaultAsync(cancellationToken); + + if (signal is null) + { + return new SemanticSnapshot( + VolatilityProbability: 0.5, + SentimentNegative: 0.33, + SentimentPositive: 0.33, + Source: "fallback", + Detail: "No enriched external signal available."); + } + + return new SemanticSnapshot( + VolatilityProbability: Clamp01(signal.VolatilityProbability ?? 0.5), + SentimentNegative: Clamp01(signal.SentimentNegative ?? 0.33), + SentimentPositive: Clamp01(signal.SentimentPositive ?? 
0.33), + Source: signal.Source, + Detail: FirstNonEmpty(signal.SummaryDigest, signal.Summary, signal.Title, "Latest enriched signal")); + } + + private static IEnumerable BuildTopSignals( + TelemetrySnapshot latest, + AnalysisResult analysis, + double preemptProbability, + SemanticSnapshot semantic) + { + var sentimentPressure = Math.Max(0.0, Clamp01(semantic.SentimentNegative) - Clamp01(semantic.SentimentPositive)); + var candidates = new[] + { + new SignalCandidate( + Key: "rebalance_signal", + Label: "Rebalance Signal", + Value: latest.Telemetry.RebalanceSignal ? 1.0 : 0.0, + Influence: latest.Telemetry.RebalanceSignal ? 1.0 : 0.0, + Source: "telemetry", + Detail: latest.Telemetry.RebalanceSignal + ? "Provider rebalance hint is active." + : "Provider rebalance hint is inactive."), + new SignalCandidate( + Key: "ai_preempt_probability", + Label: "AI P_preempt", + Value: preemptProbability, + Influence: preemptProbability, + Source: "ai", + Detail: $"status={analysis.Status}, confidence={Clamp01(analysis.Confidence):F2}, predicted_cpu={ClampPrediction(analysis.Prediction):F0}%"), + new SignalCandidate( + Key: "volatility_probability", + Label: "Volatility Probability", + Value: semantic.VolatilityProbability, + Influence: semantic.VolatilityProbability, + Source: $"external:{semantic.Source}", + Detail: semantic.Detail), + new SignalCandidate( + Key: "sentiment_pressure", + Label: "Sentiment Pressure", + Value: sentimentPressure, + Influence: sentimentPressure, + Source: $"external:{semantic.Source}", + Detail: $"negative={semantic.SentimentNegative:F2}, positive={semantic.SentimentPositive:F2}") + }; + + return candidates + .OrderByDescending(candidate => candidate.Influence) + .ThenByDescending(candidate => candidate.Value) + .Take(3) + .Select(candidate => new DashboardSignal + { + Key = candidate.Key, + Label = candidate.Label, + Value = candidate.Value, + Source = candidate.Source, + Detail = candidate.Detail + }); + } + + private static string 
BuildDecisionRationale( + DynamicRiskDecision decision, + TelemetrySnapshot latest, + SemanticSnapshot semantic) + { + if (decision.Reason == "guardrail_cooldown_active") + { + return "Cooldown guardrail is active; migration is temporarily suppressed."; + } + + if (decision.Reason == "guardrail_max_rate_exceeded") + { + return "Migration rate guardrail is active; migration is rate-limited."; + } + + if (latest.Telemetry.RebalanceSignal) + { + return "Rebalance signal is active, so alpha is pinned to max risk posture."; + } + + var sentimentPressure = Math.Max(0.0, semantic.SentimentNegative - semantic.SentimentPositive); + return $"Dynamic alpha is derived from volatility={semantic.VolatilityProbability:F2} and sentiment pressure={sentimentPressure:F2}."; + } + + private static double ResolvePreemptProbability(AnalysisResult riskResult, bool rebalanceSignal) + { + var confidence = Clamp01(riskResult.Confidence); + var prediction = Clamp01(riskResult.Prediction / 100.0); + var statusCritical = string.Equals(riskResult.Status, "CRITICAL", StringComparison.OrdinalIgnoreCase); + var unavailable = string.Equals(riskResult.Status, "Unavailable", StringComparison.OrdinalIgnoreCase); + + var baseProbability = statusCritical + ? 
Math.Max(confidence, prediction) + : prediction; + if (rebalanceSignal && unavailable) + { + baseProbability = Math.Max(baseProbability, 0.9); + } + if (rebalanceSignal && statusCritical) + { + baseProbability = Math.Max(baseProbability, 0.95); + } + + return Clamp01(baseProbability); + } + + private static string FirstNonEmpty(params string?[] candidates) + { + foreach (var candidate in candidates) + { + if (!string.IsNullOrWhiteSpace(candidate)) + { + return candidate.Trim(); + } + } + + return string.Empty; + } + + private static double Clamp01(double value) + { + if (value < 0.0) + { + return 0.0; + } + + if (value > 1.0) + { + return 1.0; + } + + return value; + } + + private sealed record SemanticSnapshot( + double VolatilityProbability, + double SentimentNegative, + double SentimentPositive, + string Source, + string Detail); + + private sealed record SignalCandidate( + string Key, + string Label, + double Value, + double Influence, + string Source, + string Detail); } diff --git a/src/shared/protos/control_plane.proto b/src/shared/protos/control_plane.proto index 902497d..90b6333 100644 --- a/src/shared/protos/control_plane.proto +++ b/src/shared/protos/control_plane.proto @@ -62,6 +62,19 @@ message DashboardAnalysis { double confidence = 2; double predicted_cpu = 3; string root_cause = 4; + double alpha = 5; + double preempt_probability = 6; + double decision_score = 7; + string rationale = 8; + repeated DashboardSignal top_signals = 9; +} + +message DashboardSignal { + string key = 1; + string label = 2; + double value = 3; + string source = 4; + string detail = 5; } message GetDashboardHistoryRequest { diff --git a/src/web/dashboard/app/DashboardClient.tsx b/src/web/dashboard/app/DashboardClient.tsx index 84912a9..551eeb7 100644 --- a/src/web/dashboard/app/DashboardClient.tsx +++ b/src/web/dashboard/app/DashboardClient.tsx @@ -54,6 +54,33 @@ const buildMockPayload = (now: number, chaosActive: boolean) => { rootCause: 'Stable capacity', rebalanceSignal: 
false, diskAvailable: 180 * 1024 * 1024 * 1024, + alpha: 0.86, + preemptProbability: 0.18, + decisionScore: 0.15, + decisionRationale: 'Dynamic alpha remains low due to calm volatility and weak sentiment pressure.', + topSignals: [ + { + key: 'ai_preempt_probability', + label: 'AI P_preempt', + value: 0.18, + source: 'ai', + detail: 'status=LOW, confidence=0.74, predicted_cpu=28%', + }, + { + key: 'volatility_probability', + label: 'Volatility Probability', + value: 0.22, + source: 'external:aws-status', + detail: 'Stable provider incident cadence.', + }, + { + key: 'sentiment_pressure', + label: 'Sentiment Pressure', + value: 0.08, + source: 'external:aws-status', + detail: 'negative=0.27, positive=0.19', + }, + ], }, { agentId: 'node-zephyr-07', @@ -68,6 +95,59 @@ const buildMockPayload = (now: number, chaosActive: boolean) => { rootCause: chaosActive ? 'Rebalance signal asserted' : 'Stable capacity', rebalanceSignal: chaosActive, diskAvailable: chaosActive ? 52 * 1024 * 1024 * 1024 : 96 * 1024 * 1024 * 1024, + alpha: chaosActive ? 1.6 : 1.03, + preemptProbability: chaosActive ? 0.95 : 0.46, + decisionScore: chaosActive ? 1.0 : 0.47, + decisionRationale: chaosActive + ? 'Rebalance signal is active, so alpha is pinned to max risk posture.' + : 'Dynamic alpha rises with volatility and sentiment pressure.', + topSignals: chaosActive + ? 
[ + { + key: 'rebalance_signal', + label: 'Rebalance Signal', + value: 1, + source: 'telemetry', + detail: 'Provider rebalance hint is active.', + }, + { + key: 'ai_preempt_probability', + label: 'AI P_preempt', + value: 0.95, + source: 'ai', + detail: 'status=CRITICAL, confidence=0.93, predicted_cpu=92%', + }, + { + key: 'volatility_probability', + label: 'Volatility Probability', + value: 0.84, + source: 'external:aws-status', + detail: 'Elevated incident volatility in latest signal.', + }, + ] + : [ + { + key: 'ai_preempt_probability', + label: 'AI P_preempt', + value: 0.46, + source: 'ai', + detail: 'status=LOW, confidence=0.67, predicted_cpu=48%', + }, + { + key: 'volatility_probability', + label: 'Volatility Probability', + value: 0.51, + source: 'external:gcp-status', + detail: 'Moderate volatility from capacity advisory.', + }, + { + key: 'sentiment_pressure', + label: 'Sentiment Pressure', + value: 0.17, + source: 'external:gcp-status', + detail: 'negative=0.41, positive=0.24', + }, + ], }, { agentId: 'node-sigma-12', @@ -82,6 +162,33 @@ const buildMockPayload = (now: number, chaosActive: boolean) => { rootCause: 'Checkpoint restore failed on target node', rebalanceSignal: true, diskAvailable: 12 * 1024 * 1024 * 1024, + alpha: 1.6, + preemptProbability: 0.98, + decisionScore: 1.0, + decisionRationale: 'Rebalance signal is active, so alpha is pinned to max risk posture.', + topSignals: [ + { + key: 'rebalance_signal', + label: 'Rebalance Signal', + value: 1, + source: 'telemetry', + detail: 'Provider rebalance hint is active.', + }, + { + key: 'ai_preempt_probability', + label: 'AI P_preempt', + value: 0.98, + source: 'ai', + detail: 'status=CRITICAL, confidence=0.98, predicted_cpu=98%', + }, + { + key: 'volatility_probability', + label: 'Volatility Probability', + value: 0.88, + source: 'external:azure-status', + detail: 'High volatility from spot price advisory.', + }, + ], }, ]; @@ -340,6 +447,33 @@ export default function DashboardClient({ userName, 
userRole }: DashboardClientP predictedCpu: 90, rootCause: 'Rebalance signal asserted', rebalanceSignal: true, + alpha: 1.6, + preemptProbability: 0.95, + decisionScore: 1.0, + decisionRationale: 'Rebalance signal is active, so alpha is pinned to max risk posture.', + topSignals: [ + { + key: 'rebalance_signal', + label: 'Rebalance Signal', + value: 1, + source: 'telemetry', + detail: 'Provider rebalance hint is active.', + }, + { + key: 'ai_preempt_probability', + label: 'AI P_preempt', + value: 0.95, + source: 'ai', + detail: 'status=CRITICAL, confidence=0.92, predicted_cpu=90%', + }, + { + key: 'volatility_probability', + label: 'Volatility Probability', + value: 0.82, + source: 'external:aws-status', + detail: 'Elevated incident volatility in latest signal.', + }, + ], } : agent, ), diff --git a/src/web/dashboard/components/ExplainabilityPanel.tsx b/src/web/dashboard/components/ExplainabilityPanel.tsx index 57bfceb..3ca0366 100644 --- a/src/web/dashboard/components/ExplainabilityPanel.tsx +++ b/src/web/dashboard/components/ExplainabilityPanel.tsx @@ -31,6 +31,21 @@ const formatPrediction = (value?: number) => { return `${Math.round(value)}%`; }; +const formatRatio = (value?: number) => { + if (value === undefined || Number.isNaN(value)) { + return '--'; + } + const normalized = value > 1 ? value / 100 : value; + return `${(Math.max(0, Math.min(1, normalized)) * 100).toFixed(1)}%`; +}; + +const formatAlpha = (value?: number) => { + if (value === undefined || Number.isNaN(value)) { + return '--'; + } + return value.toFixed(2); +}; + const formatDiskAvailable = (value?: number) => { if (value === undefined || Number.isNaN(value)) { return '--'; @@ -47,9 +62,26 @@ const formatRootCause = (value?: string) => { return normalized; }; +const formatSignalValue = (key: string, value: number) => { + if (!Number.isFinite(value)) { + return '--'; + } + + if (key === 'rebalance_signal') { + return value >= 0.5 ? 
'Active' : 'Inactive'; + } + + if (key.includes('probability') || key === 'sentiment_pressure') { + return formatRatio(value); + } + + return value.toFixed(2); +}; + export default function ExplainabilityPanel({ agent, usingMock }: ExplainabilityPanelProps) { const statusLabel = agent?.analysisStatus?.trim().toUpperCase() || 'PENDING'; const statusClass = statusStyles[statusLabel] ?? statusStyles.PENDING; + const rationale = agent?.decisionRationale?.trim() || 'Awaiting dynamic risk decision details.'; return (
@@ -77,6 +109,20 @@ export default function ExplainabilityPanel({ agent, usingMock }: Explainability Predicted CPU: {formatPrediction(agent.predictedCpu)}
+
+
Risk Factors
+
+ alpha: {formatAlpha(agent.alpha)} +
+
+ P_preempt:{' '} + {formatRatio(agent.preemptProbability)} +
+
+ Decision Score:{' '} + {formatRatio(agent.decisionScore)} +
+
Telemetry Context
@@ -91,6 +137,30 @@ export default function ExplainabilityPanel({ agent, usingMock }: Explainability
+
+
Top Signals
+
+ {(agent.topSignals ?? []).length === 0 ? ( +
Awaiting fused signal inputs.
+ ) : ( + (agent.topSignals ?? []).slice(0, 3).map((signal) => ( +
+
+ {signal.label} + {formatSignalValue(signal.key, signal.value)} +
+
{signal.detail}
+
+ )) + )} +
+
+ +
+
Decision Rationale
+
{rationale}
+
+
Root Cause
{formatRootCause(agent.rootCause)}
diff --git a/src/web/dashboard/lib/api.ts b/src/web/dashboard/lib/api.ts index 7602f1a..ce1a559 100644 --- a/src/web/dashboard/lib/api.ts +++ b/src/web/dashboard/lib/api.ts @@ -18,6 +18,17 @@ interface CoreLatestResponse { confidence: number; predictedCpu: number; rootCause: string; + alpha?: number; + preemptProbability?: number; + decisionScore?: number; + rationale?: string; + topSignals?: Array<{ + key: string; + label: string; + value: number; + source: string; + detail: string; + }>; }; } @@ -100,6 +111,11 @@ export async function fetchFleetStatus(): Promise { rootCause: analysis?.rootCause, rebalanceSignal: telemetry.rebalanceSignal, diskAvailable: telemetry.diskAvailable, + alpha: analysis?.alpha, + preemptProbability: analysis?.preemptProbability, + decisionScore: analysis?.decisionScore, + decisionRationale: analysis?.rationale, + topSignals: analysis?.topSignals, }, ]; } catch (error) { diff --git a/src/web/dashboard/types/index.ts b/src/web/dashboard/types/index.ts index eec6a0f..e29b818 100644 --- a/src/web/dashboard/types/index.ts +++ b/src/web/dashboard/types/index.ts @@ -1,5 +1,13 @@ export type WorkloadTier = 'T1' | 'T2' | 'T3'; +export interface ExplainabilitySignal { + key: string; + label: string; + value: number; + source: string; + detail: string; +} + export interface Agent { agentId: string; status: string; @@ -12,6 +20,11 @@ export interface Agent { rootCause?: string; rebalanceSignal?: boolean; diskAvailable?: number; + alpha?: number; + preemptProbability?: number; + decisionScore?: number; + decisionRationale?: string; + topSignals?: ExplainabilitySignal[]; } export interface AuditLog { From 5887cadfbe30817b4fd1acd40e649c0bc3a36fd4 Mon Sep 17 00:00:00 2001 From: JasonEran Date: Wed, 25 Feb 2026 11:40:42 +0800 Subject: [PATCH 24/24] docs(v2.3): add release notes and acceptance PR checklist --- .github/PULL_REQUEST_TEMPLATE.md | 3 ++ CHANGELOG.md | 1 + README.md | 10 ++-- docs/PR-Template-v2.3-Acceptance.md | 78 
+++++++++++++++++++++++++++++ docs/ROADMAP-v2.3.md | 3 ++ docs/Release-Notes-v2.3.md | 69 +++++++++++++++++++++++++ 6 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 docs/PR-Template-v2.3-Acceptance.md create mode 100644 docs/Release-Notes-v2.3.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 3b2f075..d1ab05d 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,5 +1,8 @@ # Pull Request +For v2.3 release-track PRs, you can use: +`docs/PR-Template-v2.3-Acceptance.md` + ## Summary Describe the changes and their purpose. diff --git a/CHANGELOG.md b/CHANGELOG.md index ae3cc1e..ae7dab1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Semantic Versioning. - Expanded v2.3 roadmap with model choices, data sources, and validation guidance. - Verification scripts now support API key headers and optional agent build flags. - Optional HTTP listener when mTLS is enabled to keep dashboard/AI traffic on port 8080. +- v2.3 release notes (`docs/Release-Notes-v2.3.md`) and PR acceptance template (`docs/PR-Template-v2.3-Acceptance.md`). ### Changed - Agent now injects W3C trace headers for HTTP requests. diff --git a/README.md b/README.md index 83d7180..d40b3b6 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ v2.2 reference architecture with a concrete implementation guide. ## Project Status -- Stage: v2.2 baseline delivered (Phase 0-4). v2.3 Milestone 1 delivered; Milestone 2+ tracked in docs/ROADMAP-v2.3.md. +- Stage: v2.2 baseline delivered (Phase 0-4). v2.3 Milestones 1-4 delivered. - License: MIT - Authors: Qi Junyi, Xiao Erdong (2026) - Sponsor: https://github.com/sponsors/JasonEran @@ -131,10 +131,10 @@ This project targets a product-grade release, not a demo. The following standard - [x] Add snapshot retention automation and S3 lifecycle policy support. - [x] Generate SBOMs and sign container images with cosign in CI. 
-## v2.3 Preview (Roadmap) +## v2.3 Delivery (Roadmap) -We keep the current README focused on v2.2 implementation details. The next evolution is documented in -`docs/ARCHITECTURE-v2.3.md`. In brief, v2.3 moves from reactive thresholds to predictive, multimodal risk allocation: +v2.3 architecture and delivery details are documented in `docs/ARCHITECTURE-v2.3.md` and +`docs/ROADMAP-v2.3.md`. In brief, v2.3 moves from reactive thresholds to predictive, multimodal risk allocation: - Multimodal inputs: telemetry plus external cloud signals (status pages, incident reports, capacity advisories). - Lightweight time-series forecasting on agents, with semantic enrichment computed in the control plane. @@ -209,6 +209,8 @@ Open the dashboard at http://localhost:3000. - v2.3 M3 canary evaluator script: scripts/qa/README.md - v2.3 M4 dynamic risk allocation (core): docs/Core-Dynamic-Risk-v2.3-M4.md - v2.3 M4 dashboard explainability: docs/Web-Explainability-v2.3-M4.md +- v2.3 release notes: docs/Release-Notes-v2.3.md +- v2.3 PR acceptance template: docs/PR-Template-v2.3-Acceptance.md If you want to simulate migrations, start at least two agents: diff --git a/docs/PR-Template-v2.3-Acceptance.md b/docs/PR-Template-v2.3-Acceptance.md new file mode 100644 index 0000000..3f58a41 --- /dev/null +++ b/docs/PR-Template-v2.3-Acceptance.md @@ -0,0 +1,78 @@ +# PR Template - v2.3 Acceptance Checklist + +Copy/paste this into your PR description for v2.3 release gating. 
+ +## Scope + +- Release track: `v2.3` +- Type: `feature / fix / docs / release` +- Related issues: closes #___ + +## What Changed + +- [ ] Core (.NET) +- [ ] Agent (C++) +- [ ] AI engine (Python) +- [ ] Dashboard (Next.js) +- [ ] Docs / runbooks + +Summary: + +## Acceptance Checklist (One-Click Gate) + +### M1 Semantic Enrichment + +- [ ] External signals ingestion + enrichment path is operational +- [ ] Enrichment outputs are schema-versioned and persisted +- [ ] Observability evidence attached (metrics/traces/dashboard) + +### M2 Fusion + Backtesting + +- [ ] Fusion/backtesting scripts run successfully +- [ ] Artifact versioning + reproducibility evidence attached + +### M3 Federated Inference + +- [ ] Heartbeat semantic payload is compatible (additive contract) +- [ ] Agent local inference gate + rollback gate validated +- [ ] Canary evaluator evidence attached + +### M4 Dynamic Risk + Explainability + +- [ ] Dynamic alpha + guardrails enabled and tested +- [ ] Dashboard shows `alpha`, `P_preempt`, top signals, rationale/confidence + +### Compatibility / Safety + +- [ ] v2.2 compatibility path verified (fallback works) +- [ ] Rollback path verified +- [ ] No breaking API changes without compatibility shim +- [ ] Security checks passed (auth/mTLS/supply-chain impact reviewed) + +## Validation Commands + +- [ ] `dotnet build src/services/core-dotnet/AetherGuard.Core/AetherGuard.Core.csproj -c Release` +- [ ] `dotnet test src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj -c Release` +- [ ] `cmake --build src/services/agent-cpp/build_m3_inference --config Release --target AetherAgent AetherAgentTests AetherAgentInferenceTests` +- [ ] `ctest --test-dir src/services/agent-cpp/build_m3_inference -C Release --output-on-failure` +- [ ] `npm run lint` (in `src/web/dashboard`) +- [ ] `npm run build` (in `src/web/dashboard`) + +Paste key output snippets: + +```text +# build/test summary here +``` + +## Evidence Links + +- Jaeger trace / 
observability screenshots: +- Canary evaluator report: +- Dashboard explainability screenshot: +- Issue comments with final evidence: + +## Risk / Rollback + +- Risk level: `low / medium / high` +- Rollback command/runbook: +- Guardrail impact notes: diff --git a/docs/ROADMAP-v2.3.md b/docs/ROADMAP-v2.3.md index 55a6bdd..740757a 100644 --- a/docs/ROADMAP-v2.3.md +++ b/docs/ROADMAP-v2.3.md @@ -44,6 +44,7 @@ minimize integration risk while preserving backward compatibility with the v2.2 ### Milestone 2: Fusion and Forecasting (Offline) **Goal**: Train and evaluate models with historical replay. +**Status**: Completed (2026-02-25) - Add TSMixer baseline for numerical telemetry (PyTorch), with export to ONNX for agent inference. - Fuse exogenous semantic vectors for `P(Preemption | Telemetry, Signals)`. @@ -57,6 +58,7 @@ minimize integration risk while preserving backward compatibility with the v2.2 ### Milestone 3: Federated Inference (Online) **Goal**: Deliver semantic vectors to agents and run local inference. +**Status**: Completed (2026-02-25) - Extend gRPC heartbeat payload with semantic features. - Deploy lightweight on-agent inference (TSMixer). @@ -70,6 +72,7 @@ minimize integration risk while preserving backward compatibility with the v2.2 ### Milestone 4: Dynamic Risk Management **Goal**: Replace static thresholds with dynamic risk allocation. +**Status**: Completed (2026-02-25) - Implement confidence score and risk allocation factor. - Add guardrails (max migration rate, minimum cool-down windows). 
diff --git a/docs/Release-Notes-v2.3.md b/docs/Release-Notes-v2.3.md new file mode 100644 index 0000000..d637e20 --- /dev/null +++ b/docs/Release-Notes-v2.3.md @@ -0,0 +1,69 @@ +# Aether-Guard v2.3 Release Notes + +Release date: 2026-02-25 +Release branch: `feature/v2.3` + +## Release Summary + +v2.3 upgrades Aether-Guard from static/reactive risk handling to a predictive multimodal flow: + +- external cloud incident signals are ingested and semantically enriched +- offline fusion/backtesting + reproducible model artifacts are in place +- online semantic delivery and optional agent local ONNX inference are enabled behind gates +- dynamic risk allocation + guardrails drive migration decisions +- dashboard explainability now exposes `alpha`, `P_preempt`, decision score, and top fused signals + +## Milestone Closure + +### M1 - Semantic Enrichment + +- `#33` [Obs] enrichment metrics/traces + dashboards (closed 2026-02-22) +- `#34` [Data] acquisition scripts for spot history/traces/incidents (closed 2026-02-22) + +### M2 - Fusion + Backtesting (Offline) + +- `#38` [Ops] model artifact versioning + reproducible runs (closed 2026-02-25) +- Epic `#11` closed (2026-02-25) + +### M3 - Federated Inference (Online) + +- `#39` heartbeat semantic payload extension +- `#40` agent local ONNX inference + feature gating +- `#41` core semantic push + safe fallback + per-agent rollout +- `#42` canary + rollback plan and evaluator +- Epic `#12` closed (2026-02-25) + +### M4 - Dynamic Risk Management + +- `#43` dynamic risk alpha + guardrails +- `#45` guardrail regression tests +- `#44` web explainability (`alpha`, `P_preempt`, top signals) +- Epic `#13` closed (2026-02-25) + +## Validation Evidence (Local) + +- Core build/test: + - `dotnet build src/services/core-dotnet/AetherGuard.Core/AetherGuard.Core.csproj -c Release` + - `dotnet test src/services/core-dotnet/AetherGuard.Core.Tests/AetherGuard.Core.Tests.csproj -c Release` +- Agent build/test (M3 path): + - `cmake -S 
src/services/agent-cpp -B src/services/agent-cpp/build_m3_inference -DAETHER_ENABLE_GRPC=OFF -DAETHER_USE_LOCAL_PROTOBUF=ON -DAETHER_ENABLE_ONNX_RUNTIME=OFF` + - `cmake --build src/services/agent-cpp/build_m3_inference --config Release --target AetherAgent AetherAgentTests AetherAgentInferenceTests` + - `ctest --test-dir src/services/agent-cpp/build_m3_inference -C Release --output-on-failure` +- Dashboard build checks: + - `npm run lint` (in `src/web/dashboard`) + - `npm run build` (in `src/web/dashboard`) +- Canary evaluator: + - `python scripts/qa/evaluate_m3_canary.py ...` (promote/rollback sample evidence under `.tmp/`) + +## Key Documentation + +- v2.3 roadmap: `docs/ROADMAP-v2.3.md` +- dynamic risk (core): `docs/Core-Dynamic-Risk-v2.3-M4.md` +- web explainability (M4): `docs/Web-Explainability-v2.3-M4.md` +- canary + rollback (M3): `docs/QA-Canary-Rollback-v2.3-M3.md` + +## Known Follow-ups + +- `#48` v2.3 release criteria checklist (open) +- `#14` CI / supply-chain stabilization epic (open) +- `#8` course/project management epic (open)