Skip to content

Commit

Permalink
[scan] Test we don't recompile under debugging env flags
Browse files Browse the repository at this point in the history
When XLA_HLO_DEBUG=1 is set, the lowered HLO contains extra scope and
line number information. When a function is scanned twice, the line
numbers stay the same but the scopes will change. This tests that those
differences don't cause the graph hash of the scan computation to
change, so that we don't recompile on every scan call.
  • Loading branch information
tengyifei committed Jan 10, 2025
1 parent 6963e19 commit d74c39c
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 2 deletions.
5 changes: 3 additions & 2 deletions test/neuron/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
#run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_scan.py"
run_test "$CDIR/scan/test_scan.py"
run_xla_hlo_debug "$CDIR/scan/test_scan_debug.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/test_grad_checkpoint.py"
run_test "$CDIR/test_grad_checkpoint.py" "$@" --test_autocast
Expand Down Expand Up @@ -321,4 +322,4 @@ if [ "$LOGFILE" != "" ]; then
run_tests 2>&1 | tee $LOGFILE
else
run_tests
fi
fi
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/scan/test_scan.py"
run_test "$CDIR/scan/test_scan_spmd.py"
run_test "$CDIR/scan/test_scan_layers.py"
run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
Expand Down
55 changes: 55 additions & 0 deletions test/scan/test_scan_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import sys
import os
import unittest

import torch
import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.experimental.scan import scan

parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)
from test_utils import XlaTestCase # type:ignore


class ScanDebugTest(XlaTestCase):
  """Checks compile caching of `scan` graphs when XLA_HLO_DEBUG=1 is set."""

  def test_scan_no_recompile_with_debug_annotations(self):
    """
    When someone adds debugging annotations to the HLO via env vars, the
    HLO graph of the combine function captured by scan would have additional
    metadata such as line numbers and file names. Still, that should not cause
    the final IR graph hash to change. This is subtle because the IR of the
    `scan` operation will reference the HLO computation within.
    """
    # The test is only meaningful when launched with XLA_HLO_DEBUG=1. Use a
    # unittest assertion with `.get()` instead of a bare `assert` on
    # `os.environ[...]`: a bare assert is stripped under `python -O`, and a
    # missing variable would surface as an opaque KeyError.
    self.assertEqual(
        os.environ.get("XLA_HLO_DEBUG"), "1",
        "This test must be run with the XLA_HLO_DEBUG=1 environment variable")
    # Reset metrics so the counters checked below only reflect this test.
    met.clear_all()

    def fn(carry, x):
      carry = carry + x
      y = x + 42
      return carry, y

    init = torch.tensor([0.0, 0.0],
                        requires_grad=True,
                        device=torch_xla.device())
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      requires_grad=True,
                      device=torch_xla.device())

    # Run some graph involving a scan operation two times.
    for _ in range(2):
      # Clear gradients left over from the previous iteration.
      init.grad = None
      xs.grad = None
      carry, ys = scan(fn, init, xs)
      (carry.sum() + ys.sum()).backward()
      torch_xla.sync()

    # Should compile once on the first iteration and hit the cache on the
    # second one.
    self.assertEqual(int(met.counter_value("UncachedCompile")), 1)
    self.assertEqual(int(met.counter_value("CachedCompile")), 1)


if __name__ == '__main__':
  # `unittest.main()` calls `sys.exit()` itself by default, which would make
  # the explicit exit below unreachable. Pass `exit=False` so the process exit
  # status is derived here from the recorded test result.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ python3 "$TEST_CDIR/test_while_loop.py"
python3 "$TEST_CDIR/scan/test_scan.py"
python3 "$TEST_CDIR/scan/test_scan_spmd.py"
python3 "$TEST_CDIR/scan/test_scan_layers.py"
run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
python3 "$TEST_CDIR/test_pallas.py" -v
python3 "$TEST_CDIR/test_pallas_spmd.py"
python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py"
Expand Down

0 comments on commit d74c39c

Please sign in to comment.