
Commit 307d8f1

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 16, 2025
1 parent 650af68 commit 307d8f1
Showing 3 changed files with 103 additions and 78 deletions.
5 changes: 3 additions & 2 deletions colossalai/checkpoint_io/__init__.py
@@ -1,7 +1,8 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from.distributed_checkpoint_io import DistributedCheckpointIO

from .distributed_checkpoint_io import DistributedCheckpointIO
from .index_file import CheckpointIndexFile
from .moe_checkpoint import MoECheckpointIO

@@ -11,5 +12,5 @@
"GeneralCheckpointIO",
"HybridParallelCheckpointIO",
"MoECheckpointIO",
"DistributedCheckpointIO"
"DistributedCheckpointIO",
]
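
Since the `__all__` list above now re-exports the class at the package root, downstream code can import it directly. A minimal, illustrative import line (constructor arguments are not visible in this diff, so none are assumed here):

from colossalai.checkpoint_io import DistributedCheckpointIO  # exported at package level after this change
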
128 changes: 61 additions & 67 deletions colossalai/checkpoint_io/distributed_checkpoint_io.py
@@ -1,51 +1,30 @@
import copy
import json
import logging
import os
from functools import reduce
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import json

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.interface import ModelWrapper
from colossalai.utils import get_non_persistent_buffers_set

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
async_save_state_dict_shards,
create_pinned_state_dict,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
search_padding_dim,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)

try:
@@ -97,7 +76,6 @@ def __init__(
self.model_metadata = None
self.optimizer_metadata = None
self.global_rank = dist.get_rank(_get_default_group())


@staticmethod
def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
@@ -106,13 +84,13 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
for name, param in model.named_parameters():
if param is None:
continue
destination[prefix+name] = param
destination[prefix + name] = param
# Save buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
destination[prefix+name] = buffer
destination[prefix + name] = buffer

# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -123,22 +101,24 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
extra_state = model.get_extra_state()
destination[extra_state_key] = extra_state
return destination

@staticmethod
def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False):
def load_state_dict(
model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
):
destination = dict()
# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
with torch.no_grad():
param.copy_(state_dict[prefix+name])
param.copy_(state_dict[prefix + name])
# Save buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
with torch.no_grad():
buf.copy_(state_dict[prefix+name])
buf.copy_(state_dict[prefix + name])

# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -151,26 +131,33 @@ def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_v
extra_state.copy_(state_dict[extra_state_key])
return destination

def create_model_metadata(self, model: nn.Module, prefix: str = "",):
def create_model_metadata(
self,
model: nn.Module,
prefix: str = "",
):
param_origin_shape = model.param_origin_shape
model = model.unwrap()
self.model_metadata = {}
for name, param in model.named_parameters():
if param is None:
continue
self.model_metadata[prefix+name] = {}
self.model_metadata[prefix + name] = {}
original_shape = param_origin_shape[name]
tp_partition_dim = search_tp_partition_dim(current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size)
self.model_metadata[prefix+name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
self.model_metadata[prefix+name]["lengths"] = list(param.shape)
self.model_metadata[prefix+name]["global_shape"] = list(original_shape)
tp_partition_dim = search_tp_partition_dim(
current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size
)
self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
self.model_metadata[prefix + name]["lengths"] = list(param.shape)
self.model_metadata[prefix + name]["global_shape"] = list(original_shape)
if tp_partition_dim is not None:
partition_size = param.shape[tp_partition_dim]
self.model_metadata[prefix+name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
if self.tp_rank == self.tp_size - 1:
self.model_metadata[prefix+name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (partition_size * (self.tp_size -1))
self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[
tp_partition_dim
] - (partition_size * (self.tp_size - 1))


def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
metadata_dicts = {
"checkpoint_version": "1.0",
@@ -188,7 +175,7 @@ def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
metadata_dicts["metadata"][name]["rank"] = self.global_rank
with open(metadata_file, "w") as json_file:
json.dump(metadata_dicts, json_file, indent=4)

def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
@@ -249,13 +236,13 @@ def load_metadata(self, checkpoint: str):
try:
with open(file_path, "r") as f:
metadata_json = json.load(f)
for name, item in metadata_json['metadata'].items():
for name, item in metadata_json["metadata"].items():
if name not in metadata_dict:
metadata_dict[name] = {}
metadata_dict[name]["global_shape"] = item['global_shape']
metadata_dict[name]["global_shape"] = item["global_shape"]
metadata_dict[name]["shards"] = {}
else:
assert metadata_dict[name]["global_shape"] == item['global_shape']
assert metadata_dict[name]["global_shape"] == item["global_shape"]
shard = {}
shard[item["rank"]] = {}
shard[item["rank"]]["file"] = item["file"]
@@ -304,7 +291,7 @@ def find_covering_shards(self, shards, target_offsets, target_lengths):

assert total_lengths == global_shape
return covering_shards

def extract_weight_from_shard_partial(self, shard, target_offsets, target_lengths):
"""
Extract the target range of weights from shard data, supporting partial overlap.
@@ -314,14 +301,16 @@
param target_lengths: A 1D array indicating the length of the target tensor in each dimension.
return: The extracted sub-tensor of the target weights and its position within the target range.
"""
shard_offsets = shard['offsets']
shard_lengths = shard['lengths']
weight = shard['weight']
shard_offsets = shard["offsets"]
shard_lengths = shard["lengths"]
weight = shard["weight"]

slices = []
target_slices = []

for dim, (t_offset, t_length, s_offset, s_length) in enumerate(zip(target_offsets, target_lengths, shard_offsets, shard_lengths)):
for dim, (t_offset, t_length, s_offset, s_length) in enumerate(
zip(target_offsets, target_lengths, shard_offsets, shard_lengths)
):
intersection_start = max(t_offset, s_offset)
intersection_end = min(t_offset + t_length, s_offset + s_length)

@@ -339,7 +328,6 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
target_weight = weight[tuple(slices)]
return target_weight, target_slices


def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_lengths, dtype):
target_tensor = torch.zeros(target_lengths, dtype=dtype)

@@ -351,15 +339,14 @@ def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_len

return target_tensor


def load_unsharded_model(
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
):
"""
Load model from a single file with the given path of checkpoint.
@@ -390,30 +377,34 @@ def load_unsharded_model(
for key, item in self.model_metadata.items():
offsets = item["offsets"]
lengths = item["lengths"]
assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
assert (
item["global_shape"] == metadata_loaded[key]["global_shape"]
), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
shards = metadata_loaded[key]["shards"]
covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
covered_shards[key] = covering_shards
# load_files.update({rank: shard['file'] for rank, shard in covering_shards.items()})
for rank, shard in covering_shards.items():
if rank not in load_files:
load_files[rank] = set()
load_files[rank].add(shard['file'])
load_files[rank].add(shard["file"])

dtype = None
for rank, files in load_files.items():
for file in files:
file_path = os.path.join(checkpoint, file)
state_dict_shard = load_state_dict(file_path)
for key, weight in state_dict_shard.items():
for key, weight in state_dict_shard.items():
if key not in covered_shards:
continue
if dtype == None:
dtype = weight.dtype
covered_shards[key][rank]["weight"] = weight
state_dict = {}
for key, shards in covered_shards.items():
state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
state = self.assemble_tensor_from_shards_partial(
shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
)
state_dict[key] = state

if not low_cpu_mem_mode:
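
To illustrate the reassembly step used above, here is a self-contained sketch of the same partial-overlap logic that find_covering_shards, extract_weight_from_shard_partial and assemble_tensor_from_shards_partial perform: each shard is cropped to its overlap with the target range and pasted into a zero-initialized tensor. The helper name, the list-based shard container, and the toy values are illustrative assumptions, not part of this commit.

import torch

def assemble_from_shards(shards, target_offsets, target_lengths, dtype=torch.float32):
    """Toy version of the partial-overlap reassembly performed above."""
    target = torch.zeros(target_lengths, dtype=dtype)
    for shard in shards:
        slices, target_slices = [], []
        for t_off, t_len, s_off, s_len in zip(target_offsets, target_lengths, shard["offsets"], shard["lengths"]):
            start = max(t_off, s_off)                # overlap start in global coordinates
            end = min(t_off + t_len, s_off + s_len)  # overlap end in global coordinates
            if start >= end:                         # no overlap in this dimension
                break
            slices.append(slice(start - s_off, end - s_off))         # region inside the shard tensor
            target_slices.append(slice(start - t_off, end - t_off))  # where it lands in the target
        else:
            target[tuple(target_slices)] = shard["weight"][tuple(slices)]
    return target

# Two 1-D shards covering global ranges [0, 4) and [4, 8), reassembled into one tensor.
shards = [
    {"offsets": [0], "lengths": [4], "weight": torch.ones(4)},
    {"offsets": [4], "lengths": [4], "weight": torch.full((4,), 2.0)},
]
print(assemble_from_shards(shards, target_offsets=[0], target_lengths=[8]))
# tensor([1., 1., 1., 1., 2., 2., 2., 2.])
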
@@ -424,7 +415,6 @@ def load_unsharded_model(
# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()


@staticmethod
def _model_sharder(
model: nn.Module,
@@ -571,7 +561,7 @@ def save_sharded_model(
)
for k, _ in self.model_metadata.items():
self.model_metadata[k]["file"] = index_file.get_checkpoint_file(k)

self.save_metadata(metadata_file, total_size=total_size)

def load_sharded_model(
@@ -606,30 +596,34 @@ def load_sharded_model(
for key, item in self.model_metadata.items():
offsets = item["offsets"]
lengths = item["lengths"]
assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
assert (
item["global_shape"] == metadata_loaded[key]["global_shape"]
), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
shards = metadata_loaded[key]["shards"]
covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
covered_shards[key] = covering_shards
for rank, shard in covering_shards.items():
if rank not in load_files:
load_files[rank] = set()
load_files[rank].add(shard['file'])
load_files[rank].add(shard["file"])

dtype = None
for rank, files in load_files.items():
for file in files:
file_path = os.path.join(checkpoint, file)
state_dict_shard = load_state_dict(file_path)
for key, weight in state_dict_shard.items():
for key, weight in state_dict_shard.items():
if key not in covered_shards:
continue
if dtype == None:
dtype = weight.dtype
covered_shards[key][rank]["weight"] = weight

state_dict = {}
for key, shards in covered_shards.items():
state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
state = self.assemble_tensor_from_shards_partial(
shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
)
state_dict[key] = state

if not low_cpu_mem_mode:
@@ -638,4 +632,4 @@ def load_sharded_model(
DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict)

# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()
model_before_wrapping.update_master_params()