diff --git a/dataloader/data_loader.py b/dataloader/data_loader.py
index f230e64..ef8ee6b 100644
--- a/dataloader/data_loader.py
+++ b/dataloader/data_loader.py
@@ -2,6 +2,7 @@
 import torchvision
 import torchvision.transforms as transforms
 
+
 def data_loader(dataset="CIFAR-10", batch_size = 16):
     if dataset == "CIFAR-10":
 
@@ -21,4 +22,4 @@ def data_loader(dataset="CIFAR-10", batch_size = 16):
 
     test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
     classes = ("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")
-    return train_loader, test_loader, classes
+    return train_loader, test_loader, classes
\ No newline at end of file
diff --git a/layers/BinaryLinear.py b/layers/BinaryLinear.py
index 79f7682..ea2c026 100644
--- a/layers/BinaryLinear.py
+++ b/layers/BinaryLinear.py
@@ -5,7 +5,7 @@
 class BinaryLinear(torch.nn.Linear):
 
-    def __init__(self, in_features, out_features, bias=True, mode="Stocastic"):
+    def __init__(self, in_features, out_features, bias=True, mode="Stochastic"):
         super().__init__(in_features, out_features, bias)
         self.mode = mode
         self.bin_weight = self.weight_binarization(self.weight, self.mode)
 
@@ -17,8 +17,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
     def weight_binarization(self, weight: torch.tensor, mode:str):
         with torch.set_grad_enabled(False):
-            if mode == "Stocastic":
-                bin_weight = self.stocastic(weight)
+            if mode == "Stochastic":
+                bin_weight = self.stochastic(weight)
             elif mode == "Deterministic":
                 bin_weight = self.deterministic(weight)
             else:
@@ -32,7 +32,7 @@ def deterministic(weight: torch.tensor) -> torch.tensor:
         return weight.sign()
 
     @staticmethod
-    def stocastic(weight: torch.tensor) -> torch.tensor:
+    def stochastic(weight: torch.tensor) -> torch.tensor:
         p = torch.sigmoid(weight)
         uniform_matrix = torch.empty(p.shape).uniform_(0,1)
         bin_weight = (p >= uniform_matrix).type(torch.float32)
@@ -49,4 +49,4 @@ def clipping_weight(self, weight:torch.tensor) -> torch.tensor:
         with torch.set_grad_enabled(False):
             weight = torch.clamp(weight, -1, 1)
             weight.requires_grad = True
-        return weight
+        return weight
\ No newline at end of file
diff --git a/models/binarized_conv.py b/models/binarized_conv.py
index c340116..6463d1a 100644
--- a/models/binarized_conv.py
+++ b/models/binarized_conv.py
@@ -114,4 +114,4 @@ def test_dataloader(self):
 
     trainer = Trainer(checkpoint_callback=checkpoint_callback, max_nb_epochs=1, train_percent_check=0.1)
     trainer.fit(model)
-    trainer.test(model)
+    trainer.test(model)
\ No newline at end of file
diff --git a/models/binarized_mlp.py b/models/binarized_mlp.py
index 652c1df..24c04d6 100644
--- a/models/binarized_mlp.py
+++ b/models/binarized_mlp.py
@@ -120,16 +120,16 @@ def test_dataloader(self):
 
         monitor='val_loss',
         mode='min',
         prefix='',
-        save_weights_only= True
+        save_weights_only=True
     )
-    
+
     gpus = torch.cuda.device_count()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = Binarized_MLP(device=device, mode="Stochastic")
     model.to(device)
     model.summary()
-    
+
    trainer = Trainer(checkpoint_callback=checkpoint_callback, max_nb_epochs=5, train_percent_check=0.1)
    trainer.fit(model)
-    trainer.test(model)
+    trainer.test(model)
\ No newline at end of file
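For reference, the `stochastic` mode renamed above implements BinaryConnect-style stochastic binarization: each binary weight is drawn as +1 with probability sigmoid(w) and -1 otherwise. A minimal standalone sketch of that idea (the hunk above cuts off before `stochastic` returns, so the final {0, 1} -> {-1, +1} mapping below is an assumption, not the repo's exact code):

    import torch

    def stochastic_binarize(weight: torch.Tensor) -> torch.Tensor:
        # Probability of drawing +1 comes from a sigmoid of the real-valued weight.
        p = torch.sigmoid(weight)
        # One uniform sample per weight, thresholded against p (as in the hunk above).
        sample = (p >= torch.empty_like(p).uniform_(0, 1)).float()
        # Map {0, 1} to {-1, +1}; assumed final step, not shown in the diff.
        return sample * 2 - 1

    print(stochastic_binarize(torch.randn(3, 3)))  # entries are -1.0 or +1.0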
diff --git a/models/resnet.py b/models/resnet.py
new file mode 100644
index 0000000..5a713db
--- /dev/null
+++ b/models/resnet.py
@@ -0,0 +1,149 @@
+import os
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+from pytorch_lightning import Trainer
+from torchvision.datasets import CIFAR10
+
+
+def conv3x3(in_channels, out_channels, stride=1):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
+        super(ResidualBlock, self).__init__()
+        self.conv1 = conv3x3(in_channels, out_channels, stride)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(out_channels, out_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+        self.downsample = downsample
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.downsample:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion*planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion*planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class ResNet(pl.LightningModule):
+    def __init__(self, block, num_blocks, num_classes=10):
+        super(ResNet, self).__init__()
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self.forward(x)
+        return {'loss': F.cross_entropy(y_hat, y)}
+
+    def validation_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self.forward(x)
+        return {'val_loss': F.cross_entropy(y_hat, y)}
+
+    def validation_end(self, outputs):
+        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        return {'avg_val_loss': avg_loss}
+
+    def test_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self.forward(x)
+        return {'test_loss': F.cross_entropy(y_hat, y)}
+
+    def test_end(self, outputs):
+        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
+        return {'avg_test_loss': avg_loss}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+    @pl.data_loader
+    def train_dataloader(self):
+        return DataLoader(CIFAR10(os.getcwd(), train=True, transform=transforms.ToTensor(), download=True), batch_size=128)
+
+    @pl.data_loader
+    def val_dataloader(self):
+        return DataLoader(CIFAR10(os.getcwd(), train=True, transform=transforms.ToTensor(), download=True), batch_size=32)
+
+    @pl.data_loader
+    def test_dataloader(self):
+        return DataLoader(CIFAR10(os.getcwd(), train=False, transform=transforms.ToTensor(), download=True), batch_size=32)
+
+
+def ResNet50():
+    return ResNet(Bottleneck, [3,4,6,3])
+
+
+if __name__ == "__main__":
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    model = ResNet50()
+    trainer = Trainer()
+    trainer.fit(model)
\ No newline at end of file
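A quick shape check for the new module (hypothetical usage, assuming the repository root is on PYTHONPATH so the file above imports as `models.resnet`): a CIFAR-10-sized batch through `ResNet50` should yield one logit per class, since the three stride-2 stages reduce 32x32 inputs to 4x4 before the 4x4 average pool.

    import torch
    from models.resnet import ResNet50

    model = ResNet50()               # Bottleneck blocks [3, 4, 6, 3], expansion 4
    x = torch.randn(2, 3, 32, 32)    # two CIFAR-10-sized images
    logits = model(x)
    print(logits.shape)              # torch.Size([2, 10])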
diff --git a/train.py b/train.py
deleted file mode 100644
index be71925..0000000
--- a/train.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.optim as optim
-
-from dataloader.data_loader import data_loader
-from tqdm import tqdm
-import argparse
-from models.conv import CNN
-from models.mlp import MLP
-from models.binarized_mlp import Binarized_MLP
-
-
-def train(args, model, device, train_loader, optimizer, epoch, criterion):
-    model.train()
-    for i, (data, target) in tqdm(enumerate(train_loader, 0)):
-        data = data.to(device)
-        target = target.to(device)
-        optimizer.zero_grad()
-        output = model(data)
-        loss = criterion(output, target)
-        loss.backward()
-        optimizer.step()
-        if i % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, i * len(data), len(train_loader.dataset),
-                100. * i / len(train_loader), loss.item()))
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--model', type=str, default='MLP')
-    parser.add_argument('--bnn_type', type=str, default='Stochastic')
-    parser.add_argument('--dataset', type=str, default='CIFAR-10')
-    parser.add_argument('--batch_size', type=int, default=16, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=1, metavar='N',
-                        help='number of epochs to train (default: 10)')
-    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                        help='learning rate (default: 0.01)')
-    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
-                        help='SGD momentum (default: 0.5)')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-
-    parser.add_argument('--save_model', action='store_true', default=False,
-                        help='For Saving the current Model')
-    args = parser.parse_args()
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    train_loader = data_loader(dataset=args.dataset, batch_size=args.batch_size)[0]
-
-    if args.model == 'CNN':
-        model = CNN()
-    elif args.model == 'MLP':
-        model = MLP()
-    elif args.model == 'BNN':
-        if args.bnn_type == "Stochastic" or args.bnn_type == "Deterministic":
-            model = Binarized_MLP(args.bnn_type)
-        else:
-            raise RuntimeError("not supported quantization method")
-
-    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
-    criterion = nn.CrossEntropyLoss()
-    for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch, criterion)
-
-    if (args.save_model):
-        torch.save(model.state_dict(), "CIFAR-10_MLP.pt")
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
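With train.py removed, the hand-written loop it contained (zero_grad, forward, loss, backward, step, plus periodic logging) is taken over by Lightning's Trainer inside each model script. A hedged sketch of the new entry point, mirroring the `__main__` block added in models/resnet.py (module path assumed; old 0.x-era Lightning API, as used throughout this diff):

    from pytorch_lightning import Trainer
    from models.resnet import ResNet50

    model = ResNet50()
    trainer = Trainer()   # accepts e.g. max_nb_epochs, as in the other scripts
    trainer.fit(model)    # runs training_step/validation_step with the configured Adam optimizer
    trainer.test(model)   # evaluates test_step over test_dataloader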