Skip to content

Commit

Permalink
support any input resolution in stable diffusion models
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 20, 2024
1 parent 8ef3997 commit f6fc95d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
19 changes: 19 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,25 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return super().generate(input_name, framework, int_dtype, float_dtype)


class DummyUnetVisionInputGenerator(DummyVisionInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
print("HERE")

if not input_name in ["sample", "latent_sample"]:
return super().generate(input_name, framework, int_dtype, float_dtype)
# add height and width discount for enable any resolution generation
return self.random_float_tensor(
shape=[self.batch_size, self.num_channels, self.height - 1 , self.width - 1],
framework=framework,
dtype=float_dtype,
)


@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UnetOpenVINOConfig(UNetOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator, ) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:]


@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
Expand Down
39 changes: 39 additions & 0 deletions tests/openvino/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,18 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

# test on inputs nondivisible on 64
height, width, batch_size = 96, 96, 1

for output_type in ["latent", "np", "pt"]:
inputs["output_type"] = output_type

ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)


@parameterized.expand(CALLBACK_SUPPORT_ARCHITECTURES)
@require_diffusers
Expand Down Expand Up @@ -541,6 +553,20 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

# test generation when input resolution nondevisible on 64
height, width, batch_size = 96, 96, 1

inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)

for output_type in ["latent", "np", "pt"]:
print(output_type)
inputs["output_type"] = output_type

ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
def test_image_reproducibility(self, model_arch: str):
Expand Down Expand Up @@ -776,6 +802,19 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)

# test generation when input resolution nondevisible on 64
height, width, batch_size = 96, 96, 1
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)

for output_type in ["latent", "np", "pt"]:
inputs["output_type"] = output_type

ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)


@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
Expand Down

0 comments on commit f6fc95d

Please sign in to comment.