Skip to content

Commit

Permalink
Do not crash on torch.export (#73)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
  • Loading branch information
gilfree and youkaichao authored Dec 1, 2024
1 parent ee7d231 commit b9fa7a5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jobs:
coverage run --append tests/test_pytorch/test_wrapper.py
coverage run --append tests/test_pytorch/test_mp.py
coverage run --append tests/test_pytorch/test_no_graph.py
coverage run --append tests/test_pytorch/test_export.py
coverage run --append tests/test_pytorch/test_irregular.py
coverage run --append tests/test_pytorch/test_simple_graph.py
TORCH_LOGS="+bytecode" coverage run --append tests/test_pytorch/test_logging.py
Expand Down
4 changes: 4 additions & 0 deletions depyf/explain/patched_lazy_format_graph_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
from depyf.explain.utils import get_current_compiled_fn_name, write_code_to_file_template
from depyf.utils import get_code_owner
# When using torch export, the name includes
# a dumped dict of the nn_module_stack of a node in the module after the ':'
if ':' in name:
name = name.split(':')[0]
func_name = get_current_compiled_fn_name()
file_name = name if name != func_name else "Captured Graph"
file_name = file_name.replace(" ", "_")
Expand Down
25 changes: 25 additions & 0 deletions tests/test_pytorch/test_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
import depyf

# Regression guard: an attribute whose name is 1000 characters long must not
# break anything downstream (e.g. generated file names or dumped code).
very_long_variable = "a" * 1000


class MyModel(torch.nn.Module):
    """Tiny transformer wrapper whose encoder is stored under a 1000-char attribute name."""

    def __init__(self):
        super().__init__()
        layer = torch.nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True)
        # Attach the encoder via setattr so the pathological name is exercised.
        setattr(self, very_long_variable, torch.nn.TransformerEncoder(layer, num_layers=6))

    def forward(self, x):
        # Fetch the encoder back through its very long attribute name each call.
        return getattr(self, very_long_variable)(x)

model = MyModel()
sample = torch.randn(1, 10, 8)

# Compiling and then torch.export-ing inside prepare_debug must not crash
# (graph names produced by export carry extra ':'-separated metadata).
with depyf.prepare_debug('export_output'):
    compiled = torch.compile(model, fullgraph=True)
    compiled(sample)
    exported_program = torch.export.export(model, (sample,))
    exported_model = exported_program.module()

0 comments on commit b9fa7a5

Please sign in to comment.