refactored tests to use verify method
vkovinicTT committed Jan 3, 2025
1 parent 5cd8425 commit a6cba01
Showing 2 changed files with 29 additions and 135 deletions.
154 changes: 25 additions & 129 deletions forge/test/mlir/vit/ops/test_vit_ops.py
@@ -7,7 +7,9 @@
 from torch import nn
 
 import forge
-from forge.op.eval.common import compare_with_golden_pcc
+from forge.verify.verify import verify
+from forge.verify.config import VerifyConfig
+from forge.verify.value_checkers import AutomaticValueChecker
 
 
 @pytest.mark.parametrize(
@@ -32,13 +34,9 @@ def forward(self, a, b):
     inputs = [torch.rand(shapes[0]), torch.rand(shapes[1])]  # when we use dtype=torch.bfloat16, pcc fails
 
     framework_model = Add()
-    fw_out = framework_model(*inputs)
-
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    co_out = [co.to("cpu") for co in co_out]
-    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.push
@@ -84,20 +82,10 @@ def forward(self, x):
 
     inputs = [torch.rand(shapes)]
 
-    # Framework Model
     framework_model = Broadcast(dim, new_shape)
-    fw_out = framework_model(*inputs)
-
-    # Compile Model
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    # Move compiled outputs to CPU for comparison
-    co_out = [co.to("cpu") for co in co_out]
-
-    # Validate the output shapes and values
-    assert fw_out.shape == co_out[0].shape, f"Expected shape {fw_out.shape}, but got {co_out[0].shape}"
-    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize(
@@ -107,7 +95,6 @@ def forward(self, x):
     ],
 )
 @pytest.mark.push
-@pytest.mark.xfail(reason="Data values do not match, pcc < 0.1")
 def test_concat(inputs_and_dim):
     in_shape1, in_shape2, dim = inputs_and_dim
 
@@ -121,13 +108,9 @@ def forward(self, a, b):
     inputs = [torch.rand(in_shape1), torch.rand(in_shape2)]
 
     framework_model = Concat()
-    fw_out = framework_model(*inputs)
-
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    co_out = [co.to("cpu") for co in co_out]
-    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.1)
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize("shape", [(1, 3, 224, 224)])
@@ -163,25 +146,13 @@ def __init__(self, conv_params):
     def forward(self, x):
         return self.conv(x)
 
-    # Instantiate the model with the parameters
-    framework_model = Conv2d(conv_params)
-
-    # Prepare the input tensor
     inputs = [torch.rand(shape)]
 
-    # Get framework output
-    fw_out = framework_model(*inputs)
-
-    # Compile the model with Forge
+    framework_model = Conv2d(conv_params)
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    # Move compiled output to CPU for comparison
-    co_out = [co.to("cpu") for co in co_out]
-    fw_out = fw_out if isinstance(fw_out, list) else [fw_out]
-
-    # Ensure the framework and compiled outputs match
-    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize("shape", [(1, 197, 3072)])
@@ -200,25 +171,13 @@ def __init__(self, gelu_params):
     def forward(self, x):
         return self.gelu(x)
 
-    # Instantiate the model with the parameters
-    framework_model = GELU(gelu_params)
-
-    # Prepare the input tensor
     inputs = [torch.rand(shape)]
 
-    # Get framework output
-    fw_out = framework_model(*inputs)
-
-    # Compile the model with Forge
+    framework_model = GELU(gelu_params)
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    # Move compiled output to CPU for comparison
-    co_out = [co.to("cpu") for co in co_out]
-    fw_out = fw_out if isinstance(fw_out, list) else [fw_out]
-
-    # Ensure the framework and compiled outputs match
-    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize("shape", [(1, 197, 768)])
@@ -246,25 +205,13 @@ def __init__(self, index_params):
     def forward(self, x):
         return x.narrow(self.dim, self.start, self.stop - self.start)[:: self.stride]
 
-    # Instantiate the model with the parameters
-    framework_model = IndexModule(index_params)
-
-    # Prepare the input tensor
     inputs = [torch.rand(shape)]
 
-    # Get framework output
-    fw_out = framework_model(*inputs)
-
-    # Compile the model with Forge
+    framework_model = IndexModule(index_params)
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    # Move compiled output to CPU for comparison
-    co_out = [co.to("cpu") for co in co_out]
-    fw_out = fw_out if isinstance(fw_out, list) else [fw_out]
-
-    # Ensure the framework and compiled outputs match
-    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize("shape", [(1, 197, 768)])
@@ -292,25 +239,12 @@ def __init__(self, layernorm_params):
     def forward(self, x):
         return nn.functional.layer_norm(x, (x.size(self.dim),), self.weights, self.bias, self.epsilon)
 
-    # Instantiate the model with the parameters
-    framework_model = LayernormModule(layernorm_params)
-
-    # Prepare the input tensor
     inputs = [torch.rand(shape)]
 
-    # Get framework output
-    fw_out = framework_model(*inputs)
-
-    # Compile the model with Forge
+    framework_model = LayernormModule(layernorm_params)
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    # Move compiled output to CPU for comparison
-    co_out = [co.to("cpu") for co in co_out]
-    fw_out = fw_out if isinstance(fw_out, list) else [fw_out]
-
-    # Ensure the framework and compiled outputs match
-    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize(
@@ -341,13 +275,9 @@ def forward(self, x, y):
     ]
 
     framework_model = Matmul()
-    fw_out = framework_model(*inputs)
-
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    co_out = [co.to("cpu") for co in co_out]
-    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.90)
+    verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.9)))
 
 
 @pytest.mark.parametrize(
@@ -374,12 +304,9 @@ def forward(self, x, y):
     ]
 
     framework_model = Multiply()
-    fw_out = framework_model(*inputs)
-
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)[0].to("cpu")
 
-    assert compare_with_golden_pcc(fw_out, co_out, pcc=0.99)
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize(
Expand All @@ -405,13 +332,6 @@ def test_reshape(source_and_target_shape):
if len(source_shape) > 4 or len(target_shape) > 4:
pytest.xfail("Only 2D, 3D, and 4D tensors are supported")

if (
source_and_target_shape == ((32, 11, 11), (1, 32, 11, 11))
or source_and_target_shape == ((32, 11, 11), (32, 11, 11))
or source_and_target_shape == ((1, 32, 64, 11), (32, 64, 11))
):
pytest.xfail("pcc < 0.99")

class Reshape(nn.Module):
def __init__(self, target_shape):
super().__init__()
@@ -421,15 +341,11 @@ def forward(self, a):
         return torch.reshape(a, self.target_shape)
 
     inputs = [torch.rand(source_shape, dtype=torch.bfloat16)]
-    framework_model = Reshape(target_shape)
-    fw_out = framework_model(*inputs)
 
+    framework_model = Reshape(target_shape)
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    co_out = [co.to("cpu") for co in co_out]
-    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
+    # some of them are failing with pcc < 0.99
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize(
@@ -451,13 +367,11 @@ def forward(self, x):
         return torch.softmax(x, dim=dim)
 
     inputs = [torch.rand(shape)]
-    framework_model = Softmax(dim)
-    fw_out = framework_model(*inputs)
 
+    framework_model = Softmax(dim)
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)[0].to("cpu")
 
-    assert compare_with_golden_pcc(fw_out, co_out, pcc=0.99)
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize(
@@ -481,14 +395,9 @@ def forward(self, a):
     inputs = [torch.rand(input_shape)]  # pcc fails if we use dtype=torch.bfloat16
 
     framework_model = Squeeze(dim)
-    fw_out = framework_model(*inputs)
-
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    co_out = [co.to("cpu") for co in co_out]
-    assert co_out[0].shape == fw_out.shape
-    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize(
@@ -498,7 +407,6 @@
     ],
 )
 @pytest.mark.push
-@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
 def test_tanh(input_shape):
     class Tanh(nn.Module):
         def __init__(self):
@@ -510,14 +418,9 @@ def forward(self, a):
     inputs = [torch.rand(input_shape)]
 
     framework_model = Tanh()
-    fw_out = framework_model(*inputs)
-
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    co_out = [co.to("cpu") for co in co_out]
-    assert co_out[0].shape == fw_out.shape
-    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize(
@@ -546,13 +449,11 @@ def forward(self, a):
         return torch.transpose(a, *self.dims)
 
     inputs = [torch.rand(shapes)]
-    framework_model = Transpose(dims)
-    fw_out = framework_model(*inputs)
 
+    framework_model = Transpose(dims)
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)[0].to("cpu")
 
-    assert compare_with_golden_pcc(fw_out, co_out, pcc=0.99)
+    verify(inputs, framework_model, compiled_model)
 
 
 @pytest.mark.parametrize(
@@ -587,11 +488,6 @@ def forward(self, a):
     inputs = [torch.rand(input_shape)]  # pcc fails if we use dtype=torch.bfloat16
 
     framework_model = Unsqueeze(dim)
-    fw_out = framework_model(*inputs)
-
     compiled_model = forge.compile(framework_model, sample_inputs=inputs)
-    co_out = compiled_model(*inputs)
 
-    co_out = [co.to("cpu") for co in co_out]
-    assert co_out[0].shape == fw_out.shape
-    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.9)
+    verify(inputs, framework_model, compiled_model)
10 changes: 4 additions & 6 deletions forge/test/mlir/vit/test_vit_inference.py
@@ -6,23 +6,21 @@
 import torch
 import forge
 from test.mlir.vit.utils.utils import load_model
+from forge.verify.verify import verify
+from forge.verify.config import VerifyConfig
 
 
 @pytest.mark.parametrize("model_path", ["google/vit-base-patch16-224"])
 def test_vit_inference(model_path):
 
     # Load Vision Transformer (ViT) model
     framework_model, image_processor = load_model(model_path=model_path)
 
     # Prepare input
     input_image = torch.rand(1, 3, 224, 224)  # Simulated image tensor
     inputs = image_processor(images=input_image, return_tensors="pt").pixel_values
 
-    # Sanity run
-    fw_out = framework_model(inputs)
-
-    # Compile the model
     compiled_model = forge.compile(framework_model, inputs)
-    co_out = compiled_model(inputs)
 
-    # TODO: add verification
+    # Run inference and verify the output
+    verify(inputs, framework_model, compiled_model, VerifyConfig(verify_data=False, verify_allclose=False))
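
For reference, every test touched by this commit converges on the same shape: build the framework module, compile it with Forge, then hand the inputs and both models to verify. Below is a minimal sketch of that pattern, assuming only the forge.verify API imported in the diff above; the Add module, shapes, and test name are illustrative and not part of the commit.

import torch
from torch import nn

import forge
from forge.verify.verify import verify
from forge.verify.config import VerifyConfig
from forge.verify.value_checkers import AutomaticValueChecker


def test_add_sketch():
    class Add(nn.Module):
        def forward(self, a, b):
            return a + b

    inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32)]

    framework_model = Add()
    compiled_model = forge.compile(framework_model, sample_inputs=inputs)

    # verify runs both models on the same inputs and compares the outputs.
    # Judging by the asserts it replaces, the default check enforces pcc=0.99;
    # a looser threshold can be passed explicitly, as test_matmul now does:
    verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.9)))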
