Skip to content

chore: float64 -> float32 in process_text_na_dataframe#926

Draft
ggprior wants to merge 1 commit into
mainfrom
georg/tabpfn-dtype-float32
Draft

chore: float64 -> float32 in process_text_na_dataframe#926
ggprior wants to merge 1 commit into
mainfrom
georg/tabpfn-dtype-float32

Conversation

@ggprior
Copy link
Copy Markdown
Contributor

@ggprior ggprior commented May 7, 2026

No description provided.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies the process_text_na_dataframe function in src/tabpfn/preprocessing/clean.py to cast the output to float32 instead of float64. Feedback indicates that this change introduces numerical stability risks, contradicts the function's docstrings, and creates inconsistencies with the rest of the preprocessing pipeline which defaults to float64.

X_encoded[:, string_cols_ix],
)
return typing.cast("np.ndarray", X_encoded.astype(np.float64))
return typing.cast("np.ndarray", X_encoded.astype(np.float32))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Changing the return type to float32 introduces several inconsistencies:

  1. Numerical Stability: src/tabpfn/constants.py (lines 50-51) defines DEFAULT_NUMPY_PREPROCESSING_DTYPE as np.float64 specifically to avoid overflows during transformations like Yeo-Johnson. Hardcoding float32 here may lead to issues in subsequent preprocessing steps.
  2. Docstring Inconsistency: The docstring for process_text_na_dataframe (lines 142 and 145) still explicitly mentions conversion to float64.
  3. Pipeline Inconsistency: fix_dtypes (line 69) defaults to float64. Since clean_data calls both, the numeric_dtype setting in fix_dtypes is now effectively overridden by this hardcoded float32 cast.

If the intention is to move the pipeline to float32, consider updating the global constant or making the dtype a parameter to maintain consistency.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant