-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtool.py
More file actions
67 lines (59 loc) · 2.21 KB
/
tool.py
File metadata and controls
67 lines (59 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#【准备】GPU管理工具函数
import gc
import torch
class GPUManager:
    """CUDA memory helper: inspect usage, trim the allocator cache, free memory.

    The manager is bound to a single device chosen at construction time, and
    every CUDA query/cleanup call it makes targets that device explicitly, so
    results stay correct even if other code later calls
    ``torch.cuda.set_device``. When CUDA is unavailable the manager reports
    the CPU device and skips all CUDA-specific steps.
    """

    def __init__(self, device_id=None):
        """Bind the manager to a device.

        Args:
            device_id: CUDA device index. When given (and CUDA is available)
                it is also made the process-wide current device via
                ``torch.cuda.set_device``. When ``None``, the current CUDA
                device is used. Ignored when CUDA is unavailable.
        """
        if torch.cuda.is_available():
            if device_id is not None:
                torch.cuda.set_device(device_id)
            self._device = torch.device(f"cuda:{torch.cuda.current_device()}")
        else:
            self._device = torch.device("cpu")

    @property
    def device(self):
        """The ``torch.device`` this manager operates on (``cpu`` without CUDA)."""
        return self._device

    def check(self, title="显存"):
        """Print allocated / reserved / peak / total memory of the bound device.

        Figures are in GiB (1024**3 bytes). Without CUDA, a one-line notice
        naming the current device is printed instead.

        Args:
            title: Label printed in square brackets at the start of the line.
        """
        if not torch.cuda.is_available():
            print(f"[{title}] 无 CUDA,当前 device={self._device}")
            return
        dev = self._device
        gib = 1024**3
        total_gb = torch.cuda.get_device_properties(dev).total_memory / gib
        alloc = torch.cuda.memory_allocated(dev) / gib
        reserved = torch.cuda.memory_reserved(dev) / gib
        peak = torch.cuda.max_memory_allocated(dev) / gib
        print(
            f"[{title}] 已分配 {alloc:.3f} GB | 预留 {reserved:.3f} GB | 峰值 {peak:.3f} GB / 总显存 {total_gb:.3f} GB"
        )

    def _release_cuda_cache(self):
        """Wait for the bound device, then hand cached/IPC memory back to the driver."""
        # Let queued work on this device finish before trimming the cache.
        torch.cuda.synchronize(self._device)
        # Return unoccupied blocks held by PyTorch's caching allocator.
        torch.cuda.empty_cache()
        # Reclaim memory from tensors shared via CUDA IPC that peers released.
        torch.cuda.ipc_collect()

    def clean(self):
        """Run garbage collection and trim the CUDA cache.

        User tensors/modules are NOT deleted; only memory that is already
        unreferenced is released. No-op when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return
        # Collect first so objects kept alive only by reference cycles
        # release their tensors before the cache is trimmed.
        gc.collect()
        self._release_cuda_cache()

    def clear(self, extra_names=None):
        """Delete ``model``/``data`` (plus ``extra_names``) from ``__main__``, then free memory.

        Only the *names* are removed from the ``__main__`` namespace; the
        underlying objects are freed only if nothing else still references
        them. Names that are not present are skipped. With CUDA available,
        the cache is trimmed and the bound device's peak and accumulated
        memory statistics are reset.

        Args:
            extra_names: A single name (``str``) or an iterable of names to
                delete in addition to ``model`` and ``data``.
        """
        import __main__

        namespace = __main__.__dict__
        targets = {"model", "data"}
        if extra_names:
            if isinstance(extra_names, str):
                targets.add(extra_names)
            else:
                targets.update(extra_names)
        for name in targets:
            namespace.pop(name, None)

        gc.collect()
        if torch.cuda.is_available():
            self._release_cuda_cache()
            torch.cuda.reset_peak_memory_stats(self._device)
            # Feature-probe instead of try/except so errors raised *inside*
            # the call are not silently swallowed; older PyTorch lacks it.
            reset_accumulated = getattr(torch.cuda, "reset_accumulated_memory_stats", None)
            if reset_accumulated is not None:
                reset_accumulated(self._device)