diff --git a/ldm/data/simple.py b/ldm/data/simple.py
index 4d200b28..e6b57e8d 100644
--- a/ldm/data/simple.py
+++ b/ldm/data/simple.py
@@ -63,10 +63,11 @@ def hf_dataset(
     split='train',
     image_key='image',
     caption_key='txt',
+    use_auth_token=False,
     ):
     """Make huggingface dataset with appropriate list of transforms applied
     """
-    ds = load_dataset(name, split=split)
+    ds = load_dataset(name, split=split, use_auth_token=use_auth_token)
     image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
     image_transforms.extend([transforms.ToTensor(),
                              transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
@@ -112,4 +113,4 @@ def __getitem__(self, index):
     def _load_caption_file(self, filename):
         with open(filename, 'rt') as f:
             captions = f.readlines()
-            return [x.strip('\n') for x in captions]
\ No newline at end of file
+            return [x.strip('\n') for x in captions]
diff --git a/main.py b/main.py
index b21a775f..bbd785f1 100644
--- a/main.py
+++ b/main.py
@@ -855,7 +855,7 @@ def check_frequency(self, check_idx):
         # configure learning rate
         bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
         if not cpu:
-            ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
+            ngpu = len(str(lightning_config.trainer.gpus).strip(",").split(','))
         else:
             ngpu = 1
         if 'accumulate_grad_batches' in lightning_config.trainer: