Skip to content

Commit

Permalink
support distribute checkpoint io
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 committed Jan 16, 2025
1 parent 5b094a8 commit 650af68
Show file tree
Hide file tree
Showing 7 changed files with 801 additions and 14 deletions.
3 changes: 3 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def __init__(
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
self.use_fp8 = use_fp8
self.param_origin_shape = {}
for name, param in module.named_parameters():
self.param_origin_shape[name] = param.shape

shardformer = ShardFormer(shard_config)
if custom_policy is not None:
Expand Down
2 changes: 2 additions & 0 deletions colossalai/checkpoint_io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from.distributed_checkpoint_io import DistributedCheckpointIO
from .index_file import CheckpointIndexFile
from .moe_checkpoint import MoECheckpointIO

Expand All @@ -10,4 +11,5 @@
"GeneralCheckpointIO",
"HybridParallelCheckpointIO",
"MoECheckpointIO",
"DistributedCheckpointIO"
]
Loading

0 comments on commit 650af68

Please sign in to comment.