phi3 vision #977

Merged 9 commits on Nov 14, 2024
6 changes: 6 additions & 0 deletions optimum/exporters/openvino/__main__.py
@@ -45,6 +45,8 @@
from .utils import _MAX_UNCOMPRESSED_SIZE, MULTI_MODAL_TEXT_GENERATION_MODELS, clear_class_registry


FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}

if TYPE_CHECKING:
from optimum.intel.openvino.configuration import OVConfig

@@ -264,6 +266,10 @@ def main_export(

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default, which does not support loading the model on CPU
if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES:
loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]
# there are some differences between the remote and in-library representations of past key values for some models;
# to avoid confusion, we disable remote code for them
if (
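A minimal sketch of how the override above takes effect at load time; the checkpoint id is illustrative and the real exporter assembles many more loading arguments:

from transformers import AutoModelForCausalLM

loading_kwargs = {"_attn_implementation": "eager"}  # what FORCE_ATTN_MODEL_CLASSES["phi3-v"] resolves to
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-vision-128k-instruct",  # illustrative phi3-v checkpoint
    trust_remote_code=True,
    **loading_kwargs,
)
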
17 changes: 16 additions & 1 deletion optimum/exporters/openvino/convert.py
@@ -712,7 +712,18 @@ def export_from_model(
)

model_name_or_path = model.config._name_or_path
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)
if preprocessors is not None:
# the phi3-vision processor has no chat_template attribute, which breaks saving the processor to disk
if is_transformers_version(">=", "4.45") and model_type == "phi3-v" and len(preprocessors) > 1:
if not hasattr(preprocessors[1], "chat_template"):
preprocessors[1].chat_template = getattr(preprocessors[0], "chat_template", None)
for processor in preprocessors:
try:
processor.save_pretrained(output)
except Exception as ex:
logger.error(f"Saving {type(processor)} failed with {ex}")
else:
maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)

files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]

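As a hedged round-trip check of what the workaround above protects (the output path is illustrative, assuming the exported directory loads with AutoProcessor):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("ov_phi3_vision", trust_remote_code=True)  # exported directory
assert getattr(processor, "chat_template", None) is not None  # copied from the tokenizer during export
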
@@ -891,6 +902,10 @@ def _get_multi_modal_submodels_and_export_configs(
if model_type == "internvl-chat" and preprocessors is not None:
model.config.img_context_token_id = preprocessors[0].convert_tokens_to_ids("<IMG_CONTEXT>")

if model_type == "phi3-v":
model.config.glb_GN = model.model.vision_embed_tokens.glb_GN.tolist()
model.config.sub_GN = model.model.vision_embed_tokens.sub_GN.tolist()

if hasattr(model, "image_newline"):
model.config.image_newline = model.image_newline.tolist()
main_config_cls = TasksManager.get_exporter_config_constructor(
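The two tensors stored in the config above are phi3-v's learned separator embeddings; serializing them as plain lists lets the exported model rebuild them without the original torch weights, as the _OVPhi3VisionForCausalLM constructor below does. A hedged sketch of the read-back, with an illustrative path:

import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("ov_phi3_vision", trust_remote_code=True)  # exported directory
glb_GN = torch.tensor(config.glb_GN)  # separator between sub-image and global image features
sub_GN = torch.tensor(config.sub_GN)  # per-row "newline" separator for HD patches
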
426 changes: 217 additions & 209 deletions optimum/exporters/openvino/model_configs.py

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -1369,6 +1369,7 @@ def phi3_442_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -3216,3 +3217,23 @@ def forward(self, input):
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward


def phi3_vision_embeddings_forward(self, pixel_values: torch.FloatTensor):
return self.get_img_features(pixel_values)


class Phi3VisionImageEmbeddingsPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(phi3_vision_embeddings_forward, model)
super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
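Like the other patchers in this file, this one swaps in the simplified forward only for the duration of tracing and restores the original on exit. A hedged sketch of how it could be used during export; export_config and vision_model are illustrative stand-ins for objects the exporter builds itself:

import torch
import openvino as ov

with Phi3VisionImageEmbeddingsPatcher(export_config, vision_model, model_kwargs={}):
    ov_vision = ov.convert_model(vision_model, example_input=torch.zeros(1, 3, 336, 336))
# after __exit__, vision_model.forward is the original again
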
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/utils.py
@@ -213,7 +213,7 @@ def get_submodels(model):
return custom_export, fn_get_submodels


MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv"]
MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v"]


def save_config(config, save_dir):
165 changes: 162 additions & 3 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -230,7 +230,15 @@ def forward(self, image_feature, pos_embed, key_padding_mask):
return result


MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler}
class OVVisionProjection(OVModelPart):
_model_name = "vision_projection"

def forward(self, img_features):
self._compile()
return self.request(img_features)[0]


MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler, "vision_projection": OVVisionProjection}


class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
@@ -1802,8 +1810,8 @@ def preprocess_inputs(
raise ValueError("Tokenizer is required.")
if image is not None and processor is None:
raise ValueError("Processor is required.")
text_content = f"<image>\n{text}" if image is not None else text
messages = [{"role": "user", "content": text_content}]
text = f"<image>\n{text}" if image is not None else text
messages = [{"role": "user", "content": text}]
if tokenizer.chat_template is not None:
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
if image is not None:
@@ -1818,10 +1826,161 @@
return result


class _OVPhi3VisionForCausalLM(OVModelForVisualCausalLM):
additional_parts = ["vision_projection"]

def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model,
text_embeddings,
vision_embeddings,
config,
device,
dynamic_shapes,
ov_config,
model_save_dir,
quantization_config,
**kwargs,
)
self.sub_GN = torch.tensor(self.config.sub_GN)
self.glb_GN = torch.tensor(self.config.glb_GN)

def get_vision_embeddings(self, pixel_values, image_sizes, **kwargs):
num_images, num_crops, c, h, w = pixel_values.shape
img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape(
num_images, num_crops, -1, self.config.img_processor["image_dim_out"]
)
image_features_proj = self.hd_feature_transform(img_features, image_sizes)
return image_features_proj

def hd_feature_transform(self, image_features, image_sizes):
"""
image_features: (num_images, num_crops+1, 24*24, 1024)
"""

image_features = torch.from_numpy(image_features)
global_image_features = image_features[:, 0] # (num_images, 24*24, 1024)
# global feature can be viewed as a special HD case with num_crops 1x1
global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)

all_image_embeddings = []
# need a for loop to process each image because of different image sizes
# (patch arrangement is different for each image)
for i, img_size in enumerate(image_sizes):
h, w = img_size
h_crop = h // 336
w_crop = w // 336
num_crops = h_crop * w_crop

# NOTE: real num_crops is padded
# (num_crops, 24*24, 1024)
sub_image_features = image_features[i, 1 : 1 + num_crops]
sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop)
sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)

# [sub features, separator, global features]
all_image_embeddings.extend(
[
sub_image_features_hd_newline.squeeze(0), # (h_crop*12*(w_crop*12+1), 4096)
self.glb_GN.squeeze(0),
global_image_features_hd_newline[i],
]
)
image_features_proj = self.vision_projection(torch.cat(all_image_embeddings, dim=0).unsqueeze(0))[0]

return image_features_proj

def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
"""
image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops
"""
N, L, C = image_features.shape
assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0
num_images = N // (h_crop * w_crop)
H = int(L**0.5)
image_features_hd = (
image_features.reshape(N, H, H, C) # N, 24, 24, 1024
.reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
.permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
.reshape(N, -1, 4 * C) # N, 144, 4096
.reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096
.permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
.reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096
)

return image_features_hd
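A standalone shape check that replays the reshape chain above on zeros, assuming an illustrative 2x3 crop grid (i.e. a 672x1008 image):

import torch

def merge_2x2(x, h_crop, w_crop):
    # same reshape chain as reshape_hd_patches_2x2merge, for shape inspection only
    N, L, C = x.shape
    H = int(L**0.5)
    num_images = N // (h_crop * w_crop)
    return (
        x.reshape(N, H, H, C)
        .reshape(N, H // 2, 2, H // 2, 2, C)
        .permute(0, 1, 3, 2, 4, 5)
        .reshape(N, -1, 4 * C)
        .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1)
        .permute(0, 1, 3, 2, 4, 5)
        .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C)
    )

print(merge_2x2(torch.zeros(6, 24 * 24, 1024), 2, 3).shape)  # torch.Size([1, 24, 36, 4096])
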

def add_image_newline(self, image_features_hd):
"""
image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
"""
num_images, h, w, hid_dim = image_features_hd.shape
# add the newline token to the HD image feature patches
newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim)
image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape(
num_images, -1, hid_dim
)
return image_features_hd_newline

def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, image_sizes=None, **kwargs
):
MAX_INPUT_ID = int(1e9)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])

# positions for image tokens
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
has_image = len(positions[0].tolist()) > 0
input_ids = input_ids.clamp_min(0).clamp_max(self.config.vocab_size)
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids, **kwargs))
if has_image:
vision_embeds = self.get_vision_embeddings(
pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs
)
image_features_proj = torch.from_numpy(vision_embeds)
inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False)

return inputs_embeds, attention_mask, position_ids
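The negative-id convention handled above comes from the phi3-v processor, which encodes image placeholders such as <|image_1|> as runs of negative token ids; a toy illustration:

import torch

input_ids = torch.tensor([[1, 450, -1, -1, -1, 29889]])  # the -1s stand in for <|image_1|>
positions = torch.nonzero((input_ids < 0) & (input_ids > -int(1e9)), as_tuple=True)
print(positions[1])  # tensor([2, 3, 4]): the slots index_put fills with image features
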

@staticmethod
def preprocess_inputs(
text: str,
image: Optional[Image] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if image is not None and "<|image_1|>" not in text:
text = "<|image_1|>\n" + text
if getattr(processor.tokenizer, "chat_template", None) is not None:
chat_prompt = [{"role": "user", "content": text}]
text = processor.tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)
inputs = processor(images=image, text=text, return_tensors="pt")
return inputs
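A hedged usage sketch of the static helper above; the checkpoint id and image URL are illustrative:

import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = _OVPhi3VisionForCausalLM.preprocess_inputs("What is shown in this image?", image=image, processor=processor)
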


MODEL_TYPE_TO_CLS_MAPPING = {
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
"internvl_chat": _OvInternVLForCausalLM,
"minicpmv": _OVMiniCPMVForCausalLM,
"llava-qwen2": _OVNanoLlavaForCausalLM,
"phi3_v": _OVPhi3VisionForCausalLM,
}
4 changes: 2 additions & 2 deletions tests/openvino/test_modeling.py
@@ -1880,9 +1880,9 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
if is_transformers_version(">=", "4.40.0"):
SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
if is_transformers_version(">=", "4.45.0"):
SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2"]
SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2", "phi3_v"]
TASK = "image-text-to-text"
REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava"]
REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava", "phi3_v"]

IMAGE = Image.open(
requests.get(
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -109,6 +109,7 @@
"pix2struct": "fxmarty/pix2struct-tiny-random",
"phi": "echarlaix/tiny-random-PhiForCausalLM",
"phi3": "Xenova/tiny-random-Phi3ForCausalLM",
"phi3_v": "katuni4ka/tiny-random-phi3-vision",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen": "katuni4ka/tiny-random-qwen",
"qwen2": "fxmarty/tiny-dummy-qwen2",