From 6bfa639591a99997855d6697656b6b9cde49b2cc Mon Sep 17 00:00:00 2001 From: Ashutosh0x Date: Tue, 6 Jan 2026 20:35:52 +0530 Subject: [PATCH] fix(dataclass): add type hints to dataclass wrapper (closes #321) --- chex/_src/dataclass.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index ea111259..d25bf5e1 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -21,6 +21,7 @@ from absl import logging import jax +from typing import Any, Callable, Optional, Type, Union from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet @@ -92,17 +93,17 @@ def new_init(self, *orig_args, **orig_kwargs): @dataclass_transform() def dataclass( - cls=None, + cls: Optional[Type[Any]] = None, *, - init=True, - repr=True, # pylint: disable=redefined-builtin - eq=True, - order=False, - unsafe_hash=False, - frozen=False, + init: bool = True, + repr: bool = True, # pylint: disable=redefined-builtin + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, kw_only: bool = False, - mappable_dataclass=True, # pylint: disable=redefined-outer-name -): + mappable_dataclass: bool = True, # pylint: disable=redefined-outer-name +) -> Union[Type[Any], Callable[[Type[Any]], Type[Any]]]: """JAX-friendly wrapper for :py:func:`dataclasses.dataclass`. This wrapper class registers new dataclasses with JAX so that tree utils @@ -148,14 +149,14 @@ class _Dataclass(): def __init__( self, - init=True, - repr=True, # pylint: disable=redefined-builtin - eq=True, - order=False, - unsafe_hash=False, - frozen=False, - kw_only=False, - mappable_dataclass=True, # pylint: disable=redefined-outer-name + init: bool = True, + repr: bool = True, # pylint: disable=redefined-builtin + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + kw_only: bool = False, + mappable_dataclass: bool = True, # pylint: disable=redefined-outer-name ): self.init = init self.repr = repr # pylint: disable=redefined-builtin