Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
node_modules
.DS_Store
.pi/workflows/sessions/
.pnpm-store
npm-debug.log
bun.lock

# Coverage
coverage/
Expand Down
30 changes: 30 additions & 0 deletions .pi/extensions/workflow-orchestrator/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,22 @@ const AllowedExtensionsByAgentSchema = Type.Record(Type.String(), Type.Array(Typ

const AgentsSchema = Type.Record(Type.String(), Type.String());

const AgentRetrySchema = Type.Object({
maxAttempts: Type.Optional(Type.Number()),
initialDelayMs: Type.Optional(Type.Number()),
maxDelayMs: Type.Optional(Type.Number()),
backoffMultiplier: Type.Optional(Type.Number()),
jitterMs: Type.Optional(Type.Number()),
});

const WorkflowSchema = Type.Object({
name: Type.String(),
goal: Type.String(),
maxWaves: Type.Optional(Type.Number()),
maxTaskRetries: Type.Optional(Type.Number()),
maxPmRetries: Type.Optional(Type.Number()),
parallelism: Type.Optional(Type.Number()),
agentRetry: Type.Optional(AgentRetrySchema),
allowedExtensions: Type.Optional(Type.Array(Type.String())),
allowedExtensionsByAgent: Type.Optional(AllowedExtensionsByAgentSchema),
agents: AgentsSchema,
Expand Down Expand Up @@ -127,6 +136,12 @@ export function loadWorkflowConfig(cwd: string, name: string): LoadedWorkflow {
config.maxWaves = config.maxWaves ?? 10;
config.maxTaskRetries = config.maxTaskRetries ?? 2;
config.parallelism = config.parallelism ?? 1;
config.agentRetry = config.agentRetry ?? {};
config.agentRetry.maxAttempts = config.agentRetry.maxAttempts ?? 5;
config.agentRetry.initialDelayMs = config.agentRetry.initialDelayMs ?? 5000;
config.agentRetry.maxDelayMs = config.agentRetry.maxDelayMs ?? 120000;
config.agentRetry.backoffMultiplier = config.agentRetry.backoffMultiplier ?? 2;
config.agentRetry.jitterMs = config.agentRetry.jitterMs ?? 1000;
config.taskFlow.memory = config.taskFlow.memory ?? {};
config.taskFlow.memory.keepDeveloperMemory = config.taskFlow.memory.keepDeveloperMemory ?? true;
config.taskFlow.memory.keepVerifierMemoryOnDeveloperFailure =
Expand All @@ -137,6 +152,21 @@ export function loadWorkflowConfig(cwd: string, name: string): LoadedWorkflow {
if (config.parallelism < 1) {
throw new Error("parallelism must be at least 1");
}
if (config.agentRetry.maxAttempts < 1) {
throw new Error("agentRetry.maxAttempts must be at least 1");
}
if (config.agentRetry.initialDelayMs < 0) {
throw new Error("agentRetry.initialDelayMs must be at least 0");
}
if (config.agentRetry.maxDelayMs < 0) {
throw new Error("agentRetry.maxDelayMs must be at least 0");
}
if (config.agentRetry.backoffMultiplier < 1) {
throw new Error("agentRetry.backoffMultiplier must be at least 1");
}
if (config.agentRetry.jitterMs < 0) {
throw new Error("agentRetry.jitterMs must be at least 0");
}

return { config, path: workflowPath };
}
21 changes: 13 additions & 8 deletions .pi/extensions/workflow-orchestrator/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ export interface TaskFlowInput<TTask, TOutput> {
output: TOutput | null,
error?: string,
reason?: "verification_failed" | "malformed_output" | "error",
) => boolean;
applyGenericFailure: (task: TTask, error: string) => boolean;
) => void;
applyGenericFailure: (task: TTask, error: string) => void;
}

function defaultGetField(obj: any, path: string): any {
Expand All @@ -56,6 +56,11 @@ export async function runTaskFlow<TTask extends { retries: number }, TOutput>(
const getNextStageId = input.getNextStageId ?? defaultGetNextStageId;
let currentStageId = input.startStageId ?? input.stages[0]?.id;

function recordFailure(): boolean {
input.task.retries += 1;
return input.task.retries <= input.maxRetries;
}

while (currentStageId) {
if (input.isStopped?.(input.task)) return;

Expand All @@ -74,17 +79,17 @@ export async function runTaskFlow<TTask extends { retries: number }, TOutput>(
input.onError?.(stage, input.task, error instanceof Error ? error : new Error(message));

if (stage.id === "verify") {
const retry = input.applyVerifyFailure(input.task, stage.id, null, message, "error");
if (!retry) {
input.applyVerifyFailure(input.task, stage.id, null, message, "error");
if (!recordFailure()) {
input.markFailed(input.task, stage.id);
return;
}
currentStageId = input.stages[0]?.id;
continue;
}

const retry = input.applyGenericFailure(input.task, message);
if (!retry) {
input.applyGenericFailure(input.task, message);
if (!recordFailure()) {
input.markFailed(input.task, stage.id);
return;
}
Expand Down Expand Up @@ -123,8 +128,8 @@ export async function runTaskFlow<TTask extends { retries: number }, TOutput>(
if ((nextStageId === firstStageId || nextStageId === stage.id) && stage.id === "verify") {
const reason =
nextStageId === stage.id && !matchedTransition ? "malformed_output" : "verification_failed";
const retry = input.applyVerifyFailure(input.task, stage.id, output, undefined, reason);
if (!retry) {
input.applyVerifyFailure(input.task, stage.id, output, undefined, reason);
if (!recordFailure()) {
input.markFailed(input.task, stage.id);
return;
}
Expand Down
Loading