-
Notifications
You must be signed in to change notification settings - Fork 37
Expand file tree
/
Copy pathparallel_train.py
More file actions
158 lines (133 loc) · 6.47 KB
/
parallel_train.py
File metadata and controls
158 lines (133 loc) · 6.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
This script solves as many puzzles as possible at the same time, trying to saturate your GPU(s). First,
it measures the amount of memory that each puzzle takes to solve, and then it schedules the puzzles in
a way that saturates your GPU memory. It does this by running a "server" in parallelize_runs() that
tracks each GPU's usage and greedily schedules jobs on the GPUs to fit under each GPU's budget.
"""
# We run 2 steps of every puzzle to determine how much memory each puzzle uses.
# We run 2000 steps per task at optimal puzzle parallelization under memory constraint.
# We have changed layers.direction_share() to make it run faster, and got something like a 5-10% speedup.
import os
import sys
import time
import json
import importlib
import multiprocessing
from multiprocessing import Pool
import numpy as np
import torch
import preprocessing
import train
import arc_compressor
import initializers
import multitensor_systems
import layers
import solution_selection
import visualization
import solve_task
# Getting all the task names, setting defaults and constants
multiprocessing.set_start_method('spawn', force=True)
torch.set_default_dtype(torch.float32)
torch.set_default_device('cuda')
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Function that can spawn processes thata solve a puzzle, and schedule them on GPUs to take up each GPUs quota
# "quota/usage" here is just a number per puzzle/GPU given to the function. It can be memory, or job count, or
# anything.
def parallelize_runs(gpu_quotas, task_usages, n_iterations, verbose=False):
"""
Runs a server that spawns processes to solve many ARC-AGI tasks in parallel.
Args:
gpu_quotas (list[float]): The max quota that each GPU has to use.
task_usages (list[float]): The amount of quota that each task uses.
n_iterations (int): The number of training iterations to use to solve each puzzle.
verbose (bool): Whether or not to print tqdm bars to show progress.
Returns:
Dict[str, int]: The max memory allocated for every puzzle.
Dict[str, list[Dict[str, list[list[int]]]]]: The guessed solution for every puzzle.
float: The amount of time taken to solve all the puzzles in parallel.
"""
t = time.time()
gpu_quotas = gpu_quotas[:]
tasks_started = [False for i in range(n_tasks)]
tasks_finished = [False for i in range(n_tasks)]
processes = [None for i in range(n_tasks)]
process_gpu_ids = [None for i in range(n_tasks)]
with multiprocessing.Manager() as manager:
# Construct structures for inter-process communication
memory_dict = manager.dict()
solutions_dict = manager.dict()
error_queue = manager.Queue()
# Job monitoring loop
while not all(tasks_finished):
# Scan for errors
if not error_queue.empty():
raise ValueError(error_queue.get())
# If a job finishes, release its quota
for i in range(n_tasks):
if tasks_started[i] and not tasks_finished[i]:
processes[i].join(timeout=0)
if not processes[i].is_alive():
tasks_finished[i] = True
gpu_quotas[process_gpu_ids[i]] += task_usages[i]
if verbose:
print(task_names[i], 'finished on gpu', process_gpu_ids[i],
'New quota is', gpu_quotas[process_gpu_ids[i]])
# If there is enough quota to start a new job, do it
for gpu_id in range(n_gpus):
for i in range(n_tasks):
enough_quota = gpu_quotas[gpu_id] >= task_usages[i]
enough_cpus = sum(map(int, tasks_started)) - sum(map(int, tasks_finished)) < n_cpus
if not tasks_started[i] and enough_quota and enough_cpus:
gpu_quotas[gpu_id] -= task_usages[i]
args = (task_names[i], split, 1e20, n_iterations, gpu_id, memory_dict, solutions_dict, error_queue)
p = multiprocessing.Process(target=solve_task.solve_task, args=args)
p.start()
processes[i] = p
tasks_started[i] = True
process_gpu_ids[i] = gpu_id
if verbose:
print(task_names[i], 'started on gpu', process_gpu_ids[i],
'New quota is', gpu_quotas[process_gpu_ids[i]])
time.sleep(1)
# Scan for errors
if not error_queue.empty():
raise ValueError(error_queue.get())
# Save the solutions in the server process
memory_dict = dict(memory_dict)
solutions_dict = dict(solutions_dict)
time_taken = time.time() - t
if verbose:
print('All jobs finished in', time_taken, 'seconds.')
return memory_dict, solutions_dict, time_taken
if __name__ == '__main__':
start_time = time.time()
n_cpus = multiprocessing.cpu_count()
n_gpus = torch.cuda.device_count()
# Find all the puzzle names
split = "training"
with open(f'dataset/arc-agi_{split}_challenges.json', 'r') as f:
problems = json.load(f)
task_names = list(problems.keys())
del problems
n_tasks = len(task_names)
# Measuring the amount of memory used for every task
gpu_memory_quotas = [torch.cuda.mem_get_info(i)[0] for i in range(n_gpus)]
gpu_task_quotas = [int(gpu_memory_quota // (4 * 1024**3)) for gpu_memory_quota in gpu_memory_quotas]
task_usages = [1 for i in range(n_tasks)]
memory_dict, _, _ = parallelize_runs(gpu_task_quotas, task_usages, 2, verbose=True)
# Sort the tasks by decreasing memory usage
tasks = sorted(memory_dict.items(), key=lambda x: x[1], reverse=True)
task_names, task_memory_usages = zip(*tasks)
# Computing the solution for every task, while saturating memory
n_steps = 2000
safe_gpu_memory_quotas = [memory_quota - 4 * 1024**3 for memory_quota in gpu_memory_quotas]
_, solutions_dict, time_taken = parallelize_runs(safe_gpu_memory_quotas, task_memory_usages, n_steps, verbose=True)
# Format the solutions and put into submission file
with open('submission.json', 'w') as f:
json.dump(solutions_dict, f, indent=4)
with open("submission.json", "r") as f:
contents = json.load(f)
print(len(contents.keys()), 'puzzles solved.')
print(n_steps, 'steps per puzzle.')
print(time_taken, 'seconds.')