
Inference.py error #4

@musdfakoc


Hi, thanks for the great work here. I tried to run the pretrained model, but I get this error:
Using GPU 0
['Tesla K80']
{'cuda': 0, 'comment': 0, 'batch_size': 8, 'train_data_dir': '/content/drive/MyDrive/Colab Notebooks/Building-GAN/Data/6types-processed_data', 'raw_dir': '/content/drive/MyDrive/Colab Notebooks/Building-GAN/Data/6types-raw_data', 'train_size': 96000, 'test_size': 4000, 'n_cpu': 8, 'variation_eval_id1': 96018, 'variation_eval_id2': 96010, 'variation_num': 25, 'latent_dim': 128, 'noise_dim': 32, 'program_layer': 4, 'voxel_layer': 12, 'gan_loss': 'WGANGP', 'gp_lambda': 10.0, 'lp_weight': 0.0, 'tr_weight': 0.0, 'far_weight': 0.0, 'lp_sample_size': 20, 'lp_similarity_fun': 'cos', 'lp_loss_fun': 'hinge', 'lp_hinge_margin': 1.0, 'n_epochs': 1000, 'n_critic_d': 1, 'n_critic_g': 5, 'n_critic_p': 5, 'plot_period': 10, 'eval_period': 20, 'g_lr': 0.0001, 'd_lr': 0.0001, 'b1': 0.5, 'b2': 0.999}
Total 120000 data: 96000 train / 4000 test

TypeError Traceback (most recent call last)
in ()
59 variation_test_data1 = torch.load(os.path.join(args.train_data_dir, data_fname_list[args.variation_eval_id1]))
60 variation_test_data2 = torch.load(os.path.join(args.train_data_dir, data_fname_list[args.variation_eval_id2]))
---> 61 variation_test_batch1 = Batch.from_data_list([variation_test_data1 for _ in range(args.variation_num)], follow_batch)
62 variation_test_batch2 = Batch.from_data_list([variation_test_data2 for _ in range(args.variation_num)], follow_batch)
63

2 frames
/usr/local/lib/python3.7/dist-packages/torch_geometric/data/collate.py in _collate(key, values, data_list, stores, increment)
135 # Concatenate a list of torch.Tensor along the cat_dim.
136 # NOTE: We need to take care of incrementing elements appropriately.
--> 137 cat_dim = data_list[0].__cat_dim__(key, elem, stores[0])
138 if cat_dim is None or elem.dim() == 0:
139 values = [value.unsqueeze(0) for value in values]

TypeError: __cat_dim__() takes 3 positional arguments but 4 were given
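For reference, the traceback shows collate.py calling `__cat_dim__(key, elem, stores[0])`, which is three arguments besides `self`, while the method actually called accepts only two besides `self`. My guess (an assumption, not confirmed here) is that a custom `Data` subclass overrides `__cat_dim__` with an older `(self, key, value)` signature, while the installed torch_geometric passes an extra store argument. A minimal pure-Python sketch of that mismatch, using hypothetical stand-in classes rather than the real torch_geometric or Building-GAN ones:

```python
# Reproduces "__cat_dim__() takes 3 positional arguments but 4 were given".
# BaseData, OldStyleData, NewStyleData and collate_like are hypothetical
# stand-ins, not the real torch_geometric or Building-GAN classes.


class BaseData:
    def __cat_dim__(self, key, value, *args, **kwargs):
        # Default behaviour: concatenate along dimension 0.
        return 0


class OldStyleData(BaseData):
    # Override written for a caller that passes only (key, value).
    def __cat_dim__(self, key, value):
        return 0


class NewStyleData(BaseData):
    # Override that also accepts the extra store argument.
    def __cat_dim__(self, key, value, *args, **kwargs):
        return 0


def collate_like(data, key, value, store):
    # Mimics a caller passing three arguments in addition to self.
    return data.__cat_dim__(key, value, store)


if __name__ == "__main__":
    try:
        collate_like(OldStyleData(), "x", [1, 2], store={})
    except TypeError as err:
        print("old-style override:", err)
    print("new-style override:", collate_like(NewStyleData(), "x", [1, 2], store={}))
```

If that guess is right, either installing the older torch_geometric release the code was written against, or accepting `*args, **kwargs` in the override, should avoid the error.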

Any idea how to solve this? I think this is a package-version issue. Could you provide a requirements.txt or something similar?
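Until a requirements.txt is available, a small helper like the one below (a hypothetical sketch, not part of the repository) prints locally installed versions in `name==version` form so environments can be compared. The package list is an assumption. It needs Python 3.8+ for `importlib.metadata`; on the Python 3.7 runtime shown above, `pip freeze` reports the same information.

```python
# Print installed versions in requirements.txt style ("name==version").
# The package names passed in below are assumptions; adjust as needed.
from importlib import metadata  # available in Python 3.8+


def installed_versions(packages):
    """Return {package: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["torch", "torch_geometric"]).items():
        print(f"{name}=={version}" if version else f"# {name}: not installed")
```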
