diff --git a/src/accelerate/parallelism_config.py b/src/accelerate/parallelism_config.py index c4135f2f791..22d15d3678c 100644 --- a/src/accelerate/parallelism_config.py +++ b/src/accelerate/parallelism_config.py @@ -248,7 +248,7 @@ def get_device_mesh(self, device_type: Optional[str] = None): if device_type is not None: self.device_mesh = self.build_device_mesh(device_type) else: - raise ("You need to pass a device_type e.g cuda to build the device mesh") + raise ValueError("You need to pass a device_type e.g cuda to build the device mesh") else: if device_type is not None: if self.device_mesh.device_type != device_type: