Modify small-batched weight only quantization #2213

Open: wants to merge 41 commits into base: main

dasistwo

I found that the small-batch weight-only GEMV kernel stalls on global memory loads in some cases.

This PR uses shared memory in those cases to:

  1. Remove some duplicate memory requests in the scale factor/zero point load
  2. Use memcpy_async to set up double buffering for load & MMA, which reduces global memory stalls
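
As a rough illustration of item 2, the following is a toy timing model, not the kernel from this PR. It assumes a fixed, hypothetical cost per tile for the load and for the compute, and ignores everything else a GPU does to hide latency (other resident warps, caches). The function names and the numbers are made up for the sketch.

```python
def serial_time(load, compute, tiles):
    """Each tile is loaded and then computed on, with no overlap."""
    return tiles * (load + compute)


def double_buffered_time(load, compute, tiles):
    """Tile i+1 is loaded into the second buffer while tile i is computed.

    Only the first load and the last compute have nothing to overlap with.
    """
    if tiles == 0:
        return 0
    return load + (tiles - 1) * max(load, compute) + compute


# Hypothetical per-tile costs in arbitrary time units.
print(serial_time(10, 2, 4))            # 48
print(double_buffered_time(10, 2, 4))   # 42
print(serial_time(10, 10, 4))           # 80
print(double_buffered_time(10, 10, 4))  # 50
```

In this model the saving is `(tiles - 1) * min(load, compute)`, so overlap pays off most when the load and the compute per tile are of similar size, and very little when one of them dominates.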

It had little or no effect on small GEMVs, but had some effect on large GEMVs.

Below is the percentage reduction in GEMV computation time, summed over the 5 types of GEMV kernel in the first decoding stage.
Tested on a single A100 40GB and a single H100 80GB, batch size 4, context length <= 512.

| Model | A100 | H100 |
| --- | --- | --- |
| Gemma 1-7B | 7.93% | 3.83% |
| Llama 2-7B | 0.84% | 7.52% |
| Llama 2-13B | -0.09% | 6.05% |
| Llama 3-8B | 1.62% | 9.77% |

dasistwo and others added 30 commits April 11, 2024 06:24
* Update MLP branch with upstream (Update TensorRT-LLM; co-authored by Shixiaowei02)
* Found some error cases with unit test cases with small load chunks.
lfr-0531 added the "triaged" (Issue has been triaged by maintainers) and "Low Precision" (Issue about lower bit quantization, including int8, int4, fp8) labels on Sep 24, 2024

MARD1NO commented Nov 13, 2024

Great work! Which quantization type did you benchmark the GEMV with: channel-wise or group-wise, and 4-bit or 8-bit?

dasistwo (Author)

These results are from W4A16_AWQ, so the configuration is 4-bit with group size 128. I also checked channel-wise quantization, but the performance gain was slightly smaller than with group-wise.
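
For readers unfamiliar with the group-wise layout, here is a minimal pure-Python sketch of a 4-bit weight-only dot product with group size 128. The nibble packing order and the `(q - zero) * scale` convention are assumptions chosen for illustration, not TensorRT-LLM's actual memory layout, and the helper names are hypothetical.

```python
GROUP_SIZE = 128  # group size quoted above


def pack_int4(values):
    """Pack pairs of unsigned 4-bit values (0..15) into bytes, low nibble first."""
    assert len(values) % 2 == 0
    return bytes(lo | (hi << 4) for lo, hi in zip(values[0::2], values[1::2]))


def unpack_int4(packed):
    """Inverse of pack_int4: two 4-bit values per byte."""
    out = []
    for b in packed:
        out.append(b & 0xF)
        out.append(b >> 4)
    return out


def gemv_row_groupwise(x, packed_row, scales, zeros, group_size=GROUP_SIZE):
    """Dot product of activations x with one quantized weight row.

    Weight k is reconstructed as (q[k] - zeros[g]) * scales[g],
    where g = k // group_size is the group that k belongs to.
    """
    q = unpack_int4(packed_row)
    acc = 0.0
    for k, (xk, qk) in enumerate(zip(x, q)):
        g = k // group_size
        acc += xk * (qk - zeros[g]) * scales[g]
    return acc


# Two groups of 128 weights, each with its own scale and zero point.
q = [k % 16 for k in range(256)]
row = pack_int4(q)
print(gemv_row_groupwise([1.0] * 256, row, scales=[0.5, 0.25], zeros=[8, 8]))  # -48.0
```

Every weight in a group shares one scale and one zero point, so each pair only has to be fetched once per group. Channel-wise quantization corresponds to a single group spanning the whole row, i.e. `group_size` equal to the row length in this sketch.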

Void1024

Thank you for your excellent work. I am the author of the batched GEMV kernel in TRT-LLM. My colleagues and I have reviewed and benchmarked your modifications in this PR. We had previously tried a similar approach, but it didn't yield significant benefits at that time.

We validated the kernel latency with your modifications on different shapes on the H100 but found that there was a performance regression in some shapes. Considering that we have other optimization work for this part of the code in progress, we are unable to merge your changes at this time.

Could you please provide benchmark data comparing the kernel latency before and after your changes for different shapes (for example, m=1, 2, 3, 4 and n/k=2048, 4096, 8192, 12288, 16384) under the GPTQ/AWQ case on both A100 and H100?
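
One way to organize the requested comparison is sketched below. It assumes that "n/k" means n and k are each drawn independently from the listed sizes, which is one possible reading of the request, and it reports the same "percentage of reduced time" metric used in the PR description. The function names are hypothetical and no latencies are measured here.

```python
from itertools import product

M_VALUES = (1, 2, 3, 4)
NK_VALUES = (2048, 4096, 8192, 12288, 16384)


def benchmark_shapes():
    """All (m, n, k) combinations of the sizes suggested above."""
    return list(product(M_VALUES, NK_VALUES, NK_VALUES))


def reduction_percent(before_us, after_us):
    """Percentage of latency removed by the change; negative means a regression."""
    return (before_us - after_us) / before_us * 100.0


shapes = benchmark_shapes()
print(len(shapes))  # 100
print(shapes[0])    # (1, 2048, 2048)
```

Reporting `reduction_percent` per shape, rather than as a sum over kernels, would make regressions on individual shapes visible.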


MARD1NO commented Nov 21, 2024

Hi author, what do you think of the idea of using async copy in GEMV? GEMV is a memory-bound operation, so will async copy improve its performance?
In my experiment (I wrote a low-bit CUDA-core GEMV with async copy, rather than using this weight-only GEMV), the async-copy version performed better only for some larger M/N/K cases and only on some specific devices.

Void1024

Yes, in my previous experiments, I came to a similar conclusion. If the tileMNK is not large enough, there might not be sufficient computation and LDS to hide the latency of copy_async. Furthermore, in GEMV cases with small batch sizes, the data often fits within the registers.
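
The memory-bound argument can be made concrete with a back-of-the-envelope arithmetic-intensity estimate. The sketch below is an assumption-laden model, not a measurement: it counts each weight, scale, and activation as read or written exactly once, assumes 4-bit weights, 2-byte activations and scales, and group size 128, and ignores zero points and caching. The function name is hypothetical.

```python
def gemv_arithmetic_intensity(m, n, k, weight_bits=4, act_bytes=2, group_size=128):
    """FLOPs per byte of memory traffic for an (m x k) @ (k x n) weight-only GEMV."""
    flops = 2 * m * n * k                            # one multiply and one add per term
    weight_bytes = n * k * weight_bits / 8           # packed quantized weights
    scale_bytes = (k // group_size) * n * act_bytes  # one scale per group per column
    act_in_bytes = m * k * act_bytes                 # input activations
    act_out_bytes = m * n * act_bytes                # output activations
    total_bytes = weight_bytes + scale_bytes + act_in_bytes + act_out_bytes
    return flops / total_bytes


for m in (1, 2, 3, 4):
    print(m, round(gemv_arithmetic_intensity(m, 4096, 4096), 2))
```

Under these assumptions the weight traffic dominates, so the intensity grows roughly linearly with m and stays at a few FLOPs per byte for m <= 4. That leaves little computation per loaded byte for an asynchronous copy to overlap with.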
