Skip to content

Commit afa3ccc

Browse files
authored
Reraise Jinja2 TemplateError (#840)
* Handle jinja2.TemplateError in gateway * Raise GatewayError in openai interface
1 parent 3121673 commit afa3ccc

2 files changed

Lines changed: 34 additions & 13 deletions

File tree

gateway/src/dstack/gateway/openai/clients/tgi.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import httpx
88
import jinja2
9+
import jinja2.sandbox
910

1011
from dstack.gateway.errors import GatewayError
1112
from dstack.gateway.openai.clients import ChatCompletionsClient
@@ -31,9 +32,17 @@ def __init__(
3132
headers={} if host is None else {"Host": host},
3233
timeout=60,
3334
)
34-
self.chat_template = jinja2.Template(chat_template)
3535
self.eos_token = eos_token
3636

37+
try:
38+
jinja_env = jinja2.sandbox.ImmutableSandboxedEnvironment(
39+
trim_blocks=True, lstrip_blocks=True
40+
)
41+
jinja_env.globals["raise_exception"] = raise_exception
42+
self.chat_template = jinja_env.from_string(chat_template)
43+
except jinja2.TemplateError as e:
44+
raise GatewayError(f"Failed to compile chat template: {e}")
45+
3746
async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
3847
payload = self.get_payload(request)
3948
resp = await self.client.post("/generate", json=payload)
@@ -123,10 +132,14 @@ async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCom
123132
yield chunk
124133

125134
def get_payload(self, request: ChatCompletionsRequest) -> Dict:
126-
inputs = self.chat_template.render(
127-
messages=request.messages,
128-
add_generation_prompt=True,
129-
)
135+
try:
136+
inputs = self.chat_template.render(
137+
messages=request.messages,
138+
add_generation_prompt=True,
139+
)
140+
except jinja2.TemplateError as e:
141+
raise GatewayError(f"Failed to render chat template: {e}")
142+
130143
stop = ([request.stop] if isinstance(request.stop, str) else request.stop) or []
131144
if self.eos_token not in stop:
132145
stop.append(self.eos_token)
@@ -178,3 +191,7 @@ def __del__(self):
178191
asyncio.get_running_loop().create_task(self.aclose())
179192
except Exception:
180193
pass
194+
195+
196+
def raise_exception(message: str):
    """Jinja2 global exposed to chat templates as ``raise_exception``.

    HuggingFace-style chat templates call this to abort rendering with a
    message; raising ``jinja2.TemplateError`` lets the surrounding
    render/compile code catch it uniformly and re-raise as ``GatewayError``.
    """
    raise jinja2.TemplateError(message)

gateway/src/dstack/gateway/openai/routes.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from fastapi import APIRouter, Depends
44
from fastapi.responses import StreamingResponse
55

6+
from dstack.gateway.errors import GatewayError
67
from dstack.gateway.openai.schemas import (
78
ChatCompletionsChunk,
89
ChatCompletionsRequest,
@@ -25,14 +26,17 @@ async def get_models(
2526
async def post_chat_completions(
    project: str, body: ChatCompletionsRequest, store: Annotated[OpenAIStore, Depends(get_store)]
):
    """Handle `POST /chat/completions` for a project.

    Resolves the chat client for ``body.model`` and either returns a single
    completion or an SSE stream, depending on ``body.stream``.

    Raises:
        HTTPException: any ``GatewayError`` from client lookup or generation
            is converted via ``GatewayError.http()`` so FastAPI returns a
            proper HTTP error response.
    """
    try:
        client = await store.get_chat_client(project, body.model)
        if not body.stream:
            return await client.generate(body)
        else:
            # NOTE(review): once the StreamingResponse is returned, chunks are
            # produced outside this try block — a GatewayError raised lazily
            # inside client.stream() will NOT be converted here. Presumably
            # acceptable (headers are already sent); verify against callers.
            return StreamingResponse(
                stream_chunks(client.stream(body)),
                media_type="text/event-stream",
            )
    except GatewayError as e:
        # Chain explicitly so the original gateway error is preserved as the
        # __cause__ of the HTTP error (clearer tracebacks in server logs).
        raise e.http() from e
3640

3741

3842
async def stream_chunks(chunks: AsyncIterator[ChatCompletionsChunk]) -> AsyncIterator[bytes]:

0 commit comments

Comments
 (0)