From 69fac7bb937a5e9f7c9ff17bc8b1f14ad4c04fd2 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Mon, 23 Dec 2024 15:14:53 +0400
Subject: [PATCH] restore SDPA in gpt neo after 4.45

---
 optimum/exporters/openvino/model_configs.py | 20 ++++++
 optimum/exporters/openvino/model_patcher.py | 68 +++++++++++++++++++++
 2 files changed, 88 insertions(+)

diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 02a8c300a..cecd8f36e 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -30,6 +30,7 @@
     FalconOnnxConfig,
     GemmaOnnxConfig,
     GPTJOnnxConfig,
+    GPTNeoOnnxConfig,
     GPTNeoXOnnxConfig,
     IBertOnnxConfig,
     LlamaOnnxConfig,
@@ -68,6 +69,7 @@
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
     GptJModelPatcher,
+    GptNeoModelPatcher,
     GptNeoxJapaneseModelPatcher,
     GptNeoxModelPatcher,
     IBertModelPatcher,
@@ -790,6 +792,24 @@ def patch_model_for_export(
         return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
+@register_in_tasks_manager(
+    "gpt-neo",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GPTNeoOpenVINOConfig(GPTNeoOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptNeoModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "gptj",
     *[
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 825eaac48..1128e3b79 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -2654,6 +2654,74 @@ def __exit__(self, exc_type, exc_value, traceback):
         unpatch_update_causal_mask(self._model, "gpt_neox_japanese")
 
 
+def _gpt_neo_sdpa_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    # Keep the attention weights computation in fp32 to avoid overflow issues
+    query = query.to(torch.float32)
+    key = key.to(torch.float32)
+
+    # Apply sliding window masking for local attention layers
+    query_length, key_length = query.size(-2), key.size(-2)
+    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+    # unlike the original, apply the mask value to the attention mask instead of the weights to prevent overflow
+    mask_value = torch.finfo(torch.float16).min
+    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+    mask_value = torch.tensor(mask_value, dtype=query.dtype).to(query.device)
+    if attention_mask is None:
+        attention_mask = torch.ones_like(causal_mask)
+    attention_mask = torch.where(causal_mask, attention_mask[:, :, :, : key.shape[-2]], mask_value)
+
+    # Mask heads if we want to
+    if head_mask is not None:
+        attention_mask = attention_mask * head_mask
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+
+    return attn_output, None
+
+
+def _gpt_neo_attn_forward(
+    self,
+    hidden_states,
+    attention_mask=None,
+    layer_past=None,
+    head_mask=None,
+    use_cache=False,
+    output_attentions=False,
+    cache_position=None,
+):
+    if output_attentions:
+        self._attn = self._orig_attn
+
+    return self._orig_forward(
+        hidden_states,
+        attention_mask=attention_mask,
+        layer_past=layer_past,
+        head_mask=head_mask,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        cache_position=cache_position,
+    )
+
+
+class GptNeoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            for layer in self._model.transformer.h:
+                self_attn = layer.attn.attention
+                self_attn._orig_attn, self_attn._orig_forward = self_attn._attn, self_attn.forward
+                self_attn._attn = types.MethodType(_gpt_neo_sdpa_attn, self_attn)
+                self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.transformer.h:
+            if hasattr(layer.attn.attention, "_orig_forward"):
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn
+
+
 class Gemma2ModelPatcher(LlamaModelPatcher):
     def __init__(
         self,
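Usage sketch (reviewer note, not part of the patch): a minimal illustration of how the pieces added above are expected to fit together. The checkpoint id, prompt, and the direct forward pass are assumptions for demonstration; the entry points used (AutoModelForCausalLM, the GPTNeoOpenVINOConfig class introduced here, and patch_model_for_export) are the existing optimum/transformers APIs plus the new config class, so treat this as a sketch rather than part of the change.

    # Sketch only: exercising the gpt-neo export config and patcher added by this patch.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from optimum.exporters.openvino.model_configs import GPTNeoOpenVINOConfig

    model_id = "EleutherAI/gpt-neo-125m"  # example checkpoint, any GPT-Neo model should behave the same
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # "-with-past" tasks are expressed via use_past=True on the export config
    export_config = GPTNeoOpenVINOConfig(model.config, task="text-generation", use_past=True)

    # patch_model_for_export() returns the GptNeoModelPatcher defined above
    patcher = export_config.patch_model_for_export(model)

    inputs = tokenizer("OpenVINO", return_tensors="pt")
    with patcher, torch.no_grad():
        # inside the context (transformers >= 4.45, torch >= 2.1) every
        # GPTNeoSelfAttention._attn is the SDPA variant, which is what gets traced
        model(**inputs)

    # The high-level export path triggers the same patching implicitly, e.g.:
    #   from optimum.intel import OVModelForCausalLM
    #   OVModelForCausalLM.from_pretrained(model_id, export=True)

The design choice worth noting: rather than subtracting a large constant from the eager attention weights, the fp16 minimum is folded into the mask handed to torch.nn.functional.scaled_dot_product_attention, so the traced graph stays on the SDPA op while, per the in-line comments, still avoiding overflow when the exported model is run in half precision.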