content: The sparse_grad_rng generator in ZeroFlow is hardcoded to use 'cuda' whenever CUDA is available (and 'cpu' otherwise), ignoring the device the parameters actually live on. If CUDA is available but the model is on CPU, torch.Generator(device='cuda') still succeeds at construction. If the model is on 'mps', the generator falls back to 'cpu'. In either case the generator's device differs from that of the tensors fast_random_mask_like later receives, so sampling will either fail (PyTorch sampling ops such as torch.rand reject a generator whose device differs from the output device) or produce silent inconsistencies. Additionally, SimpleNamespace is imported twice, and self.get_grad_reduce(grad_reduce) is called twice in sam.py (once in the base __init__ and again in the subclass).
file: src/infty/optim/zeroth_order_updates/zeroflow.py
code:
self.sparse_grad_rng = torch.Generator(device='cuda' if torch.cuda.is_available() else 'cpu')
description: Build the RNG generator on the same device as the parameters, derived from next(self.model.parameters()).device. Also remove the duplicate SimpleNamespace import on line 11, and drop the redundant second call to get_grad_reduce(grad_reduce) in sam.py. The device fix removes a hidden portability issue on setups where the parameters are not on CUDA (e.g. Mac MPS, or a CPU-resident model on a CUDA machine).
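A minimal sketch of the device-matched generator, assuming the model's parameters all share one device. make_param_rng and random_mask_like are hypothetical helpers for illustration; random_mask_like is only a stand-in for fast_random_mask_like, whose real signature is not shown here. Inside ZeroFlow the equivalent one-liner would be self.sparse_grad_rng = torch.Generator(device=next(self.model.parameters()).device).

```python
from typing import Optional

import torch
import torch.nn as nn


def make_param_rng(model: nn.Module, seed: Optional[int] = None) -> torch.Generator:
    """Hypothetical helper: create a generator on the parameters' device."""
    # Assumes a single-device model; raises StopIteration if it has no parameters.
    device = next(model.parameters()).device
    gen = torch.Generator(device=device)
    if seed is not None:
        gen.manual_seed(seed)  # reproducible mask stream
    return gen


def random_mask_like(param: torch.Tensor, density: float, gen: torch.Generator) -> torch.Tensor:
    """Hypothetical stand-in for fast_random_mask_like: boolean mask with P(True) = density."""
    # torch.rand raises a RuntimeError if gen.device does not match param.device.
    return torch.rand(param.shape, device=param.device, generator=gen) < density


if __name__ == "__main__":
    model = nn.Linear(4, 3)  # on CPU here; model.to("mps") or model.to("cuda") also works
    rng = make_param_rng(model, seed=0)
    mask = random_mask_like(model.weight, density=0.5, gen=rng)
    print(rng.device, mask.shape)  # cpu torch.Size([3, 4])
```

One design caveat: a generator created in __init__ stays on that device, so if the model is moved (e.g. .to("cuda")) after the optimizer is built, the generator would need to be recreated, for example lazily on first use.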