66
77import httpx
88import jinja2
9+ import jinja2 .sandbox
910
1011from dstack .gateway .errors import GatewayError
1112from dstack .gateway .openai .clients import ChatCompletionsClient
@@ -31,9 +32,17 @@ def __init__(
3132 headers = {} if host is None else {"Host" : host },
3233 timeout = 60 ,
3334 )
34- self .chat_template = jinja2 .Template (chat_template )
3535 self .eos_token = eos_token
3636
37+ try :
38+ jinja_env = jinja2 .sandbox .ImmutableSandboxedEnvironment (
39+ trim_blocks = True , lstrip_blocks = True
40+ )
41+ jinja_env .globals ["raise_exception" ] = raise_exception
42+ self .chat_template = jinja_env .from_string (chat_template )
43+ except jinja2 .TemplateError as e :
44+ raise GatewayError (f"Failed to compile chat template: { e } " )
45+
3746 async def generate (self , request : ChatCompletionsRequest ) -> ChatCompletionsResponse :
3847 payload = self .get_payload (request )
3948 resp = await self .client .post ("/generate" , json = payload )
@@ -123,10 +132,14 @@ async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCom
123132 yield chunk
124133
125134 def get_payload (self , request : ChatCompletionsRequest ) -> Dict :
126- inputs = self .chat_template .render (
127- messages = request .messages ,
128- add_generation_prompt = True ,
129- )
135+ try :
136+ inputs = self .chat_template .render (
137+ messages = request .messages ,
138+ add_generation_prompt = True ,
139+ )
140+ except jinja2 .TemplateError as e :
141+ raise GatewayError (f"Failed to render chat template: { e } " )
142+
130143 stop = ([request .stop ] if isinstance (request .stop , str ) else request .stop ) or []
131144 if self .eos_token not in stop :
132145 stop .append (self .eos_token )
@@ -178,3 +191,7 @@ def __del__(self):
178191 asyncio .get_running_loop ().create_task (self .aclose ())
179192 except Exception :
180193 pass
194+
195+
def raise_exception(message: str):
    """Raise a ``jinja2.TemplateError`` carrying ``message``.

    This function is installed as the ``raise_exception`` global of the
    Jinja environment that compiles the chat template, so a template can
    call it to reject input it cannot handle.

    Args:
        message: Human-readable description of the problem.

    Raises:
        jinja2.TemplateError: Always.
    """
    template_error = jinja2.TemplateError(message)
    raise template_error
0 commit comments