diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py
index eb20889f2a04f..e3030bd5982fd 100644
--- a/test/distributed/tensor/debug/test_debug_mode.py
+++ b/test/distributed/tensor/debug/test_debug_mode.py
@@ -52,10 +52,21 @@
 @requires_accelerator
 def _get_accelerator_memory():
-    try:
-        return torch.accelerator.get_memory_info(0)[1]
-    except (NotImplementedError):
-        return 0  # Return 0, as that would help skip the test is not skipped
+    """Return device 0's memory size, or -1 when it cannot be queried.
+
+    The value is index 1 of ``torch.accelerator.get_memory_info(0)``
+    (documented as total bytes). -1 is returned when no accelerator is
+    available or its ``torch.<device>`` module lacks ``mem_get_info``;
+    being negative, it also compares below any positive size threshold
+    such as ``2**26``.
+    """
+    if not torch.accelerator.is_available():
+        return -1
+    device = torch.accelerator.current_accelerator()
+    backend = getattr(torch, device.type, None)
+    if not hasattr(backend, "mem_get_info"):
+        return -1
+    return torch.accelerator.get_memory_info(0)[1]
 
 
 class TestDTensorDebugMode(TestCase):
     def tearDown(self):
@@ -917,11 +928,13 @@ def test_check_hash_mismatches(self):
             [call["call"] for call in mismatches], ["aten::sin", "aten::sum"]
         )
 
-    @unittest.skipIf(
-        not torch.accelerator.is_available()
-        or _get_accelerator_memory() < 2**26,
-        "Being conservative, test peak memory is 25MB?",
-    )
     def test_tensor_hash_redistribute(self):
+        # Query memory inside the test body rather than in a decorator
+        # argument, so it is not evaluated while the class is being defined.
+        mem = _get_accelerator_memory()
+        if mem == -1:
+            self.skipTest("No accelerator available or memory query not supported")
+        if mem < 2**26:
+            self.skipTest("Requires accelerator with at least 64MB memory")
         # test that hashing collectives gives correct results
         mesh = DeviceMesh(self.device_type, list(range(self.world_size)))