diff --git a/augmentations.py b/augmentations.py new file mode 100644 index 0000000..e285ec5 --- /dev/null +++ b/augmentations.py @@ -0,0 +1,76 @@ +import json +import cv2 +from matplotlib import pyplot as plt +from albumentations import ( + HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90, + Transpose, Blur, OpticalDistortion, GridDistortion, HueSaturationValue, + IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, + IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose +) + + +def aug(config: dict = None, p=0.5): + config = config or {"MotionBlur": 0.2, "blur_limit": 3, "MedianBlur": 0.1, "shift_limit": 0.0625, "scale_limit": 0.2, "rotate_limit": 45, "clip_limit": 2, "HueSaturationValue": 0.3} + return Compose([ + HorizontalFlip(True), + RandomRotate90(True), + Flip(), + Transpose(), + OneOf([ + IAAAdditiveGaussianNoise(), + GaussNoise(), + ], p=0.2), + OneOf([ + + MotionBlur(p=config["MotionBlur"]), + MedianBlur(blur_limit=config["blur_limit"], p=0.1), + Blur(blur_limit=config["blur_limit"], p=0.1), + ], p=0.2), + ShiftScaleRotate(shift_limit=config["shift_limit"], scale_limit=config['scale_limit'], + rotate_limit=config["rotate_limit"], p=0.2), + OneOf([ + OpticalDistortion(p=0.3), + GridDistortion(p=0.1), + IAAPiecewiseAffine(p=0.3), + ], p=0.2), + OneOf([ + CLAHE(clip_limit=config["clip_limit"]), + IAASharpen(), + IAAEmboss(), + RandomBrightnessContrast(), + + ], p=0.3), + HueSaturationValue(p=config["HueSaturationValue"]), + ], p=p) + + +def show_img(img, figsize=(8, 8)): + fig, ax = plt.subplots(figsize=figsize) + ax.grid(False) + ax.set_yticklabels([]) + ax.set_xticklabels([]) + ax.imshow(img) + plt.imshow(img) + + +# with open('aug_config.json') as json_file: +# json_data = json.load(json_file) +# print(json_data['MotionBlur']) + +if __name__ == "__main__": + image = cv2.imread('dog.12473.jpg') + config = { + "MotionBlur": 0.2, + "blur_limit": 3, + "MedianBlur": 0.1, + "shift_limit": 0.0625, + "scale_limit": 0.2, + "rotate_limit": 45, + "clip_limit": 2, + "HueSaturationValue": 0.3 + } + augmentation = aug(config) + data = {'image': image} + augmented = 
augmentation(**data) + image = augmented['image'] + show_img(image) \ No newline at end of file diff --git a/dataloader/data_loader.py b/dataloader/data_loader.py index f230e64..ef8ee6b 100644 --- a/dataloader/data_loader.py +++ b/dataloader/data_loader.py @@ -2,6 +2,7 @@ import torchvision import torchvision.transforms as transforms + def data_loader(dataset="CIFAR-10", batch_size = 16): if dataset == "CIFAR-10": @@ -21,4 +22,4 @@ def data_loader(dataset="CIFAR-10", batch_size = 16): test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2) classes = ("0", "1", "2", "3", "4", "5", "6", "7", "8", "9") - return train_loader, test_loader, classes + return train_loader, test_loader, classes \ No newline at end of file diff --git a/models/binarized_conv.py b/models/binarized_conv.py index b825586..6bb2dd5 100644 --- a/models/binarized_conv.py +++ b/models/binarized_conv.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn import torchvision.transforms as transforms - +from augmentations import aug from torchsummary import summary from torch.nn import functional as F from torch.utils.data import DataLoader @@ -77,15 +77,15 @@ def configure_optimizers(self): @pl.data_loader def train_dataloader(self): - return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=128) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=aug()), batch_size=128) @pl.data_loader def val_dataloader(self): - return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=aug()), batch_size=32) @pl.data_loader def test_dataloader(self): - return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=aug()), batch_size=32) if 
__name__ == "__main__": diff --git a/models/binarized_mlp.py b/models/binarized_mlp.py index a832cbf..fea4146 100644 --- a/models/binarized_mlp.py +++ b/models/binarized_mlp.py @@ -4,6 +4,7 @@ import torch.nn as nn import torchvision.transforms as transforms +from augmentations import aug from torchsummary import summary from torch.nn import functional as F from torch.utils.data import DataLoader @@ -63,11 +64,11 @@ def configure_optimizers(self): @pl.data_loader def train_dataloader(self): - return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=128) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=aug()), batch_size=128) @pl.data_loader def val_dataloader(self): - return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=aug()), batch_size=32) @pl.data_loader def test_dataloader(self): diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000..4cd3d42 --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,143 @@ +import os +import torch +import torch.nn as nn +import pytorch_lightning as pl +import torch.nn.functional as F +import torchvision.transforms as transforms +from pytorch_lightning import Trainer +from augmentations import aug +from torchvision.models import resnet50 +from torch.utils.data import DataLoader +from torchvision.datasets import CIFAR10 + + +def conv3x3(in_channels, out_channels, stride=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, downsample=None): + super(ResidualBlock, self).__init__() + self.conv1 = conv3x3(in_channels, out_channels, stride) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(out_channels, 
out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + self.downsample = downsample + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet(pl.LightningModule): + def __init__(self, block, layers, num_classes): + super(ResNet, self).__init__() + self.in_channels = 16 + self.conv = conv3x3(3, 16) + self.bn = nn.BatchNorm2d(16) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self.make_layer(block, 16, layers[0]) + self.layer2 = self.make_layer(block, 32, layers[1], 2) + self.layer3 = self.make_layer(block, 64, layers[2], 2) + self.avg_pool = nn.AvgPool2d(8) + self.fc = nn.Linear(64, num_classes) + + def make_layer(self, block, out_channels, blocks, stride=1): + downsample = None + if (stride != 1) or (self.in_channels != out_channels): + downsample = nn.Sequential( + conv3x3(self.in_channels, out_channels, stride=stride), + nn.BatchNorm2d(out_channels)) + layers = [] + layers.append(block(self.in_channels, out_channels, stride, downsample)) + self.in_channels = out_channels + for i in range(1, blocks): + layers.append(block(out_channels, out_channels)) + return nn.Sequential(*layers) + + def forward(self, x): + out = self.conv(x) + out = self.bn(out) + out = self.relu(out) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.avg_pool(out) + out = out.view(out.size(0), -1) + out = self.fc(out) + return out + + def transform(self): + config = { + "MotionBlur": 0.2, + "blur_limit": 3, + "MedianBlur": 0.1, + "shift_limit": 0.0625, + "scale_limit": 0.2, + "rotate_limit": 45, + "clip_limit": 2, + "HueSaturationValue": 0.3 + } + augs = aug(config) + return augs + + def loss(self, y_hat, y): + return F.cross_entropy(y_hat, y) + + def training_step(self, batch, batch_nb): + x, y = batch + x = x.to(device) + y = y.to(device) + y_hat = 
self.forward(x) + return {'loss': F.cross_entropy(y_hat, y)} + + def validation_step(self, batch, batch_nb): + x, y = batch + x = x.to(device) + y = y.to(device) + y_hat = self.forward(x) + return {'val_loss': F.cross_entropy(y_hat, y)} + + def validation_end(self, outputs): + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + return {'avg_val_loss': avg_loss} + + def test_step(self, batch, batch_nb): + x, y = batch + y_hat = self.forward(x) + return {'test_loss': F.cross_entropy(y_hat, y)} + + def test_end(self, outputs): + avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() + return {'avg_test_loss': avg_loss} + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + @pl.data_loader + def train_dataloader(self): + return DataLoader(CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=128) + + @pl.data_loader + def val_dataloader(self): + return DataLoader(CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) + + @pl.data_loader + def test_dataloader(self): + return DataLoader(CIFAR10(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), + batch_size=32) + + +device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') +model = ResNet(ResidualBlock, [3, 4, 6], 10).to(device) +trainer = Trainer() +trainer.fit(model) diff --git a/train.py b/train.py deleted file mode 100644 index be71925..0000000 --- a/train.py +++ /dev/null @@ -1,75 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim - -from dataloader.data_loader import data_loader -from tqdm import tqdm -import argparse -from models.conv import CNN -from models.mlp import MLP -from models.binarized_mlp import Binarized_MLP - - -def train(args, model, device, train_loader, optimizer, epoch, criterion): - model.train() - for i, (data, target) in tqdm(enumerate(train_loader, 0)): - data = data.to(device) - target = 
target.to(device) - optimizer.zero_grad() - output = model(data) - loss = criterion(output, target) - loss.backward() - optimizer.step() - if i % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, i * len(data), len(train_loader.dataset), - 100. * i / len(train_loader), loss.item())) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, default='MLP') - parser.add_argument('--bnn_type', type=str, default='Stochastic') - parser.add_argument('--dataset', type=str, default='CIFAR-10') - parser.add_argument('--batch_size', type=int, default=16, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=1, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--momentum', type=float, default=0.5, metavar='M', - help='SGD momentum (default: 0.5)') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - - parser.add_argument('--save_model', action='store_true', default=False, - help='For Saving the current Model') - args = parser.parse_args() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - train_loader = data_loader(dataset=args.dataset, batch_size=args.batch_size)[0] - - if args.model == 'CNN': - model = CNN() - elif args.model == 'MLP': - model = MLP() - elif args.model == 'BNN': - if args.bnn_type == "Stochastic" or args.bnn_type == "Deterministic": - model = Binarized_MLP(args.bnn_type) - else: - raise RuntimeError("not supported 
quantization method") - - optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) - criterion = nn.CrossEntropyLoss() - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch, criterion) - - if (args.save_model): - torch.save(model.state_dict(), "CIFAR-10_MLP.pt") - - if __name__ == "__main__": - main() \ No newline at end of file diff --git a/utils/BinaryLinear.py b/utils/BinaryLinear.py index 79f7682..f6bf18b 100644 --- a/utils/BinaryLinear.py +++ b/utils/BinaryLinear.py @@ -32,7 +32,7 @@ def deterministic(weight: torch.tensor) -> torch.tensor: return weight.sign() @staticmethod - def stocastic(weight: torch.tensor) -> torch.tensor: + def stochastic(weight: torch.tensor) -> torch.tensor: p = torch.sigmoid(weight) uniform_matrix = torch.empty(p.shape).uniform_(0,1) bin_weight = (p >= uniform_matrix).type(torch.float32) diff --git a/utils/binarized_conv.py b/utils/binarized_conv.py index fa26c8c..8e14f0d 100644 --- a/utils/binarized_conv.py +++ b/utils/binarized_conv.py @@ -54,7 +54,7 @@ def deterministic(self, weight: torch.tensor) -> torch.tensor: bin_weight = weight.sign() return bin_weight - def stocastic(self, weight: torch.tensor) -> torch.tensor: + def stochastic(self, weight: torch.tensor) -> torch.tensor: with torch.no_grad(): p = torch.sigmoid(weight) uniform_matrix = torch.empty(p.shape).uniform_(0, 1) diff --git a/utils/binarized_linear.py b/utils/binarized_linear.py index a148160..25a1f59 100644 --- a/utils/binarized_linear.py +++ b/utils/binarized_linear.py @@ -34,7 +34,7 @@ def deterministic(self, weight: torch.tensor) -> torch.tensor: bin_weight = weight.sign() return bin_weight - def stocastic(self, weight: torch.tensor) -> torch.tensor: + def stochastic(self, weight: torch.tensor) -> torch.tensor: with torch.no_grad(): p = torch.sigmoid(weight) uniform_matrix = torch.empty(p.shape).uniform_(0,1)