
Commit e4c99b0

[JAX] Use default factory for not sharing mutable default values (NVIDIA#1364)

* Bug Fix: Use default factory for not sharing mutable default values
---------

Signed-off-by: Reese Wang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
zlsh80826 and phu0ngng authored Dec 10, 2024
1 parent 3102fdd commit e4c99b0
Showing 2 changed files with 26 additions and 8 deletions.
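
Why this matters: in a Python dataclass, a plain default such as WeightInit.Constant(0.0) is evaluated once at class-definition time, so every instance of the dataclass shares that single WeightInit object; field(default_factory=...) defers construction so each instance gets its own. Below is a minimal, self-contained sketch of the pitfall and of the fix pattern used in this commit, with a hypothetical ConstInit class standing in for praxis WeightInit.Constant (the real class is not reproduced here):

from dataclasses import dataclass, field
from functools import partial


class ConstInit:
    """Hypothetical stand-in for praxis WeightInit.Constant."""

    def __init__(self, scale=0.0):
        self.scale = scale


@dataclass
class SharedDefault:
    # Anti-pattern: ConstInit(scale=0.0) is built once when the class is
    # defined, so every SharedDefault() instance refers to the same object.
    bias_init: ConstInit = ConstInit(scale=0.0)


@dataclass
class PerInstanceDefault:
    # Fix pattern from this commit: default_factory builds a fresh ConstInit
    # for each instance, so nothing is shared between instances.
    bias_init: ConstInit = field(default_factory=partial(ConstInit, scale=0.0))


a, b = SharedDefault(), SharedDefault()
a.bias_init.scale = 1.0
print(b.bias_init.scale)  # 1.0 -- the mutation leaked through the shared default

c, d = PerInstanceDefault(), PerInstanceDefault()
c.bias_init.scale = 1.0
print(d.bias_init.scale)  # 0.0 -- each instance owns its own default
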
25 changes: 19 additions & 6 deletions transformer_engine/jax/praxis/module.py
@@ -4,6 +4,7 @@
 """
 Praxis Modules
 """
+from dataclasses import field
 from functools import partial
 from typing import Callable, Iterable, Sequence, Tuple, Union
 
@@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     bias_axes: Tuple[str, ...] = ()
     transpose_batch_sequence: bool = False
 
@@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer):
     out_features: int = 512
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = True
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = WeightInit.Constant(1.0)
+    ln_bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=1.0)
+    )
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = WeightInit.Constant(1.0)
+    ln_bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=1.0)
+    )
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes_1: Tuple[str, ...] = ()
     kernel_axes_2: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     bias_axes_1: Tuple[str, ...] = ()
     bias_axes_2: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
9 changes: 7 additions & 2 deletions transformer_engine/jax/praxis/transformer.py
@@ -4,6 +4,7 @@
 """
 Praxis Modules related Transformer
 """
+from dataclasses import field
 from functools import partial
 from typing import Optional, Sequence, Tuple
 import warnings
@@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     return_layernorm_output: bool = False
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     attn_mask_type: str = "causal"
     attn_bias_type: Optional[str] = None
     enable_rotary_pos_emb: bool = False
@@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
     dropout_rng_name: str = "dropout"
     mlp_activations: Sequence[str] = ("relu",)
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
     float32_attention_logits: bool = False
