
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 20, 2025
1 parent 17ee5a7 commit e8659ea
Showing 4 changed files with 57 additions and 20 deletions.
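The fixes below are characteristic of standard pre-commit hooks: import sorting, wrapping of long call sites onto one argument per line, dropping the spaces around keyword-argument equals signs, and adding a missing newline at the end of a file. The repository's actual hook configuration is not part of this commit, so the following .pre-commit-config.yaml is only a sketch of hooks that could produce fixes of this kind; the hook list and the rev values are assumptions, not taken from the ColossalAI repository.

# Sketch only: hook selection and revisions are assumed, not read from this repo.
repos:
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2                  # sorts and groups imports alphabetically
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 24.10.0                 # reflows long calls and normalizes kwarg spacing
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0                  # generic whitespace fixers
    hooks:
      - id: end-of-file-fixer    # ensures files end with exactly one newline
      - id: trailing-whitespace

With such a config, running "pre-commit run --all-files" locally applies the same fixes; pre-commit.ci runs the repository's configured hooks on each pull request and pushes the resulting changes as an automated commit like this one.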
1 change: 0 additions & 1 deletion colossalai/checkpoint_io/__init__.py
@@ -1,7 +1,6 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
-
from .index_file import CheckpointIndexFile
from .moe_checkpoint import MoECheckpointIO

2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/general_checkpoint_io.py
@@ -309,4 +309,4 @@ def load_sharded_model(
)

def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
-    raise NotImplementedError
+    raise NotImplementedError
60 changes: 44 additions & 16 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -24,6 +24,13 @@
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

+from .distributed_checkpoint_utils import (
+    create_model_metadata,
+    is_pytorch_model_meta_dist_file,
+    load_dist_model,
+    save_dist_sharded_model,
+    save_dist_unshard_model,
+)
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
@@ -47,14 +47,6 @@
sharded_optimizer_loading_epilogue,
)

-from .distributed_checkpoint_utils import (
-    save_dist_sharded_model,
-    save_dist_unshard_model,
-    load_dist_model,
-    is_pytorch_model_meta_dist_file,
-    create_model_metadata
-)
-
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
@@ -244,9 +243,19 @@ def save_sharded_model(
return
dist_id = self.tp_size * self.pp_rank + self.tp_rank
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+save_dist_sharded_model(
+    model=model,
+    model_metadata=model_metadata,
+    checkpoint=checkpoint,
+    prefix=prefix,
+    size_per_shard=size_per_shard,
+    use_safetensors=use_safetensors,
+    use_async=use_async,
+    dist_id=dist_id,
+    pinned_state_dicts=self.pinned_state_dicts,
+)
return

model = model.unwrap()

if os.path.isfile(checkpoint):
@@ -394,9 +403,15 @@ def load_sharded_model(

if is_pytorch_model_meta_dist_file(checkpoint_index_file):
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint_index_file, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+load_dist_model(
+    model=model,
+    model_metadata=model_metadata,
+    checkpoint=checkpoint_index_file,
+    low_cpu_mem_mode=low_cpu_mem_mode,
+    num_threads=num_threads,
+)
return

model_before_wrapping = model # backup for model before wrapping
model = model.unwrap()

@@ -792,9 +807,17 @@ def save_unsharded_model(
if self.dp_rank != 0 and self.sp_rank != 0:
return
dist_id = self.tp_size * self.pp_rank + self.tp_rank
-save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+save_dist_unshard_model(
+    model=model,
+    model_metadata=model_metadata,
+    checkpoint=checkpoint,
+    use_safetensors=use_safetensors,
+    use_async=use_async,
+    dist_id=dist_id,
+    pinned_state_dicts=self.pinned_state_dicts,
+)
return

model = model.unwrap()
if self.dp_rank != 0:
return
@@ -867,7 +890,13 @@ def load_unsharded_model(
for filename in os.listdir(checkpoint):
if is_pytorch_model_meta_dist_file(filename):
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+load_dist_model(
+    model=model,
+    model_metadata=model_metadata,
+    checkpoint=checkpoint,
+    low_cpu_mem_mode=low_cpu_mem_mode,
+    num_threads=num_threads,
+)
return

strict = False
@@ -1099,7 +1128,6 @@ def gather_from_sharded_optimizer_state(
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)

-
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
14 changes: 12 additions & 2 deletions tests/test_checkpoint_io/test_dist_checkpointio.py
@@ -79,7 +79,12 @@ def _preprocess_data(data):
model_ckpt_path_0 = f"{tempdir}/model_0"

booster_0.save_model(
-    model_0, model_ckpt_path_0, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+    model_0,
+    model_ckpt_path_0,
+    shard=shard,
+    gather_dtensor=True,
+    size_per_shard=size_per_shard,
+    use_async=use_async,
)
booster_0.checkpoint_io._sync_d2h()
booster_0.checkpoint_io._sync_io()
@@ -96,7 +101,12 @@ def _preprocess_data(data):

model_ckpt_path_1 = f"{tempdir}/model_1"
booster_1.save_model(
-    model_1, model_ckpt_path_1, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+    model_1,
+    model_ckpt_path_1,
+    shard=shard,
+    gather_dtensor=True,
+    size_per_shard=size_per_shard,
+    use_async=use_async,
)
booster_1.checkpoint_io._sync_d2h()
booster_1.checkpoint_io._sync_io()
