Skip to content

Commit

Permalink
Add image reshaping for statically reshaped OpenVINO model (#428)
Browse files Browse the repository at this point in the history
* add vae image processor

* add image reshaping for statically reshaped model

* add test

* format

* fix pipeline

* fix reshaping

* disable reshaping for inpaint SD models

* add reshaping for inpaint
  • Loading branch information
echarlaix authored Sep 22, 2023
1 parent 99f6008 commit 1db2651
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 40 deletions.
250 changes: 234 additions & 16 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import openvino
import PIL
from diffusers import (
DDIMScheduler,
LMSDiscreteScheduler,
Expand Down Expand Up @@ -351,6 +352,13 @@ def width(self) -> int:
return -1
return width.get_length() * self.vae_scale_factor

@property
def _batch_size(self) -> int:
batch_size = self.unet.model.inputs[0].get_partial_shape()[0]
if batch_size.is_dynamic:
return -1
return batch_size.get_length()

def _reshape_unet(
self,
model: openvino.runtime.Model,
Expand Down Expand Up @@ -649,6 +657,7 @@ def __call__(
width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor
_height = self.height
_width = self.width
expected_batch_size = self._batch_size

if _height != -1 and height != _height:
logger.warning(
Expand All @@ -664,11 +673,15 @@ def __call__(
)
width = _width

if guidance_scale is not None and guidance_scale <= 1 and not self.is_dynamic:
raise ValueError(
f"`guidance_scale` was set to {guidance_scale}, static shapes are only supported for `guidance_scale` > 1, "
"please set `dynamic_shapes` to `True` when loading the model."
)
if expected_batch_size != -1:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = kwargs.get("prompt_embeds").shape[0]

_raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale)

return StableDiffusionPipelineMixin.__call__(
self,
Expand All @@ -684,16 +697,115 @@ def __call__(


class OVStableDiffusionImg2ImgPipeline(OVStableDiffusionPipelineBase, StableDiffusionImg2ImgPipelineMixin):
def __call__(self, *args, **kwargs):
# TODO : add default height and width if model statically reshaped
# resize image if doesn't match height and width given during reshaping
return StableDiffusionImg2ImgPipelineMixin.__call__(self, *args, **kwargs)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
**kwargs,
):
_height = self.height
_width = self.width
expected_batch_size = self._batch_size

if _height != -1 and _width != -1:
image = self.image_processor.preprocess(image, height=_height, width=_width).transpose(0, 2, 3, 1)

if expected_batch_size != -1:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = kwargs.get("prompt_embeds").shape[0]

_raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale)

return StableDiffusionImg2ImgPipelineMixin.__call__(
self,
prompt=prompt,
image=image,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
**kwargs,
)


class OVStableDiffusionInpaintPipeline(OVStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin):
def __call__(self, *args, **kwargs):
# TODO : add default height and width if model statically reshaped
return StableDiffusionInpaintPipelineMixin.__call__(self, *args, **kwargs)
def __call__(
self,
prompt: Optional[Union[str, List[str]]],
image: PIL.Image.Image,
mask_image: PIL.Image.Image,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
**kwargs,
):
height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor
width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor
_height = self.height
_width = self.width
expected_batch_size = self._batch_size

if _height != -1 and _width != -1:
if height != _height:
logger.warning(
f"`height` was set to {height} but the static model will output images of height {_height}."
"To fix the height, please reshape your model accordingly using the `.reshape()` method."
)
height = _height

if width != _width:
logger.warning(
f"`width` was set to {width} but the static model will output images of width {_width}."
"To fix the width, please reshape your model accordingly using the `.reshape()` method."
)
width = _width

if isinstance(image, list):
image = [self.image_processor.resize(i, _height, _width) for i in image]
else:
image = self.image_processor.resize(image, _height, _width)

if isinstance(mask_image, list):
mask_image = [self.image_processor.resize(i, _height, _width) for i in mask_image]
else:
mask_image = self.image_processor.resize(mask_image, _height, _width)

if expected_batch_size != -1:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = kwargs.get("prompt_embeds").shape[0]

_raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale)

return StableDiffusionInpaintPipelineMixin.__call__(
self,
prompt=prompt,
image=image,
mask_image=mask_image,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
**kwargs,
)


class OVStableDiffusionXLPipelineBase(OVStableDiffusionPipelineBase):
Expand All @@ -718,10 +830,116 @@ def __init__(self, *args, add_watermarker: Optional[bool] = None, **kwargs):


class OVStableDiffusionXLPipeline(OVStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin):
def __call__(self, *args, **kwargs):
return StableDiffusionXLPipelineMixin.__call__(self, *args, **kwargs)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
**kwargs,
):
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
width = width or self.unet.config["sample_size"] * self.vae_scale_factor
_height = self.height
_width = self.width
expected_batch_size = self._batch_size

if _height != -1 and height != _height:
logger.warning(
f"`height` was set to {height} but the static model will output images of height {_height}."
"To fix the height, please reshape your model accordingly using the `.reshape()` method."
)
height = _height

if _width != -1 and width != _width:
logger.warning(
f"`width` was set to {width} but the static model will output images of width {_width}."
"To fix the width, please reshape your model accordingly using the `.reshape()` method."
)
width = _width

if expected_batch_size != -1:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = kwargs.get("prompt_embeds").shape[0]

_raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale)

return StableDiffusionXLPipelineMixin.__call__(
self,
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
**kwargs,
)


class OVStableDiffusionXLImg2ImgPipeline(OVStableDiffusionXLPipelineBase, StableDiffusionXLImg2ImgPipelineMixin):
def __call__(self, *args, **kwargs):
return StableDiffusionXLImg2ImgPipelineMixin.__call__(self, *args, **kwargs)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
strength: float = 0.3,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
**kwargs,
):
_height = self.height
_width = self.width
expected_batch_size = self._batch_size

if _height != -1 and _width != -1:
image = self.image_processor.preprocess(image, height=_height, width=_width).transpose(0, 2, 3, 1)

if expected_batch_size != -1:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = kwargs.get("prompt_embeds").shape[0]

_raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale)

return StableDiffusionXLImg2ImgPipelineMixin.__call__(
self,
prompt=prompt,
image=image,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
**kwargs,
)


def _raise_invalid_batch_size(
expected_batch_size: int, batch_size: int, num_images_per_prompt: int, guidance_scale: float
):
current_batch_size = batch_size * num_images_per_prompt * (1 if guidance_scale <= 1 else 2)

if expected_batch_size != current_batch_size:
msg = ""
if guidance_scale is not None and guidance_scale <= 1:
msg = f"`guidance_scale` was set to {guidance_scale}, static shapes are currently only supported for `guidance_scale` > 1 "

raise ValueError(
"The model was statically reshaped and the pipeline inputs do not match the expected shapes. "
f"The `batch_size`, `num_images_per_prompt` and `guidance_scale` were respectively set to {batch_size}, {num_images_per_prompt} and {guidance_scale}. "
f"The static model expects an input of size equal to {expected_batch_size} and got the following value instead : {current_batch_size}. "
f"To fix this, please either provide a different inputs to your model so that `batch_size` * `num_images_per_prompt` * 2 is equal to {expected_batch_size} "
"or reshape it again accordingly using the `.reshape()` method by setting `batch_size` to -1. " + msg
)
46 changes: 22 additions & 24 deletions tests/openvino/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,11 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
batch_size, num_images, height, width = 2, 3, 128, 64
pipeline.half()
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
for _height in [height, height + 16]:
inputs = self.generate_inputs(height=_height, width=width, batch_size=batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"):
inputs = _generate_inputs(batch_size)
Expand Down Expand Up @@ -264,21 +265,15 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
model_id = MODEL_NAMES[model_arch]
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False)
batch_size, num_images, height, width = 3, 4, 128, 64
prompt = "sailing ship in storm by Leonardo da Vinci"
pipeline.half()
pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
self.assertFalse(pipeline.is_dynamic)
pipeline.compile()
# Verify output shapes requirements not matching the static model don't impact the final outputs
outputs = pipeline(
[prompt] * batch_size,
num_inference_steps=2,
num_images_per_prompt=num_images,
height=height + 8,
width=width,
output_type="np",
).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
# Verify output shapes requirements not matching the static model doesn't impact the final outputs
for _height in [height, height + 16]:
inputs = _generate_inputs(batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=_height, width=width).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_height_width_properties(self, model_arch: str):
Expand Down Expand Up @@ -341,10 +336,11 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
batch_size, num_images, height, width = 1, 3, 128, 64
pipeline.half()
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
for _height in [height, height + 16]:
inputs = self.generate_inputs(height=_height, width=width, batch_size=batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

def generate_inputs(self, height=128, width=128, batch_size=1):
inputs = super(OVStableDiffusionInpaintPipelineTest, self).generate_inputs(height, width, batch_size)
Expand Down Expand Up @@ -432,10 +428,11 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
self.assertFalse(pipeline.is_dynamic)
pipeline.compile()
# Verify output shapes requirements not matching the static model don't impact the final outputs
inputs = _generate_inputs(batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=height, width=width).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

for _height in [height, height + 16]:
inputs = _generate_inputs(batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=_height, width=width).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))


class OVStableDiffusionXLImg2ImgPipelineTest(unittest.TestCase):
Expand Down Expand Up @@ -467,10 +464,11 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
batch_size, num_images, height, width = 2, 3, 128, 64
pipeline.half()
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
for _height in [height, height + 16]:
inputs = self.generate_inputs(height=_height, width=width, batch_size=batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"):
inputs = _generate_inputs(batch_size)
Expand Down

0 comments on commit 1db2651

Please sign in to comment.