
feat: Add auto-save checkpoint on training interruption #139

Open
facok wants to merge 2 commits into tdrussell:main from facok:main

Conversation

@facok facok commented Mar 8, 2025

  • Add signal handlers for Ctrl+C (SIGINT) and SIGTERM
  • Auto-save training progress on interruption
  • Add exception handling for safe exit
  • Support synchronized exit in distributed training
  • Prevent duplicate save triggers
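
The mechanism described above can be sketched as follows. This is a minimal illustration, not the PR's actual code: the `setup_interrupt_handler` and `should_save` names appear in the diff below, but the body here is a reconstruction, and `saver` is left unused.

```python
import signal


def setup_interrupt_handler(saver):
    """Install SIGINT/SIGTERM handlers that request a checkpoint save.

    Sketch only: in the PR the training loop checks the should_save
    flag once per step and calls into `saver`; here the handler just
    sets the flag and ignores repeated signals.
    """
    triggered = {"value": False}

    def handler(signum, frame):
        if triggered["value"]:
            return  # a save is already pending; prevent duplicate triggers
        triggered["value"] = True
        setup_interrupt_handler.should_save = True

    signal.signal(signal.SIGINT, handler)
    signal.signal(signal.SIGTERM, handler)


# Flag polled by the training loop at the end of each step.
setup_interrupt_handler.should_save = False
```

Deferring the actual save to the end of the current step (rather than saving inside the handler) keeps the handler async-signal-safe and, in distributed training, lets all ranks reach the save point together.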

facok added 2 commits March 8, 2025 19:29
- Add signal handlers for Ctrl+C (SIGINT) and SIGTERM
- Auto-save training progress on interruption
- Add exception handling for safe exit
- Support synchronized exit in distributed training
- Prevent duplicate save triggers
- Add triggered flag to interrupt handler to avoid multiple saves when Ctrl+C is pressed
- Ensure checkpoint is only saved once during program termination
num_steps += 1
train_dataloader.sync_epoch()

# Check if checkpoint save is needed
Owner

Theoretically you could still save a checkpoint twice on the same step, since saver.process_step() could checkpoint on this step. I don't think it matters much (it would just overwrite the checkpoint dir) but it would be nice to gracefully handle this case.

  1. Modify saver.process_step to return whether a checkpoint was saved.
  2. Move this block below saver.process_step.
  3. Add an extra condition to only save the checkpoint if process_step didn't. Make sure to still set the should_save flag to False regardless.
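
A sketch of that reordering, with hypothetical stand-ins (`Saver`, `end_of_step`, `request_save`, and a module-level `should_save` flag are illustrative names, not the repository's API):

```python
class Saver:
    """Minimal stand-in for the trainer's checkpoint saver."""

    def __init__(self, save_every):
        self.save_every = save_every
        self.saves = []

    def process_step(self, step):
        # Step 1: report whether a periodic checkpoint was written.
        if step % self.save_every == 0:
            self.save_checkpoint(step)
            return True
        return False

    def save_checkpoint(self, step):
        self.saves.append(step)


should_save = False  # set by the interrupt handler


def request_save():
    global should_save
    should_save = True


def end_of_step(saver, step):
    global should_save
    # Step 2: the interrupt check now runs after process_step.
    saved = saver.process_step(step)
    # Step 3: only save if process_step didn't already.
    if should_save and not saved:
        saver.save_checkpoint(step)
    # Clear the flag regardless of which path saved.
    should_save = False
```

With this ordering, an interrupt arriving on a step that was going to checkpoint anyway results in a single save instead of two writes to the same checkpoint dir.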

setup_checkpoint_signal.should_save = False


def setup_interrupt_handler(saver):
Owner

I'm fine with always setting up a handler for the USR1 signal to checkpoint. But the interrupt handler should be guarded behind a new parameter in the TOML config file which is false by default. For testing things I am often ctrl+c killing the program. And sometimes you launch training and immediately realize something is wrong and just want to terminate it immediately. If my understanding is correct the current code will always wait until you get through the next step, and then save a checkpoint.

An alternative, maybe, would be to set up something where double ctrl+c instantly kills the program. I've seen other tools do that, and I think that would be an okay option as well. Or maybe ctrl+c blocks and waits for user input with a "want to checkpoint [y]/n?" option, and a second ctrl+c at that point instantly kills it?
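
The "double Ctrl+C kills instantly" alternative could look roughly like this. Again a hypothetical sketch (`install_double_interrupt_handler` is not a name from the PR): the first SIGINT requests a checkpoint, a second one exits immediately.

```python
import signal
import sys


def install_double_interrupt_handler():
    """First Ctrl+C requests a checkpoint; second Ctrl+C exits at once."""
    state = {"interrupted": False}

    def handler(signum, frame):
        if state["interrupted"]:
            # Second Ctrl+C: terminate immediately, no checkpoint.
            sys.exit(1)
        state["interrupted"] = True
        print("Interrupt received; will checkpoint after this step. "
              "Press Ctrl+C again to quit immediately.")

    signal.signal(signal.SIGINT, handler)
    return state
```

This keeps the fast-kill escape hatch without needing a config flag, though a `[y]/n` prompt as suggested above would require reading stdin from the handler's caller rather than the handler itself.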


birunram commented May 18, 2025

Hi. I'm a noob here. I've successfully trained 5 LoRAs with this diffusion pipe, but now I can't, and I don't know why. During my last successful run, training was interrupted in the middle; I later resumed from the previous checkpoint and finished successfully. But since then I can't train any new LoRA at all; it just gets stuck. What happened? Thanks in advance.

I got this error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/multiprocess/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.11/dist-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/content/diffusion-pipe/utils/dataset.py", line 741, in _cache_fn
    ds.cache_latents(latents_map_fn, regenerate_cache=regenerate_cache, caching_batch_size=caching_batch_size)
  File "/content/diffusion-pipe/utils/dataset.py", line 687, in cache_latents
    ds.cache_latents(map_fn, regenerate_cache=regenerate_cache, caching_batch_size=caching_batch_size)
  File "/content/diffusion-pipe/utils/dataset.py", line 561, in cache_latents
    ds.cache_latents(map_fn, regenerate_cache=regenerate_cache, caching_batch_size=caching_batch_size)
  File "/content/diffusion-pipe/utils/dataset.py", line 243, in cache_latents
    ds.cache_latents(map_fn, regenerate_cache=regenerate_cache, caching_batch_size=caching_batch_size)
  File "/content/diffusion-pipe/utils/dataset.py", line 135, in cache_latents
    for example in self.latent_dataset.select_columns(['image_file', 'caption']):
  File "/usr/local/lib/python3.11/dist-packages/datasets/arrow_dataset.py", line 2384, in __iter__
    formatted_output = format_table(
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 629, in format_table
    return formatter(pa_table, query_type=query_type)
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 396, in __call__
    return self.format_row(pa_table)
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/torch_formatter.py", line 88, in format_row
    row = self.numpy_arrow_extractor().extract_row(pa_table)
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 158, in extract_row
    return _unnest(self.extract_batch(pa_table))
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 164, in extract_batch
    return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 164, in <dictcomp>
    return {col: self._arrow_array_to_numpy(pa_table[col]) for col in pa_table.column_names}
  File "/usr/local/lib/python3.11/dist-packages/datasets/formatting/formatting.py", line 196, in _arrow_array_to_numpy
    return np.array(array, copy=False)
ValueError: Unable to avoid copy while creating an array as requested.
If using np.array(obj, copy=False) replace it with np.asarray(obj) to allow a copy when needed (no behavior change in NumPy 1.x).
For more details, see https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword.
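
For context on that error (not related to this PR's changes): it is the NumPy 2.0 behavior change for `copy=False`, hit by the installed `datasets` version, whose `_arrow_array_to_numpy` still calls `np.array(array, copy=False)`. Upgrading `datasets` or pinning `numpy<2` typically resolves it. A minimal sketch of the migration the error message itself recommends, with a hypothetical helper name:

```python
import numpy as np


def to_numpy(obj):
    """NumPy 2.x-safe conversion.

    np.asarray copies only when it must, which is the documented
    replacement for np.array(obj, copy=False). Under NumPy 2.x the
    latter raises ValueError whenever a copy cannot be avoided, e.g.
    when converting a plain Python list.
    """
    return np.asarray(obj)
```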

3 participants