Skip to content

Commit

Permalink
Fix the default userbuffer communicator init settings (NVIDIA#755)
Browse files Browse the repository at this point in the history
fix the default userbuffer communicator init settings

Signed-off-by: Sangkug Lym <[email protected]>
  • Loading branch information
erhoo82 authored Apr 6, 2024
1 parent e3de403 commit d541d20
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,13 @@ def get_method(name):
def add_ub(
name: str,
method: str,
+ is_reduce_scatter: int,
num_sm: int = 16,
cga_size: int = 2,
set_sm_margin: int = 0,
- num_splits: int = 4,
+ num_splits: int = 0,
aggregate: int = 0,
atomic_gemm: int = 0,
- is_reduce_scatter: int = 0,
fp8_buf: bool = False,
) -> None:
if atomic_gemm:
Expand Down Expand Up @@ -243,7 +243,7 @@ def add_ub(
method = ub_cfg.get("method", get_method(name))
num_sm = ub_cfg.get("num_sm", 16)
cga_size = ub_cfg.get("cga_size", 2)
- num_splits = ub_cfg.get("num_splits", 4)
+ num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
set_sm_margin = ub_cfg.get("set_sm_margin", 0)
aggregate = ub_cfg.get("aggregate", 0)
atomic_gemm = ub_cfg.get("atomic_gemm", 0)
Expand All @@ -254,21 +254,24 @@ def add_ub(
add_ub(
name,
method,
+ is_reduce_scatter,
num_sm,
cga_size,
set_sm_margin,
num_splits,
aggregate,
atomic_gemm,
- is_reduce_scatter,
fp8_buf,
)
else:
method = get_method(name)
- if method == "pipeline":
- add_ub(name, method)
- else:
- add_ub(name, method, num_splits=0)
+ add_ub(
+ name,
+ method=method,
+ is_reduce_scatter=1 if name in layers_reduce_scatter_overlap else 0,
+ num_splits=4 if method == "pipeline" else 0,
+ fp8_buf=name in layers_all_gather_overlap,
+ )


def get_ub(name: str):
Expand Down

0 comments on commit d541d20

Please sign in to comment.