From 307d8f1838a13ceffb5f2b060466216a327da2c9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 16 Jan 2025 10:29:08 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/checkpoint_io/__init__.py          |   5 +-
 .../distributed_checkpoint_io.py              | 128 +++++++++---------
 .../test_dist_checkpointio.py                 |  48 +++++--
 3 files changed, 103 insertions(+), 78 deletions(-)

diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index 02964e9ae5ee..bf61a862aab4 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -1,7 +1,8 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
 from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
-from.distributed_checkpoint_io import DistributedCheckpointIO
+
+from .distributed_checkpoint_io import DistributedCheckpointIO
 from .index_file import CheckpointIndexFile
 from .moe_checkpoint import MoECheckpointIO
 
@@ -11,5 +12,5 @@
     "GeneralCheckpointIO",
     "HybridParallelCheckpointIO",
     "MoECheckpointIO",
-    "DistributedCheckpointIO"
+    "DistributedCheckpointIO",
 ]
diff --git a/colossalai/checkpoint_io/distributed_checkpoint_io.py b/colossalai/checkpoint_io/distributed_checkpoint_io.py
index f89fcd4ce973..e062ab951074 100644
--- a/colossalai/checkpoint_io/distributed_checkpoint_io.py
+++ b/colossalai/checkpoint_io/distributed_checkpoint_io.py
@@ -1,29 +1,18 @@
-import copy
+import json
 import logging
 import os
-from functools import reduce
 from pathlib import Path
-from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-import json
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
-from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils._pytree import tree_map
+from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.cluster import DistCoordinator
-from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.tensor.padded_tensor import (
-    init_as_padded_tensor,
-    is_padded_tensor,
-    to_padded_tensor,
-    to_unpadded_tensor,
-)
-from colossalai.utils import get_current_device, get_non_persistent_buffers_set
-from torch.distributed.distributed_c10d import _get_default_group
+from colossalai.interface import ModelWrapper
+from colossalai.utils import get_non_persistent_buffers_set
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -31,21 +20,11 @@
     StateDictSharder,
     async_save_state_dict_shards,
     create_pinned_state_dict,
-    gather_distributed_param,
     get_model_base_filenames,
-    get_optimizer_base_filenames,
-    is_safetensors_available,
-    load_shard_state_dict,
     load_state_dict,
-    load_state_dict_into_model,
-    load_states_into_optimizer,
-    save_config_file,
-    save_param_groups,
     save_state_dict,
     save_state_dict_shards,
-    search_padding_dim,
     search_tp_partition_dim,
-    sharded_optimizer_loading_epilogue,
 )
 
 try:
@@ -97,7 +76,6 @@ def __init__(
         self.model_metadata = None
         self.optimizer_metadata = None
         self.global_rank = dist.get_rank(_get_default_group())
-
 
     @staticmethod
     def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
@@ -106,13 +84,13 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
         for name, param in model.named_parameters():
             if param is None:
                 continue
-            destination[prefix+name] = param
+            destination[prefix + name] = param
         # Save buffers.
         non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
-                destination[prefix+name] = buffer
+                destination[prefix + name] = buffer
 
         # Save extra states.
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -123,22 +101,24 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
             extra_state = model.get_extra_state()
             destination[extra_state_key] = extra_state
         return destination
-
+
     @staticmethod
-    def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False):
+    def load_state_dict(
+        model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
+    ):
         destination = dict()
         # Save parameters.
         for name, param in model.named_parameters():
             if param is None:
                 continue
             with torch.no_grad():
-                param.copy_(state_dict[prefix+name])
+                param.copy_(state_dict[prefix + name])
         # Save buffers.
         non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persist_buffers_set:
                 with torch.no_grad():
-                    buf.copy_(state_dict[prefix+name])
+                    buf.copy_(state_dict[prefix + name])
 
         # Save extra states.
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -151,26 +131,33 @@ def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_v
                 extra_state.copy_(state_dict[extra_state_key])
         return destination
 
-    def create_model_metadata(self, model: nn.Module, prefix: str = "",):
+    def create_model_metadata(
+        self,
+        model: nn.Module,
+        prefix: str = "",
+    ):
         param_origin_shape = model.param_origin_shape
         model = model.unwrap()
         self.model_metadata = {}
         for name, param in model.named_parameters():
             if param is None:
                 continue
-            self.model_metadata[prefix+name] = {}
+            self.model_metadata[prefix + name] = {}
             original_shape = param_origin_shape[name]
-            tp_partition_dim = search_tp_partition_dim(current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size)
-            self.model_metadata[prefix+name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
-            self.model_metadata[prefix+name]["lengths"] = list(param.shape)
-            self.model_metadata[prefix+name]["global_shape"] = list(original_shape)
+            tp_partition_dim = search_tp_partition_dim(
+                current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size
+            )
+            self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+            self.model_metadata[prefix + name]["lengths"] = list(param.shape)
+            self.model_metadata[prefix + name]["global_shape"] = list(original_shape)
             if tp_partition_dim is not None:
                 partition_size = param.shape[tp_partition_dim]
-                self.model_metadata[prefix+name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
+                self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
                 if self.tp_rank == self.tp_size - 1:
-                    self.model_metadata[prefix+name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (partition_size * (self.tp_size -1))
+                    self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[
+                        tp_partition_dim
+                    ] - (partition_size * (self.tp_size - 1))
 
-
     def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
         metadata_dicts = {
             "checkpoint_version": "1.0",
@@ -188,7 +175,7 @@ def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
             metadata_dicts["metadata"][name]["rank"] = self.global_rank
         with open(metadata_file, "w") as json_file:
             json.dump(metadata_dicts, json_file, indent=4)
-
+
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
     ):
@@ -249,13 +236,13 @@ def load_metadata(self, checkpoint: str):
             try:
                 with open(file_path, "r") as f:
                     metadata_json = json.load(f)
-                    for name, item in metadata_json['metadata'].items():
+                    for name, item in metadata_json["metadata"].items():
                         if name not in metadata_dict:
                             metadata_dict[name] = {}
-                            metadata_dict[name]["global_shape"] = item['global_shape']
+                            metadata_dict[name]["global_shape"] = item["global_shape"]
                             metadata_dict[name]["shards"] = {}
                         else:
-                            assert metadata_dict[name]["global_shape"] == item['global_shape']
+                            assert metadata_dict[name]["global_shape"] == item["global_shape"]
                         shard = {}
                         shard[item["rank"]] = {}
                         shard[item["rank"]]["file"] = item["file"]
@@ -304,7 +291,7 @@ def find_covering_shards(self, shards, target_offsets, target_lengths):
         assert total_lengths == global_shape
         return covering_shards
 
-
+
     def extract_weight_from_shard_partial(self, shard, target_offsets, target_lengths):
         """
         Extract the target range of weights from shard data, supporting partial overlap.
@@ -314,14 +301,16 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
         param target_lengths: A 1D array indicating the length of the target tensor in each dimension.
         return: The extracted sub-tensor of the target weights and its position within the target range.
         """
-        shard_offsets = shard['offsets']
-        shard_lengths = shard['lengths']
-        weight = shard['weight']
+        shard_offsets = shard["offsets"]
+        shard_lengths = shard["lengths"]
+        weight = shard["weight"]
 
         slices = []
         target_slices = []
 
-        for dim, (t_offset, t_length, s_offset, s_length) in enumerate(zip(target_offsets, target_lengths, shard_offsets, shard_lengths)):
+        for dim, (t_offset, t_length, s_offset, s_length) in enumerate(
+            zip(target_offsets, target_lengths, shard_offsets, shard_lengths)
+        ):
             intersection_start = max(t_offset, s_offset)
             intersection_end = min(t_offset + t_length, s_offset + s_length)
 
@@ -339,7 +328,6 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
         target_weight = weight[tuple(slices)]
         return target_weight, target_slices
 
-
     def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_lengths, dtype):
         target_tensor = torch.zeros(target_lengths, dtype=dtype)
 
@@ -351,15 +339,14 @@ def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_len
 
         return target_tensor
 
-
-    def load_unsharded_model(
+    def load_unsharded_model(
         self,
         model: ModelWrapper,
         checkpoint: str,
         strict: bool = False,
         low_cpu_mem_mode: bool = True,
         num_threads: int = 1,
-    ):
+    ):
         """
         Load model from a single file with the given path of checkpoint.
 
@@ -390,7 +377,9 @@ def load_unsharded_model(
         for key, item in self.model_metadata.items():
             offsets = item["offsets"]
             lengths = item["lengths"]
-            assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
+            assert (
+                item["global_shape"] == metadata_loaded[key]["global_shape"]
+            ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
             shards = metadata_loaded[key]["shards"]
             covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
             covered_shards[key] = covering_shards
@@ -398,14 +387,14 @@ def load_unsharded_model(
             for rank, shard in covering_shards.items():
                 if rank not in load_files:
                     load_files[rank] = set()
-                load_files[rank].add(shard['file'])
+                load_files[rank].add(shard["file"])
 
         dtype = None
         for rank, files in load_files.items():
             for file in files:
                 file_path = os.path.join(checkpoint, file)
                 state_dict_shard = load_state_dict(file_path)
-                for key, weight in state_dict_shard.items(): 
+                for key, weight in state_dict_shard.items():
                     if key not in covered_shards:
                         continue
                     if dtype == None:
@@ -413,7 +402,9 @@ def load_unsharded_model(
                     covered_shards[key][rank]["weight"] = weight
         state_dict = {}
         for key, shards in covered_shards.items():
-            state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
+            state = self.assemble_tensor_from_shards_partial(
+                shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
+            )
             state_dict[key] = state
 
         if not low_cpu_mem_mode:
@@ -424,7 +415,6 @@ def load_unsharded_model(
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()
 
-
     @staticmethod
     def _model_sharder(
         model: nn.Module,
@@ -571,7 +561,7 @@ def save_sharded_model(
             )
             for k, _ in self.model_metadata.items():
                 self.model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
-
+
         self.save_metadata(metadata_file, total_size=total_size)
 
     def load_sharded_model(
@@ -606,30 +596,34 @@ def load_sharded_model(
         for key, item in self.model_metadata.items():
             offsets = item["offsets"]
             lengths = item["lengths"]
-            assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
+            assert (
+                item["global_shape"] == metadata_loaded[key]["global_shape"]
+            ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
             shards = metadata_loaded[key]["shards"]
             covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
             covered_shards[key] = covering_shards
             for rank, shard in covering_shards.items():
                 if rank not in load_files:
                     load_files[rank] = set()
-                load_files[rank].add(shard['file'])
-
+                load_files[rank].add(shard["file"])
+
         dtype = None
         for rank, files in load_files.items():
             for file in files:
                 file_path = os.path.join(checkpoint, file)
                 state_dict_shard = load_state_dict(file_path)
-                for key, weight in state_dict_shard.items(): 
+                for key, weight in state_dict_shard.items():
                     if key not in covered_shards:
                         continue
                     if dtype == None:
                         dtype = weight.dtype
                     covered_shards[key][rank]["weight"] = weight
-
+
         state_dict = {}
         for key, shards in covered_shards.items():
-            state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
+            state = self.assemble_tensor_from_shards_partial(
+                shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
+            )
             state_dict[key] = state
 
         if not low_cpu_mem_mode:
@@ -638,4 +632,4 @@ def load_sharded_model(
         DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict)
 
         # Update master params if mixed-precision training is enabled.
-        model_before_wrapping.update_master_params()
\ No newline at end of file
+        model_before_wrapping.update_master_params()
diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py
index b1d3b7a63de4..fc1b9650809c 100644
--- a/tests/test_checkpoint_io/test_dist_checkpointio.py
+++ b/tests/test_checkpoint_io/test_dist_checkpointio.py
@@ -4,16 +4,14 @@
 from packaging.version import Version
 from torch.optim import Adam
 from utils import shared_tempdir
-from copy import deepcopy
 
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.checkpoint_io import DistributedCheckpointIO
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
-from colossalai.checkpoint_io import DistributedCheckpointIO
 from colossalai.testing import (
-    assert_close_loose,
     check_state_dict_equal,
     clear_cache_before_run,
     parameterize,
@@ -36,8 +34,24 @@
 else:
     TEST_CONFIGS = [
         # TODO(ver217): other configs lead to hang
-        ({"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
-        {"tp_size": 2, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},)
+        (
+            {
+                "tp_size": 1,
+                "pp_size": 2,
+                "num_microbatches": 4,
+                "zero_stage": 1,
+                "precision": "fp16",
+                "initial_scale": 1,
+            },
+            {
+                "tp_size": 2,
+                "pp_size": 1,
+                "num_microbatches": 4,
+                "zero_stage": 1,
+                "precision": "fp16",
+                "initial_scale": 1,
+            },
+        )
     ]
 
 
@@ -59,7 +73,13 @@ def exam_state_dict(
     plugin_0 = HybridParallelPlugin(**test_config_0)
    booster_0 = Booster(plugin=plugin_0)
     hybrid_ckp_0 = booster_0.checkpoint_io
-    booster_0.checkpoint_io = DistributedCheckpointIO(hybrid_ckp_0.global_dp_group, hybrid_ckp_0.pp_group, hybrid_ckp_0.tp_group, hybrid_ckp_0.sp_group, hybrid_ckp_0.use_zero)
+    booster_0.checkpoint_io = DistributedCheckpointIO(
+        hybrid_ckp_0.global_dp_group,
+        hybrid_ckp_0.pp_group,
+        hybrid_ckp_0.tp_group,
+        hybrid_ckp_0.sp_group,
+        hybrid_ckp_0.use_zero,
+    )
 
     def _criterion(outputs, inputs):
         outputs = output_transform_fn(outputs)
@@ -95,7 +115,9 @@ def _preprocess_data(data):
 
     with shared_tempdir() as tempdir:
         model_ckpt_path_0 = f"{tempdir}/model_0"
-        booster_0.save_model(model_0, model_ckpt_path_0, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
+        booster_0.save_model(
+            model_0, model_ckpt_path_0, shard=shard, size_per_shard=size_per_shard, use_async=use_async
+        )
         booster_0.checkpoint_io._sync_d2h()
         booster_0.checkpoint_io._sync_io()
         dist.barrier()
@@ -103,7 +125,13 @@ def _preprocess_data(data):
         plugin_1 = HybridParallelPlugin(**test_config_1)
         booster_1 = Booster(plugin=plugin_1)
         hybrid_ckp_1 = booster_1.checkpoint_io
-        booster_1.checkpoint_io = DistributedCheckpointIO(hybrid_ckp_1.global_dp_group, hybrid_ckp_1.pp_group, hybrid_ckp_1.tp_group, hybrid_ckp_1.sp_group, hybrid_ckp_1.use_zero)
+        booster_1.checkpoint_io = DistributedCheckpointIO(
+            hybrid_ckp_1.global_dp_group,
+            hybrid_ckp_1.pp_group,
+            hybrid_ckp_1.tp_group,
+            hybrid_ckp_1.sp_group,
+            hybrid_ckp_1.use_zero,
+        )
         model_1 = model_fn().cuda()
         optimizer_1 = Adam(model_1.parameters(), lr=1e-3)
 
@@ -112,7 +140,9 @@ def _preprocess_data(data):
         booster_1.load_model(model_1, model_ckpt_path_0, low_cpu_mem_mode=low_cpu_mem_mode)
 
         model_ckpt_path_1 = f"{tempdir}/model_1"
-        booster_1.save_model(model_1, model_ckpt_path_1, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
+        booster_1.save_model(
+            model_1, model_ckpt_path_1, shard=shard, size_per_shard=size_per_shard, use_async=use_async
+        )
         booster_1.checkpoint_io._sync_d2h()
         booster_1.checkpoint_io._sync_io()
         dist.barrier()
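
Usage note (not part of the patch): the test above round-trips a checkpoint across two parallel layouts, saving under tp=1/pp=2 and reloading under tp=2/pp=1, with each rank reassembling its local parameter slice from partially overlapping shards via the saved metadata JSON. The following is a minimal sketch distilled from that test; it assumes a torchrun-style distributed launch and a `model`/`optimizer` already constructed, and the exact `launch_from_torch` signature may differ across ColossalAI versions.

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.checkpoint_io import DistributedCheckpointIO

    colossalai.launch_from_torch()

    plugin = HybridParallelPlugin(
        tp_size=2, pp_size=1, num_microbatches=4, zero_stage=1, precision="fp16", initial_scale=1
    )
    booster = Booster(plugin=plugin)

    # Swap the plugin's default hybrid-parallel checkpoint IO for the
    # distributed one, reusing the process groups the plugin already created.
    hybrid_ckp = booster.checkpoint_io
    booster.checkpoint_io = DistributedCheckpointIO(
        hybrid_ckp.global_dp_group,
        hybrid_ckp.pp_group,
        hybrid_ckp.tp_group,
        hybrid_ckp.sp_group,
        hybrid_ckp.use_zero,
    )

    model, optimizer, *_ = booster.boost(model, optimizer)

    # Each rank writes its own shards plus a metadata file recording
    # per-parameter offsets/lengths within the global shape.
    booster.save_model(model, "ckpt/model_0", shard=True, size_per_shard=32)

    # A booster built with a different parallel layout can load the same
    # checkpoint: shards covering each local slice are located via the metadata.
    booster.load_model(model, "ckpt/model_0")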