forked from HoagyC/sparse_coding
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcluster_runs.py
More file actions
157 lines (122 loc) · 4.2 KB
/
cluster_runs.py
File metadata and controls
157 lines (122 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import itertools
import sys
import time
import progressbar
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from autoencoders.ensemble import FunctionalEnsemble
def job_wrapper(job, ensemble_state_dict, cfg, args, tag, dataset, done_flag, progress_counter):
    """Child-process entry point: rebuild the ensemble and run `job` over `dataset`.

    Args:
        job: callable invoked as job(ensemble, cfg, args, tag, sampler, dataset, progress_counter).
        ensemble_state_dict: state from FunctionalEnsemble.state_dict(); the ensemble is
            reconstructed here rather than pickled across the process boundary.
        cfg, args, tag: opaque configuration forwarded to `job`; `args` must contain
            "batch_size" (and "device", though it is not used directly here).
        dataset: tensor-like with rows along dim 0, already placed in shared memory by the parent.
        done_flag: multiprocessing Value('i'); set to 1 when this worker finishes.
        progress_counter: multiprocessing Value incremented by `job` to report batch progress.
    """
    if not sys.warnoptions:
        # Silence warning spam from every worker process.
        import warnings
        warnings.filterwarnings("ignore")
    ensemble = FunctionalEnsemble.from_state(ensemble_state_dict)
    batch_size = args["batch_size"]
    # NOTE(review): `device` is read but unused in this wrapper — presumably `job`
    # consumes it from `args` directly; kept so a missing key still fails early.
    device = args["device"]
    # can't use DataLoaders because they copy the dataset
    # instead, we use a custom sampler for indexes
    sampler = torch.utils.data.BatchSampler(
        torch.utils.data.RandomSampler(range(dataset.shape[0])),
        batch_size=batch_size,
        drop_last=False,
    )
    try:
        job(ensemble, cfg, args, tag, sampler, dataset, progress_counter)
    finally:
        # Always signal completion, even if `job` raises — otherwise the parent's
        # polling loop in dispatch_job_on_chunk would wait on this flag forever.
        done_flag.value = 1
def job_wrapper_lite(ensemble_state_dict, cfg, args, tag, done_flag, progress_counter, job):
    """Child-process entry point for dataset-free jobs: rebuild the ensemble and run `job`.

    Args:
        ensemble_state_dict: state from FunctionalEnsemble.state_dict(), rebuilt in-process.
        cfg, args, tag: opaque configuration forwarded to `job`.
        done_flag: multiprocessing Value('i'); set to 1 when this worker finishes.
        progress_counter: multiprocessing Value updated by `job` to report progress.
        job: callable invoked as job(ensemble, cfg, args, tag, progress_counter).
    """
    if not sys.warnoptions:
        # Silence warning spam from every worker process.
        import warnings
        warnings.filterwarnings("ignore")
    ensemble = FunctionalEnsemble.from_state(ensemble_state_dict)
    try:
        job(ensemble, cfg, args, tag, progress_counter)
    finally:
        # Always signal completion, even if `job` raises — otherwise a parent
        # polling this flag (e.g. via collect_lite) would never see it set.
        done_flag.value = 1
def dispatch_lite(cfg, ensemble, args, name, job):
    """Launch `job` on `ensemble` in a fresh process and return tracking handles.

    Moves the ensemble's parameters into shared memory first, then spawns a
    worker running job_wrapper_lite.

    Returns:
        (process, done_flag, progress) — done_flag is an int Value set to 1 on
        completion; progress is a float Value updated by the job.
    """
    ensemble.to_shared_memory()
    done_flag = mp.Value("i", 0)
    progress_counter = mp.Value("f", 0)
    worker = mp.Process(
        target=job_wrapper_lite,
        args=(ensemble.state_dict(), cfg, args, name, done_flag, progress_counter, job),
    )
    worker.start()
    return worker, done_flag, progress_counter
def statusbar_lite(processes, n_points=1000):
    """Build a progress bar with one labelled status variable per tracked process.

    Args:
        processes: sequence of ((process, done_flag, progress), args, tag) entries;
            only the tags are consumed here, to label per-process variables.
        n_points: resolution of the bar (its max_value).

    Returns:
        An un-started progressbar.ProgressBar.
    """
    widgets = [
        progressbar.Bar(),
        " ",
        progressbar.AdaptiveETA(),
        " | ",
        progressbar.Timer(),
        " | ",
    ]
    for (_, _, _), _, tag in processes:
        widgets.append(progressbar.Variable(tag, precision=0, width=1, format="{formatted_value}"))
    return progressbar.ProgressBar(widgets=widgets, max_value=n_points)
def update_statusbar_lite(bar, processes, n_points=1000):
    """Refresh `bar` with the mean progress across processes and each process's done flag.

    The bar position is the mean of all per-process progress values scaled to
    `n_points`; each tag's variable shows that process's done flag (0 or 1).
    """
    total_progress = 0.0
    done_by_tag = {}
    for (_, done, progress), _, tag in processes:
        total_progress += progress.value
        done_by_tag[tag] = done.value
    bar.update(int(total_progress / len(processes) * n_points), **done_by_tag)
def collect_lite(processes):
    """Reap all worker processes once every one has signalled completion.

    Returns:
        True if every done flag was set (all processes have been joined),
        False if any worker is still running (nothing is joined).
    """
    for (_, done, _), _, _ in processes:
        if done.value != 1:
            return False
    for (proc, _, _), _, _ in processes:
        proc.join()
    return True
def dispatch_job_on_chunk(ensembles, cfg, dataset, job):
    """Run `job` over `dataset` once per ensemble, each in its own process, with a progress bar.

    Args:
        ensembles: sequence of (ensemble, args, tag) triples; each `args` must contain
            "batch_size".
        cfg: opaque configuration forwarded to every worker.
        dataset: tensor with rows along dim 0; pinned and moved into shared memory so
            all workers read the same copy without pickling it.
        job: callable executed inside each child via job_wrapper.

    Blocks until every worker has set its done flag, then joins all processes.
    """
    # Share one copy of the data across all worker processes.
    dataset.pin_memory()
    dataset.share_memory_()
    for ensemble, _, _ in ensembles:
        ensemble.to_shared_memory()

    processes = []
    done_flags = []
    progress_counters = []
    n_batches_total = 0
    for ensemble, args, tag in ensembles:
        finished = mp.Value("i", 0)
        progress = mp.Value("i", 0)
        p = mp.Process(
            target=job_wrapper,
            args=(
                job,
                ensemble.state_dict(),
                cfg,
                args,
                tag,
                dataset,
                finished,
                progress,
            ),
        )
        p.start()
        processes.append(p)
        done_flags.append(finished)
        # Ceil division: with drop_last=False the sampler yields ceil(rows / batch_size)
        # batches. The previous `rows // batch_size + 1` overcounted by one whenever
        # rows was an exact multiple of batch_size, inflating the bar's max_value.
        n_batches_total += -(-dataset.shape[0] // args["batch_size"])
        progress_counters.append(progress)

    bar = progressbar.ProgressBar(
        widgets=[
            progressbar.Bar(),
            " ",
            progressbar.AdaptiveETA(),
            " | ",
            progressbar.Timer(),
            " | ",
            *[progressbar.Variable(tag, precision=0, width=1, format="{formatted_value}") for _, _, tag in ensembles],
        ],
        max_value=n_batches_total,
    )

    # Poll workers, refreshing the bar as batches complete, until every done flag is set.
    while True:
        done = [flag.value for flag in done_flags]
        n_batches_done = sum(counter.value for counter in progress_counters)
        bar.update(n_batches_done, **{tag: value for value, (_, _, tag) in zip(done, ensembles)})
        if all(done):
            break
        time.sleep(0.1)
    for p in processes:
        p.join()