From e1d027ee363ca8b615a87f29bf0a6f82851781bb Mon Sep 17 00:00:00 2001
From: eaidova
Date: Thu, 17 Oct 2024 15:26:19 +0400
Subject: [PATCH] fix llava legacy processing selection

---
 .../openvino/modeling_visual_language.py | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py
index 588235c1ec..e7148c5faf 100644
--- a/optimum/intel/openvino/modeling_visual_language.py
+++ b/optimum/intel/openvino/modeling_visual_language.py
@@ -749,12 +749,12 @@ def merge_vision_text_embeddings(
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
         if legacy_processing is None:
             legacy_processing = (
-                not hasattr(self.config, "image_seq_length")
+                not (hasattr(self.config, "image_seq_length") and (input_ids.shape[-1] == 1))
                 or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-                or (input_ids.shape[-1] == 1)
             )
 
         if legacy_processing:
+            logger.warning("Using legacy processing to merge vision and text embeddings")
             pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
 
             num_images, num_image_patches, embed_dim = image_features.shape
@@ -832,11 +832,13 @@ def merge_vision_text_embeddings(
     def get_multimodal_embeddings(
         self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
     ):
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
+        legacy_processing = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
+        inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
+
+        if pixel_values is not None and not legacy_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            self._legacy_processing = legacy_processing
+
         inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
             input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
         )
@@ -847,6 +849,7 @@ def get_multimodal_embeddings(
         return inputs_embeds, attention_mask, position_ids
 
     def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
+        logger.warning("Using legacy processing to filter unattended tokens")
         if not self.language_model.stateful:
             first_layer_past_key_value = torch.from_numpy(past_key_values[0][0][:, :, :, 0])
         else:
@@ -954,12 +957,13 @@ def get_multimodal_embeddings(
         from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
 
         inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
+        legacy_processing = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
+
+
+        if pixel_values is not None and not legacy_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            self._legacy_processing = legacy_processing
 
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
         if pixel_values is not None and pixel_values.size(0) > 0:
             # ! infer image_num_patches from image_sizes
             image_num_patches = [
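
Note: the core of this fix is that the legacy-vs-expanded input decision is now made
once, on the prefill call (pixel_values present, past_key_values empty), and cached on
the instance so that later single-token decode steps reuse it instead of re-deriving it
from a 1-token input_ids, which the removed `input_ids.shape[-1] == 1 and pixel_values
is not None` heuristic could misclassify. Below is a minimal standalone sketch of that
caching pattern; it is illustrative only, and the names (LegacyProcessingSelector,
DummyConfig, the toy tensors) are hypothetical, not optimum-intel code.

import torch


class LegacyProcessingSelector:
    """Caches the legacy/expanded processing decision across generation steps."""

    def __init__(self, config):
        self.config = config

    def select(self, input_ids, pixel_values=None, past_key_values=None):
        # Default: configs without `image_seq_length` predate expanded inputs,
        # so they always take the legacy merging path.
        legacy = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
        if pixel_values is not None and not legacy and past_key_values is None:
            # Prefill step: a legacy prompt holds one placeholder token per image,
            # an expanded prompt holds `image_seq_length` of them per image.
            legacy = bool(
                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
            )
            self._legacy_processing = legacy  # reused on 1-token decode steps
        return legacy


class DummyConfig:
    image_seq_length = 576
    image_token_index = 32000


selector = LegacyProcessingSelector(DummyConfig())
prompt = torch.tensor([[1, 32000, 9, 9]])  # single image token -> legacy prompt
print(selector.select(prompt, pixel_values=torch.zeros(1, 3, 336, 336)))  # True
decode = torch.tensor([[42]])  # decode step: cached decision is reused as-is
print(selector.select(decode, past_key_values=object()))  # True

Caching on the instance keeps the decision stable for the whole generation loop, which
is what the patch does for both the LLaVA and LLaVA-Next code paths.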