[Device]Support npu (#6159)
* support npu

* support pretrain

support pretrain

fix

* support lora

fix

fix

* support chatglm

fix

fix

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

fix

* Update train.py

* Update train.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 and pre-commit-ci[bot] authored Dec 17, 2024
1 parent e994c64 commit aaafb38
Showing 18 changed files with 292 additions and 149 deletions.
@@ -100,7 +100,7 @@ def dict(self):
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
seps=["<|begin_of_text|>", "<|eot_id|>"],
)

default_conversation = LLaMA3_Conv
@@ -88,7 +88,7 @@ def supervised_tokenize_sft(

assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."

if ignore_index is None:
ignore_index = IGNORE_INDEX
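The two changes above work together: the LLaMA-3 conversation template now ends turns with `<|eot_id|>`, and the tokenizer check reports the actual special tokens when they do not match the template separators. A minimal sketch of that check follows; the checkpoint path is a placeholder, and the expectation that a Llama-3 chat tokenizer uses `<|eot_id|>` as its eos token is an assumption about the model being loaded.

# Sketch only: confirm the tokenizer's special tokens match the updated template
# separators before building SFT data. "path/to/llama-3-checkpoint" is a placeholder.
from transformers import AutoTokenizer

seps = ["<|begin_of_text|>", "<|eot_id|>"]  # mirrors LLaMA3_Conv above
tokenizer = AutoTokenizer.from_pretrained("path/to/llama-3-checkpoint", trust_remote_code=True)
assert (
    tokenizer.bos_token == seps[0] and tokenizer.eos_token == seps[1]
), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{seps}."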
6 changes: 5 additions & 1 deletion applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py
@@ -43,6 +43,7 @@ def save_checkpoint(
step: int,
batch_size: int,
coordinator: DistCoordinator,
use_lora: bool = False,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
@@ -51,7 +52,10 @@
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
if use_lora:
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
else:
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
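With the new `use_lora` flag, a LoRA run saves only the adapter via `booster.save_lora_as_pretrained` instead of full sharded weights. A hedged sketch of reloading such a checkpoint is below; it assumes the saved folder is PEFT-compatible, and both paths are placeholders (the `epoch-*_step-*/modeling` layout comes from the hunk above).

# Sketch, not part of the diff: attach a saved LoRA adapter back onto its base model.
# Assumes booster.save_lora_as_pretrained produced a PEFT-style adapter directory.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base-model", trust_remote_code=True)
model = PeftModel.from_pretrained(base, "path/to/save_dir/epoch-0_step-100/modeling")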
115 changes: 61 additions & 54 deletions applications/Colossal-LLaMA/train.py
@@ -21,6 +21,7 @@
from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
from peft import LoraConfig
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -65,7 +66,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
@@ -75,7 +76,7 @@
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
@@ -101,10 +102,9 @@
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
@@ -117,11 +117,17 @@
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
if args.pad_token == "eos":
tokenizer.pad_token = tokenizer.eos_token
try:
tokenizer.pad_token = tokenizer.eos_token
except AttributeError:
coordinator.print_on_master(f"pad_token can't be set")
elif args.pad_token == "unk":
tokenizer.pad_token = tokenizer.unk_token
try:
tokenizer.pad_token = tokenizer.unk_token
except AttributeError:
coordinator.print_on_master(f"pad_token can't be set")
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False

@@ -164,33 +170,31 @@ def train(args) -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
else nullcontext()
)
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)

if args.lora_rank > 0:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)

# this is essential, otherwise the grad checkpoint will not work.
model.train()

if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

model_numel = get_model_numel(model)
@@ -327,6 +331,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@@ -371,44 +376,45 @@ def train(args) -> None:
total_loss.fill_(0.0)
pbar.update()

# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)

if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)

if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")

if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)

accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)

if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)

# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()

# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
@@ -522,6 +528,7 @@ def train(args) -> None:
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

# Additional arguments for benchmark.
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
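Among the train.py changes above, pad-token assignment is wrapped in try/except because some tokenizers loaded with `trust_remote_code=True` (ChatGLM-style ones, per the commit message) do not allow it to be set. A self-contained sketch of that fallback follows; `set_pad_token` is an illustrative helper name, not a function from the repository.

# Sketch of the pad-token fallback used in train.py.
def set_pad_token(tokenizer, pad_token: str, log=print) -> None:
    source = tokenizer.eos_token if pad_token == "eos" else tokenizer.unk_token
    try:
        tokenizer.pad_token = source
    except AttributeError:
        log("pad_token can't be set")

# usage (illustrative): set_pad_token(tokenizer, args.pad_token, coordinator.print_on_master)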
6 changes: 3 additions & 3 deletions colossalai/lazy/lazy_init.py
@@ -509,9 +509,9 @@ def wrap_factory_like_method(orig_target, target):
# factory_like functions (eg. torch.empty_like())
def wrapper(*args, **kwargs):
orig_t = args[0]
return self.tensor_cls(
orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
)
device = kwargs.pop("device", orig_t.device)
dtype = kwargs.pop("dtype", orig_t.dtype)
return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)

return wrapper, target

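The lazy_init change above exists because factory-like calls such as `torch.empty_like(t, dtype=...)` pass `device`/`dtype` through `**kwargs`, while the old wrapper also forwarded `device=orig_t.device, dtype=orig_t.dtype`, so Python raised "got multiple values for keyword argument". Popping the kwargs lets an explicit caller value win and falls back to the source tensor otherwise. A plain-PyTorch sketch of the call pattern that used to collide:

# Plain torch illustration (no lazy init involved): the caller overrides dtype,
# which is exactly the kwarg the old wrapper would have passed a second time.
import torch

t = torch.ones(2, 3, dtype=torch.float16)
like = torch.empty_like(t, dtype=torch.float32)  # dtype arrives via **kwargs in the wrapper
assert like.dtype == torch.float32 and like.shape == t.shape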
2 changes: 1 addition & 1 deletion colossalai/legacy/communication/p2p.py
@@ -171,7 +171,7 @@ def _communicate(
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
get_accelerator().synchronize()

if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
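The one-line change above is representative of the whole commit: device-specific `torch.cuda.*` calls are routed through ColossalAI's accelerator interface so the same code path works on NVIDIA GPUs and Ascend NPUs. A minimal sketch using only the accelerator methods that appear in this diff:

# Sketch: each call below (is_available, current_device, synchronize, empty_cache)
# is used elsewhere in this commit in place of its torch.cuda equivalent.
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()
if accelerator.is_available():
    device_index = accelerator.current_device()  # index of the active GPU/NPU
    accelerator.synchronize()                    # device-agnostic torch.cuda.synchronize()
    accelerator.empty_cache()                    # release cached device memory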
33 changes: 20 additions & 13 deletions colossalai/pipeline/p2p.py
@@ -14,6 +14,8 @@
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten

from colossalai.accelerator import get_accelerator

from .stage_manager import PipelineStageManager


@@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
buf = tensor.numpy().tobytes()[:tensor_size]
if b"cuda" in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
device_index = get_accelerator().current_device()
# There might be more than one output tensors during forward
for cuda_str in re.finditer(b"cuda", buf_array):
pos = cuda_str.start()
@@ -86,7 +88,7 @@ def _broadcast_object_list(
else:
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device("cuda", get_accelerator().current_device())

my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
@@ -139,29 +141,29 @@ def _broadcast_object_list(
# unconsistence in device
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.cuda()
unpickle_object = unpickle_object.to(get_accelerator().current_device())

object_list[i] = unpickle_object


def _check_for_nccl_backend(group):
def _check_for_nccl_hccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg

return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL


def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
is_nccl_backend = _check_for_nccl_hccl_backend(group)
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device(get_accelerator().current_device())
return current_device, is_nccl_backend


@@ -348,8 +350,11 @@ def _send_recv_serialization_object(

unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
unpickle_object = unpickle_object.cuda()
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.to(get_accelerator().current_device())

return unpickle_object

@@ -474,9 +479,11 @@ def _p2p_comm(
recv_prev_shape = None

if tensor_send_next is not None:
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
send_next_shape = torch.tensor(
tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
)
if recv_prev:
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)

ops = []
if send_next_shape is not None:
@@ -501,7 +508,7 @@
# send and recv data
tensor_recv_prev = None
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)

ops = []
if tensor_send_next is not None:
3 changes: 1 addition & 2 deletions colossalai/pipeline/schedule/interleaved_pp.py
@@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@@ -18,7 +17,7 @@
from .base import PipelineSchedule


def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
def _wait_p2p(wait_handles) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
1 change: 0 additions & 1 deletion colossalai/pipeline/schedule/one_f_one_b.py
@@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
