-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathutils.py
More file actions
47 lines (40 loc) · 1.45 KB
/
utils.py
File metadata and controls
47 lines (40 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from texttable import Texttable
from torch_sparse import SparseTensor
import torch
import numpy as np
# Byte-size constants for reporting memory usage (binary units).
MB = 1024 ** 2  # bytes per mebibyte
GB = 1024 ** 3  # bytes per gibibyte
def print_args(args):
    """Pretty-print an argument namespace as a two-column table.

    Args:
        args: any object accepted by ``vars()`` (e.g. an argparse.Namespace).
    """
    table = Texttable()
    table.add_row(["Parameter", "Value"])
    for name, value in vars(args).items():
        table.add_row([name, value])
    print(table.draw())
def get_memory_usage(gpu, print_info=False):
    """Get accurate gpu memory usage by querying torch runtime.

    Args:
        gpu: device index (or device) passed to the torch.cuda queries.
        print_info: when True, also print allocated/reserved amounts in MB.

    Returns:
        Number of bytes currently allocated on the device.
    """
    allocated_bytes = torch.cuda.memory_allocated(gpu)
    reserved_bytes = torch.cuda.memory_reserved(gpu)
    if print_info:
        print("allocated: %.2f MB" % (allocated_bytes / 1024 / 1024), flush=True)
        print("reserved: %.2f MB" % (reserved_bytes / 1024 / 1024), flush=True)
    return allocated_bytes
def compute_tensor_bytes(tensors):
    """Compute the bytes used by a tensor or a list/tuple of tensors.

    Args:
        tensors: a single torch.Tensor, or a list/tuple of tensors.

    Returns:
        int: total number of bytes occupied by the tensors' elements
        (element count times per-element width; storage overhead ignored).

    Raises:
        ValueError: if a tensor has a dtype with no width mapping below.
    """
    if not isinstance(tensors, (list, tuple)):
        tensors = [tensors]
    ret = 0
    for x in tensors:
        # BUG FIX: the original used a standalone `if` for the 8-byte case,
        # so an int64 tensor added 8 bytes and then fell into the second
        # chain's `else`, raising ValueError. One if/elif chain fixes that.
        # x.numel() is used instead of np.prod(x.size()) so scalar tensors
        # (empty size tuple) yield an int, not numpy's float 1.0.
        if x.dtype in [torch.int64, torch.long, torch.float64]:
            ret += x.numel() * 8
        elif x.dtype in [torch.float32, torch.int, torch.int32]:
            ret += x.numel() * 4
        elif x.dtype in [torch.bfloat16, torch.float16, torch.int16]:
            ret += x.numel() * 2
        elif x.dtype in [torch.int8, torch.uint8, torch.bool]:
            ret += x.numel()
        else:
            raise ValueError("unsupported dtype: %s" % x.dtype)
    return ret