
[checkpointio]support distributed checkpoint io for model saving. #6181

Open
wants to merge 5 commits into main from dist-ckp

Conversation

flybird11111
Contributor

📌 Checklist before creating the PR

  • I have created an issue for this PR for traceability
  • The title follows the standard format: [doc/gemini/tensor/...]: A concise description
  • I have added relevant tags if possible for us to better distinguish different PRs
  • I have installed pre-commit: pip install pre-commit && pre-commit install

🚨 Issue number

Link this PR to your issue with words like fixed to automatically close the linked issue upon merge

e.g. fixed #1234, closed #1234, resolved #1234

📝 What does this PR do?

Summarize your work here.
if you have any plots/diagrams/screenshots/tables, please attach them here.

💥 Checklist before requesting a review

  • I have linked my PR to an issue (instruction)
  • My issue clearly describes the problem/feature/proposal, with diagrams/charts/table/code if possible
  • I have performed a self-review of my code
  • I have added thorough tests.
  • I have added docstrings for all the functions/methods I implemented

⭐️ Do you enjoy contributing to Colossal-AI?

  • 🌝 Yes, I do.
  • 🌚 No, I don't.

Tell us more if you don't enjoy contributing to Colossal-AI.

@flybird11111 flybird11111 requested a review from a team as a code owner January 16, 2025 10:28
@flybird11111 flybird11111 changed the title [checkpointio]support distribute checkpoint io [checkpointio]support distributed checkpoint io for model saving. Jan 16, 2025

MODEL_META_PREFIX = "pytorch_model-meta-dist-"
MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
MODEL_SHARD_SUUFIX = ".index.json"
Member

SHARD_META_SUFFIX?

@@ -10,4 +11,5 @@
"GeneralCheckpointIO",
"HybridParallelCheckpointIO",
"MoECheckpointIO",
"DistributedCheckpointIO",
Member

It should not be an independent checkpoint IO class. It should provide utility functions that each existing checkpoint IO class can use.
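A rough sketch of that suggestion (the helper and class names below are hypothetical, not this PR's actual API): expose the distributed-save logic as free functions that every existing CheckpointIO class calls directly.

from typing import Dict


def save_dist_model_shards(model_metadata: Dict, checkpoint: str, dist_id: int) -> None:
    # Hypothetical shared helper: write this rank's weight shards plus a metadata
    # file under `checkpoint`, reusing the common shard writers and save_metadata().
    ...


class HybridParallelCheckpointIOSketch:
    # Stand-in for the existing HybridParallelCheckpointIO.
    def save_sharded_model(self, model, checkpoint: str, dist_id: int = 0) -> None:
        model_metadata = {}  # would be built per parameter from the TP layout
        # Reuse the shared helper instead of routing through a separate
        # DistributedCheckpointIO class.
        save_dist_model_shards(model_metadata, checkpoint, dist_id)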

@Lemon-412

Hi all, please take a look at this. This bug is quite annoying for me.

#6168

@flybird11111 flybird11111 force-pushed the dist-ckp branch 2 times, most recently from e8659ea to 51c208c, January 20, 2025 03:30
@flybird11111
Contributor Author

Hi all, please take a look at this. This bug is quite annoying for me.

#6168

ok

    return destination


def load_state_dict_into_dist_model(
Member

What is this function for? Is it for loading the whole state dict? The default model.load_state_dict() already implements this feature.

Contributor Author

The ParallelModule will perform the tensor gather operation.
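Roughly, the point is that each rank's local shards can be copied into the unwrapped parameters directly, without going through the wrapper's load path (a hand-written sketch, not the PR's implementation; the gather behavior of ColossalAI's ParallelModule hooks is assumed here):

import torch
import torch.nn as nn


@torch.no_grad()
def copy_local_shards_(model: nn.Module, local_state_dict: dict) -> None:
    # Each value in `local_state_dict` is already this rank's local shard, so a
    # plain copy_ suffices. Calling model.load_state_dict() on the wrapped model
    # would instead route through the ParallelModule hooks and trigger gather ops.
    params = dict(model.named_parameters())
    for name, shard in local_state_dict.items():
        params[name].copy_(shard)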

    tp_rank=None,
):
    param_origin_shape = model.param_origin_shape
    model = model.unwrap()
Member

The type hint and the method called here do not match. If you assume the model is wrapped, then the type should be ModuleWrapper.

Contributor Author

updated

tp_partition_dim = search_tp_partition_dim(
    current_shape=param.shape, original_shape=original_shape, tp_size=tp_size
)
model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
Member

Why not use a list directly?
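For reference, the difference being pointed at (illustrative shape; both forms carry the same information, but the plain list serializes to JSON without a .tolist() round trip):

import torch

original_shape = (1024, 4096)  # example only

offsets_tensor = torch.zeros(len(original_shape), dtype=torch.int)  # current code: a tensor
offsets_list = [0] * len(original_shape)                            # reviewer's suggestion: a plain list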

Comment on lines +91 to +96
def create_model_metadata(
    model: nn.Module,
    prefix: str = "",
    tp_size=None,
    tp_rank=None,
):
Member

It seems that this function is only intended for TP. What about Gemini? If it's only designed for TP, then move it to the hybrid parallel checkpoint IO file.

Contributor Author

DP support can be added in the future.
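To make the TP-only scope concrete, this is roughly what the metadata records for one tensor-parallel parameter (illustrative values; only the offsets/lengths fields from the snippet above are assumed):

# Example: a weight of original shape (4096, 1024), split along dim 0 across tp_size=4.
original_shape = (4096, 1024)
tp_size, tp_rank, tp_partition_dim = 4, 1, 0

offsets = [0] * len(original_shape)
lengths = list(original_shape)
offsets[tp_partition_dim] = original_shape[tp_partition_dim] // tp_size * tp_rank
lengths[tp_partition_dim] = original_shape[tp_partition_dim] // tp_size

entry = {"offsets": offsets, "lengths": lengths}  # {"offsets": [1024, 0], "lengths": [1024, 1024]}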

metadata_dicts["metadata"][name][k] = v
if checkpoint_file is not None:
metadata_dicts["metadata"][name]["file"] = checkpoint_file
metadata_dicts["metadata"][name]["rank"] = dist.get_rank(_get_default_group())
Member

This rank is not contiguous. Is it just for indicating order?

Contributor Author
@flybird11111 flybird11111 Jan 20, 2025

This rank just marks which shard the tensor was split into. It is for internal use and, as I understand it, only needs to be unique.

Comment on lines +260 to +510
            if key not in covered_shards or rank not in covered_shards[key]:
                continue
            if dtype is None:
                dtype = weight.dtype
            covered_shards[key][rank]["weight"] = weight
    state_dict = {}
    for key, shards in covered_shards.items():
        state = assemble_tensor_from_shards_partial(
            shards, model_metadata[key]["offsets"], model_metadata[key]["lengths"], dtype=dtype
        )
        state_dict[key] = state

    if not low_cpu_mem_mode:
        state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)

    load_state_dict_into_dist_model(model=model, state_dict=state_dict)

    # Update master params if mixed-precision training is enabled.
    model_before_wrapping.update_master_params()


def save_dist_sharded_model(
    model: ModelWrapper,
    model_metadata: Dict,
    checkpoint: str,
    prefix: Optional[str] = None,
    size_per_shard: int = 1024,
    use_safetensors: bool = False,
    use_async: bool = False,
    dist_id: int = 0,
    pinned_state_dicts=None,
) -> None:
    """
    Save sharded model checkpoint under the given checkpointing path.
    The following files will be created under the path:
    - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
    - Multiple files that store state tensors of models.
    If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
    If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"

    Args:
        model (nn.Module): Model on local device to be saved.
        checkpoint (str): Checkpointing path which should be a directory path.
        gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
        prefix (str, optional): Prefix of file to save. Defaults to None.
        size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
        use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
        use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
    """

    model = model.unwrap()

    if os.path.isfile(checkpoint):
        logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
        return

    Path(checkpoint).mkdir(parents=True, exist_ok=True)
    # Devices along the same dp_group share the same copies of model.
    # So only let the device with dp_rank == 0 and sp_rank == 0 save the model.

    if use_async:
        if id(model) not in pinned_state_dicts:
            pinned_state_dicts[id(model)] = {}
        pinned_state_dicts = pinned_state_dicts[id(model)]
    else:
        pinned_state_dicts = None
    state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
    weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
    index_file = CheckpointIndexFile(checkpoint)

    # Manage filenames of sharded weights and index file for each pipeline stage.
    weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
    weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
    async_writers = []
    if use_async:
        total_size, writers = async_save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint,
            index_file=index_file,
            base_filename=weights_name,
            is_master=True,
            state_preprocess=False,
        )
        async_writers.extend(writers)
    else:
        total_size = save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint,
            index_file=index_file,
            base_filename=weights_name,
            is_master=True,
            use_safetensors=use_safetensors,
            use_pp_format=True,
        )
    for k, _ in model_metadata.items():
        model_metadata[k]["file"] = index_file.get_checkpoint_file(k)

    save_metadata(model_metadata, metadata_file, total_size=total_size)
Member

If it's only designed for hybrid parallelism, then move it to the hybrid parallel checkpoint IO file. Also, there is too much redundant code; please try to reuse some common code snippets.

Contributor Author

The format of metadata is uniform, and save_metadata is generic.
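For context, the per-shard metadata file then looks roughly like this (a hand-written illustration limited to the fields visible in this diff; all values are made up):

example_metadata = {
    "total_size": 13476839424,  # as returned by the shard writers
    "metadata": {
        "model.layers.0.self_attn.q_proj.weight": {
            "offsets": [1024, 0],
            "lengths": [1024, 4096],
            "file": "pytorch_model-dist-00001-shard.safetensors",
            "rank": 1,  # marks which shard of the tensor this is; only needs to be unique
        },
    },
}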

@ver217
Member

ver217 commented Jan 20, 2025

DON'T merge to main. Create a new feature branch on the org repo and merge to it.

Labels: None yet
Projects: None yet
3 participants