Skip to content

Fix calibrator pickle compatibility and add clean inference API#16

Open
michaelzenkay wants to merge 2 commits intoyala:masterfrom
michaelzenkay:fix/calibrator-compat-and-inference-api
Open

Fix calibrator pickle compatibility and add clean inference API#16
michaelzenkay wants to merge 2 commits intoyala:masterfrom
michaelzenkay:fix/calibrator-compat-and-inference-api

Conversation

@michaelzenkay
Copy link
Copy Markdown

Summary

  • Fix CalibratedClassifierCV pickle loading across sklearn versions (0.20 → 1.7+) by patching renamed internal attributes (calibratorscalibrators_, classesclasses_, base_estimatorestimator)
  • Add scripts/infer.py — a clean programmatic API (predict()) that builds args internally instead of requiring sys.argv manipulation
  • Fix torch.load for PyTorch 2.6+ (weights_only=False)

Details

Calibrator compatibility (scripts/main.py)

Calibrator pickles saved with older sklearn versions fail on newer versions due to renamed internals:

Pickle saved with Breaks on Error
sklearn ~0.20 >= 0.22 ModuleNotFoundError: sklearn.svm.classes
sklearn 0.23 0.24+ _CalibratedClassifier missing calibrators attr
sklearn 0.23 1.2+ CalibratedClassifierCV missing estimator attr

Added _patch_calibrator() to fix attribute names after loading, and load_calibrator() with graceful error handling.

Clean inference API (scripts/infer.py)

New programmatic interface that avoids sys.argv manipulation:

from scripts.infer import predict

df = predict(
    metadata_csv="path/to/metadata.csv",
    model_dir="path/to/models/",
    calibrate=True,
)
# df has columns: patient_exam_id, 1_year_risk, ..., 5_year_risk

PyTorch 2.6+ compatibility (onconet/models/factory.py)

Added weights_only=False to torch.load calls since PyTorch 2.6 changed the default to True.

Test plan

  • Load calibrator pickle on sklearn 0.24.2 and 1.7.0, verify predict_proba works
  • Run infer.predict() on single exam, verify calibrated output
  • Run scripts/main.py end-to-end, verify CSV export with calibrated scores
  • Batch inference on 1k+ exams (Windows + Linux HPC), calibrated output verified

- Fix calibrator loading across sklearn versions (0.20 -> 1.7+) by patching
  renamed internal attributes (calibrators/calibrators_, classes/classes_,
  base_estimator/estimator) at load time
- Add `scripts/infer.py` with a clean `predict()` function for programmatic
  inference without sys.argv manipulation. Auto-detects model files in a
  directory and returns a pandas DataFrame with calibrated risk scores.
- Extract `run()` function from `scripts/main.py` so it can be called
  programmatically (previously all logic was in `if __name__ == '__main__'`)
- Make `import git` optional (try/except) so the code works in environments
  without GitPython installed
- Add `weights_only=False` to `torch.load` in model factory for PyTorch 2.6+
  compatibility
- Update README with sklearn compatibility note and inference API docs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant