Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 81 additions & 41 deletions src/client/predictions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
WebPredictionParams,
SchemaResponse,
GenerationConfigParams,
AgentSkill,
} from "./types";
import { processImage } from "../utils/image";
import { convertToJsonSchema } from "../utils/utils";
Expand Down Expand Up @@ -184,6 +185,15 @@ export class ImagePredictions extends Predictions {
callbackUrl,
} = params;

const hasSkills = config?.skills && config.skills.length > 0;
if (!domain && !hasSkills) {
throw new InputError(
"Either `domain` or `config.skills` must be provided",
"missing_parameter",
"Provide either a domain or skills in the config"
);
}

const imagesData = this._handleImagesOrUrls(images, urls);

let jsonSchema = config?.jsonSchema;
Expand All @@ -194,32 +204,43 @@ export class ImagePredictions extends Predictions {
);
}

const serializedSkills = config?.skills?.map((s) =>
s instanceof AgentSkill ? s.toJSON() : new AgentSkill(s).toJSON()
);

const data: Record<string, any> = {
images: imagesData,
model,
batch,
config: {
detail: config?.detail ?? "auto",
json_schema: jsonSchema,
skills: serializedSkills,
confidence: config?.confidence ?? false,
grounding: config?.grounding ?? false,
gql_stmt: config?.gqlStmt ?? null,
},
metadata: {
environment: metadata?.environment ?? "dev",
session_id: metadata?.sessionId,
allow_training: metadata?.allowTraining ?? true,
},
callback_url: callbackUrl,
};
if (domain !== undefined) {
data.domain = domain;
}

const [response] = await this.requestor.request<PredictionResponse>(
"POST",
"image/generate",
undefined,
{
images: imagesData,
model,
domain,
batch,
config: {
detail: config?.detail ?? "auto",
json_schema: jsonSchema,
confidence: config?.confidence ?? false,
grounding: config?.grounding ?? false,
gql_stmt: config?.gqlStmt ?? null,
},
metadata: {
environment: metadata?.environment ?? "dev",
session_id: metadata?.sessionId,
allow_training: metadata?.allowTraining ?? true,
},
callback_url: callbackUrl,
}
data
);

this._castResponseToSchema(response, domain, config);
if (domain) {
this._castResponseToSchema(response, domain, config);
}

return response;
}
Expand Down Expand Up @@ -307,6 +328,15 @@ export class FilePredictions extends Predictions {
callbackUrl,
} = params;

const hasSkills = config?.skills && config.skills.length > 0;
if (!domain && !hasSkills) {
throw new InputError(
"Either `domain` or `config.skills` must be provided",
"missing_parameter",
"Provide either a domain or skills in the config"
);
}

const fileOrUrl = this._handleFileOrUrl(fileId, url);

let jsonSchema = config?.jsonSchema;
Expand All @@ -317,33 +347,43 @@ export class FilePredictions extends Predictions {
);
}

const serializedSkills = config?.skills?.map((s) =>
s instanceof AgentSkill ? s.toJSON() : new AgentSkill(s).toJSON()
);

const data: Record<string, any> = {
...fileOrUrl,
model,
batch,
config: {
detail: config?.detail ?? "auto",
json_schema: jsonSchema,
skills: serializedSkills,
confidence: config?.confidence ?? false,
grounding: config?.grounding ?? false,
gql_stmt: config?.gqlStmt ?? null,
},
metadata: {
environment: metadata?.environment ?? "dev",
session_id: metadata?.sessionId,
allow_training: metadata?.allowTraining ?? true,
},
callback_url: callbackUrl,
};
if (domain !== undefined) {
data.domain = domain;
}

const [response] = await this.requestor.request<PredictionResponse>(
"POST",
`/${this.route}/generate`,
undefined,
{
...fileOrUrl,
model,
domain,
batch,
config: {
detail: config?.detail ?? "auto",
json_schema: jsonSchema,
confidence: config?.confidence ?? false,
grounding: config?.grounding ?? false,
gql_stmt: config?.gqlStmt ?? null,
},
metadata: {
environment: metadata?.environment ?? "dev",
session_id: metadata?.sessionId,
allow_training: metadata?.allowTraining ?? true,
},
callback_url: callbackUrl,
}
data
);

// Cast response to schema if needed
this._castResponseToSchema(response, domain!, config);
if (domain) {
this._castResponseToSchema(response, domain, config);
}

return response;
}
Expand Down
53 changes: 52 additions & 1 deletion src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,41 @@ export interface FileUploadParams {
generatePublicUrl?: boolean;
}

export interface AgentSkillParams {
type?: string;
skillId?: string;
skillName?: string;
version?: string;
}

export class AgentSkill {
type: string = "vlm-run";
skillId?: string;
skillName?: string;
version: string = "latest";

constructor(params: AgentSkillParams = {}) {
if (!params.skillId && !params.skillName) {
throw new Error("Either 'skillId' or 'skillName' must be provided");
}
Object.assign(this, params);
}

toJSON() {
return {
type: this.type,
skill_id: this.skillId,
skill_name: this.skillName,
version: this.version,
};
}
}

export type AgentSkillInput = AgentSkill | AgentSkillParams;

export interface PredictionGenerateParams {
model?: string;
domain: string;
domain?: string;
config?: GenerationConfigParams;
metadata?: RequestMetadataParams;
callbackUrl?: string;
Expand Down Expand Up @@ -294,6 +326,7 @@ export type GenerationConfigParams = {
responseModel?: ZodType;
zodToJsonParams?: any;
jsonSchema?: Record<string, any> | null;
skills?: AgentSkillInput[];
confidence?: boolean;
grounding?: boolean;
gqlStmt?: string | null;
Expand All @@ -310,6 +343,11 @@ export class GenerationConfig {
*/
jsonSchema: Record<string, any> | null = null;

/**
* List of skills to enable for this request.
*/
skills?: AgentSkillInput[];

/**
* Include confidence scores in the response (included in the `_metadata` field).
*/
Expand All @@ -336,6 +374,9 @@ export class GenerationConfig {
return {
detail: this.detail,
json_schema: this.jsonSchema,
skills: this.skills?.map((s) =>
s instanceof AgentSkill ? s.toJSON() : new AgentSkill(s).toJSON()
),
confidence: this.confidence,
grounding: this.grounding,
gql_stmt: this.gqlStmt,
Expand Down Expand Up @@ -601,11 +642,13 @@ export type AgentExecutionConfigParams = {
prompt?: string;
responseModel?: ZodType;
jsonSchema?: Record<string, any>;
skills?: AgentSkillInput[];
};

export class AgentExecutionConfig {
prompt?: string;
jsonSchema?: Record<string, any>;
skills?: AgentSkillInput[];

constructor(params: Partial<AgentExecutionConfig> = {}) {
Object.assign(this, params);
Expand All @@ -615,6 +658,9 @@ export class AgentExecutionConfig {
return {
prompt: this.prompt,
json_schema: this.jsonSchema,
skills: this.skills?.map((s) =>
s instanceof AgentSkill ? s.toJSON() : new AgentSkill(s).toJSON()
),
};
}
}
Expand All @@ -623,11 +669,13 @@ export type AgentCreationConfigParams = {
prompt?: string;
responseModel?: ZodType;
jsonSchema?: Record<string, any>;
skills?: AgentSkillInput[];
};

export class AgentCreationConfig {
prompt?: string;
jsonSchema?: Record<string, any>;
skills?: AgentSkillInput[];

constructor(params: Partial<AgentCreationConfig> = {}) {
Object.assign(this, params);
Expand All @@ -637,6 +685,9 @@ export class AgentCreationConfig {
return {
prompt: this.prompt,
json_schema: this.jsonSchema,
skills: this.skills?.map((s) =>
s instanceof AgentSkill ? s.toJSON() : new AgentSkill(s).toJSON()
),
};
}
}
Expand Down
Loading