Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions src/client/definePlaygroundAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -276,18 +276,22 @@ export function definePlaygroundAPI<DataModel extends GenericDataModel>(
},
{ contextOptions, storageOptions, saveStreamDeltas: true },
);
const outputMessages = await Promise.all(
(await steps).map(async (step) => {
const { messages } = await serializeNewMessagesInStep(
ctx,
component,
step,
{
model: getModelName(agent.options.languageModel),
provider: getProviderName(agent.options.languageModel),
},
);
return messages.map((messageWithMetadata, i) => {
const outputMessages: MessageDoc[][] = [];
let previousResponseMessageCount = 0;
for (const step of await steps) {
const { messages } = await serializeNewMessagesInStep(
ctx,
component,
step,
{
model: getModelName(agent.options.languageModel),
provider: getProviderName(agent.options.languageModel),
},
previousResponseMessageCount,
);
previousResponseMessageCount = step.response.messages.length;
outputMessages.push(
messages.map((messageWithMetadata, i) => {
return {
...messageWithMetadata,
tool: isTool(messageWithMetadata.message),
Expand All @@ -300,9 +304,9 @@ export function definePlaygroundAPI<DataModel extends GenericDataModel>(
order: 0,
stepOrder: i + 1,
} satisfies MessageDoc;
});
}),
);
}),
);
}
return { text: await text, messages: outputMessages.flat() };
},
returns: v.object({ text: v.string(), messages: v.array(vMessageDoc) }),
Expand Down
91 changes: 91 additions & 0 deletions src/client/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { describe, expect, test } from "vitest";
import {
Agent,
createThread,
createTool,
filterOutOrphanedToolMessages,
type MessageDoc,
} from "./index.js";
Expand Down Expand Up @@ -62,6 +63,74 @@ export const createThreadManually = mutation({
},
});

const saveStepAgent = new Agent(components.agent, {
name: "save-step-test",
instructions: "test",
tools: {
echo: createTool({
description: "Echo a value",
inputSchema: z.object({ value: z.string() }),
execute: async (_ctx, input) => `echo:${input.value}`,
}),
},
languageModel: mockModel({
contentSteps: [
[
{
type: "tool-call",
toolCallId: "ss-1",
toolName: "echo",
input: JSON.stringify({ value: "hi" }),
},
],
[{ type: "text", text: "done" }],
],
}),
stopWhen: stepCountIs(5),
});

export const replayStepsViaSaveStep = action({
args: { withWatermark: v.boolean() },
handler: async (ctx, args) => {
const { thread } = await saveStepAgent.createThread(ctx, {
userId: "ss-gen",
});
const genResult = await thread.generateText({ prompt: "echo hi" });
const steps = genResult.steps;

const { threadId } = await saveStepAgent.createThread(ctx, {
userId: "ss-replay",
});
const { messageId: promptMessageId } = await saveStepAgent.saveMessage(ctx, {
threadId,
message: { role: "user", content: "echo hi" },
skipEmbeddings: true,
});
let previousStep: (typeof steps)[number] | undefined;
for (const step of steps) {
await saveStepAgent.saveStep(ctx, {
threadId,
promptMessageId,
step,
previousStep: args.withWatermark ? previousStep : undefined,
});
previousStep = step;
}

const replayed = await saveStepAgent.listMessages(ctx, {
threadId,
paginationOpts: { cursor: null, numItems: 50 },
statuses: ["success", "pending", "failed"],
});
const contentTypes = replayed.page.flatMap((m) =>
Array.isArray(m.message?.content)
? m.message!.content.map((c: { type?: string }) => c.type ?? "text")
: ["text"],
);
return { stepCount: steps.length, contentTypes };
},
});

export const createThreadMutation = agent.createThreadMutation();
export const generateObjectAction = agent.asObjectAction({
schema: z.object({ hello: z.string().describe("A string for testing") }),
Expand Down Expand Up @@ -162,6 +231,7 @@ const testApi: ApiFromModules<{
generateTextAction: typeof generateTextAction;
generateObjectAction: typeof generateObjectAction;
saveMessageMutation: typeof saveMessageMutation;
replayStepsViaSaveStep: typeof replayStepsViaSaveStep;
};
}>["fns"] = anyApi["index.test"] as any;

Expand All @@ -177,6 +247,27 @@ describe("Agent thick client", () => {
expect(result).toBeDefined();
expect(result).toMatch(TEST_TEXT);
});
test("saveStep with previousStep saves each step's new messages exactly once", async () => {
const t = initConvexTest(schema);
const res = await t.action(testApi.replayStepsViaSaveStep, {
withWatermark: true,
});
expect(res.stepCount).toBe(2);
const toolCalls = res.contentTypes.filter((t) => t === "tool-call").length;
const toolResults = res.contentTypes.filter(
(t) => t === "tool-result",
).length;
expect(toolCalls).toBe(1);
expect(toolResults).toBe(1);
});
test("saveStep without previousStep duplicates prior messages", async () => {
const t = initConvexTest(schema);
const res = await t.action(testApi.replayStepsViaSaveStep, {
withWatermark: false,
});
const toolCalls = res.contentTypes.filter((t) => t === "tool-call").length;
expect(toolCalls).toBeGreaterThan(1);
});
});

describe("filterOutOrphanedToolMessages", () => {
Expand Down
26 changes: 25 additions & 1 deletion src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1187,7 +1187,9 @@ export class Agent<
}

/**
* Explicitly save a "step" created by the AI SDK.
* Explicitly save a "step" created by the AI SDK. For multi-step generation
* loops, pass `previousStep` so we save only the new response messages —
* see the arg JSDoc for why.
* @param ctx The ctx argument to a mutation or action.
* @param args The Step generated by the AI SDK.
*/
Expand All @@ -1204,6 +1206,15 @@ export class Agent<
* The step to save, possibly including multiple tool calls.
*/
step: StepResult<TOOLS>;
/**
* The previous step in the same generation loop, if any. Pass it so we
* can compute how many of `step.response.messages` are already saved.
* Omit for the first step. AI SDK v6's `step.response.messages` is
* cumulative across steps; without this, multi-step callers duplicate
* every prior message on every save — the exact failure mode this fix
* addresses, just at the public-API layer.
*/
previousStep?: StepResult<TOOLS>;
/**
* The model used to generate the step.
* Defaults to the chat model for the Agent.
Expand All @@ -1216,6 +1227,18 @@ export class Agent<
provider?: string;
},
): Promise<{ messages: MessageDoc[] }> {
const previousResponseMessageCount =
args.previousStep?.response.messages.length ?? 0;
if (
args.previousStep !== undefined &&
args.step.response.messages.length < previousResponseMessageCount
) {
throw new Error(
`saveStep: step.response.messages length (${args.step.response.messages.length}) is less than ` +
`previousStep.response.messages length (${previousResponseMessageCount}). ` +
`Ensure previousStep is from the immediately preceding step in the same generation loop.`,
);
}
const { messages } = await serializeNewMessagesInStep(
ctx,
this.component,
Expand All @@ -1224,6 +1247,7 @@ export class Agent<
provider: args.provider ?? getProviderName(this.options.languageModel),
model: args.model ?? getModelName(this.options.languageModel),
},
previousResponseMessageCount,
);
const embeddings = await this.generateEmbeddings(
ctx,
Expand Down
Loading
Loading