Skip to content

Commit

Permalink
Add tvm compiler flags for detr
Browse files Browse the repository at this point in the history
  • Loading branch information
meenakshiramanathan1 committed Jan 6, 2025
1 parent d28dad5 commit 80ca96a
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion forge/test/models/pytorch/vision/detr/test_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@pytest.mark.nightly
@pytest.mark.model_analysis
@pytest.mark.xfail(reason="AttributeError: <class 'tvm.ir.op.Op'> has no attribute name_hint")
@pytest.mark.xfail(reason="Failing with pcc=0.97")
@pytest.mark.parametrize("variant", ["facebook/detr-resnet-50"])
def test_detr_detection(variant):

Expand All @@ -29,6 +29,9 @@ def test_detr_detection(variant):
input_batch = preprocess_input_data(image_url)

# Compiler test
compiler_cfg = forge.config._get_global_compiler_config()
compiler_cfg.enable_tvm_constant_prop = True
compiler_cfg.convert_framework_params_to_tvm = True
compiled_model = forge.compile(
framework_model, sample_inputs=[input_batch], module_name="pt_" + str(variant.split("/")[-1].replace("-", "_"))
)
Expand Down

0 comments on commit 80ca96a

Please sign in to comment.