Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/model/providers/anthropic/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,16 @@ function toAnthropicContentBlock(block: CanonicalContentBlock): unknown {
switch (block.type) {
case "text":
return { type: "text", text: block.text };
case "thinking":
return { type: "thinking", thinking: block.text };
case "thinking": {
const thinking: { type: "thinking"; thinking: string; signature?: string } = {
type: "thinking",
thinking: block.text,
};
if (block.signature) {
thinking.signature = block.signature;
}
return thinking;
}
case "image":
return block.source === "url"
? { type: "image", source: { type: "url", url: block.data } }
Expand Down
14 changes: 12 additions & 2 deletions src/model/providers/anthropic/response.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type {
CanonicalContentBlock,
CanonicalModelResponse,
CanonicalThinkingBlock,
CanonicalToolCallBlock,
} from "../../protocol/canonical.js";
import { normalizeAnthropicFinishReason } from "../../response/normalizeFinishReason.js";
Expand All @@ -26,8 +27,17 @@ function toCanonicalContentBlock(block: unknown): CanonicalContentBlock[] {
switch (record.type) {
case "text":
return [{ type: "text", text: readString(record.text) ?? "" }];
case "thinking":
return [{ type: "thinking", text: readString(record.thinking) ?? readString(record.text) ?? "" }];
case "thinking": {
const thinking: CanonicalThinkingBlock = {
type: "thinking",
text: readString(record.thinking) ?? readString(record.text) ?? "",
};
const signature = readString(record.signature);
if (signature) {
thinking.signature = signature;
}
return [thinking];
}
case "tool_use":
return [
{
Expand Down
4 changes: 1 addition & 3 deletions src/model/request/validateModelRequest.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { CanonicalModelRequest, ModelConfig, ModelDefinition, ProviderConfig } from "../protocol/canonical.js";
import { ModelRequestError } from "../protocol/errors.js";
import { assertContentSupported, downgradeUnsupportedContent } from "../protocol/multimodal.js";
import { assertContentSupported } from "../protocol/multimodal.js";

export type ResolvedModelRequest = {
provider: ProviderConfig;
Expand Down Expand Up @@ -39,8 +39,6 @@ export function validateModelRequest(
throw new ModelRequestError("unsupported_tool_use", `Model ${request.model} does not support tools.`);
}

downgradeUnsupportedContent(request.messages, model.multimodal);

for (const message of request.messages) {
assertContentSupported(message.content, model.multimodal);
}
Expand Down
184 changes: 159 additions & 25 deletions src/router/RouterRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import type {
CanonicalModelRequest,
ModelRuntime,
} from "../model/index.js";
import { ModelRequestError } from "../model/index.js";
import type { InputModality } from "../model/index.js";
import {
DEFAULT_SUBAGENT_MAX_TOKENS,
DEFAULT_SUBAGENT_POLICY,
Expand Down Expand Up @@ -37,6 +39,10 @@ import {
import { TokenStatsCollector } from "./stats/TokenStatsCollector.js";
import { classifyAndRoute } from "./tokenSaver/classifyAndRoute.js";
import { countMessagesTokens, countResponseTokens, dispose as disposeTokenizer } from "./utils/countTokens.js";
import {
collectRequiredInputModalities,
missingInputModalities,
} from "./utils/mediaRequirements.js";
import type { TelemetryClient } from "../telemetry/index.js";

export type RouterRuntimeDeps = {
Expand Down Expand Up @@ -113,6 +119,90 @@ export function createRouterRuntime(
return tracker;
}

/**
 * Lists the required input modalities that the referenced model cannot accept.
 *
 * @param ref - Provider/model pair whose multimodal capabilities are looked up.
 * @param required - Modalities the request needs.
 * @returns The result of `missingInputModalities` for that model; an empty
 *   array when nothing is required; a copy of `required` when the lookup or
 *   comparison throws.
 */
function missingForModel(
  ref: RouterModelRef,
  required: readonly InputModality[],
): InputModality[] {
  if (required.length > 0) {
    try {
      const capabilities = deps.modelRuntime.getMultimodal(ref.provider, ref.model);
      return missingInputModalities(capabilities, required);
    } catch {
      // Any failure while resolving capabilities is treated as "supports none
      // of the required modalities" rather than being propagated.
      return [...required];
    }
  }
  // Nothing required: skip the capability lookup entirely.
  return [];
}

/**
 * Tells whether the referenced model covers every required input modality.
 *
 * @param ref - Provider/model pair to check.
 * @param required - Modalities the request needs.
 * @returns `true` when `missingForModel` reports no gaps for this model.
 */
function supportsMediaRequirements(
  ref: RouterModelRef,
  required: readonly InputModality[],
): boolean {
  const gaps = missingForModel(ref, required);
  return gaps.length === 0;
}

/**
 * Builds the ordered, de-duplicated list of fallback models for a scenario.
 *
 * Entries configured under the scenario's own key come first, followed by the
 * entries under `default`. Two entries are considered the same model when both
 * `provider` and `model` match; only the first occurrence is kept. Each
 * returned entry is a shallow copy whose `id` falls back to
 * `provider/model` when the configured `id` is empty or absent.
 *
 * @param scenarioType - Scenario whose fallback list is consulted first.
 * @returns The fallback candidates in priority order.
 */
function fallbackCandidatesFor(scenarioType: RouterScenarioType): RouterModelRef[] {
  const byScenario = config.fallback as Record<string, RouterModelRef[] | undefined> | undefined;
  const sources: Array<RouterModelRef[] | undefined> = [
    byScenario?.[scenarioType],
    config.fallback?.default,
  ];

  const unique: RouterModelRef[] = [];
  for (const refs of sources) {
    if (!refs) {
      continue;
    }
    for (const ref of refs) {
      const alreadyListed = unique.some(
        (known) => known.provider === ref.provider && known.model === ref.model,
      );
      if (alreadyListed) {
        continue;
      }
      unique.push({ ...ref, id: ref.id || `${ref.provider}/${ref.model}` });
    }
  }
  return unique;
}

/**
 * Picks the first fallback candidate that supports all required modalities.
 *
 * @param scenarioType - Scenario used to look up fallback candidates.
 * @param required - Modalities the request needs.
 * @returns The first compatible candidate in priority order, or `undefined`
 *   when none qualifies.
 */
function findCompatibleFallback(
  scenarioType: RouterScenarioType,
  required: readonly InputModality[],
): RouterModelRef | undefined {
  for (const candidate of fallbackCandidatesFor(scenarioType)) {
    if (supportsMediaRequirements(candidate, required)) {
      return candidate;
    }
  }
  return undefined;
}

/**
 * Redirects a routing decision to a fallback model when the chosen model
 * cannot accept the media present in the messages.
 *
 * When a reroute happens, `decision` is mutated in place: `provider` and
 * `model` are replaced and `resolvedFrom` becomes `"fallback"`.
 *
 * @param decision - Routing decision to inspect and possibly mutate.
 * @param messages - Messages scanned for required input modalities.
 * @param mutations - Mutation log accumulated so far.
 * @returns The same `mutations` object when nothing changes (no media
 *   required, the selected model already suffices, or no compatible fallback
 *   exists); otherwise a new log extended with `mediaCapabilityRerouted`.
 */
function rerouteDecisionForMedia(
  decision: RouterDecision,
  messages: CanonicalModelRequest["messages"],
  mutations: RouterMutationsLog,
): RouterMutationsLog {
  const required = collectRequiredInputModalities(messages);
  const selected: RouterModelRef = {
    id: `${decision.provider}/${decision.model}`,
    provider: decision.provider,
    model: decision.model,
  };

  // Only look for a replacement when media is present and the current pick
  // cannot handle it; the capability check is skipped when nothing is required.
  const needsReroute = required.length > 0 && !supportsMediaRequirements(selected, required);
  const replacement = needsReroute
    ? findCompatibleFallback(decision.scenarioType, required)
    : undefined;
  if (!replacement) {
    return mutations;
  }

  decision.provider = replacement.provider;
  decision.model = replacement.model;
  decision.resolvedFrom = "fallback";

  const mediaCapabilityRerouted = {
    required: [...required],
    from: selected.id,
    to: replacement.id || `${replacement.provider}/${replacement.model}`,
  };
  return { ...mutations, mediaCapabilityRerouted };
}

async function resolveCustom(
input: RouterDecisionInput,
): Promise<Partial<RouterDecision> | undefined> {
Expand Down Expand Up @@ -279,38 +369,16 @@ export function createRouterRuntime(
`[router] decision: tier=${tokenSaverTier}, model=${selection.provider}/${selection.model}, orchGate=${orchGate}, alreadyOrch=${alreadyOrchestrating}, resolvedFrom=${resolvedFrom}`,
);

let skillPrompt: string | undefined;
if (
config.autoOrchestrate?.enabled &&
orchGate &&
input.isMainAgent &&
config.autoOrchestrate.skillExtensionId &&
deps.loadSkillPrompt
) {
try {
skillPrompt = await deps.loadSkillPrompt(config.autoOrchestrate.skillExtensionId);
} catch {
skillPrompt = undefined;
}
}

let mutations: RouterMutationsLog = {};
if (config.autoOrchestrate?.enabled && orchGate) {
const orchestrated = applyOrchestration({
request: input.request,
config: config.autoOrchestrate,
isMainAgent: input.isMainAgent,
tier: tokenSaverTier,
alreadyOrchestrating,
skillPrompt,
});
if (orchestrated.applied) {
mutations = { ...mutations, ...orchestrated.mutations };
decision.requestPatch = {
messages: orchestrated.request.messages,
tools: orchestrated.request.tools,
systemPrompt: orchestrated.request.systemPrompt,
};
decision.orchestrating = true;
if (config.autoOrchestrate.mainAgentModel) {
decision.provider = config.autoOrchestrate.mainAgentModel.provider;
Expand All @@ -329,6 +397,9 @@ export function createRouterRuntime(
mutations = { ...mutations, subagentTagStripped: true };
}

const mediaMessages = decision.requestPatch?.messages ?? input.request.messages;
mutations = rerouteDecisionForMedia(decision, mediaMessages, mutations);

decision.mutations = mutations;

sessionStore.set({
Expand Down Expand Up @@ -375,10 +446,21 @@ export function createRouterRuntime(
): AsyncIterable<CanonicalModelEvent> {
const startedAt = (deps.now?.() ?? new Date()).toISOString();
const fallbackPlan = planFallback(config.fallback, decision.scenarioType);
const baseRequest = applyDecisionToRequest(decision, request);
const requiredModalities = collectRequiredInputModalities(baseRequest.messages);
const requestedAttempt: RouterModelRef = {
id: `${decision.provider}/${decision.model}`,
provider: decision.provider,
model: decision.model,
};
const attempts: RouterModelRef[] = [
{ id: `${decision.provider}/${decision.model}`, provider: decision.provider, model: decision.model },
requestedAttempt,
...fallbackPlan.attempts,
];
].filter((attempt, index, all) =>
all.findIndex((candidate) =>
candidate.provider === attempt.provider && candidate.model === attempt.model
) === index
).filter((attempt) => supportsMediaRequirements(attempt, requiredModalities));
const zeroUsageMax = Math.max(1, config.zeroUsageRetry?.maxAttempts ?? 5);
const zeroUsageEnabled = config.zeroUsageRetry?.enabled ?? true;
const transientRetryEnabled = config.transientRetry?.enabled ?? true;
Expand All @@ -393,6 +475,22 @@ export function createRouterRuntime(
let lastDecision: RouterDecision = decision;
let lastHasYieldedContent = false;

if (attempts.length === 0) {
const missing = missingForModel(requestedAttempt, requiredModalities);
const error = createUnsupportedMediaError(requestedAttempt, requiredModalities, missing);
events.emit({
type: "pilotdeck_router_execute_failed",
sessionId: ctx.sessionId,
turnId: ctx.turnId,
scenarioType: decision.scenarioType,
provider: requestedAttempt.provider,
model: requestedAttempt.model,
error,
});
yield { type: "error", error };
return;
}

outer: for (let attemptIndex = 0; attemptIndex < attempts.length; attemptIndex += 1) {
if (ctx.abortSignal?.aborted) {
return;
Expand Down Expand Up @@ -856,7 +954,7 @@ async function* streamAttempt(
throw error;
}
const fromError = (error as { error?: import("../model/index.js").CanonicalModelError })?.error;
providerError = fromError ?? {
providerError = fromError ?? canonicalizeModelRequestError(error, request) ?? {
provider: request.provider,
protocol: "anthropic",
code: classifyNetworkErrorCode(error),
Expand All @@ -876,6 +974,24 @@ async function* streamAttempt(
};
}

/**
 * Converts a `ModelRequestError` into the canonical model error shape.
 *
 * @param error - Value caught from a failed attempt.
 * @param request - Request that was being executed; supplies the provider.
 * @returns A non-retryable canonical error carrying the original code,
 *   message and details, or `undefined` when `error` is not a
 *   `ModelRequestError`. The protocol is always reported as `"anthropic"`.
 */
function canonicalizeModelRequestError(
  error: unknown,
  request: CanonicalModelRequest,
): import("../model/index.js").CanonicalModelError | undefined {
  return error instanceof ModelRequestError
    ? {
        provider: request.provider,
        protocol: "anthropic",
        code: error.code,
        message: error.message,
        retryable: false,
        raw: error.details,
      }
    : undefined;
}

function abortableDelay(ms: number, signal?: AbortSignal): Promise<void> {
if (!signal) {
return new Promise((resolve) => setTimeout(resolve, ms));
Expand Down Expand Up @@ -944,6 +1060,24 @@ function classifyRetryReason(errorCode: string): "rate_limit" | "server_error" |
return "server_error";
}

/**
 * Builds the non-retryable error reported when no configured model can accept
 * the request's media.
 *
 * @param attempt - The originally requested provider/model.
 * @param required - Input modalities the request needs.
 * @param missing - Modalities the requested model lacks; when empty, the
 *   message lists `required` in its place.
 * @returns A canonical error with code `"unsupported_modality"`.
 */
function createUnsupportedMediaError(
  attempt: RouterModelRef,
  required: readonly InputModality[],
  missing: readonly InputModality[],
): import("../model/index.js").CanonicalModelError {
  const modelLabel = `${attempt.provider}/${attempt.model}`;
  const requiredText = required.join(", ");
  const missingText = missing.length === 0 ? requiredText : missing.join(", ");
  const message =
    `Router could not find a configured fallback model for ${modelLabel} ` +
    `that supports required input modalities: ${requiredText}. Missing: ${missingText}.`;

  return {
    provider: attempt.provider,
    // NOTE(review): other router-built errors in this file report protocol
    // "anthropic"; confirm that "openai" is intended here.
    protocol: "openai",
    code: "unsupported_modality",
    message,
    retryable: false,
  };
}

function extractPartialText(buffered: CanonicalModelEvent[]): string {
let text = "";
for (const ev of buffered) {
Expand Down
Loading