Changes from all commits · 41 commits
a583155
# First commit
Ronalmoo Oct 1, 2019
fcb8508
CNN Model
Ronalmoo Oct 2, 2019
3364dc6
add argparse
Ronalmoo Oct 2, 2019
dc55d07
Merge remote-tracking branch 'upstream/develop'
Ronalmoo Oct 3, 2019
ea7a5e8
repository hierarchy reconstructed
Ronalmoo Oct 3, 2019
8ecd211
Re-construct hierarchy (#11)
ssaru Oct 5, 2019
574e269
add ResNet50
Ronalmoo Oct 8, 2019
943e85f
Use GPU
Ronalmoo Oct 8, 2019
28a7fc1
test
Ronalmoo Oct 8, 2019
f2d8c74
minor fixed
Ronalmoo Oct 11, 2019
6de5887
add train_acc
Ronalmoo Oct 11, 2019
b3e85a5
add test.py
Ronalmoo Oct 11, 2019
6433bd1
add validation_set_loader
Ronalmoo Oct 12, 2019
afc72c9
validate
Ronalmoo Oct 12, 2019
5ce21d4
add confusion_matrix
Ronalmoo Oct 14, 2019
00f6e5e
add validation
Ronalmoo Oct 14, 2019
933f67c
minor change
Ronalmoo Oct 14, 2019
702eaa8
minor change
Ronalmoo Oct 14, 2019
693c81d
split data into train, validation, test
Ronalmoo Oct 14, 2019
4539c07
remove metrics.py
Ronalmoo Oct 24, 2019
c9a39af
Merge pull request #35 from chromatices/develop
chromatices Oct 25, 2019
0c22a0f
Merge branch 'master' of https://github.com/DeepBaksuVision/BinaryCon…
Ronalmoo Oct 26, 2019
f119606
Merge branch 'develop' of https://github.com/DeepBaksuVision/BinaryCo…
Ronalmoo Oct 26, 2019
3f24d24
add aug module with albumentations
Ronalmoo Oct 26, 2019
e60db94
Merge remote-tracking branch 'origin/master'
Ronalmoo Oct 26, 2019
6832bd9
fix dataloader
Ronalmoo Oct 26, 2019
83b87c9
Please enter the commit message for your changes. Lines starting with '#'… [accidentally committed git template, translated from Korean]
Ronalmoo Oct 26, 2019
0ebaacd
delete train
Ronalmoo Oct 26, 2019
fb7e9ac
delete validate
Ronalmoo Oct 26, 2019
b7df5c0
delete resnet
Ronalmoo Oct 26, 2019
53433a5
add BinaryLinear
Ronalmoo Oct 26, 2019
46330c8
add binarized_conv
Ronalmoo Oct 26, 2019
4e0eed0
add binarized_linear
Ronalmoo Oct 26, 2019
d4fa978
add aug module with albumentation
Ronalmoo Oct 26, 2019
3a24004
add aug module with albumentations
Ronalmoo Oct 26, 2019
4b50002
Merge remote-tracking branch 'upstream/develop'
Ronalmoo Oct 28, 2019
0b809bf
Merge remote-tracking branch 'upstream/develop'
Ronalmoo Oct 28, 2019
142794d
change configurations
Ronalmoo Oct 29, 2019
77b6743
pep-8
Ronalmoo Oct 30, 2019
875ec6c
Merge branch 'master' of https://github.com/Ronalmoo/BinaryConnect
Ronalmoo Oct 30, 2019
d375683
fix model
Ronalmoo Oct 31, 2019
75 changes: 75 additions & 0 deletions augmentations.py
@@ -0,0 +1,75 @@
import json
import cv2
from matplotlib import pyplot as plt
from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose
)


def aug(config: dict, p=0.5):
    return Compose([
        HorizontalFlip(True),
        RandomRotate90(True),
        Flip(),
        Transpose(),
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([
            MotionBlur(p=config["MotionBlur"]),
            MedianBlur(blur_limit=config["blur_limit"], p=0.1),
Contributor:

Don't the p values need to be controlled separately?

Contributor Author:

I thought about this — how about, later on, passing the options (probability values) as a list via the parser Lightning uses, like this?

parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)

Honestly, I'm not sure how fine-grained that p needs to be.
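The `opt_list` call above comes from test-tube's `HyperOptArgumentParser`, which Lightning relied on at the time. A rough stdlib sketch of the same idea — restricting a probability flag to a fixed list of candidate values to sweep over — might look like this (a hypothetical stand-in, not part of the PR):

```python
import argparse

# Hypothetical stdlib stand-in for test-tube's opt_list: constrain the
# probability flag to a fixed set of candidate values.
parser = argparse.ArgumentParser()
parser.add_argument("--drop_prob", type=float, default=0.2,
                    choices=[0.2, 0.5],
                    help="application probability for the augmentation")
args = parser.parse_args(["--drop_prob", "0.5"])
print(args.drop_prob)  # 0.5
```

Unlike `opt_list`, plain argparse has no notion of "tunable" options; a sweep would iterate over the `choices` list externally.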

Contributor @ssaru (Oct 31, 2019):

Yes, passing them as a list seems fine too.

Let me take ShiftScaleRotate as an example, and assume the configuration file is written in JSON. (Parsing JSON in Python yields a dict, so assume config has type dict.)

Checking the ShiftScaleRotate API shows the following signature:

ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, interpolation=1, border_mode=4, value=None, mask_value=None, always_apply=False, p=0.5)

The config can then be written as follows; parsed into a dict, it looks like this:

{
    ...
    "augmentation": {
        "shift_scale_rotate": {
            "shift_limit": 0.0625,
            "scale_limit": 0.1,
            "rotate_limit": 45,
            "interpolation": 1,
            "border_mode": 4,
            "value": null,
            "mask_value": null,
            "always_apply": false,
            "p": 0.5
        }
    }
    ...
}

When the parsed config is actually used, it can look like this:

def aug(config: dict) -> Compose:
    return Compose([
        ...
        ShiftScaleRotate(**config["augmentation"]["shift_scale_rotate"]),
        ...
    ])

Done this way, everything can be controlled through the config.json settings.

(My example is pseudocode; when you actually use it, write it while testing and confirming that it works.)

Additionally, if Lightning uses a parser with a mechanism similar to Argparse, I would rather use Hydra. Compared to Hydra, Argparse makes history tracking relatively hard unless the code has separate save logic, is more typo-prone, and any custom (non-default) values have to be typed in every time, which is inconvenient.

I suspect the parser built into Lightning has save logic baked in, but if it follows the Argparse mechanism, I still think Hydra is the better choice.
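The pattern described in this review — parse the JSON config once, then splat the relevant sub-dict into the transform's keyword arguments — can be sketched with the stdlib alone. The `shift_scale_rotate` function below is a hypothetical stand-in for albumentations' `ShiftScaleRotate`, used only to keep the sketch dependency-free:

```python
import json

config_text = """{
    "augmentation": {
        "shift_scale_rotate": {
            "shift_limit": 0.0625,
            "scale_limit": 0.1,
            "rotate_limit": 45,
            "p": 0.5
        }
    }
}"""

def shift_scale_rotate(shift_limit, scale_limit, rotate_limit, p):
    # Stand-in for albumentations' ShiftScaleRotate; just echoes its kwargs.
    return {"shift_limit": shift_limit, "scale_limit": scale_limit,
            "rotate_limit": rotate_limit, "p": p}

config = json.loads(config_text)  # use json.load(f) when reading config.json from disk
transform = shift_scale_rotate(**config["augmentation"]["shift_scale_rotate"])
print(transform["rotate_limit"])  # 45
```

Because every keyword in the sub-dict must match a parameter name exactly, a typo in config.json fails fast with a TypeError rather than silently using a default.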

            Blur(blur_limit=config["blur_limit"], p=0.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=config["shift_limit"], scale_limit=config['scale_limit'],
                         rotate_limit=config["rotate_limit"], p=0.2),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=0.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([
            CLAHE(clip_limit=config["clip_limit"]),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(),
        ], p=0.3),
        HueSaturationValue(p=config["HueSaturationValue"]),
    ], p=p)


def show_img(img, figsize=(8, 8)):
    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(False)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.imshow(img)
    plt.show()


# with open('aug_config.json') as json_file:
# json_data = json.load(json_file)
# print(json_data['MotionBlur'])

if __name__ == "__main__":
    image = cv2.imread('dog.12473.jpg')
    config = {
        "MotionBlur": 0.2,
        "blur_limit": 3,
        "MedianBlur": 0.1,
        "shift_limit": 0.0625,
        "scale_limit": 0.2,
        "rotate_limit": 45,
        "clip_limit": 2,
        "HueSaturationValue": 0.3
    }
    augmentation = aug(config)
    data = {'image': image}
    augmented = augmentation(**data)
    image = augmented['image']
    show_img(image)
3 changes: 2 additions & 1 deletion dataloader/data_loader.py
@@ -2,6 +2,7 @@
import torchvision
import torchvision.transforms as transforms


def data_loader(dataset="CIFAR-10", batch_size=16):

    if dataset == "CIFAR-10":
@@ -21,4 +22,4 @@ def data_loader(dataset="CIFAR-10", batch_size = 16):
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    classes = ("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")

-    return train_loader, test_loader, classes
+    return train_loader, test_loader, classes
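For reference, the batching and shuffling that the `DataLoader` calls above delegate to can be sketched in plain Python (a toy stand-in, not the torch API — torch additionally handles collation, workers, and pinned memory):

```python
import random

def batches(data, batch_size=16, shuffle=False, seed=0):
    """Toy stand-in for torch.utils.data.DataLoader's batching."""
    idx = list(range(len(data)))
    if shuffle:
        random.Random(seed).shuffle(idx)  # deterministic shuffle for the sketch
    for start in range(0, len(idx), batch_size):
        yield [data[i] for i in idx[start:start + batch_size]]

sizes = [len(b) for b in batches(list(range(50)), batch_size=16)]
print(sizes)  # [16, 16, 16, 2] — the final batch holds the remainder
```

As in the real loader, the last batch is smaller unless the dataset size divides evenly by `batch_size` (torch's `drop_last=True` would discard it).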
8 changes: 4 additions & 4 deletions models/binarized_conv.py
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from augmentations import aug
from torchsummary import summary
from torch.nn import functional as F
from torch.utils.data import DataLoader
@@ -77,15 +77,15 @@ def configure_optimizers(self):

    @pl.data_loader
    def train_dataloader(self):
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=128)
+        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=aug()), batch_size=128)
Contributor:

aug() is called here with no arguments, but above its default value is an empty dict {}. Does this actually run?

Contributor Author:

I hadn't tested this in the model, so I missed it. Fixed.

Contributor:

You should always test the modules you write thoroughly.


    @pl.data_loader
    def val_dataloader(self):
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
+        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=aug()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
-        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
+        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=aug()), batch_size=32)


if __name__ == "__main__":
5 changes: 3 additions & 2 deletions models/binarized_mlp.py
@@ -4,6 +4,7 @@
import torch.nn as nn
import torchvision.transforms as transforms

from augmentations import aug
from torchsummary import summary
from torch.nn import functional as F
from torch.utils.data import DataLoader
@@ -63,11 +64,11 @@ def configure_optimizers(self):

    @pl.data_loader
    def train_dataloader(self):
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=128)
+        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=aug()), batch_size=128)

    @pl.data_loader
    def val_dataloader(self):
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
+        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=aug()), batch_size=32)

@pl.data_loader
def test_dataloader(self):
143 changes: 143 additions & 0 deletions models/resnet.py
@@ -0,0 +1,143 @@
import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorch_lightning import Trainer
from augmentations import aug
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10


def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)  # ReLU between the two convolutions (standard residual block)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet(pl.LightningModule):
    def __init__(self, block, layers, num_classes):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

    def transform(self):
        config = {
            "MotionBlur": 0.2,
            "blur_limit": 3,
            "MedianBlur": 0.1,
            "shift_limit": 0.0625,
            "scale_limit": 0.2,
            "rotate_limit": 45,
            "clip_limit": 2,
            "HueSaturationValue": 0.3
        }
        augs = aug(config)
        return augs

    def loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'avg_test_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        return DataLoader(CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=128)

    @pl.data_loader
    def val_dataloader(self):
        return DataLoader(CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        return DataLoader(CIFAR10(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()),
                          batch_size=32)


device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

if __name__ == "__main__":
    model = ResNet(ResidualBlock, [3, 4, 6], 10).to(device)
    trainer = Trainer()
    trainer.fit(model)
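A quick sanity check of why `AvgPool2d(8)` and `Linear(64, num_classes)` line up, assuming 32x32 CIFAR-10 inputs: the stem and layer1 keep the spatial size at 32, while layer2 and layer3 each halve it once via their stride-2 first block:

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    # Output size of a conv layer: floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

size = 32                        # CIFAR-10 input width/height (assumed)
size = conv_out(size)            # stem conv3x3, stride 1 -> 32 (layer1 also preserves size)
size = conv_out(size, stride=2)  # layer2's first block -> 16
size = conv_out(size, stride=2)  # layer3's first block -> 8
print(size)  # 8: AvgPool2d(8) then reduces 8x8 to 1x1, leaving 64 features for fc
```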
75 changes: 0 additions & 75 deletions train.py

This file was deleted.

4 changes: 4 additions & 0 deletions utils/BinaryLinear.py
@@ -32,7 +32,11 @@ def deterministic(weight: torch.tensor) -> torch.tensor:
        return weight.sign()

    @staticmethod
<<<<<<< HEAD
    def stochastic(weight: torch.tensor) -> torch.tensor:
=======
    def stocastic(weight: torch.tensor) -> torch.tensor:
>>>>>>> upstream/develop
        p = torch.sigmoid(weight)
        uniform_matrix = torch.empty(p.shape).uniform_(0, 1)
        bin_weight = (p >= uniform_matrix).type(torch.float32)
2 changes: 1 addition & 1 deletion utils/binarized_conv.py
@@ -54,7 +54,7 @@ def deterministic(self, weight: torch.tensor) -> torch.tensor:
        bin_weight = weight.sign()
        return bin_weight

-    def stocastic(self, weight: torch.tensor) -> torch.tensor:
+    def stochastic(self, weight: torch.tensor) -> torch.tensor:
        with torch.no_grad():
            p = torch.sigmoid(weight)
            uniform_matrix = torch.empty(p.shape).uniform_(0, 1)
2 changes: 1 addition & 1 deletion utils/binarized_linear.py
@@ -34,7 +34,7 @@ def deterministic(self, weight: torch.tensor) -> torch.tensor:
        bin_weight = weight.sign()
        return bin_weight

-    def stocastic(self, weight: torch.tensor) -> torch.tensor:
+    def stochastic(self, weight: torch.tensor) -> torch.tensor:
        with torch.no_grad():
            p = torch.sigmoid(weight)
            uniform_matrix = torch.empty(p.shape).uniform_(0, 1)
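The `stochastic` method above draws a per-weight Bernoulli sample with success probability sigmoid(w). A dependency-free sketch of the same sampling (note it produces 0/1 like the code above, whereas the BinaryConnect paper binarizes to ±1 using a hard sigmoid):

```python
import math
import random

def stochastic_binarize(weights, rng=None):
    """Per-weight Bernoulli draw with probability sigmoid(w), mirroring
    (p >= uniform_matrix).type(torch.float32) in the method above."""
    rng = rng or random.Random(0)  # seeded for a reproducible sketch
    out = []
    for w in weights:
        p = 1.0 / (1.0 + math.exp(-w))   # torch.sigmoid(weight)
        u = rng.uniform(0.0, 1.0)        # torch.empty(...).uniform_(0, 1)
        out.append(1.0 if p >= u else 0.0)
    return out

print(stochastic_binarize([100.0, -100.0]))  # [1.0, 0.0] — saturated weights binarize deterministically
```

Near w = 0 the outcome is a coin flip (p = 0.5), which is exactly the regularizing noise BinaryConnect relies on during training.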