Skip to content

Commit 84bc143

Browse files
CISCabetlen
andauthored
fix: match Transformers tojson in chat template rendering (abetlen#1486)
* Render chat template tojson filter as unicode * replace tojson with custom filter * add indent parameter * add break, continue and strftime * renamed to strftime_now * add separators and sort_keys plus import fix * allow ensure_ascii parameter as a few templates are using it now --------- Co-authored-by: abetlen <abetlen@gmail.com>
1 parent 5848020 commit 84bc143

2 files changed

Lines changed: 23 additions & 2 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- fix: match Transformers `tojson` in chat template rendering by @CISC in #1486
1011
- fix: use env var configured multimodal library override paths when loading shared libraries by @navratil-matej in #1782
1112
- feat: add Jinja2 loop controls to chat templates by @handshape in #2018
1213
- fix: avoid cleanup errors for partially initialized `LlamaModel` objects by @usernames122 in #2173

llama_cpp/llama_chat_format.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(
219219
set(stop_token_ids) if stop_token_ids is not None else None
220220
)
221221

222-
self._environment = ImmutableSandboxedEnvironment(
222+
environment = ImmutableSandboxedEnvironment(
223223
loader=jinja2.BaseLoader(),
224224
trim_blocks=True,
225225
lstrip_blocks=True,
@@ -229,12 +229,32 @@ def __init__(
229229
Jinja2ChatFormatter.IgnoreGenerationTags,
230230
jinja2.ext.loopcontrols,
231231
],
232-
).from_string(self.template)
232+
)
233+
# Match Transformers' chat-template JSON rendering behavior.
234+
# https://github.com/huggingface/transformers/blob/39603d0e5cdb6f00e8d473d7fcbb01032d709181/src/transformers/utils/chat_template_utils.py#L481-L484
235+
environment.filters["tojson"] = self.tojson
236+
self._environment = environment.from_string(self.template)
233237

234238
@staticmethod
235239
def strftime_now(f: str) -> str:
236240
return datetime.now().strftime(f)
237241

242+
@staticmethod
243+
def tojson(
244+
x: Any,
245+
ensure_ascii: bool = False,
246+
indent: Optional[int] = None,
247+
separators: Optional[Tuple[str, str]] = None,
248+
sort_keys: bool = False,
249+
) -> str:
250+
return json.dumps(
251+
x,
252+
ensure_ascii=ensure_ascii,
253+
indent=indent,
254+
separators=separators,
255+
sort_keys=sort_keys,
256+
)
257+
238258
def __call__(
239259
self,
240260
*,

0 commit comments

Comments
 (0)