Skip to content

[Type] ndarray typing 4: Make primitive dtypes Python classes wrapping DataTypeCxx#414

Draft
hughperkins wants to merge 1 commit intohp/typing-t4-3-cook-dtypefrom
hp/typing-t4-4-dtype-classes
Draft

[Type] ndarray typing 4: Make primitive dtypes Python classes wrapping DataTypeCxx#414
hughperkins wants to merge 1 commit intohp/typing-t4-3-cook-dtypefrom
hp/typing-t4-4-dtype-classes

Conversation

@hughperkins
Copy link
Collaborator

Convert primitive dtypes (f32, i32, etc.) from bare DataTypeCxx module-level variables into Python classes with a PrimitiveMeta metaclass. Each class has a .cxx attribute holding the underlying DataTypeCxx, and the metaclass delegates eq, hash, getattr for backward compatibility.

Update cook_dtype, to_quadrants_type, MAP_TYPE_IDS, and type utility functions to handle the new class-based types. Add PrimitiveBase checks in expr_init and quant.py.

Issue: #

Brief Summary

copilot:summary

Walkthrough

copilot:walkthrough

@hughperkins hughperkins force-pushed the hp/typing-t4-3-cook-dtype branch from 89f34d8 to 72a9636 Compare March 12, 2026 03:49
@hughperkins hughperkins force-pushed the hp/typing-t4-4-dtype-classes branch from 2ec1c5d to 24584de Compare March 12, 2026 03:49
@hughperkins
Copy link
Collaborator Author

Opus 4.6 review:

PR Review: hp/typing-t4-4-dtype-classes

Summary

This PR refactors primitive dtypes (f32, i32, etc.) from bare DataTypeCxx module-level variables into Python classes with a PrimitiveMeta metaclass. Each dtype class holds a .cxx class variable pointing to the underlying C++ DataTypeCxx instance. The metaclass delegates __eq__, __hash__, __getattr__, and __repr__ to the .cxx object for backward compatibility. The cook_dtype and to_quadrants_type functions are updated to normalize both Python classes and raw DataTypeCxx instances to DataTypeCxx for C++ APIs.


Issues Found

1. PrimitiveBase._registry overwrites on subclass registration

In primitive_types.py, __init_subclass__ does:

if hasattr(cls, "cxx"):
    PrimitiveBase._registry[cls.cxx] = cls

If a user subclasses a primitive (e.g. class MyF32(f32): pass), MyF32.cxx inherits f32_cxx, so _registry[f32_cxx] = MyF32 overwrites the previous f32 entry. cxx_to_py(f32_cxx) would then return MyF32 instead of f32. cxx_to_py is not used in the codebase yet, but this is a latent bug for future use. Consider either:

  • Only registering leaf classes (e.g. those that define their own cxx), or
  • Documenting that cxx_to_py returns the most recently registered subclass for a given cxx.

2. real_type_ids and integer_type_ids do not include C++ ids

type_ids is updated to include both Python class ids and DataTypeCxx ids (type_ids = _py_type_ids | cxx_type_ids), but real_type_ids and integer_type_ids are not:

real_type_ids = {id(t) for t in real_types}      # only Python classes
integer_type_ids = {id(t) for t in integer_types}

In _func_base.py, needed_arg_type_id in primitive_types.real_type_ids and integer_type_ids is used for scalar argument validation. needed_arg_type comes from kernel parameter annotations, which are Python classes (e.g. def foo(x: f32)), so id(needed_arg_type) is id(f32) and is in real_type_ids. This is fine for the current call paths. If needed_arg_type could ever be a raw DataTypeCxx (e.g. from some future API), it would not match. Worth a brief comment or test to lock in this assumption.


No Critical Issues

  • Metaclass delegation: PrimitiveMeta.__eq__ and __ne__ correctly handle both PrimitiveMeta (identity) and DataTypeCxx (delegate to .cxx). Comparisons like f32_cxx == f32 and f32 == f32_cxx work via f32.__eq__(f32_cxx) returning cls.cxx == other.
  • Set membership: dtype in primitive_types.real_types works when dtype is DataTypeCxx, because f32.__eq__(f32_cxx) returns True and set lookup uses __eq__.
  • _cook_cache: Uses id() of module-level constants (Python classes and *_cxx objects) that are never GC'd. Object id reuse is not a concern for these keys.
  • cook_dtype / to_quadrants_type: All paths (Python classes, DataTypeCxx, numpy/torch dtypes, Type wrappers, builtins) are covered. Cache lookup is correct; uncached paths fall through to the right branches.
  • quant.py _to_ptr: Correctly converts PrimitiveBase classes to .cxx before get_ptr(), fixing the case where compute=i32 (a class) was passed directly to C++.
  • types/utils.py: _cook_if_needed correctly normalizes PrimitiveBase to .cxx before calling C++ helpers.
  • type_ids: Combining _py_type_ids and cxx_type_ids is correct for id(annotation) in type_ids and id(arg.element_type) in type_ids, where element_type can be DataTypeCxx from C++.

Suggestions for Improvement

  1. cxx_to_py robustness: Add a KeyError handler or get() with a clear error message when the DataTypeCxx is not a known primitive (e.g. quantized or custom types).
  2. _cook_cache documentation: Add a short comment that the cache is safe because it only keys on module-level constants that live for the process lifetime.
  3. MAP_TYPE_IDS semantics: The int branch in to_quadrants_type uses MAP_TYPE_IDS[dt] where dt is a "primitive type id" (i.e. id(f32) or id(f32_cxx)). Consider a brief docstring or comment clarifying that dt must be such an id when type(dt) is int.
  4. Test coverage: Add tests for:
    • f32_cxx == f32 and f32 == f32_cxx
    • f32_cxx in real_types
    • cook_dtype(f32) and cook_dtype(f32_cxx) both returning the same DataTypeCxx
    • to_numpy_type(f32_cxx) and to_pytorch_type(f32_cxx) (dtype from C++ side)

Edge Cases Considered

  • isinstance(dtype, DataTypeCxx): After the refactor, dtype from annotations is a Python class, so isinstance(f32, DataTypeCxx) is False. Code that needs C++ types uses cook_dtype() or to_quadrants_type(), which handle both classes and DataTypeCxx. No problematic isinstance/type checks were found.
  • dtype_to_torch_dtype: Uses i32, f32, etc. as dict keys. PrimitiveMeta.__hash__ delegates to hash(cls.cxx), so these remain valid and stable keys.
  • impl.expr_init: The new PrimitiveBase branch correctly returns rhs.cxx when rhs is a dtype class.

@hughperkins hughperkins force-pushed the hp/typing-t4-4-dtype-classes branch from 24584de to cb8f4a0 Compare March 12, 2026 04:33
@hughperkins hughperkins force-pushed the hp/typing-t4-3-cook-dtype branch from 72a9636 to 9b5e1a8 Compare March 12, 2026 04:33
@hughperkins hughperkins force-pushed the hp/typing-t4-4-dtype-classes branch from cb8f4a0 to 9899efb Compare March 12, 2026 04:34
@hughperkins hughperkins force-pushed the hp/typing-t4-3-cook-dtype branch 2 times, most recently from 1adb863 to 71dcd62 Compare March 12, 2026 04:37
@hughperkins hughperkins force-pushed the hp/typing-t4-4-dtype-classes branch 4 times, most recently from bbc0f0b to 3a1b4f5 Compare March 12, 2026 04:52
Convert primitive dtypes (f32, i32, etc.) from bare DataTypeCxx module-level
variables into Python classes with a PrimitiveMeta metaclass. Each class has
a .cxx attribute holding the underlying DataTypeCxx, and the metaclass
delegates __eq__, __hash__, __getattr__ for backward compatibility.

Update cook_dtype, to_quadrants_type, MAP_TYPE_IDS, and type utility functions
to handle the new class-based types. Add PrimitiveBase checks in expr_init
and quant.py.
@hughperkins hughperkins force-pushed the hp/typing-t4-4-dtype-classes branch from 3a1b4f5 to 9bbc16b Compare March 12, 2026 04:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant