Fix update causal mask for transformers 4.42 (#852)
* fix update causal mask for transformers 4.42

* more models

* revert rope for phi3

* fix phi3

* phi3 issue
eaidova authored Jul 29, 2024
1 parent 4b4bbcb commit 3ee174c
Showing 2 changed files with 199 additions and 10 deletions.
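
Background for the changes below: transformers 4.42 reworked _update_causal_mask in several decoder models, and the stock implementation fills masked positions with the minimum of the activation dtype. The float32 minimum overflows to -inf once a traced model later runs in float16, which is the overflow the patched _llama_gemma_update_causal_mask avoids. A quick standalone check of that numeric claim (plain PyTorch, nothing from this repo):

import torch

# the float32 minimum does not fit in float16: the cast saturates to -inf,
# which can poison the attention softmax on fp16 hardware
print(torch.tensor(torch.finfo(torch.float32).min).to(torch.float16))  # tensor(-inf, dtype=torch.float16)

# the float16 minimum survives the cast and is still a large negative
# additive bias, so masked positions keep ~0 probability after softmax
print(torch.tensor(torch.finfo(torch.float16).min).to(torch.float16))  # tensor(-65504., dtype=torch.float16)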
39 changes: 39 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -60,6 +60,7 @@
PersimmonModelPatcher,
Phi3ModelPatcher,
QwenModelPatcher,
UpdateCausalMaskModelPatcher,
XverseModelPatcher,
)

@@ -119,6 +120,11 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)
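
This three-line override recurs verbatim for qwen2-moe, stablelm, starcoder2 and phi below. For orientation, the export pipeline is what consumes it, roughly as in this sketch (the real call site lives in optimum's export code, not in this diff; variable names are placeholders):

# export_config: a registered OpenVINO config instance; model: the loaded PreTrainedModel
patcher = export_config.patch_model_for_export(model)  # returns UpdateCausalMaskModelPatcher
with patcher:
    pass  # trace/convert here; the patched _update_causal_mask is active inside the block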


@register_in_tasks_manager("qwen2-moe", *["text-generation", "text-generation-with-past"], library_name="transformers")
class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -128,6 +134,11 @@ class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -146,6 +157,11 @@ class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)


class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
@@ -468,6 +484,11 @@ class Starcoder2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("internlm2", *["text-generation", "text-generation-with-past"], library_name="transformers")
class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -532,6 +553,24 @@ def patch_model_for_export(
return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"phi",
*[
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
],
library_name="transformers",
)
class PhiOpenVINOConfig(PhiOnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)


class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator):
def __init__(
self,
170 changes: 160 additions & 10 deletions optimum/exporters/openvino/model_patcher.py
@@ -101,6 +101,11 @@ def patch_model_with_bettertransformer(model):
return model


def patch_update_causal_mask(model, transformers_version):
    # install the overflow-safe mask builder on newer transformers, keeping a
    # handle to the original method so the patchers' __exit__ can restore it
    if is_transformers_version(">=", transformers_version):
        model.model._orig_update_causal_mask = model.model._update_causal_mask
        model.model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, model.model)
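
The helper leans on types.MethodType, which rebinds a plain function as a bound method on one instance while leaving the class untouched. A minimal, self-contained illustration of the pattern (toy names, not from the patch):

import types

class Toy:
    def greet(self):
        return "original"

toy = Toy()
toy._orig_greet = toy.greet  # keep a handle for restoring later
toy.greet = types.MethodType(lambda self: "patched", toy)  # instance-level rebind
assert toy.greet() == "patched"
assert Toy().greet() == "original"  # other instances are unaffected
toy.greet = toy._orig_greet  # undo, as the patchers' __exit__ does
assert toy.greet() == "original"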


def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -144,6 +149,8 @@ def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class MixtralModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.42.0")

for layer in self._model.model.layers:
layer.block_sparse_moe._unpatched_forward = layer.block_sparse_moe.forward
layer.block_sparse_moe.forward = types.MethodType(
@@ -152,6 +159,9 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

for layer in self._model.model.layers:
layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward

@@ -549,11 +559,9 @@ def __enter__(self):

# llama/gemma have accuracy issues with bf16 on transformers >= 4.39
# fill the causal mask in a slightly different way to avoid overflow on some platforms
+        patch_update_causal_mask(self._model, "4.39.0")

-        if is_transformers_version(">=", "4.39.0"):
-            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
-            self._model.model._update_causal_mask = types.MethodType(
-                _llama_gemma_update_causal_mask, self._model.model
-            )
register_sin_cos_buffer(self._model)

def __exit__(self, exc_type, exc_value, traceback):
@@ -620,7 +628,7 @@ def _mistral_update_causal_mask(
return None

dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
+        min_dtype = torch.finfo(torch.float16).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
@@ -1328,6 +1336,128 @@ def __exit__(self, exc_type, exc_value, traceback):
block.attention.forward = block.attention._orig_forward


def phi3_442_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

past_key_values_length = 0

if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)

if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
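        # editorial note, not part of the diff: _prepare_4d_causal_attention_mask
        # returns an additive mask of shape (batch_size, 1, seq_length,
        # past_key_values_length + seq_length), with zeros on visible positions
        # and the dtype minimum on masked ones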

hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
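
A detail worth noting in phi3_442_forward above: it keeps the pre-4.42 legacy-cache path, converting tuple caches through DynamicCache on entry and back on exit. A small standalone illustration of that round trip (requires transformers; shapes are arbitrary):

import torch
from transformers.cache_utils import DynamicCache

# legacy format: a tuple over layers of (key, value) tensors; one layer here,
# with batch 1, 2 kv heads, 3 cached tokens, head dim 4
legacy = ((torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)),)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_usable_length(1))    # 3 -> tokens already cached
print(len(cache.to_legacy_cache()))  # 1 -> back to one (key, value) pair per layer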


# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L729
def _phi3_self_attn_sdpa_forward(
self,
@@ -1373,7 +1503,7 @@ def _phi3_self_attn_sdpa_forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -1411,6 +1541,11 @@ def _phi3_self_attn_sdpa_forward(
class Phi3ModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()

if is_transformers_version(">=", "4.42.0"):
self._model.model._orig_forward = self._model.model.forward
self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)

# https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
# init inv_freq for torchscript tracing
for layer in self._model.model.layers:
@@ -1425,15 +1560,15 @@ def __enter__(self):
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)

        # phi3 has issues with bf16 inference; precollect sin/cos for rotary_position_embedding to avoid accuracy issues
register_sin_cos_buffer(self._model)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_forward"):
self._model.model.forward = self._model.model._orig_forward
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward


def _aquila_self_attn_sdpa_forward(
@@ -2089,6 +2224,8 @@ def _persimmon_self_attn_sdpa_forward(
class PersimmonModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.42.0")

for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
@@ -2097,6 +2234,8 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
@@ -2221,3 +2360,14 @@ def __exit__(self, exc_type, exc_value, traceback):
if hasattr(layer.attn, "_orig_attn"):
layer.attn._attn = layer.attn._orig_attn
layer.attn.forward = layer.attn._orig_forward


class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.42.0")

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
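
Since DecoderModelPatcher works as a context manager, every config that returns this class gets the same enter/exit discipline; a hedged usage sketch (the surrounding export plumbing and variable names are assumed):

patcher = UpdateCausalMaskModelPatcher(export_config, model, model_kwargs=None)
with patcher:
    # on transformers >= 4.42 the patched _llama_gemma_update_causal_mask is installed here
    ...
# on exit the original model.model._update_causal_mask has been restored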
