@junpenglao I'm taking a look at implementing the `run_inference_loop` from here. I'm running into a potential issue: some inference algorithms require more than `rng_key` and `state` as inputs to their step function. Take, for example, `sgld`, which requires a minibatch of data and a step size at each call to its `.step`.
I suspect this will also be the case for the variational inference algorithms once they are in a more final state. `run_inference_loop` cannot currently handle such cases.
Should I just leave alone the algorithms where this is the case, and use `run_inference_loop` wherever I can?
One potential solution, which allows batches to be passed in at each step, is to modify `run_inference_loop` like so:
```python
import jax
from jax import lax
from jax.random import split


def run_inference_algorithm(
    rng_key,
    initial_state_or_position,
    inference_algorithm,
    batches,
    num_steps,
) -> tuple[State, State, Info]:
    try:
        initial_state = inference_algorithm.init(initial_state_or_position)
    except TypeError:
        # We assume initial_state_or_position is already in the right format.
        initial_state = initial_state_or_position
    keys = split(rng_key, num_steps)

    @jax.jit
    def one_step(state, rng_key):
        batch = next(batches)
        state, info = inference_algorithm.step(rng_key, state, batch)
        return state, (state, info)

    final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys)
    return final_state, state_history, info_history
```
Here, `batches` is any iterator (possibly a generator) over batches of data examples. However, if `batches` is a generator that uses any jax operations, I run into issues with `scan` (I'm not exactly sure why); if `batches` is a generator that uses, say, numpy, then it does work.
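A possible explanation for the `scan` issues: `lax.scan` traces the body function, so a Python-level `next(batches)` executes at trace time rather than once per iteration (jax operations inside the generator would then leak tracers across traces). A minimal probe (the names here are hypothetical) that counts how often a generator is advanced during a 5-step scan:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

advance_count = []  # records each time the generator is advanced


def counting_stream():
    i = 0
    while True:
        advance_count.append(i)
        yield np.full((2,), float(i + 1))
        i += 1


batches = counting_stream()


def one_step(carry, _):
    batch = next(batches)  # Python side effect: runs while tracing, not per step
    return carry + jnp.sum(batch), None


final, _ = lax.scan(one_step, jnp.zeros(()), jnp.arange(5))
# The generator is advanced only while the body is traced,
# far fewer times than the 5 scan iterations.
print(len(advance_count))
```

If that is what is happening, it is worth double-checking whether the numpy-generator version actually sees a fresh batch at each step; the usual jax pattern is to stack `num_steps` batches into an array up front and scan over the keys and batches together.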
An example of a numpy data generator:
```python
import numpy as np


def data_stream(seed, data, batch_size, data_size):
    """Return an iterator over batches of data."""
    rng = np.random.RandomState(seed)
    num_batches = int(np.ceil(data_size / batch_size))
    while True:
        perm = rng.permutation(data_size)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield data[batch_idx]


batches = data_stream(...)
```
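As a quick sanity check (repeating `data_stream` so the snippet runs standalone), one epoch's batches cover every example exactly once, with the final batch truncated; the toy 10×3 array below is made up for illustration:

```python
import numpy as np


def data_stream(seed, data, batch_size, data_size):
    """Return an iterator over shuffled batches; reshuffles every epoch."""
    rng = np.random.RandomState(seed)
    num_batches = int(np.ceil(data_size / batch_size))
    while True:
        perm = rng.permutation(data_size)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield data[batch_idx]


# toy dataset: 10 examples of dimension 3
data = np.arange(30.0).reshape(10, 3)
batches = data_stream(seed=0, data=data, batch_size=4, data_size=10)

# one epoch = ceil(10 / 4) = 3 batches; together they cover every row once
epoch = [next(batches) for _ in range(3)]
sizes = [b.shape[0] for b in epoch]
print(sizes)  # [4, 4, 2]
rows = np.concatenate(epoch, axis=0)
print(sorted(rows[:, 0].tolist()) == np.arange(0.0, 30.0, 3.0).tolist())  # True
```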
This also works with, e.g., a Hugging Face datasets data loader. Something like:

```python
from datasets import Dataset

batches = Dataset.from_dict({"data": data}).with_format("jax").iter(batch_size=50)
```
I'm not sure this would be the preferred solution. In any case, I'll think about it some more.
Thanks!