Functions to Run Kernels #51

Description

@PaulScemama

@junpenglao I'm taking a look at implementing the run_inference_loop from here. I'm running into a potential issue: it seems that some inference algorithms require more than rng_key and state as inputs to their step function. Take, for example, sgld, which requires a minibatch of data and a step size at each call to its .step. A minimal sketch of the mismatch is below.
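To make the mismatch concrete, here is a minimal sketch with stand-in stubs (illustrative placeholders only, not the real objects; the sgld-like signature just represents "extra per-step arguments"):

class GenericAlgorithm:
    # Matches the (rng_key, state) -> (state, info) interface
    # that run_inference_loop assumes.
    def step(self, rng_key, state):
        return state, None


class SgldLikeAlgorithm:
    # Also needs a minibatch and a step size at every call, which
    # run_inference_loop has no way to supply.
    def step(self, rng_key, state, minibatch, step_size):
        return state, None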

I suspect this will also be the case for the variational inference algorithms once they are in a more finished state. run_inference_loop currently cannot handle these cases.

Should I just leave those particular algorithms alone, and use run_inference_loop wherever I can?

One potential solution, allowing batches to be passed in during step, is to modify run_inference_loop like so:

from __future__ import annotations  # so the State/Info annotations need not resolve at definition time

import jax
from jax import lax
from jax.random import split


def run_inference_algorithm(
    rng_key,
    initial_state_or_position,
    inference_algorithm,
    batches,
    num_steps,
) -> tuple[State, State, Info]:
    try:
        initial_state = inference_algorithm.init(initial_state_or_position)
    except TypeError:
        # We assume initial_state is already in the right format.
        initial_state = initial_state_or_position

    keys = split(rng_key, num_steps)

    @jax.jit
    def one_step(state, rng_key):
        # Pull the next minibatch from the iterator and hand it to step.
        batch = next(batches)
        state, info = inference_algorithm.step(rng_key, state, batch)
        return state, (state, info)

    final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys)
    return final_state, state_history, info_history

Here batches is any iterator (possibly a generator) over minibatches of data. However, if batches is a generator that uses any JAX operations, I run into issues with scan (I'm not exactly sure why); if the generator uses only, say, NumPy, then it does work.

An example of a numpy data generator:

import numpy as np


def data_stream(seed, data, batch_size, data_size):
    """Return an infinite iterator over shuffled batches of data."""
    rng = np.random.RandomState(seed)
    num_batches = int(np.ceil(data_size / batch_size))
    while True:
        # Reshuffle once per epoch, then yield one index slice per batch.
        perm = rng.permutation(data_size)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield data[batch_idx]

batches = data_stream(...)
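For concreteness, a hypothetical end-to-end wiring (the dataset, the values, and the dummy algorithm below are all made up for illustration; the dummy only exists to expose the (rng_key, state, batch) step interface):

import numpy as np
import jax
import jax.numpy as jnp
from types import SimpleNamespace

X = np.random.randn(1_000, 2)  # toy dataset, purely illustrative
batches = data_stream(seed=0, data=X, batch_size=50, data_size=X.shape[0])

# Placeholder algorithm with the step signature the modified loop expects.
dummy = SimpleNamespace(
    init=lambda position: position,
    step=lambda rng_key, state, batch: (state + 0.01 * jnp.mean(batch), None),
)

final_state, states, infos = run_inference_algorithm(
    jax.random.PRNGKey(0), jnp.zeros(2), dummy, batches, num_steps=100
)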

This also works with, say, a Hugging Face datasets data loader. Something like:

from datasets import Dataset
batches = Dataset.from_dict({"data":data}).with_format("jax").iter(batch_size=50)
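Another thought: the scan issue might be sidestepped entirely by precomputing all the minibatches up front and threading them through lax.scan's xs argument, so that no Python iterator has to run inside the traced step. A minimal sketch (a hypothetical helper, not a proposal for the final API; it samples indices with replacement instead of streaming epoch permutations):

import jax
from jax import lax

def run_with_scanned_batches(
    rng_key, initial_state, inference_algorithm, data, batch_size, num_steps
):
    idx_key, scan_key = jax.random.split(rng_key)
    # Draw every step's minibatch indices up front (with replacement,
    # unlike the permutation-based generator above).
    batch_idx = jax.random.randint(idx_key, (num_steps, batch_size), 0, data.shape[0])
    batches = data[batch_idx]  # leading axis has length num_steps
    keys = jax.random.split(scan_key, num_steps)

    def one_step(state, xs):
        step_key, batch = xs
        state, info = inference_algorithm.step(step_key, state, batch)
        return state, (state, info)

    # scan slices one key and one batch per iteration, so no Python-side
    # iteration happens inside the compiled loop.
    final_state, (state_history, info_history) = lax.scan(
        one_step, initial_state, (keys, batches)
    )
    return final_state, state_history, info_history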

I'm not sure either of these would be the preferred solution. In any case, I'll think about it some more.

Thanks!
