
Missing device-aware RNG construction in ZeroFlow #8

Description

@MichaeWeixi

content: The sparse_grad_rng generator in ZeroFlow is hardcoded to use 'cuda' whenever CUDA is available, regardless of which device the parameters actually live on. When the model is on CPU or 'mps', the generator call (torch.Generator(device='cuda'), or its 'cpu' fallback on machines without CUDA) still constructs successfully. However, when fast_random_mask_like later uses this generator with tensors on a different device, the devices disagree and the draw will either fail or produce silently inconsistent results. Additionally, SimpleNamespace is imported twice, and self.get_grad_reduce(grad_reduce) is called twice in sam.py (once in the base __init__ and again in the subclass).
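The duplicate call in sam.py follows a familiar pattern: the base __init__ already resolves grad_reduce, so repeating the call in the subclass does the same work twice. The sketch below uses hypothetical stand-ins (`BaseOptimizer`, `SAMOptimizer`, and a `get_grad_reduce` that only counts its calls and returns its argument); the real sam.py classes and the actual behavior of get_grad_reduce are not shown in this issue.

```python
class BaseOptimizer:
    """Hypothetical stand-in for the base class used by sam.py."""

    def __init__(self, grad_reduce="mean"):
        self.calls = 0
        # The base __init__ already resolves grad_reduce once.
        self.grad_reduce = self.get_grad_reduce(grad_reduce)

    def get_grad_reduce(self, grad_reduce):
        # Hypothetical: count invocations so any duplication is observable.
        self.calls += 1
        return grad_reduce


class SAMOptimizer(BaseOptimizer):
    """Hypothetical subclass after the cleanup."""

    def __init__(self, rho=0.05, grad_reduce="mean"):
        # super().__init__ already calls get_grad_reduce; a second
        # self.get_grad_reduce(grad_reduce) here would be redundant.
        super().__init__(grad_reduce=grad_reduce)
        self.rho = rho


opt = SAMOptimizer(grad_reduce="sum")
print(opt.grad_reduce, opt.calls)
```

With the second call removed, get_grad_reduce runs exactly once per construction.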
file: src/infty/optim/zeroth_order_updates/zeroflow.py
code:

self.sparse_grad_rng = torch.Generator(device='cuda' if torch.cuda.is_available() else 'cpu')
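Because the line above picks the device from CUDA availability alone, the mismatch only surfaces later, inside the random draw. A minimal guard, sketched here under the assumption that all of the model's parameters share one device, would catch it at construction time instead. `check_generator_device` is a hypothetical name, not part of ZeroFlow, and it compares device types (cpu, cuda, mps) only, not CUDA device indices.

```python
import torch
import torch.nn as nn


def check_generator_device(gen: torch.Generator, model: nn.Module) -> None:
    """Hypothetical guard: fail early if the RNG and the parameters disagree on device type."""
    param_device = next(model.parameters()).device
    if gen.device.type != param_device.type:
        raise ValueError(
            f"sparse_grad_rng is on {gen.device}, but parameters are on {param_device}"
        )


model = nn.Linear(4, 3)  # parameters stay on CPU
# Mirrors the hardcoded construction above.
rng = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
try:
    check_generator_device(rng, model)
    print("devices agree")
except ValueError as err:
    # Reached on a CUDA machine whose model was left on CPU.
    print(err)
```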

description: The generator should be created on the same device as the model's parameters, derived from next(self.model.parameters()).device, rather than from torch.cuda.is_available(). Also remove the duplicate SimpleNamespace import on line 11 and the redundant second call to get_grad_reduce(grad_reduce) in sam.py. This removes a hidden portability issue on non-CUDA setups (e.g. Mac MPS) and on CUDA machines where the model is kept on CPU.
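A minimal sketch of the suggested fix, assuming the model is an `nn.Module` whose parameters all live on one device. `make_sparse_grad_rng` is a hypothetical helper name; inside ZeroFlow the same logic would replace the hardcoded line in `__init__`. The mask draw at the end only illustrates sampling with a device-matched generator; it is not ZeroFlow's fast_random_mask_like, whose implementation the issue does not show. Generators on 'mps' also assume a PyTorch build that supports them.

```python
from typing import Optional

import torch
import torch.nn as nn


def make_sparse_grad_rng(model: nn.Module, seed: Optional[int] = None) -> torch.Generator:
    """Hypothetical helper: build the RNG on the same device as the model's parameters."""
    # Follow the parameters' device instead of torch.cuda.is_available().
    device = next(model.parameters()).device
    gen = torch.Generator(device=device)
    if seed is not None:
        gen.manual_seed(seed)
    return gen


model = nn.Linear(4, 3)  # call .to("cuda") or .to("mps") first as needed
rng = make_sparse_grad_rng(model, seed=0)

# Illustrative mask draw: the generator and the output tensor share a device.
param = next(model.parameters())
mask = torch.rand(param.shape, generator=rng, device=param.device) < 0.5
print(mask.device, mask.shape)
```

Seeding through manual_seed keeps the mask stream reproducible on a given device, but CPU and CUDA generators produce different sequences for the same seed, so moving the model to another device changes the masks. A model with no parameters would make next() raise StopIteration, the same as the one-liner proposed above.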
