xe: sdpa: Improve performance of quantization with better alignment and prefetching #2322
base: main
Conversation
Commit `make test`: force-pushed from 9af6de9 to 4746108
src/gpu/intel/ocl/micro_sdpa.cl
Outdated
/* n_sg */ sg_per_wg,
/* sg_size */ SUBGROUP_SIZE,
/* cache */ LSC_LDCC_L1C_L3C);
//return;
Does it improve performance to have the first K tile prefetch here (before loading Q)? IIRC in my earlier testing it was better to delay the first K tile prefetch until after issuing the Q load.
I saw a slight improvement by moving that forward, but I want to test with a larger set of examples. I am compiling those results now and will post them.
/* sg_id */ sg_ij,
/* n_sg */ sg_per_wg,
/* sg_size */ SUBGROUP_SIZE,
/* cache */ LSC_LDCC_L1C_L3C);
This tile is so small that it doesn't need cooperative prefetching (hence the earlier simpler logic). Does this change improve performance?
It looked like all subgroups would prefetch the same memory region with the simpler call, because every subgroup is assigned the same sg_id (0). I didn't think the compiler would be able to gate the other subgroups from executing the prefetch operation. Am I interpreting that incorrectly?
I think I saw a ~5% gain for certain sizes, but I can test it again and post my results with the rest of the changes.
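To make the point concrete, here is a minimal standalone C sketch of the concern (the values of N_SG and TILE_BYTES are hypothetical, not taken from the kernel): when every subgroup is handed sg_id 0, all subgroups compute the same prefetch range, whereas offsetting by the real sg_id splits the tile cooperatively.

```c
#include <stdio.h>

/* Hypothetical parameters for illustration only. */
#define N_SG 8          /* subgroups per work-group */
#define TILE_BYTES 4096 /* total bytes to prefetch for the tile */

int main(void) {
    int per_sg = TILE_BYTES / N_SG;
    for (int sg = 0; sg < N_SG; ++sg) {
        /* sg_id forced to 0: every subgroup computes the same range. */
        int dup_start = 0 * per_sg;
        /* Cooperative split: each subgroup offsets its range by its own sg_id. */
        int coop_start = sg * per_sg;
        printf("sg %d: same-sg_id range [%d, %d)  cooperative range [%d, %d)\n",
               sg, dup_start, dup_start + per_sg, coop_start, coop_start + per_sg);
    }
    return 0;
}
```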
Right, the original code has all subgroups doing the prefetch. It's a bit of a tradeoff: with all subgroups doing the prefetch, there's likely some additional overhead in LSC keeping track of the outstanding prefetches. On the other hand, with cooperative prefetch, we're just relying on the timing being right, since there's no barrier between this prefetch and the mask load.
But all that said, if cooperative prefetch shows better performance, let's use it.
Here is the difference between the previous version and the current version of the mask prefetch for broadcasted masks.
| shape | k | ks | kzp | q | msk | v | vs | vzp | file | baseline | new time | speedup |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| --in-shapes | 1x1x128x384*abcd | 1x1x2x384 | 1x1x2x384 | 1x1x384x128 | 1x1x1x384 | 1x1x384x128 | 1x1x384x2 | 1x1x384x2 | --case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json | 0.03088 | 0.02864 | 1.0782123 |
| --in-shapes | 1x1x128x384*abdc | 1x1x2x384 | 1x1x2x384 | 1x1x384x128 | 1x1x1x384 | 1x1x384x128 | 1x1x384x2 | 1x1x384x2 | --case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json | 0.02752 | 0.02752 | 1.0000000 |
| --in-shapes | 1x1x128x512*abcd | 1x1x2x512 | 1x1x2x512 | 1x1x512x128 | 1x1x1x512 | 1x1x512x128 | 1x1x512x2 | 1x1x512x2 | --case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json | 0.03184 | 0.03056 | 1.0418848 |
| --in-shapes | 1x1x128x512*abdc | 1x1x2x512 | 1x1x2x512 | 1x1x512x128 | 1x1x1x512 | 1x1x512x128 | 1x1x512x2 | 1x1x512x2 | --case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json | 0.0312 | 0.02976 | 1.0483871 |
| --in-shapes | 1x1x128x1024*abcd | 1x1x2x1024 | 1x1x2x1024 | 1x1x1024x128 | 1x1x1x1024 | 1x1x1024x128 | 1x1x1024x2 | 1x1x1024x2 | --case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json | 0.0584 | 0.05696 | 1.0252809 |
| --in-shapes | 1x1x128x1024*abdc | 1x1x2x1024 | 1x1x2x1024 | 1x1x1024x128 | 1x1x1x1024 | 1x1x1024x128 | 1x1x1024x2 | 1x1x1024x2 | --case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json | 0.05632 | 0.05504 | 1.0232558 |
| --in-shapes | 1x1x128x2048*abcd | 1x1x2x2048 | 1x1x2x2048 | 1x1x2048x128 | 1x1x1x2048 | 1x1x2048x128 | 1x1x2048x2 | 1x1x2048x2 | --case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json | 0.11568 | 0.1192 | 0.97046980 |
| --in-shapes | 1x1x128x2048*abdc | 1x1x2x2048 | 1x1x2x2048 | 1x1x2048x128 | 1x1x1x2048 | 1x1x2048x128 | 1x1x2048x2 | 1x1x2048x2 | --case=complexfusion/mha/sdpa-0ks8f16s8-3qf16-wscale-wmask-6vs8f16s8.json | 0.10832 | 0.10608 | 1.0211161 |
Force-pushed from 4746108 to 1d2375a
src/gpu/intel/ocl/tile_ops.h
Outdated
-const uint cl_per_sg = (cl + n_sg - 1) / n_sg;
-const uint cl_iters = (cl_per_sg + sg_size - 1) / sg_size;
+const uint cl_per_sg = (cl + sg_size - 1) / sg_size;
+const uint cl_iters = (cl_per_sg + n_sg - 1) / n_sg;
Can you explain what's going on in this patch? This doesn't look right.
cl_per_sg was using the number of subgroups instead of the subgroup size to calculate the cache lines per subgroup.
The main difference in this commit is that multiple subgroups were prefetching the same memory region because the i_cl indexing was not offsetting across subgroups. This increased the number of iterations, and I believe some cache lines were skipped as a result.
> cl_per_sg was using the number of subgroups instead of the subgroup size to calculate the cache lines per subgroup.
The existing code has the expected behavior. We're gathering up all the cache lines to prefetch (`cl`), then splitting them among the `n_sg` available subgroups, leaving `cl_per_sg` cache lines per subgroup. Then, we're splitting up `cl_per_sg` between the work-items in the subgroup.
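As a concrete illustration of that split, here is a minimal standalone C sketch with example values (cl, n_sg, and sg_size are hypothetical, not taken from the kernel) showing the two ceiling divisions in the original code:

```c
#include <stdio.h>

int main(void) {
    /* Example values for illustration only. */
    unsigned cl = 100;     /* total cache lines to prefetch */
    unsigned n_sg = 8;     /* subgroups in the work-group */
    unsigned sg_size = 16; /* work-items per subgroup */

    /* Split the cache lines among subgroups (ceiling division)... */
    unsigned cl_per_sg = (cl + n_sg - 1) / n_sg;             /* 13 */
    /* ...then split each subgroup's share across its work-items. */
    unsigned cl_iters = (cl_per_sg + sg_size - 1) / sg_size; /* 1 */

    printf("cl_per_sg = %u, cl_iters = %u\n", cl_per_sg, cl_iters);
    return 0;
}
```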
> The main difference in this commit is that multiple subgroups were prefetching the same memory region because the i_cl indexing was not offsetting across subgroups
Ah yes, it wasn't offsetting properly, thanks for catching that. I think you can quickly fix it by reverting the patch to these lines (617-618) and applying the patch I suggested below.
src/gpu/intel/ocl/tile_ops.h
Outdated
uint i_cl = ii_cl * cl_per_sg * sg_size + (sg_id * sg_size)
        + get_sub_group_local_id();
-uint i_cl = ii_cl * cl_per_sg * sg_size + (sg_id * sg_size)
-        + get_sub_group_local_id();
+uint i_cl = (ii_cl + (sg_id * cl_per_sg)) * sg_size + get_sub_group_local_id();
Shouldn't it be:
-uint i_cl = ii_cl * cl_per_sg * sg_size + (sg_id * sg_size)
-        + get_sub_group_local_id();
+uint i_cl = (ii_cl * cl_per_sg + sg_id) * sg_size + get_sub_group_local_id();
Otherwise the second iteration will only offset by sg_size, which will overlap with iteration zero of subgroup one.
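To compare the two suggestions side by side, a standalone C sketch along these lines (the values for n_sg, sg_size, cl_per_sg, and cl_iters are placeholders, not taken from the kernel) enumerates the cache-line indices each formula generates, so overlaps and gaps between subgroups can be inspected directly; it does not by itself decide which variant is correct.

```c
#include <stdio.h>

int main(void) {
    /* Placeholder values for illustration only. */
    unsigned n_sg = 2, sg_size = 4;
    unsigned cl_per_sg = 8; /* cache lines assigned to each subgroup */
    unsigned cl_iters = 2;  /* iterations each subgroup performs */

    for (unsigned sg_id = 0; sg_id < n_sg; ++sg_id) {
        for (unsigned ii_cl = 0; ii_cl < cl_iters; ++ii_cl) {
            for (unsigned lid = 0; lid < sg_size; ++lid) {
                /* First suggestion: subgroup block offset of sg_id * cl_per_sg
                   (in units of sg_size), iterations advance by sg_size. */
                unsigned a = (ii_cl + (sg_id * cl_per_sg)) * sg_size + lid;
                /* Second suggestion: iterations advance by cl_per_sg * sg_size,
                   subgroups are offset by sg_size. */
                unsigned b = (ii_cl * cl_per_sg + sg_id) * sg_size + lid;
                printf("sg %u iter %u lid %u: formula A -> %2u, formula B -> %2u\n",
                       sg_id, ii_cl, lid, a, b);
            }
        }
    }
    return 0;
}
```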
Force-pushed from 1d2375a to 675e3c8
Description
This PR improves the performance of the micro SDPA kernel by using prefetching and by setting better alignment when generating the microkernels. The change has a significant impact on certain sizes, with relative performance ranging from 0.89x to 1.26x compared to the original version.