Hi, thanks for the great work here. I tried to run the pretrained model, but I get this error:
Using GPU 0
['Tesla K80']
{'cuda': 0, 'comment': 0, 'batch_size': 8, 'train_data_dir': '/content/drive/MyDrive/Colab Notebooks/Building-GAN/Data/6types-processed_data', 'raw_dir': '/content/drive/MyDrive/Colab Notebooks/Building-GAN/Data/6types-raw_data', 'train_size': 96000, 'test_size': 4000, 'n_cpu': 8, 'variation_eval_id1': 96018, 'variation_eval_id2': 96010, 'variation_num': 25, 'latent_dim': 128, 'noise_dim': 32, 'program_layer': 4, 'voxel_layer': 12, 'gan_loss': 'WGANGP', 'gp_lambda': 10.0, 'lp_weight': 0.0, 'tr_weight': 0.0, 'far_weight': 0.0, 'lp_sample_size': 20, 'lp_similarity_fun': 'cos', 'lp_loss_fun': 'hinge', 'lp_hinge_margin': 1.0, 'n_epochs': 1000, 'n_critic_d': 1, 'n_critic_g': 5, 'n_critic_p': 5, 'plot_period': 10, 'eval_period': 20, 'g_lr': 0.0001, 'd_lr': 0.0001, 'b1': 0.5, 'b2': 0.999}
Total 120000 data: 96000 train / 4000 test
TypeError Traceback (most recent call last)
in ()
59 variation_test_data1 = torch.load(os.path.join(args.train_data_dir, data_fname_list[args.variation_eval_id1]))
60 variation_test_data2 = torch.load(os.path.join(args.train_data_dir, data_fname_list[args.variation_eval_id2]))
---> 61 variation_test_batch1 = Batch.from_data_list([variation_test_data1 for _ in range(args.variation_num)], follow_batch)
62 variation_test_batch2 = Batch.from_data_list([variation_test_data2 for _ in range(args.variation_num)], follow_batch)
63
2 frames
/usr/local/lib/python3.7/dist-packages/torch_geometric/data/collate.py in _collate(key, values, data_list, stores, increment)
135 # Concatenate a list of torch.Tensor along the cat_dim.
136 # NOTE: We need to take care of incrementing elements appropriately.
--> 137 cat_dim = data_list[0].__cat_dim__(key, elem, stores[0])
138 if cat_dim is None or elem.dim() == 0:
139 values = [value.unsqueeze(0) for value in values]
TypeError: __cat_dim__() takes 3 positional arguments but 4 were given
Any idea how to solve this? I suspect a package version mismatch. Could you provide a requirements.txt or list the package versions you used?
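For anyone hitting the same error, here is a minimal sketch of what the message suggests, under assumptions not confirmed by the repository. The installed torch_geometric calls the `__cat_dim__` hook (line 137 of `collate.py` in the traceback) with one more positional argument than the hook accepts. That would be consistent with a custom data class written against an older two-argument signature. The classes and the `collate_like_newer_pyg` helper below are hypothetical stand-ins written in plain Python. They are not the project's code or the library's implementation.

```python
# Stand-alone mock of the signature mismatch; torch and torch_geometric are not needed.
# All names below are hypothetical and exist only for illustration.


class OldStyleData:
    """Stand-in for a data class whose hook takes only (key, value)."""

    def __cat_dim__(self, key, value):
        # Concatenate edge indices along the last dimension, everything else along dim 0.
        return -1 if key == "edge_index" else 0


class TolerantData:
    """Same hook, but it also accepts extra positional and keyword arguments."""

    def __cat_dim__(self, key, value, *args, **kwargs):
        return -1 if key == "edge_index" else 0


def collate_like_newer_pyg(data, key, value, store=None):
    # Mirrors the call shape in the traceback: three arguments after self.
    return data.__cat_dim__(key, value, store)


def reproduce():
    """Return the TypeError text raised by the old-style hook, or None if no error."""
    try:
        collate_like_newer_pyg(OldStyleData(), "x", [1, 2, 3])
    except TypeError as exc:
        return str(exc)
    return None


error_message = reproduce()
fixed_result = collate_like_newer_pyg(TolerantData(), "edge_index", [[0], [1]])
print(error_message)
print(fixed_result)
```

If this assumption holds, there are two plausible ways forward: install the torch_geometric release the project was developed with, or widen the overriding hook to accept `*args, **kwargs` as in `TolerantData`. Which one is appropriate depends on the versions the authors actually used, which the post does not state.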