[PyTorch] Userbuffers support in operation-based API (NVIDIA#1142)
* Add Userbuffers support for column TP linear layer

Signed-off-by: Tim Moon <[email protected]>

* Add Userbuffers support for row TP linear layer

Signed-off-by: Tim Moon <[email protected]>

* Interpret linear+RS as row TP linear

Signed-off-by: Tim Moon <[email protected]>

* Add Userbuffers support for FP8 row TP linear layer

Assumes FP8 RS, which is not a good assumption.

Signed-off-by: Tim Moon <[email protected]>

* Debug bug with incorrect bias pointers in UB GEMM

Bias pointers are not properly offset for different data chunks. Also removed logic for FP8 RS.

Signed-off-by: Tim Moon <[email protected]>

* Add Userbuffers support for linear dgrad

Test passes with row TP, fails with col TP.

Signed-off-by: Tim Moon <[email protected]>

* Add Userbuffers support for linear wgrad

Signed-off-by: Tim Moon <[email protected]>

* Add support for grad bias

Signed-off-by: Tim Moon <[email protected]>

* Fused cast-transpose-dbias

Signed-off-by: Tim Moon <[email protected]>

* Support case where wgrad is optional

Signed-off-by: Tim Moon <[email protected]>

* Expand documentation

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Use recently added convenience functions in Float8Tensor

Signed-off-by: Tim Moon <[email protected]>

* Respect autograd dtype

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix missing imports

Signed-off-by: Tim Moon <[email protected]>

* Respect PyT autocast dtype in bprop

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Debug merge conflicts

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 and pre-commit-ci[bot] authored Nov 6, 2024
1 parent 77c37d4 commit 095b27d
Showing 11 changed files with 2,033 additions and 10 deletions.
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
@@ -10,4 +10,5 @@ pip install pytest==8.2.1
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py