From d541d208bbe746cc5c69019dfb386e4f42f66a73 Mon Sep 17 00:00:00 2001
From: Sangkug Lym
Date: Fri, 5 Apr 2024 22:44:27 -0700
Subject: [PATCH] Fix the default userbuffer communicator init settings (#755)

fix the default userbuffer communicator init settings

Signed-off-by: Sangkug Lym
---
 transformer_engine/pytorch/module/base.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 6ef6d4eb3b..56dd3c8fc4 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -153,13 +153,13 @@ def get_method(name):
     def add_ub(
         name: str,
         method: str,
+        is_reduce_scatter: int,
         num_sm: int = 16,
         cga_size: int = 2,
         set_sm_margin: int = 0,
-        num_splits: int = 4,
+        num_splits: int = 0,
         aggregate: int = 0,
         atomic_gemm: int = 0,
-        is_reduce_scatter: int = 0,
         fp8_buf: bool = False,
     ) -> None:
         if atomic_gemm:
@@ -243,7 +243,7 @@ def add_ub(
             method = ub_cfg.get("method", get_method(name))
             num_sm = ub_cfg.get("num_sm", 16)
             cga_size = ub_cfg.get("cga_size", 2)
-            num_splits = ub_cfg.get("num_splits", 4)
+            num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
             set_sm_margin = ub_cfg.get("set_sm_margin", 0)
             aggregate = ub_cfg.get("aggregate", 0)
             atomic_gemm = ub_cfg.get("atomic_gemm", 0)
@@ -254,21 +254,24 @@ def add_ub(
             add_ub(
                 name,
                 method,
+                is_reduce_scatter,
                 num_sm,
                 cga_size,
                 set_sm_margin,
                 num_splits,
                 aggregate,
                 atomic_gemm,
-                is_reduce_scatter,
                 fp8_buf,
             )
         else:
             method = get_method(name)
-            if method == "pipeline":
-                add_ub(name, method)
-            else:
-                add_ub(name, method, num_splits=0)
+            add_ub(
+                name,
+                method=method,
+                is_reduce_scatter=1 if name in layers_reduce_scatter_overlap else 0,
+                num_splits=4 if method == "pipeline" else 0,
+                fp8_buf=name in layers_all_gather_overlap,
+            )
 
 
 def get_ub(name: str):