-
Notifications
You must be signed in to change notification settings - Fork 488
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[scan] Test we don't recompile under debugging env flags
When XLA_HLO_DEBUG=1 is set, the lowered HLO contains extra scope and line number information. When a function is scanned twice, the line numbers stay the same but the scopes will change. This tests that those differences don't cause the graph hash of the scan computation to change, so that we don't recompile on every scan call.
- Loading branch information
Showing
4 changed files
with
60 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import sys | ||
import os | ||
import unittest | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.debug.metrics as met | ||
from torch_xla.experimental.scan import scan | ||
|
||
parent_folder = os.path.dirname(os.path.dirname(__file__)) | ||
sys.path.append(parent_folder) | ||
from test_utils import XlaTestCase # type:ignore | ||
|
||
|
||
class ScanDebugTest(XlaTestCase): | ||
|
||
def test_scan_no_recompile_with_debug_annotations(self): | ||
""" | ||
When someone adds debugging annotations to the HLO via env vars, the | ||
HLO graph of the combine function captured by scan would have additional metadata | ||
such as line numbers and file names. Still, that should not cause the final IR | ||
graph hash to change. This is subtle because the IR of the `scan` operation will | ||
reference the HLO computation within. | ||
""" | ||
assert os.environ["XLA_HLO_DEBUG"] == "1" | ||
met.clear_all() | ||
|
||
def fn(carry, x): | ||
carry = carry + x | ||
y = x + 42 | ||
return carry, y | ||
|
||
init = torch.tensor([0.0, 0.0], | ||
requires_grad=True, | ||
device=torch_xla.device()) | ||
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], | ||
requires_grad=True, | ||
device=torch_xla.device()) | ||
|
||
# Run some graph involving a scan operation two times. | ||
for i in range(2): | ||
init.grad = None | ||
xs.grad = None | ||
carry, ys = scan(fn, init, xs) | ||
(carry.sum() + ys.sum()).backward() | ||
torch_xla.sync() | ||
|
||
# Should only compile once and cache the next two times. | ||
self.assertEqual(int(met.counter_value("UncachedCompile")), 1) | ||
self.assertEqual(int(met.counter_value("CachedCompile")), 1) | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters