-
Notifications
You must be signed in to change notification settings - Fork 480
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] Step-based checkpointing in torchtune #2105
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2105
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c4d7a93 with merge base efa91bf (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
It makes sense to me. I am not 100% sure how i feel about keeping epochs AND steps, since defaulting to one might simplify things. Internally, we only have the concept of steps, and it is up to the user to use bsz + len(dataset) to calculate how many steps == 1 epoch (which is a bit annoying). Can we have a config showing the args for streaming dataset and checkpointing? If it looks simple, i like keeping both regarding |
Yay step-based checkpointing! Some thoughts:
|
I totally get this. I might; however, be annoying and punt this to a different discussion b/c right now I am keeping epochs as a method to determine how long to train the model, whereas steps only determine when to checkpoint. I'd rather give more time to properly educate users if we do end up taking away epochs entirely.
Yeah, this makes me a little uncomfortable, too, but it does seem to be the standard in other training code like Titan and TNT. |
Yes, I think this is a natural extension to this work. If you look at the TNT code, they actual have some scaffolding to do this. I'd be happy to add that as a follow-up or even have this be community contributed.
This makes a ton of sense and how I would want to incorporate this information now is to make sure it's easy to extend to such a use case when the time comes. |
Enabling step-based checkpointing in torchtune
Original context: #2070
What are we currently doing?
We currently only checkpoint at epoch boundaries. That means a fine-tuning run has to iterate through all data in a dataset before saving a checkpoint. That's a problem when GPUs (especially interconnected GPUs) can fail frequently, losses can diverge, and datasets keep getting larger and larger.
We provide a tiny amount of flexibility by allowing the user to specify
max_steps_per_epoch
, so they can short-circuit the epoch and save sooner. In addition, it's always possible to split a dataset into chunks and train over them independently, resuming from training to simulate a larger training run.Both of these "hacks" are not ideal and we've had users continually asking if they can control checkpointing based on number of training steps. (#988, #1107)
What does step-based checkpointing look like for the user?
I think the best way to do this would to show an example. Let's take our Llama3 8B single device full fine-tuning recipe, which utilizes the Alpaca dataset. The Alpaca dataset has ~52k samples. Using a batch size of 2 and a gradient accumulation of 16 steps, we can estimate around 1625 steps in this training run. Let's save a checkpointing every 500 steps!
From the config, we can specify:
And in our output directory, we can expect to see something like this:
At this point you might be saying: @joecummings, do you think memory grows on trees? Do you think we all drive Bugattis and smash up Grace Hopper machines for fun? Each Llama3 8B model is roughly 16 GB of memory and we've saved 4 copies of that in addition to the base model we used. That's 80 GB just for checkpoints! Not even to mention if we wanted to save the optimizer states, too...
Introducing:
This param will prune all the checkpoints except for the last N specified, leaving you with just the checkpoints you're interested in:
What about the concept of epochs?
The concept of epochs will stay as a way to control how long training runs, as will the possibility to shorten training using
max_steps_per_epoch
; however, checkpointing will be entirely handled by a specification of steps.Will this slow down training?
Great question! Checkpointing can take a long time, especially if saving the optimizer state for resuming training at a later date. For single device recipes, this likely isn't a huge issue, but for distributed recipes where the state dict needs to be collected on rank zero before saving, this can be verrrrrrry slow so anything that increases the frequency of checkpointing will increase the time it takes for training to complete. There are two ways to mitigate this:
What changes need to be made in code?
In the recipe:
And in the checkpointer:
Inspiration from relevant repositories: