-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path3_mnist_distributed.py
More file actions
68 lines (58 loc) · 2.13 KB
/
3_mnist_distributed.py
File metadata and controls
68 lines (58 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os

import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Minimal multilayer perceptron for classifying flattened MNIST digits.
class Net(nn.Module):
    """Two-layer fully connected network: 784 -> 128 -> 10 logits."""

    def __init__(self):
        super().__init__()
        hidden = 128
        self.fc = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Map a (batch, 784) input tensor to (batch, 10) class logits."""
        return self.fc(x)
# Training function for each worker
def train_worker(rank, epochs=5, batch_size=6400):
    """Train an independent copy of Net on the full MNIST training set.

    Each spawned process trains its own model with no gradient
    synchronization (deliberately non-DDP), then writes its loss curve
    to ./img/mnist_loss_curve_gpu<rank>.png.

    Args:
        rank: Process index; selects device cuda:<rank> when CUDA is available.
        epochs: Number of full passes over the training set.
        batch_size: Mini-batch size for the DataLoader (default preserves
            the original hard-coded 6400).
    """
    # Flatten each 28x28 image into a 784-vector to match Net's input layer.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1)),
    ])
    train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # One GPU per worker by rank; fall back to CPU when CUDA is unavailable.
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    model = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Record the per-iteration loss for the curve saved below.
    losses = []
    for epoch in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print(f"[GPU {rank}] Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

    # savefig does not create parent directories -- without this the worker
    # crashes with FileNotFoundError when ./img is absent.
    os.makedirs("./img", exist_ok=True)
    plt.figure()  # fresh figure so nothing from prior state is overplotted
    plt.plot(losses)
    plt.xlabel("Iteration")
    plt.ylabel("Training Loss")
    plt.title(f"MNIST Loss Curve (GPU {rank})")
    plt.savefig(f"./img/mnist_loss_curve_gpu{rank}.png")
if __name__ == "__main__":
print('Start MNIST training on multiple GPUs (non-DDP)...')
# n_gpus = torch.cuda.device_count()
world_size = 4 # use up to 4 GPUs
mp.spawn(train_worker, args=(5,), nprocs=world_size, join=True)