
Add one more AT_DISPATCH for layer norm gamma scalar type #883

Draft
Liangliang-Ma wants to merge 1 commit into main

Conversation

Liangliang-Ma commented Sep 9, 2024

The original kernel only does AT_DISPATCH on the input dtype, while the weight/bias (gamma/beta) may have a different data type. It breaks when trying to get gamma.data_ptr<weight_t>(), because the wrong scalar type is used.

The same breakage is confirmed on CUDA.

This is fixed by adding one more AT_DISPATCH on the gamma/beta dtype.
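Roughly, the intended pattern is sketched below; the tensor and function names (X, gamma, beta, launch_layer_norm_kernel) are placeholders rather than the actual torch-xpu-ops code. The outer dispatch resolves the input dtype, the inner one resolves the gamma/beta dtype, so the weight/bias pointers are read through their own template type:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical kernel launcher, declared only to keep the sketch self-contained.
template <typename input_t, typename weight_t>
void launch_layer_norm_kernel(const input_t* X, const weight_t* gamma, const weight_t* beta);

void layer_norm_nested_dispatch_sketch(
    const at::Tensor& X, const at::Tensor& gamma, const at::Tensor& beta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      X.scalar_type(), "layer_norm_xpu", [&] {
        // Capture the input dtype before the inner dispatch shadows scalar_t.
        using input_t = scalar_t;
        AT_DISPATCH_FLOATING_TYPES_AND2(
            at::ScalarType::Half, at::ScalarType::BFloat16,
            gamma.scalar_type(), "layer_norm_xpu_gamma", [&] {
              using weight_t = scalar_t;  // dtype of gamma/beta, may differ from input_t
              launch_layer_norm_kernel<input_t, weight_t>(
                  X.data_ptr<input_t>(),
                  gamma.data_ptr<weight_t>(),  // matches gamma's real dtype (e.g. float32)
                  beta.data_ptr<weight_t>());
            });
      });
}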

xytintel (Contributor) commented

@Liangliang-Ma What scenarios need this fix? Since CUDA doesn’t specifically address this issue, is it necessary? @fengyuan14 @EikanWang

Liangliang-Ma (Author) commented Sep 11, 2024

@Liangliang-Ma What scenarios need this fix? Since CUDA doesn’t specifically address this issue, is it necessary? @fengyuan14 @EikanWang

I met this in an IPEX rebase 2.5 workload JIRA: a BART model doing a question-answering task in float16 precision (no AMP). It turned out to be easy to reproduce in a single UT, which also fails with the latest stock PyTorch wheel on CUDA.

import torch
from torch.nn import LayerNorm

dtype1 = torch.float16
# Input is float16, but LayerNorm's weight/bias default to float32.
inputs = torch.rand([1, 1024, 4096], dtype=dtype1, requires_grad=True).cuda()
ln = LayerNorm(4096, eps=1e-05, elementwise_affine=True).cuda()
outputs = ln(inputs)
# RuntimeError: expected scalar type Half but found Float

More detailed description:
LayerNorm's weight/bias are initialized as float32 by default if no dtype is specified in the Python interface. The layer norm kernel then receives parameters with mixed dtypes: input float16, weight float32, bias float32. But currently AT_DISPATCH dispatches all three to the same type <scalar_t, scalar_t, scalar_t>, which ends up failing when getting the data_ptr.
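For reference, a minimal sketch of why the single dispatch fails (again with placeholder tensor names, not the actual kernel source): data_ptr<T>() checks that T matches the tensor's dtype, so reading a float32 gamma as Half raises the error above.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

void layer_norm_single_dispatch_sketch(
    const at::Tensor& X, const at::Tensor& gamma, const at::Tensor& beta) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      X.scalar_type(), "layer_norm_xpu", [&] {
        // scalar_t is at::Half when the input is float16.
        const scalar_t* input_ptr = X.data_ptr<scalar_t>();
        // gamma/beta default to float32, so data_ptr's dtype check throws:
        // "RuntimeError: expected scalar type Half but found Float".
        const scalar_t* gamma_ptr = gamma.data_ptr<scalar_t>();
        const scalar_t* beta_ptr = beta.data_ptr<scalar_t>();
        // (kernel launch omitted in this sketch)
      });
}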

xytintel marked this pull request as draft November 5, 2024 03:52