[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ver217 and pre-commit-ci[bot] authored Dec 25, 2024
1 parent 8369924 commit af06d16
Showing 15 changed files with 484 additions and 174 deletions.
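
For context, the "non blocking pin load" in the title refers to loading checkpoint data into pinned (page-locked) host memory, which is what lets the subsequent host-to-device copy run with non_blocking=True and overlap with other work. Below is a minimal, generic PyTorch sketch of that mechanism; it is illustrative only and not code from this commit.

import torch

if torch.cuda.is_available():
    # Page-locked host buffer: the CUDA driver can DMA from it asynchronously.
    src = torch.randn(1024, 1024).pin_memory()
    dst = torch.empty(1024, 1024, device="cuda")
    # Because `src` is pinned, this copy can be truly asynchronous.
    dst.copy_(src, non_blocking=True)
    torch.cuda.synchronize()  # wait for the copy before reading `dst`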
32 changes: 25 additions & 7 deletions colossalai/booster/booster.py
@@ -288,7 +288,14 @@ def enable_lora(

return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)

def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
def load_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> None:
"""Load model from checkpoint.
Args:
@@ -298,8 +305,12 @@ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, str
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
self.checkpoint_io.load_model(model, checkpoint, strict)
self.checkpoint_io.load_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_model(
self,
@@ -338,18 +349,25 @@ def save_model(
use_async=use_async,
)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
def load_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> None:
"""Load optimizer from checkpoint.
Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
self.checkpoint_io.load_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_optimizer(
self,
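Based on the updated Booster signatures above, a hedged usage sketch of the new keyword arguments follows; booster, model, optimizer, and the checkpoint paths are assumed to have been created elsewhere and are placeholders, not part of this commit.

# Assumes `booster`, `model`, and `optimizer` were created and boosted beforehand;
# the paths are placeholders.
booster.load_model(
    model,
    "ckpt/model",            # directory if the checkpoint is sharded, file otherwise
    strict=True,
    low_cpu_mem_mode=False,  # disable low-CPU-memory mode to use the pinned RAM cache
    num_threads=8,           # pinning threads; only used when low_cpu_mem_mode=False
)
booster.load_optimizer(
    optimizer,
    "ckpt/optimizer",
    low_cpu_mem_mode=False,
    num_threads=8,
)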
53 changes: 43 additions & 10 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -1,4 +1,3 @@
import gc
import os
import random
from pathlib import Path
@@ -97,13 +96,22 @@ def save_unsharded_model(
else:
save_state_dict(state_dict, checkpoint, use_safetensors)

def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
super().load_unsharded_model(
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -131,13 +139,17 @@ def save_unsharded_optimizer(
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Loading unsharded optimizer from checkpoint file.
For each process, only loading optimizer states of parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_sharded_model(
self,
@@ -206,13 +218,27 @@ def save_sharded_model(
)

def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
self,
model: GeminiDDP,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load shard model, load model from multiple files.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
return super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module=False,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)

def save_sharded_optimizer(
self,
@@ -289,7 +315,14 @@ def save_sharded_optimizer(
ranks=[0],
)

def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
def load_sharded_optimizer(
self,
optimizer: GeminiOptimizer,
checkpoint_index_file: Path,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
@@ -322,9 +355,9 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
state_dict_shard = load_flat(shard_file)
else:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
gc.collect()

optimizer.optimizer_loading_epilogue()

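Several hunks in this commit call create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) from colossalai.checkpoint_io.utils. That helper's implementation is not part of this diff; the sketch below only illustrates the idea (copy a flat tensor dict into page-locked buffers, optionally with several threads) under made-up names.

from concurrent.futures import ThreadPoolExecutor
from typing import Dict

import torch


def pinned_state_dict_sketch(
    state_dict: Dict[str, torch.Tensor], empty: bool = True, num_threads: int = 1
) -> Dict[str, torch.Tensor]:
    """Illustrative only: copy a flat tensor dict into pinned host memory."""

    def pin(key: str) -> torch.Tensor:
        t = state_dict[key]
        if empty:
            # Pre-allocate a page-locked buffer without copying the data yet.
            return torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
        # Tensor.pin_memory() returns a page-locked copy of the tensor.
        return t.pin_memory()

    keys = list(state_dict)
    # Page-locking is slow, so fan the per-tensor work out over threads.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return dict(zip(keys, pool.map(pin, keys)))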
48 changes: 42 additions & 6 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -20,6 +20,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
create_pinned_state_dict,
get_optimizer_base_filenames,
get_shard_filename,
load_param_groups_into_optimizer,
@@ -145,14 +146,18 @@ def save_unsharded_optimizer(
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
use_async = checkpoint.endswith(".safetensors")
if use_async:
from colossalai.utils.safetensors import load_flat

checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
optimizer.load_state_dict(checkpoint)

def save_sharded_optimizer(
@@ -239,7 +244,14 @@ def save_sharded_optimizer(
ranks=[0],
)

def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""Load sharded optimizer with the given path to index file.
Args:
@@ -283,14 +295,28 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
if low_cpu_mem_mode:
state_dict[param_idx][k] = state_dict[param_idx][k].clone()

if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_unsharded_model(model, checkpoint, strict)
super().load_unsharded_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
model.update_master_params()

def load_sharded_model(
@@ -300,10 +326,20 @@ def load_sharded_model(
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
model.update_master_params()

def save_unsharded_model(
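The load_sharded_optimizer change above is subtle: the per-rank slice is cloned immediately only in low-CPU-memory mode, so the padded full tensor can be freed right away, while in cached mode the slice stays a view because the whole state dict is copied into pinned buffers in one pass afterwards. A simplified, stand-alone illustration of the two paths, assuming a flattened 1-D tensor and placeholder names:

import torch
import torch.nn.functional as F


def keep_rank_slice(v: torch.Tensor, rank: int, world_size: int, low_cpu_mem_mode: bool) -> torch.Tensor:
    # Pad so the flat tensor splits evenly across ranks, then keep this rank's slice.
    padding = (-v.numel()) % world_size
    if padding > 0:
        v = F.pad(v, [0, padding])
    piece = v.split(v.numel() // world_size)[rank].detach()
    if low_cpu_mem_mode:
        # Clone so the padded full tensor can be released immediately.
        return piece.clone()
    # Keep the view; the caller pins the whole state dict afterwards, which
    # copies the data anyway (see the create_pinned_state_dict sketch above).
    return piece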
39 changes: 33 additions & 6 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -26,12 +26,21 @@ def __init__(self) -> None:
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
super().load_unsharded_model(
model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
@@ -45,12 +54,16 @@ def save_unsharded_model(
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -101,12 +114,22 @@ def load_sharded_model(
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from sharded checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
super().load_sharded_model(
model.unwrap(),
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)

def save_sharded_optimizer(
self,
@@ -131,12 +154,16 @@ def load_sharded_optimizer(
optimizer: Optimizer,
index_file_path: str,
prefix: Optional[str] = None,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from sharded checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
super().load_sharded_optimizer(
optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
19 changes: 16 additions & 3 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -43,13 +43,17 @@ def __init__(self) -> None:
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
def load_unsharded_model(
self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
model = model.unwrap()
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint, seperator=".")
@@ -232,6 +236,8 @@ def load_sharded_model(
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model to checkpoint but only on master process.
@@ -354,7 +360,14 @@ def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
f"index located at {save_index_file}."
)

def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
size_per_shard: int,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer to checkpoint but only on master process.
"""