Skip to content

尝试使用 base model 训练后无法正常推理（推理结果为彩色噪声） #16

@kenny-hash

Description

@kenny-hash

作者好，最近关注到你们的工作。我尝试使用 sd-v1.4 模型进行训练和推理时遇到了一些问题：训练过程中，在 log 文件夹里查看测试生成的图片没有问题；但使用 vico_txt2img.py 进行推理时，生成的结果是彩色噪声。
a-photo-of-_-on-the-beach

此外，在 v1-finetune.yaml 配置文件中修改 batch_size 后，训练会报错（完整报错信息如下）。请问代码中是否硬编码了与 batch_size 相关的参数？

  Traceback (most recent call last):
    File "main-real.py", line 820, in <module>
      trainer.fit(model, data)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
      self._call_and_handle_interrupt(
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
      return trainer_fn(*args, **kwargs)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
      self._run(model, ckpt_path=ckpt_path)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
      self._dispatch()
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
      self.training_type_plugin.start_training(self)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
      self._results = trainer.run_stage()
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
      return self._run_train()
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1319, in _run_train
      self.fit_loop.run()
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
      self.advance(*args, **kwargs)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
      self.epoch_loop.run(data_fetcher)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
      self.advance(*args, **kwargs)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 216, in advance
      self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1495, in call_hook
      callback_fx(*args, **kwargs)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py", line 179, in on_train_batch_end
      callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0)
    File "/home/azureuser/ViCo/main.py", line 442, in on_train_batch_end
      self.log_img(pl_module, batch, batch_idx, split="train")
    File "/home/azureuser/ViCo/main.py", line 410, in log_img
      images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
      return func(*args, **kwargs)
    File "/home/azureuser/ViCo/ldm/models/diffusion/ddpm.py", line 1409, in log_images
      sample_scaled, _ = self.sample_log(cond=c, 
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
      return func(*args, **kwargs)
    File "/home/azureuser/ViCo/ldm/models/diffusion/ddpm.py", line 1337, in sample_log
      samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
      return func(*args, **kwargs)
    File "/home/azureuser/ViCo/ldm/models/diffusion/ddim.py", line 98, in sample
      samples, intermediates = self.ddim_sampling(conditioning, image_cond, ph_pos, size,
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
      return func(*args, **kwargs)
    File "/home/azureuser/ViCo/ldm/models/diffusion/ddim.py", line 151, in ddim_sampling
      outs = self.p_sample_ddim(img, cond, image_cond, ts, ph_pos, index=index, total_steps=total_steps, use_original_steps=ddim_use_original_steps,
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
      return func(*args, **kwargs)
    File "/home/azureuser/ViCo/ldm/models/diffusion/ddim.py", line 187, in p_sample_ddim
      e_t_uncond, e_t = self.model.apply_model(x_in, c_img_in, t_in, c_in, c_in, ph_pos_in, use_img_cond=True)[0].chunk(2)
    File "/home/azureuser/ViCo/ldm/models/diffusion/ddpm.py", line 1062, in apply_model
      x_recon, loss_reg = self.model(x_noisy, x_ref, t, cond_init, ph_pos, use_img_cond, **cond,)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
      return forward_call(*input, **kwargs)
    File "/home/azureuser/ViCo/ldm/models/diffusion/ddpm.py", line 1624, in forward
      out, loss_reg = self.diffusion_model(x, xr, t, cc_init, ph_pos, use_img_cond, context=cc)        
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
      return forward_call(*input, **kwargs)
    File "/home/azureuser/ViCo/ldm/modules/diffusionmodules/openaimodel.py", line 766, in forward
      h, hr, loss_reg, attn = module(h, hr, emb, context, cc_init, ph_pos, use_img_cond)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
      return forward_call(*input, **kwargs)
    File "/home/azureuser/ViCo/ldm/modules/diffusionmodules/openaimodel.py", line 87, in forward
      x, xr, loss_reg, attn = layer(x, xr, context, cc_init, ph_pos, use_img_cond, return_attn=True)
    File "/opt/miniconda/envs/vico/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
      return forward_call(*input, **kwargs)
    File "/home/azureuser/ViCo/ldm/modules/attention.py", line 333, in forward
      attn_ph = attn[ph_idx].squeeze(1) # bs, n_patch
  IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [4], [2]

提前感谢你们的回复！

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions