Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 59 additions & 72 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,48 @@ export async function resolveModel(client, requestedModel, providerOverride) {
throw new Error(`Unknown model '${requestedModel}'. Call GET /v1/models to inspect available IDs.`)
}

export function createSseQueue() {
const chunks = []
let resolve = null
let done = false

function enqueue(value) {
chunks.push(value)
if (resolve) {
const r = resolve
resolve = null
r()
}
}

function finish() {
done = true
if (resolve) {
const r = resolve
resolve = null
r()
}
}

async function* generateChunks() {
while (true) {
while (chunks.length > 0) {
yield chunks.shift()
}
if (done) break
await new Promise((r) => {
resolve = r
})
}
// Drain any remaining chunks
while (chunks.length > 0) {
yield chunks.shift()
}
}

return { enqueue, finish, generateChunks }
}

function sseResponse(corsHeadersObj, generator) {
const encoder = new TextEncoder()
const body = new ReadableStream({
Expand Down Expand Up @@ -595,18 +637,7 @@ export function createProxyFetchHandler(client) {
const completionID = `chatcmpl_${crypto.randomUUID().replace(/-/g, "")}`
const now = Math.floor(Date.now() / 1000)

const chunks = []
let resolve = null
let done = false

function enqueue(value) {
chunks.push(value)
if (resolve) {
const r = resolve
resolve = null
r()
}
}
const queue = createSseQueue()

async function* generateSse() {
const runPromise = executePromptStreaming(
Expand All @@ -622,7 +653,7 @@ export function createProxyFetchHandler(client) {
model: model.id,
choices: [{ index: 0, delta: { role: "assistant", content: delta }, finish_reason: null }],
})
enqueue(`data: ${chunk}\n\n`)
queue.enqueue(`data: ${chunk}\n\n`)
},
)
.then((streamResult) => {
Expand All @@ -638,7 +669,7 @@ export function createProxyFetchHandler(client) {
total_tokens: streamResult.tokens.input + streamResult.tokens.output,
},
})
enqueue(`data: ${finalChunk}\n\ndata: [DONE]\n\n`)
queue.enqueue(`data: ${finalChunk}\n\ndata: [DONE]\n\n`)
})
.catch(async (err) => {
const streamError = err instanceof Error ? err.message : String(err)
Expand All @@ -649,30 +680,13 @@ export function createProxyFetchHandler(client) {
const errChunk = JSON.stringify({
error: { message: streamError, type: "server_error" },
})
enqueue(`data: ${errChunk}\n\ndata: [DONE]\n\n`)
queue.enqueue(`data: ${errChunk}\n\ndata: [DONE]\n\n`)
})
.finally(() => {
done = true
if (resolve) {
const r = resolve
resolve = null
r()
}
queue.finish()
})

while (true) {
while (chunks.length > 0) {
yield chunks.shift()
}
if (done) break
await new Promise((r) => {
resolve = r
})
}
// Drain any remaining chunks
while (chunks.length > 0) {
yield chunks.shift()
}
yield* queue.generateChunks()

await runPromise
}
Expand Down Expand Up @@ -739,25 +753,14 @@ export function createProxyFetchHandler(client) {
const itemID = `msg_${crypto.randomUUID().replace(/-/g, "")}`
const now = Math.floor(Date.now() / 1000)

const chunks = []
let resolve = null
let done = false

function enqueue(value) {
chunks.push(value)
if (resolve) {
const r = resolve
resolve = null
r()
}
}
const queue = createSseQueue()

function sseEvent(eventType, data) {
return `event: ${eventType}\ndata: ${JSON.stringify(data)}\n\n`
}

async function* generateSse() {
enqueue(
queue.enqueue(
sseEvent("response.created", {
type: "response.created",
response: {
Expand All @@ -770,7 +773,7 @@ export function createProxyFetchHandler(client) {
},
}),
)
enqueue(
queue.enqueue(
sseEvent("response.output_item.added", {
type: "response.output_item.added",
output_index: 0,
Expand All @@ -786,7 +789,7 @@ export function createProxyFetchHandler(client) {
system,
(delta) => {
if (partIndex === 0) {
enqueue(
queue.enqueue(
sseEvent("response.content_part.added", {
type: "response.content_part.added",
item_id: itemID,
Expand All @@ -797,7 +800,7 @@ export function createProxyFetchHandler(client) {
)
partIndex++
}
enqueue(
queue.enqueue(
sseEvent("response.output_text.delta", {
type: "response.output_text.delta",
item_id: itemID,
Expand All @@ -809,7 +812,7 @@ export function createProxyFetchHandler(client) {
},
)
.then((streamResult) => {
enqueue(
queue.enqueue(
sseEvent("response.output_text.done", {
type: "response.output_text.done",
item_id: itemID,
Expand All @@ -818,14 +821,14 @@ export function createProxyFetchHandler(client) {
text: "",
}),
)
enqueue(
queue.enqueue(
sseEvent("response.output_item.done", {
type: "response.output_item.done",
output_index: 0,
item: { id: itemID, type: "message", status: "completed", role: "assistant" },
}),
)
enqueue(
queue.enqueue(
sseEvent("response.completed", {
type: "response.completed",
response: {
Expand All @@ -849,7 +852,7 @@ export function createProxyFetchHandler(client) {
error: errMsg,
requestedModel: body.model,
})
enqueue(
queue.enqueue(
sseEvent("response.failed", {
type: "response.failed",
response: {
Expand All @@ -863,26 +866,10 @@ export function createProxyFetchHandler(client) {
)
})
.finally(() => {
done = true
if (resolve) {
const r = resolve
resolve = null
r()
}
queue.finish()
})

while (true) {
while (chunks.length > 0) {
yield chunks.shift()
}
if (done) break
await new Promise((r) => {
resolve = r
})
}
while (chunks.length > 0) {
yield chunks.shift()
}
yield* queue.generateChunks()

await runPromise
}
Expand Down
Loading