Skip to content

Commit

Permalink
[JAX] Adapt latest JAX/PAX image (NVIDIA#744)
Browse files Browse the repository at this point in the history
* value_and_grad requires same shape for input and gradients

Signed-off-by: Reese Wang <[email protected]>

* Use high precision layernorm

Signed-off-by: Reese Wang <[email protected]>

* Remove local_device_ids as it caused unexpected behaviors

Signed-off-by: Reese Wang <[email protected]>

* Revert "Remove local_device_ids as it caused unexpected behaviors"

This reverts commit c54349b.

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Reese Wang <[email protected]>
  • Loading branch information
zlsh80826 authored Apr 6, 2024
1 parent d541d20 commit bfe21c3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
5 changes: 3 additions & 2 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ def primitive_bwd(ctx, g):
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3))

return func(inputs, no_use, no_use, no_use)
return func(inputs, jnp.transpose(inputs, (2, 0, 1)),
jnp.zeros(inputs.shape[-1], dtype=inputs.dtype), no_use)

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
Expand Down Expand Up @@ -582,7 +583,7 @@ def primitive_bwd(ctx, g):
primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2))

return func(inputs, no_use, no_use)
return func(inputs, jnp.transpose(inputs, (1, 2, 0)), no_use)

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
Expand Down
7 changes: 3 additions & 4 deletions tests/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,19 +731,18 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
axes=('embed',))
bias = jnp.asarray(bias, self.dtype)

y = jnp.asarray(y, self.dtype)
if not self.zero_centered_gamma:
z = y * scale + bias
else:
z = y * (scale + 1) + bias
z = y * (scale + 1.) + bias
else:
assert self.layernorm_type == 'rmsnorm'
assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
y = x * lax.rsqrt(mean2 + self.epsilon)
z = y * scale

return z
return jnp.asarray(z, self.dtype)


class RelativePositionBiases(nn.Module):
Expand Down

0 comments on commit bfe21c3

Please sign in to comment.