Add ensemble_batch_size for single-device inference #906

Open

randommm wants to merge 1 commit into PriorLabs:main from randommm:ensemble_batch_size

Conversation

@randommm (Contributor) commented Apr 29, 2026

Issue

#905

Motivation and Context

On devices with large amounts of RAM, such as AMD Strix Halo, this can greatly speed up inference.
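
Purely for illustration, a usage sketch of the proposed parameter (`ensemble_batch_size` is the name this PR introduces and may change in review; `n_estimators` is TabPFN's existing ensemble-size parameter, and the dataset is just an example):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# n_estimators is an existing TabPFN parameter; ensemble_batch_size is the
# parameter proposed in this PR and may change before merge.
clf = TabPFNClassifier(n_estimators=8, ensemble_batch_size=4)
clf.fit(X_train, y_train)
proba = clf.predict_proba(X_test)
```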

Public API Changes

  • No Public API changes
  • Yes, Public API changes (Details below)

How Has This Been Tested?

Local testing.

Checklist

  • The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes).
  • A changelog entry has been added (see changelog/README.md), or "no changelog needed" label requested.
  • The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

@randommm requested a review from a team as a code owner on April 29, 2026 22:01
@randommm requested review from adrian-prior and removed the request for a team on April 29, 2026 22:01
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. Credits must be used to enable repository-wide code reviews.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request implements ensemble batching for single-device predictions in TabPFN, enabling multiple compatible ensemble members to be processed in a single forward pass to improve performance. The changes introduce an ensemble_batch_size parameter across the API and update the inference engines to handle batched outputs. Feedback focuses on fixing a potential shape mismatch in embedding extraction, addressing an unused parameter in the on-demand engine, and standardizing telemetry timing for model execution.
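
To make the batching idea concrete, here is a minimal PyTorch sketch of the technique the summary describes (illustrative only, not TabPFN's actual engine code; `model` and `member_inputs` are assumed placeholders):

```python
import torch

def predict_in_ensemble_batches(model, member_inputs, ensemble_batch_size):
    """Run ensemble members through `model` in chunks rather than one by one.

    Illustrative sketch only: `member_inputs` is a list of same-shaped
    tensors (one per ensemble member), and `model` is assumed to accept
    a leading batch dimension.
    """
    outputs = []
    for start in range(0, len(member_inputs), ensemble_batch_size):
        chunk = member_inputs[start:start + ensemble_batch_size]
        batched = torch.stack(chunk, dim=0)   # one forward pass per chunk
        with torch.no_grad():
            out = model(batched)              # shape: (len(chunk), ...)
        outputs.extend(out.unbind(dim=0))     # split back into per-member outputs
    return outputs
```

The chunk size bounds peak memory while amortizing per-call overhead across members, which is presumably the trade-off the new parameter exposes.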

Comment thread: src/tabpfn/base.py (outdated)
Comment thread: src/tabpfn/inference.py
Comment thread: src/tabpfn/inference.py
@randommm force-pushed the ensemble_batch_size branch 2 times, most recently from 709dc67 to c65fda0, on April 29, 2026 22:09
Comment thread: src/tabpfn/inference.py (outdated)
Comment thread: src/tabpfn/base.py
@adrian-prior (Contributor) left a comment

Hey @randommm,

Thanks for setting this up, and apologies for taking so long to get back on this! I do have some open questions, which I left on the PR, and there are also a bunch of comments left by Gemini, of which I think many are valid. Would you mind going through them?

@randommm force-pushed the ensemble_batch_size branch 5 times, most recently from e3643a9 to a37bc35, on May 14, 2026 06:30
@randommm (Contributor, Author)

@adrian-prior I think it should all be fixed now.

@randommm force-pushed the ensemble_batch_size branch from a37bc35 to a49e278 on May 14, 2026 07:00