
xe: sdpa: Improve performance of quantization with better alignment and prefetching #2322

Open · umar456 wants to merge 5 commits into main from uarshad/sdpa_scale_zp_alignment

Conversation

umar456 (Contributor) commented Dec 27, 2024

Description

This PR improves the performance of the micro SDPA kernel by using prefetching and by setting better alignment when generating the microkernels. The change has a significant impact on certain sizes, with measured speedups ranging from 0.89x to 1.26x relative to the original version.

umar456 added the performance and platform:gpu-intel (Codeowner: @oneapi-src/onednn-gpu-intel) labels on Dec 27, 2024
umar456 requested a review from a team as a code owner on December 27, 2024 20:56
umar456 (Contributor Author) commented Dec 27, 2024

make test
disable device_cpu
disable benchdnn_all
enable benchdnn_nightly
enable benchdnn_graph

umar456 force-pushed the uarshad/sdpa_scale_zp_alignment branch from 9af6de9 to 4746108 on January 3, 2025 09:01
umar456 (Contributor Author) commented Jan 3, 2025

make test
disable device_cpu
disable benchdnn_all
enable benchdnn_nightly
enable benchdnn_graph

```
            /* n_sg */ sg_per_wg,
            /* sg_size */ SUBGROUP_SIZE,
            /* cache */ LSC_LDCC_L1C_L3C);
    //return;
```
Contributor:

Does it improve performance to have the first K tile prefetch here (before loading Q)? IIRC in my earlier testing it was better to delay the first K tile prefetch until after issuing the Q load.

umar456 (Contributor Author):

I saw a slight improvement by moving that forward, but I want to test with a larger set of examples. I am compiling that now and will post my results.

```
            /* sg_id */ sg_ij,
            /* n_sg */ sg_per_wg,
            /* sg_size */ SUBGROUP_SIZE,
            /* cache */ LSC_LDCC_L1C_L3C);
```
Contributor:

This tile is so small that it doesn't need cooperative prefetching (hence the earlier simpler logic). Does this change improve performance?

umar456 (Contributor Author):

It looked like all subgroups would prefetch the same memory region with the simpler call, because all subgroups are assigned the same sg_id (0). I didn't think the compiler would be able to gate the other subgroups from executing the prefetch operation. Am I interpreting that incorrectly?

I think I saw a ~5% gain for certain sizes, but I can test it again and post my results along with the rest of the changes.

petercad (Contributor) commented Jan 3, 2025

Right, the original code has all subgroups doing the prefetch. It's a bit of a tradeoff: with all subgroups doing the prefetch, there's likely some additional overhead in LSC keeping track of the outstanding prefetches. On the other hand, with cooperative prefetch, we're just relying on the timing being right, since there's no barrier between this prefetch and the mask load.

But all that said, if cooperative prefetch shows better performance, let's use it.
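
A minimal host-side C sketch of the two schemes discussed above (not the actual kernel code), assuming the simple call effectively hands every subgroup the same sg_id so each subgroup requests the same region, while cooperative prefetch offsets each subgroup by its real sg_id; N_SG and TILE_CL are hypothetical sizes:

```c
#include <stdio.h>

/* Host-side sketch (not the kernel) of the tradeoff above: redundant
 * requests with the "simple" prefetch vs. disjoint slices with the
 * cooperative prefetch.  Sizes are illustrative only.                */
enum { N_SG = 4, TILE_CL = 16 };

int main(void) {
    int slice = TILE_CL / N_SG; /* cache lines per subgroup, cooperative */

    for (int sg = 0; sg < N_SG; sg++) {
        /* Simple prefetch: every subgroup requests the same region.   */
        printf("simple:      subgroup %d -> lines [0..%d)\n", sg, TILE_CL);
        /* Cooperative prefetch: each subgroup covers a disjoint slice. */
        printf("cooperative: subgroup %d -> lines [%d..%d)\n",
               sg, sg * slice, (sg + 1) * slice);
    }
    return 0;
}
```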

umar456 (Contributor Author):

Here is the difference between the previous version and the current version of the mask prefetch for broadcasted masks.

All rows are benchdnn runs passing the listed shapes via `--in-shapes` with `--case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json`.

| k | ks | kzp | q | msk | v | vs | vzp | baseline | newtime | speedup |
|---|----|-----|---|-----|---|----|-----|----------|---------|---------|
| 1x1x128x384*abcd | 1x1x2x384 | 1x1x2x384 | 1x1x384x128 | 1x1x1x384 | 1x1x384x128 | 1x1x384x2 | 1x1x384x2 | 0.03088 | 0.02864 | 1.0782123 |
| 1x1x128x384*abdc | 1x1x2x384 | 1x1x2x384 | 1x1x384x128 | 1x1x1x384 | 1x1x384x128 | 1x1x384x2 | 1x1x384x2 | 0.02752 | 0.02752 | 1.0 |
| 1x1x128x512*abcd | 1x1x2x512 | 1x1x2x512 | 1x1x512x128 | 1x1x1x512 | 1x1x512x128 | 1x1x512x2 | 1x1x512x2 | 0.03184 | 0.03056 | 1.0418848 |
| 1x1x128x512*abdc | 1x1x2x512 | 1x1x2x512 | 1x1x512x128 | 1x1x1x512 | 1x1x512x128 | 1x1x512x2 | 1x1x512x2 | 0.0312 | 0.02976 | 1.0483871 |
| 1x1x128x1024*abcd | 1x1x2x1024 | 1x1x2x1024 | 1x1x1024x128 | 1x1x1x1024 | 1x1x1024x128 | 1x1x1024x2 | 1x1x1024x2 | 0.0584 | 0.05696 | 1.0252809 |
| 1x1x128x1024*abdc | 1x1x2x1024 | 1x1x2x1024 | 1x1x1024x128 | 1x1x1x1024 | 1x1x1024x128 | 1x1x1024x2 | 1x1x1024x2 | 0.05632 | 0.05504 | 1.0232558 |
| 1x1x128x2048*abcd | 1x1x2x2048 | 1x1x2x2048 | 1x1x2048x128 | 1x1x1x2048 | 1x1x2048x128 | 1x1x2048x2 | 1x1x2048x2 | 0.11568 | 0.1192 | 0.97046980 |
| 1x1x128x2048*abdc | 1x1x2x2048 | 1x1x2x2048 | 1x1x2048x128 | 1x1x1x2048 | 1x1x2048x128 | 1x1x2048x2 | 1x1x2048x2 | 0.10832 | 0.10608 | 1.0211161 |

umar456 force-pushed the uarshad/sdpa_scale_zp_alignment branch from 4746108 to 1d2375a on January 3, 2025 18:51
```diff
-const uint cl_per_sg = (cl + n_sg - 1) / n_sg;
-const uint cl_iters = (cl_per_sg + sg_size - 1) / sg_size;
+const uint cl_per_sg = (cl + sg_size - 1) / sg_size;
+const uint cl_iters = (cl_per_sg + n_sg - 1) / n_sg;
```
Contributor:

Can you explain what's going on in this patch? This doesn't look right.

umar456 (Contributor Author):

cl_per_sg was using the number of subgroups instead of the subgroup size to calculate the cache lines per sg.

The main difference in this commit is that multiple subgroups were prefetching the same memory region because the i_cl indexing was not offsetting across subgroups. This increased the number of iterations and I believe some cache lines were skipped because of this.

petercad (Contributor) commented Jan 3, 2025

> cl_per_sg was using the number of subgroups instead of the subgroup size to calculate the cache lines per sg.

The existing code has the expected behavior. We're gathering up all the cache lines to prefetch (cl) then splitting them among the n_sg available subgroups, leaving cl_per_sg cache lines per subgroup. Then, we're splitting up cl_per_sg between work-items in the subgroup.

> The main difference in this commit is that multiple subgroups were prefetching the same memory region because the i_cl indexing was not offsetting across subgroups

Ah yes, it wasn't offsetting properly, thanks for catching that. I think you can quickly fix it by reverting the patch to these lines (617-618) and applying the patch I suggested below.
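
As a concrete illustration of the split described above, here is a small C sketch with hypothetical values for cl, n_sg, and sg_size (the kernel's actual values are not shown in this thread):

```c
#include <stdio.h>

/* Worked example of the split: cl cache lines are divided among n_sg
 * subgroups, and each subgroup's share is then divided among its
 * sg_size work-items.  All values here are hypothetical.            */
int main(void) {
    unsigned cl = 96;      /* total cache lines to prefetch (assumed) */
    unsigned n_sg = 8;     /* subgroups in the work-group (assumed)   */
    unsigned sg_size = 16; /* work-items per subgroup (assumed)       */

    unsigned cl_per_sg = (cl + n_sg - 1) / n_sg;             /* ceil(96/8)  = 12 */
    unsigned cl_iters = (cl_per_sg + sg_size - 1) / sg_size; /* ceil(12/16) = 1  */

    printf("cl_per_sg = %u, cl_iters = %u\n", cl_per_sg, cl_iters);
    return 0;
}
```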

Comment on lines 621 to 623
```
uint i_cl = ii_cl * cl_per_sg * sg_size + (sg_id * sg_size)
        + get_sub_group_local_id();
```
petercad (Contributor):

Suggested change:

```diff
-uint i_cl = ii_cl * cl_per_sg * sg_size + (sg_id * sg_size)
-        + get_sub_group_local_id();
+uint i_cl = (ii_cl + (sg_id * cl_per_sg)) * sg_size + get_sub_group_local_id();
```

umar456 (Contributor Author):

Shouldn't it be:

Suggested change:

```diff
-uint i_cl = ii_cl * cl_per_sg * sg_size + (sg_id * sg_size)
-        + get_sub_group_local_id();
+uint i_cl = (ii_cl * cl_per_sg + sg_id) * sg_size + get_sub_group_local_id();
```

Otherwise the second iteration will only offset by sg_size, which will overlap with iteration zero of subgroup one.
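
For reference, a small standalone C harness (hypothetical parameters, not PR code) that enumerates the cache-line indices produced by each of the two suggested one-liners, using the original cl_per_sg/cl_iters definitions, and reports missed or repeated lines over [0, CL). It is only a way to sanity-check the indexing, not an assertion about which variant the kernel should use:

```c
#include <stdio.h>

/* Standalone harness with hypothetical sizes to compare the two i_cl
 * formulas suggested above.  It marks every index a formula produces
 * and reports how many lines in [0, CL) are missed or repeated.      */
enum { CL = 96, N_SG = 4, SG_SIZE = 16 };

/* Variant A: the first suggested change above. */
static unsigned idx_a(unsigned ii_cl, unsigned sg_id, unsigned lid,
                      unsigned cl_per_sg) {
    return (ii_cl + (sg_id * cl_per_sg)) * SG_SIZE + lid;
}

/* Variant B: the counter-suggestion above. */
static unsigned idx_b(unsigned ii_cl, unsigned sg_id, unsigned lid,
                      unsigned cl_per_sg) {
    return (ii_cl * cl_per_sg + sg_id) * SG_SIZE + lid;
}

static void check(const char *name,
                  unsigned (*idx)(unsigned, unsigned, unsigned, unsigned)) {
    /* Original definitions (the ones suggested to be reverted to). */
    unsigned cl_per_sg = (CL + N_SG - 1) / N_SG;
    unsigned cl_iters = (cl_per_sg + SG_SIZE - 1) / SG_SIZE;
    unsigned hits[CL] = {0};

    for (unsigned sg = 0; sg < N_SG; sg++)
        for (unsigned ii = 0; ii < cl_iters; ii++)
            for (unsigned lid = 0; lid < SG_SIZE; lid++) {
                unsigned i = idx(ii, sg, lid, cl_per_sg);
                if (i < CL) hits[i]++; /* out-of-range indices are skipped */
            }

    unsigned missed = 0, repeated = 0;
    for (unsigned i = 0; i < CL; i++) {
        if (hits[i] == 0) missed++;
        if (hits[i] > 1) repeated++;
    }
    printf("%s: missed=%u repeated=%u\n", name, missed, repeated);
}

int main(void) {
    check("variant A", idx_a);
    check("variant B", idx_b);
    return 0;
}
```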

umar456 force-pushed the uarshad/sdpa_scale_zp_alignment branch from 1d2375a to 675e3c8 on January 4, 2025 21:46