diff --git a/main_ce.py b/main_ce.py index 29573d30..cbe88254 100644 --- a/main_ce.py +++ b/main_ce.py @@ -161,8 +161,8 @@ def set_loader(opt): train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None), num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler) val_loader = torch.utils.data.DataLoader( - val_dataset, batch_size=256, shuffle=False, - num_workers=8, pin_memory=True) + val_dataset, batch_size=opt.batch_size, shuffle=False, + num_workers=opt.num_workers, pin_memory=True) return train_loader, val_loader