Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions config/cifar10/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import chainer
from src.model.wide_resnet_new import WideResNet
from src.dataset.cifar import cifar
from src.preprocess.preprocess_svhn import PreprocessSVHN
from src.extension.learning_rate_scheduler import LearningRateScheduler
from src.extension.learning_rate_scheduler import ExponentialSchedule
from src.hook.power_iteration import PowerIteration

batchsize = 128
dataset = cifar()
epoch = 160
preprocess = PreprocessSVHN()
predictor = WideResNet(k=4, n_layer=16, drop=.4)
lr = 1e-2
optimizer = chainer.optimizers.NesterovAG(lr)
extension = [(LearningRateScheduler(ExponentialSchedule(.1, (80, 120))), (1, 'iteration'))]
hook = [PowerIteration()]
10 changes: 10 additions & 0 deletions config/cifar10/lmt_001.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from config.cifar10.base import *
from src.model.classifier import LMTraining
from src.extension.c_scheduler import CScheduler
from src.extension.c_scheduler import GoyalSchedule

mode = ['lmt', 'lmt-fc']
start_c = 1e-5
end_c = 1e-2
model = LMTraining(predictor, preprocess, c=start_c)
extension += [(CScheduler(GoyalSchedule(start_c, end_c, 5)), (1, 'iteration'))]
22 changes: 22 additions & 0 deletions src/dataset/cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import chainer
from chainer.datasets import TupleDataset

means = (0.4914, 0.4822, 0.4465)
sds = (0.2023, 0.1994, 0.2010)

def normalize(X):
for i in range(3):
X[:,i,:,:] -= means[i]
X[:,i,:,:] /= sds[i]
return X

def transform(dataset):
X = dataset._datasets[0]
y = dataset._datasets[1]
return TupleDataset(normalize(X), y)

def cifar():
cifar10 = chainer.datasets.get_cifar10()
train = transform(cifar10[0])
test = transform(cifar10[1])
return (train, test)