From b5204f5c8e5abbf05c1f69d998d800d68210e6d7 Mon Sep 17 00:00:00 2001 From: Andrew Deng Date: Tue, 28 Apr 2026 22:26:43 +0000 Subject: [PATCH] fix: update dynamo test in line with nightly --- test/dynamo/test_guard_manager.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 319104f7a5cf3..5132eff9e3fb2 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -2,7 +2,6 @@ import abc import functools import inspect -import unittest import weakref import torch @@ -257,8 +256,12 @@ def test_default_device_guard(self): guard = guards.DEFAULT_DEVICE(root, ["cpu device"], None) self.assertTrue(guard(foo)) + if not torch.accelerator.is_available(): + self.skipTest("Accelerator is not available") + try: - torch.set_default_device("cuda") + device = torch.accelerator.current_accelerator() + torch.set_default_device(device) self.assertFalse(guard(foo)) finally: torch.set_default_device(None) @@ -448,11 +451,14 @@ def test_weakref_alive_guard(self): del x self.assertFalse(guard(weakref_x())) - @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") - @requires_cuda def test_call_function_no_args_guard(self): + if not torch.accelerator.is_available(): + self.skipTest("Accelerator is not available") + root = RootGuardManager() - x = torch.cuda.current_device() + device = torch.accelerator.current_accelerator() + # Use device.index which is device-agnostic (works on all accelerators) + x = device.index if device.index is not None else 0 guard = guards.EQUALS_MATCH(root, x, [0], None) self.assertTrue(guard(0)) self.assertFalse(guard(1))