
Commit

fix llava legacy processing selection
eaidova committed Oct 17, 2024
1 parent 682362d commit e1d027e
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -749,12 +749,12 @@ def merge_vision_text_embeddings(
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
         if legacy_processing is None:
             legacy_processing = (
-                not hasattr(self.config, "image_seq_length")
+                not (hasattr(self.config, "image_seq_length") and (input_ids.shape[-1] == 1))
                 or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-                or (input_ids.shape[-1] == 1)
             )

         if legacy_processing:
+            logger.warn("LEGACY")
             pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

             num_images, num_image_patches, embed_dim = image_features.shape
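Background on the heuristic above: older llava processors keep a single <image> placeholder in the prompt and expand it at merge time (the legacy path), while newer processors pre-expand the prompt to config.image_seq_length placeholder tokens, so counting placeholders distinguishes the two. A toy sketch of that count check (not library code; the image_token_index and image_seq_length values are assumed):

    import torch

    image_token_index = 32000  # assumed placeholder token id
    image_seq_length = 576     # assumed vision embeddings per image

    expanded = torch.full((1, 600), 1)                  # prompt pre-expanded by a new processor
    expanded[0, :image_seq_length] = image_token_index
    legacy = torch.tensor([[1, image_token_index, 2, 3]])  # legacy prompt: one <image> token

    (expanded == image_token_index).sum(1).max() < image_seq_length  # tensor(False) -> new path
    (legacy == image_token_index).sum(1).max() < image_seq_length    # tensor(True)  -> legacy path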
@@ -832,11 +832,13 @@ def merge_vision_text_embeddings(
     def get_multimodal_embeddings(
         self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
     ):
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
+        legacy_processing = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
+        inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
+
+        if pixel_values is not None and not legacy_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            self._legacy_processing = legacy_processing
+
         inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
             input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
         )
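This block is the heart of the fix: the legacy/new decision is derived once, on the prefill call (past_key_values is None) while input_ids still contains the whole prompt, and is cached on the instance as _legacy_processing. The removed condition re-evaluated input_ids.shape[-1] == 1 on every call, so single-token decode steps could flip the mode mid-generation. A minimal, hypothetical sketch of the cache-on-prefill pattern (class and attribute names are assumed, not the library code):

    import torch

    class ModeCache:
        IMAGE_TOKEN = 32000   # assumed stand-in for config.image_token_index
        IMAGE_SEQ_LEN = 576   # assumed stand-in for config.image_seq_length

        def legacy_processing(self, input_ids, pixel_values, past_key_values):
            # Start from the cached decision; first call defaults to the new path.
            legacy = getattr(self, "_legacy", False)
            if pixel_values is not None and not legacy and past_key_values is None:
                # Prefill only: the full prompt is visible, so counting
                # placeholder tokens is meaningful here (and only here).
                count = (input_ids == self.IMAGE_TOKEN).sum(1).max()
                legacy = bool(count < self.IMAGE_SEQ_LEN)
                self._legacy = legacy  # reused on every decode step
            return legacy

    cache = ModeCache()
    prompt = torch.tensor([[1, 32000, 2]])  # legacy-style prompt, one placeholder
    cache.legacy_processing(prompt, torch.zeros(1, 3), past_key_values=None)          # True, cached
    cache.legacy_processing(torch.tensor([[7]]), torch.zeros(1, 3), past_key_values=object())  # still True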
@@ -847,6 +849,7 @@ def get_multimodal_embeddings(
         return inputs_embeds, attention_mask, position_ids

     def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
+        logger.warn("LEGACY")
         if not self.language_model.stateful:
             first_layer_past_key_value = torch.from_numpy(past_key_values[0][0][:, :, :, 0])
         else:
@@ -954,12 +957,13 @@ def get_multimodal_embeddings(
         from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

         inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
+        legacy_processing = getattr(self, "_legacy_processing", not hasattr(self.config, "image_seq_length"))
+
+        if pixel_values is not None and not legacy_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            self._legacy_processing = legacy_processing
+
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )

         if pixel_values is not None and pixel_values.size(0) > 0:
             # ! infer image_num_patches from image_sizes
             image_num_patches = [
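The comprehension is truncated here by the diff view and its body is not part of this change. For context only, a sketch of how the upstream transformers LLaVA-NeXT code computes per-image patch counts with image_size_to_num_patches (argument names follow transformers; image_sizes and the config attributes come from the surrounding method):

    # Context sketch following transformers' modeling_llava_next; not the diffed lines.
    image_num_patches = [
        image_size_to_num_patches(
            image_size=imsize,
            grid_pinpoints=self.config.image_grid_pinpoints,
            patch_size=self.config.vision_config.image_size,
        )
        for imsize in image_sizes
    ]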
