From b9fa7a599c4ae8e7859aa50ce27eba30ecd04df4 Mon Sep 17 00:00:00 2001
From: Gilad <88031955+gilfree@users.noreply.github.com>
Date: Sun, 1 Dec 2024 10:58:28 +0200
Subject: [PATCH] Do not crash on torch.export (#73)

Signed-off-by: youkaichao
Co-authored-by: youkaichao
---
 .github/workflows/test_pytorch.yml              |  1 +
 .../explain/patched_lazy_format_graph_code.py   |  4 +++
 tests/test_pytorch/test_export.py               | 25 +++++++++++++++++++
 3 files changed, 30 insertions(+)
 create mode 100644 tests/test_pytorch/test_export.py

diff --git a/.github/workflows/test_pytorch.yml b/.github/workflows/test_pytorch.yml
index a00cde0a..daa9c385 100644
--- a/.github/workflows/test_pytorch.yml
+++ b/.github/workflows/test_pytorch.yml
@@ -45,6 +45,7 @@ jobs:
           coverage run --append tests/test_pytorch/test_wrapper.py
           coverage run --append tests/test_pytorch/test_mp.py
           coverage run --append tests/test_pytorch/test_no_graph.py
+          coverage run --append tests/test_pytorch/test_export.py
           coverage run --append tests/test_pytorch/test_irregular.py
           coverage run --append tests/test_pytorch/test_simple_graph.py
           TORCH_LOGS="+bytecode" coverage run --append tests/test_pytorch/test_logging.py
diff --git a/depyf/explain/patched_lazy_format_graph_code.py b/depyf/explain/patched_lazy_format_graph_code.py
index a2f3b3c3..2ab26ad0 100644
--- a/depyf/explain/patched_lazy_format_graph_code.py
+++ b/depyf/explain/patched_lazy_format_graph_code.py
@@ -1,6 +1,10 @@
 def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
     from depyf.explain.utils import get_current_compiled_fn_name, write_code_to_file_template
     from depyf.utils import get_code_owner
+    # When using torch.export, the name includes a dumped dict of the
+    # nn_module_stack of a node in the module after the ':'
+    if ':' in name:
+        name = name.split(':')[0]
     func_name = get_current_compiled_fn_name()
     file_name = name if name != func_name else "Captured Graph"
     file_name = file_name.replace(" ", "_")
diff --git a/tests/test_pytorch/test_export.py b/tests/test_pytorch/test_export.py
new file mode 100644
index 00000000..94e899f3
--- /dev/null
+++ b/tests/test_pytorch/test_export.py
@@ -0,0 +1,25 @@
+import torch
+import depyf
+
+# make sure a very long variable name will not cause any problem
+very_long_variable = "a" * 1000
+class MyModel(torch.nn.Module):
+    def __init__(self):
+        super(MyModel, self).__init__()
+        encoder = torch.nn.TransformerEncoder(
+            torch.nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True),
+            num_layers=6,
+        )
+        setattr(self, very_long_variable, encoder)
+
+    def forward(self, x):
+        encoder = getattr(self, very_long_variable)
+        return encoder(x)
+
+model = MyModel()
+x = torch.randn(1, 10, 8)
+with depyf.prepare_debug('export_output'):
+    model_opt = torch.compile(model, fullgraph=True)
+    model_opt(x)
+    exported = torch.export.export(model, (x,))
+    exported_model = exported.module()
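
Note (not part of the patch): under torch.export, the graph name passed to the patched
lazy_format_graph_code can carry a serialized nn_module_stack dict after a ':', which is
too long and too irregular to use as a file name. The sketch below only illustrates the
guard added above; the example name string is an assumption loosely modeled on the
comment in the patch, not a value taken from it.

    # Illustrative sketch only -- the example name is hypothetical.
    name = "MyModel: {'encoder': ('L__self__.encoder', 'TransformerEncoder')}"
    if ':' in name:
        name = name.split(':')[0]
    assert name == "MyModel"  # only the part before the first ':' is kept

Everything after the first ':' is discarded, so the resulting file name stays short no
matter how deep the module stack is; the new test exercises this with a 1000-character
attribute name plus torch.compile and torch.export on the same module.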