diff --git a/.github/workflows/test_inc.yml b/.github/workflows/test_inc.yml index 81d102bc01..7b87a20d0c 100644 --- a/.github/workflows/test_inc.yml +++ b/.github/workflows/test_inc.yml @@ -4,9 +4,12 @@ name: Intel Neural Compressor - Test on: push: - branches: [ main ] + branches: + - main + - v*-release pull_request: - branches: [ main ] + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml index 8e02bd5510..96ef047aaf 100644 --- a/.github/workflows/test_ipex.yml +++ b/.github/workflows/test_ipex.yml @@ -4,9 +4,12 @@ name: Intel IPEX - Test on: push: - branches: [ main ] + branches: + - main + - v*-release pull_request: - branches: [ main ] + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -18,6 +21,7 @@ jobs: fail-fast: false matrix: python-version: [3.8, 3.9] + transformers-version: [4.39.0, 4.41.2] os: [ubuntu-latest] runs-on: ${{ matrix.os }} @@ -32,6 +36,7 @@ jobs: python -m pip install --upgrade pip pip install torch torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu pip install .[ipex,tests] + pip install transformers==${{ matrix.transformers-version }} - name: Test with Pytest run: | pytest tests/ipex/ diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index c7d20eb321..37cf81fecc 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -4,9 +4,12 @@ name: OpenVINO - Test on: push: - branches: [ main ] + branches: + - main + - v*-release pull_request: - branches: [ main ] + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -17,14 +20,15 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.11] + python-version: ["3.8", "3.12"] + transformers-version: ["4.36.0", "4.41.*"] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -32,6 +36,7 @@ jobs: python -m pip install --upgrade pip # install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install transformers==${{ matrix.transformers-version }} pip install .[openvino,openvino-tokenizers,tests,diffusers] onnxruntime - name: Test with Pytest run: | @@ -46,3 +51,4 @@ jobs: pip install openvino-nightly python -c "from optimum.intel import OVModelForCausalLM; OVModelForCausalLM.from_pretrained('hf-internal-testing/tiny-random-gpt2', export=True, compile=False)" optimum-cli export openvino -m hf-internal-testing/tiny-random-gpt2 gpt2-ov + diff --git a/.github/workflows/test_openvino_basic.yml b/.github/workflows/test_openvino_basic.yml index 3135e6c004..240428e70a 100644 --- a/.github/workflows/test_openvino_basic.yml +++ b/.github/workflows/test_openvino_basic.yml @@ -24,16 +24,16 @@ jobs: matrix: # Testing lower and upper bound of supported Python versions # This also ensures that the test fails if dependencies break for Python 3.7 - python-version: ["3.8", "3.11"] - transformers: ['transformers'] + python-version: ["3.8", "3.12"] optimum: ['optimum', 'git+https://github.com/huggingface/optimum.git'] + 
os: ["ubuntu-22.04", "windows-latest"] - runs-on: ubuntu-20.04 + runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -43,12 +43,17 @@ jobs: # optimum or transformers to a specific version # Install PyTorch CPU to prevent unnecessary downloading/installing of CUDA packages pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install .[tests] openvino onnx onnxruntime ${{ matrix.optimum}} ${{ matrix.transformers }} + pip install .[tests] openvino onnxruntime ${{ matrix.optimum}} - - name: Pip freeze + - name: Pip freeze run: pip freeze - name: Test with Pytest run: | pytest tests/openvino/test_modeling_basic.py - RUN_SLOW=1 pytest tests/openvino/test_modeling.py -s -m "run_slow" --durations=0 \ No newline at end of file + + - name: Slow tests + run: | + pytest tests/openvino/test_modeling.py -s -m "run_slow" --durations=0 + env: + RUN_SLOW: 1 diff --git a/.github/workflows/test_openvino_examples.yml b/.github/workflows/test_openvino_examples.yml index 747afa31b5..c76374e9ea 100644 --- a/.github/workflows/test_openvino_examples.yml +++ b/.github/workflows/test_openvino_examples.yml @@ -22,7 +22,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.11"] + python-version: ["3.8", "3.12"] runs-on: ubuntu-22.04 diff --git a/.github/workflows/test_openvino_notebooks.yml b/.github/workflows/test_openvino_notebooks.yml index ed77077e87..34017e0baf 100644 --- a/.github/workflows/test_openvino_notebooks.yml +++ b/.github/workflows/test_openvino_notebooks.yml @@ -23,7 +23,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.11"] + python-version: ["3.8", "3.12"] runs-on: ubuntu-22.04 diff --git a/README.md b/README.md index 49f0d79768..0226b5d470 100644 --- a/README.md +++ b/README.md @@ -239,3 +239,8 @@ Do not forget to install requirements for every example: cd pip install -r requirements.txt ``` + + +## Gaudi + +To train your model on [Intel Gaudi AI Accelerators (HPU)](https://docs.habana.ai/en/latest/index.html), check out [Optimum Habana](https://github.com/huggingface/optimum-habana) which provides a set of tools enabling easy model loading, training and inference on single- and multi-HPU settings for different downstream tasks. After training your model, feel free to submit it to the Intel [leaderboard](https://huggingface.co/spaces/Intel/powered_by_intel_llm_leaderboard) which is designed to evaluate, score, and rank open-source LLMs that have been pre-trained or fine-tuned on Intel Hardwares. Models submitted to the leaderboard will be evaluated on the Intel Developer Cloud. The evaluation platform consists of Gaudi Accelerators and Xeon CPUs running benchmarks from the Eleuther AI Language Model Evaluation Harness. 
diff --git a/docs/source/inference.mdx b/docs/source/inference.mdx index e0b60baa2e..305beac3c9 100644 --- a/docs/source/inference.mdx +++ b/docs/source/inference.mdx @@ -28,8 +28,11 @@ As shown in the table below, each task is associated with a class enabling to au | `image-classification` | `OVModelForImageClassification` | | `feature-extraction` | `OVModelForFeatureExtraction` | | `fill-mask` | `OVModelForMaskedLM` | -| `text-generation` | `OVModelForCausalLM` | -| `text2text-generation` | `OVModelForSeq2SeqLM` | +| `audio-classification` | `OVModelForAudioClassification` | +| `text-generation-with-past` | `OVModelForCausalLM` | +| `text2text-generation-with-past` | `OVModelForSeq2SeqLM` | +| `automatic-speech-recognition` | `OVModelForSpeechSeq2Seq` | +| `image-to-text` | `OVModelForVision2Seq` | ### Export @@ -42,7 +45,7 @@ optimum-cli export openvino --model gpt2 ov_model The example above illustrates exporting a checkpoint from the 🤗 Hub. When exporting a local model, first make sure that you saved both the model’s weights and tokenizer files in the same directory (`local_path`). When using CLI, pass the `local_path` to the model argument instead of the checkpoint name of the model hosted on the Hub and provide the `--task` argument. You can review the list of supported tasks in the 🤗 [Optimum documentation](https://huggingface.co/docs/optimum/exporters/task_manager). If task argument is not provided, it will default to the model architecture without any task specific head. -Here we set the `task` to `text-generation-with-past`, with the `-with-past` suffix enabling the re-use of the pre-computed key/values hidden-states `use_cache=True`. +The `-with-past` suffix enables the re-use of the pre-computed key/value hidden states and is the recommended option. To export the model without it (equivalent to `use_cache=False`), remove this suffix. ```bash optimum-cli export openvino --model local_path --task text-generation-with-past ov_model @@ -50,6 +53,12 @@ optimum-cli export openvino --model local_path --task text-generation-with-past To export your model in fp16, you can add `--weight-format fp16` when exporting your model. + + +Models larger than 1 billion parameters are exported to the OpenVINO format with 8-bit weights by default. You can disable it with `--weight-format fp32`. + + + Once the model is exported, you can load the OpenVINO model using : ```python @@ -126,7 +135,7 @@ model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True) -`load_in_8bit` is enabled by default for the models larger than 1 billion parameters. You can disable it with `load_in_8bit=False`. +If not specified, `load_in_8bit` will be set to `True` by default when models larger than 1 billion parameters are exported to the OpenVINO format (with `export=True`). You can disable it with `load_in_8bit=False`. diff --git a/docs/source/optimization_ov.mdx b/docs/source/optimization_ov.mdx index e018134964..c82f2ab384 100644 --- a/docs/source/optimization_ov.mdx +++ b/docs/source/optimization_ov.mdx @@ -44,7 +44,7 @@ model.save_pretrained(saving_directory) -`load_in_8bit` is enabled by default for the models larger than 1 billion parameters. You can disable it with `load_in_8bit=False`. +If not specified, `load_in_8bit` will be set to `True` by default when models larger than 1 billion parameters are exported to the OpenVINO format (with `export=True`). You can disable it with `load_in_8bit=False`.
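To make the export notes above concrete, here is a small CLI sketch. `local_path` and `ov_model` are the placeholder paths already used in the documentation, and the flags are the ones documented in the hunks above.

```bash
# Export with the KV cache enabled (the recommended *-with-past task variant).
# Weights of models larger than 1 billion parameters are compressed to 8 bits by default.
optimum-cli export openvino --model local_path --task text-generation-with-past ov_model

# Opt out of the default 8-bit weight compression and keep full-precision weights.
optimum-cli export openvino --model local_path --task text-generation-with-past --weight-format fp32 ov_model
```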
diff --git a/docs/source/reference_ov.mdx b/docs/source/reference_ov.mdx index 4c5ede653e..32385eae00 100644 --- a/docs/source/reference_ov.mdx +++ b/docs/source/reference_ov.mdx @@ -14,56 +14,113 @@ See the License for the specific language governing permissions and limitations under the License. --> -# Reference +# Models -## OVModelForFeatureExtraction +## Natural Language Processing -[[autodoc]] openvino.modeling.OVModelForFeatureExtraction +The following classes are available for the following natural language processing tasks. + +### OVModelForCausalLM + +[[autodoc]] openvino.modeling_decoder.OVModelForCausalLM + - forward + - generate -## OVModelForMaskedLM +### OVModelForMaskedLM [[autodoc]] openvino.modeling.OVModelForMaskedLM + - forward + +### OVModelForSeq2SeqLM + +[[autodoc]] openvino.modeling_seq2seq.OVModelForSeq2SeqLM + - forward -## OVModelForQuestionAnswering +### OVModelForQuestionAnswering [[autodoc]] openvino.modeling.OVModelForQuestionAnswering + - forward -## OVModelForSequenceClassification +### OVModelForSequenceClassification [[autodoc]] openvino.modeling.OVModelForSequenceClassification + - forward -## OVModelForTokenClassification +### OVModelForTokenClassification [[autodoc]] openvino.modeling.OVModelForTokenClassification + - forward -## OVModelForAudioClassification + +## Audio + +The following classes are available for the following audio tasks. + +### OVModelForAudioClassification [[autodoc]] openvino.modeling.OVModelForAudioClassification + - forward -## OVModelForAudioFrameClassification +### OVModelForAudioFrameClassification [[autodoc]] openvino.modeling.OVModelForAudioFrameClassification + - forward -## OVModelForCTC +### OVModelForCTC [[autodoc]] openvino.modeling.OVModelForCTC + - forward -## OVModelForAudioXVector +### OVModelForAudioXVector [[autodoc]] openvino.modeling.OVModelForAudioXVector + - forward + +### OVModelForSpeechSeq2Seq + +[[autodoc]] openvino.modeling_seq2seq.OVModelForSpeechSeq2Seq + - forward + + +## Computer Vision -## OVModelForImageClassification +The following classes are available for the following computer vision tasks. + +### OVModelForImageClassification [[autodoc]] openvino.modeling.OVModelForImageClassification + - forward -## OVModelForCausalLM -[[autodoc]] openvino.modeling_decoder.OVModelForCausalLM +## Multimodal -## OVModelForSeq2SeqLM +The following classes are available for the following multimodal tasks. -[[autodoc]] openvino.modeling_seq2seq.OVModelForSeq2SeqLM +### OVModelForVision2Seq + +[[autodoc]] openvino.modeling_seq2seq.OVModelForVision2Seq + - forward + +### OVModelForPix2Struct + +[[autodoc]] openvino.modeling_seq2seq.OVModelForPix2Struct + - forward + +## Custom Tasks + +### OVModelForCustomTasks + +[[autodoc]] openvino.modeling.OVModelForCustomTasks + - forward + +### OVModelForFeatureExtraction + +[[autodoc]] openvino.modeling.OVModelForFeatureExtraction + - forward + + +# Quantization -## OVQuantizer +### OVQuantizer [[autodoc]] openvino.quantization.OVQuantizer diff --git a/notebooks/openvino/quantized_generation_demo.ipynb b/notebooks/openvino/quantized_generation_demo.ipynb index 7671064088..5673243cb2 100644 --- a/notebooks/openvino/quantized_generation_demo.ipynb +++ b/notebooks/openvino/quantized_generation_demo.ipynb @@ -32,7 +32,7 @@ "metadata": {}, "outputs": [], "source": [ - "# ! pip install optimum[openvino,nncf] torch" + "# ! 
pip install optimum[openvino,nncf] torch==2.2.2" ] }, { diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index ffd084d4e6..07e1dcffae 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -19,9 +19,11 @@ from typing import TYPE_CHECKING, Optional from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from transformers.utils.quantization_config import QuantizationMethod from ...exporters import TasksManager from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available +from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from ..base import BaseOptimumCLICommand, CommandInfo @@ -128,6 +130,33 @@ def parse_args_openvino(parser: "ArgumentParser"): "compression is applied, they are compressed to INT8." ), ) + optional_group.add_argument( + "--awq", + action="store_true", + default=None, + help=( + "Whether to apply AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but requires " + "additional time for tuning weights on a calibration dataset. To run AWQ, please also provide a dataset " + "argument. Note: it's possible that there will be no matching patterns in the model to apply AWQ, in such " + "case it will be skipped." + ), + ) + optional_group.add_argument( + "--sensitivity-metric", + type=str, + default=None, + help=( + "The sensitivity metric for assigning quantization precision to layers. Can be one of the following: " + "['weight_quantization_error', 'hessian_input_activation', 'mean_activation_variance', " + "'max_activation_variance', 'mean_activation_magnitude']." + ), + ) + optional_group.add_argument( + "--num-samples", + type=int, + default=None, + help="The maximum number of samples to take from the dataset for quantization.", + ) optional_group.add_argument( "--disable-stateful", action="store_true", @@ -180,7 +209,7 @@ def parse_args(parser: "ArgumentParser"): return parse_args_openvino(parser) def run(self): - from ...exporters.openvino.__main__ import main_export + from ...exporters.openvino.__main__ import infer_task, main_export, maybe_convert_tokenizers from ...intel.openvino.configuration import _DEFAULT_4BIT_CONFIGS, OVConfig if self.args.fp16: @@ -208,6 +237,10 @@ def run(self): and self.args.group_size is None and self.args.sym is None and self.args.all_layers is None + and self.args.dataset is None + and self.args.num_samples is None + and self.args.awq is None + and self.args.sensitivity_metric is None and self.args.model in _DEFAULT_4BIT_CONFIGS ): quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model] @@ -218,6 +251,10 @@ def run(self): "sym": self.args.sym or False, "group_size": -1 if is_int8 else self.args.group_size, "all_layers": None if is_int8 else self.args.all_layers, + "dataset": self.args.dataset, + "num_samples": self.args.num_samples, + "quant_method": QuantizationMethod.AWQ if self.args.awq else None, + "sensitivity_metric": self.args.sensitivity_metric, } if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}: @@ -226,7 +263,6 @@ def run(self): ) quantization_config["sym"] = "asym" not in self.args.weight_format quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64 - quantization_config["dataset"] = self.args.dataset ov_config = OVConfig(quantization_config=quantization_config) library_name = TasksManager.infer_library_from_model(self.args.model, library_name=self.args.library) @@ -240,12 +276,11 @@ def run(self): 
if self.args.convert_tokenizer: logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.") - if ( - library_name == "diffusers" - and ov_config - and ov_config.quantization_config - and ov_config.quantization_config.dataset is not None - ): + quantization_config = ov_config.quantization_config if ov_config else None + quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None + task = infer_task(self.args.task, self.args.model) + + if library_name == "diffusers" and quantize_with_dataset: if not is_diffusers_available(): raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models")) @@ -270,25 +305,29 @@ def run(self): else: raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.") - model = model_cls.from_pretrained( - self.args.model, export=True, quantization_config=ov_config.quantization_config + model = model_cls.from_pretrained(self.args.model, export=True, quantization_config=quantization_config) + model.save_pretrained(self.args.output) + if not self.args.disable_convert_tokenizer: + maybe_convert_tokenizers(library_name, self.args.output, model) + elif task.startswith("text-generation") and quantize_with_dataset: + from optimum.intel import OVModelForCausalLM + + # To quantize a text-generation model with a dataset, an instantiated OVModelForCausalLM is required + model = OVModelForCausalLM.from_pretrained( + self.args.model, + export=True, + quantization_config=quantization_config, + stateful=not self.args.disable_stateful, + trust_remote_code=self.args.trust_remote_code, ) model.save_pretrained(self.args.output) - if self.args.disable_convert_tokenizer: - return - - # avoid import when using other exporters (IPEX, INC) - from ...exporters.openvino.convert import export_tokenizer - - output = Path(self.args.output) - tokenizer = getattr(model, "tokenizer", None) - if tokenizer is not None: - export_tokenizer(tokenizer, output / "tokenizer") - - tokenizer_2 = getattr(model, "tokenizer_2", None) - if tokenizer_2 is not None: - export_tokenizer(tokenizer_2, output / "tokenizer_2") + maybe_save_preprocessors(self.args.model, self.args.output, trust_remote_code=self.args.trust_remote_code) + if not self.args.disable_convert_tokenizer: + preprocessors = maybe_load_preprocessors( + self.args.model, trust_remote_code=self.args.trust_remote_code + ) + maybe_convert_tokenizers(library_name, self.args.output, preprocessors=preprocessors) else: # TODO : add input shapes main_export( diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 60ff3b721b..c6ed33006e 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -13,23 +13,26 @@ # limitations under the License. 
from transformers.models.llama.modeling_llama import ( - LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm, ) -from optimum.intel.utils.import_utils import is_ipex_version +from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version from .modeling_utils import ( - _IPEXLlamaDecoderLayerRef, - _llama_attn_forward, + _IPEX_MINIMUM_VERSION_FOR_PATCHING, + _IPEXLlamaDecoderLayer, _llama_layer_norm_forward, _llama_model_forward, ) +# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version +_TRANSFORMERS_MIN_VERSION = "4.39.0" +_TRANSFORMERS_MAX_VERSION = "4.41.2" + _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",) _IPEX_EXPORTED_TASK = ("text-generation",) @@ -62,26 +65,17 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): - if is_ipex_version("<", "2.5.0"): - raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache") - - from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding - - ipex_rope = RotaryEmbedding( - model.config.max_position_embeddings, - model.config.hidden_size // model.config.num_attention_heads, - model.config.rope_theta, - model.config.architectures[0], - ) - ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings) - patch_op(model, LlamaAttention, "ipex_rope", ipex_rope) - patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product) - + if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): + raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching") + if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version( + ">", _TRANSFORMERS_MAX_VERSION + ): + raise ImportError( + f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified." 
+ ) convert_functions(model, LlamaModel, "forward", _llama_model_forward) - convert_functions(model, LlamaAttention, "forward", _llama_attn_forward) convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward) - - convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config) + convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index f75e559eaf..5870a8c792 100644 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -19,9 +19,13 @@ from torch import nn from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import repeat_kv +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from optimum.intel.utils.import_utils import is_ipex_version +from optimum.intel.utils.modeling_utils import _setattr_from_module + + +_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0" # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 @@ -29,90 +33,6 @@ def _llama_layer_norm_forward(self, hidden_states): return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 -def _llama_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len - - query = query.view(bsz, q_len, self.num_heads, self.head_dim) - key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - # Use ipex op to rotary position embedding more efficient. - key = self.ipex_rope( - key, - position_ids, - self.num_key_value_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - query = self.ipex_rope( - query, - position_ids, - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - kv_seq_len, - ) - - if use_cache: - # This ipex op pre-allocates buffers for past_key_values and use beam index history - # which to decide which beam should be used to make attention scale dot more efficient. 
- (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - None, - attention_mask, - ) - else: - value_states = value.transpose(1, 2) - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - past_key_value = None - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 def _llama_model_forward( self, @@ -216,30 +136,212 @@ def _llama_model_forward( ) -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 -class _IPEXLlamaDecoderLayerRef(nn.Module): - def __init__(self, module, config, distributed=False): - if is_ipex_version("<", "2.5.0"): - raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd") +# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321 +class _IPEXLlamaAttention(nn.Module): + def __init__(self, module, config, distributed=False) -> None: + if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): + raise ImportError( + f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding" + ) + super().__init__() + _setattr_from_module(self, module) + self.config = config + self.distributed = distributed + from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding - from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd + if not self.distributed: + self.mha_linear_add = LinearAdd(self.o_proj) + del self.__dict__["_modules"]["o_proj"] + self.ipex_scale_dot_product = IndirectAccessKVCacheAttention( + text_max_length=module.config.max_position_embeddings + ) + self.ipex_rope = RotaryEmbedding( + module.config.max_position_embeddings, + module.config.hidden_size // module.config.num_attention_heads, + module.config.rope_theta, + module.config.architectures[0], + ) + + def qkv_gemm(self, hidden_states): + bsz, seq_len, _ = hidden_states.size() + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + query = query.view(bsz, seq_len, self.num_heads, self.head_dim) + key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + + return query, key, value + + def rope(self, query, key, kv_seq_len, position_ids, use_cache): + if use_cache: + key = 
self.ipex_rope( + key, + position_ids, + self.num_key_value_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + query = self.ipex_rope( + query, + position_ids, + self.num_heads, + self.head_dim, + self.head_dim // 2, + self.head_dim, + kv_seq_len, + ) + return query, key + + def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, position_ids): + # This ipex op pre-allocates buffers for past_key_values and use beam index history + # which to decide which beam should be used to make attention scale dot more efficient. + (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( + query, + key, + value, + math.sqrt(self.head_dim), + past_key_value, + None, + attention_mask, + ) + return attn_output, past_key_value, attn_weights + + # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341 + def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, position_ids): + value_states = value.transpose(1, 2) + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = None + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + return attn_output, past_key_value, attn_weights + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + residual: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. + residual (`torch.Tensor`): residual tensor to the layer of shape (batch, seq_len, embed_dim)` + """ + bsz, seq_len, _ = hidden_states.size() + kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len + + query, key, value = self.qkv_gemm(hidden_states) + query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache) + + sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache + attn_output, past_key_value, attn_weights = sdpa( + query, key, value, past_key_value, attention_mask, position_ids + ) + attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size) + + if hasattr(self, "mha_linear_add"): + attn_output = self.mha_linear_add(attn_output, residual) + else: + attn_output = self.o_proj(attn_output) + attn_output = residual + attn_output + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186 +class _IPEXLlamaMLP(nn.Module): + def __init__(self, module, config, distributed=False) -> None: + if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): + raise ImportError( + f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul, LinearAdd" + ) super().__init__() - for k, v in module.__dict__.items(): - setattr(self, k, v) - for k, v in module.__class__.__dict__.items(): - if k.startswith("__") or k.startswith("forward"): - continue - setattr(self.__class__, k, getattr(module.__class__, k)) + _setattr_from_module(self, module) + self.config = config self.distributed = distributed + from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd + if not self.distributed: - self.mha_linear_add = LinearAdd(module.self_attn.o_proj) - self.mlp_linear_add = LinearAdd(module.mlp.down_proj) - del self.__dict__["_modules"]["self_attn"].o_proj - del self.__dict__["_modules"]["mlp"].down_proj - self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj) - del self.__dict__["_modules"]["mlp"].gate_proj - del self.__dict__["_modules"]["mlp"].up_proj + self.mlp_linear_add = LinearAdd(module.down_proj) + del self.__dict__["_modules"]["down_proj"] + self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) + del self.__dict__["_modules"]["gate_proj"] + del self.__dict__["_modules"]["up_proj"] + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + residual (`torch.Tensor`): residual tensor to the layer of shape (batch, seq_len, embed_dim)` + """ + if hasattr(self, "linear_silu_mul"): + mlp_gate = self.linear_silu_mul(hidden_states) + if hasattr(self, "mlp_linear_add"): + hidden_states = self.mlp_linear_add(mlp_gate, residual) + else: + hidden_states = self.down_proj(mlp_gate) + hidden_states = residual + hidden_states + else: + hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + hidden_states = residual + hidden_states + + return hidden_states + + +# 
Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 +class _IPEXLlamaDecoderLayer(nn.Module): + def __init__(self, module, config, distributed=False): + super().__init__() + _setattr_from_module(self, module) + self.distributed = distributed + self.self_attn = _IPEXLlamaAttention(module.self_attn, config, distributed) + self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed) def forward( self, @@ -255,15 +357,17 @@ def forward( Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + Attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states @@ -277,24 +381,15 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=None, + residual=residual, + **kwargs, ) - if not self.distributed: - hidden_states = self.mha_linear_add(hidden_states, residual) - else: - hidden_states = self.self_attn.o_proj(hidden_states) - hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - - mlp_gate = self.linear_silu_mul(hidden_states) - - if not self.distributed: - hidden_states = self.mlp_linear_add(mlp_gate, residual) - else: - hidden_states = self.mlp.down_proj(mlp_gate) - hidden_states = residual + hidden_states + hidden_states = self.mlp(hidden_states, residual, **kwargs) outputs = (hidden_states,) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 9db6719069..927c98ac37 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -44,6 +44,22 @@ logger = logging.getLogger(__name__) +def infer_task(task, model_name_or_path): + task = TasksManager.map_from_synonym(task) + if task == "auto": + try: + task = TasksManager.infer_task_from_model(model_name_or_path) + except KeyError as e: + raise KeyError( + f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" + ) + except RequestsConnectionError as e: + raise RequestsConnectionError( + f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. 
Detailed error: {e}" + ) + return task + + def main_export( model_name_or_path: str, output: Union[str, Path], @@ -174,7 +190,7 @@ def main_export( ov_config = OVConfig(quantization_config=q_config) original_task = task - task = TasksManager.map_from_synonym(task) + task = infer_task(task, model_name_or_path) framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) library_name_is_not_provided = library_name is None library_name = TasksManager.infer_library_from_model( @@ -188,18 +204,6 @@ def main_export( ) library_name = "transformers" - if task == "auto": - try: - task = TasksManager.infer_task_from_model(model_name_or_path) - except KeyError as e: - raise KeyError( - f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" - ) - except RequestsConnectionError as e: - raise RequestsConnectionError( - f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" - ) - do_gptq_patching = False custom_architecture = False loading_kwargs = {} @@ -360,17 +364,35 @@ class StoreAttr(object): **kwargs_shapes, ) - # hide openvino import when using other exporters - from optimum.exporters.openvino.convert import export_tokenizer + if convert_tokenizer: + maybe_convert_tokenizers(library_name, output, model, preprocessors) + + # Unpatch modules after GPTQ export + if do_gptq_patching: + torch.cuda.is_available = orig_cuda_check + GPTQQuantizer.post_init_model = orig_post_init_model - if convert_tokenizer and is_openvino_tokenizers_available(): - if library_name != "diffusers": - tokenizer = next( - (preprocessor for preprocessor in preprocessors if isinstance(preprocessor, PreTrainedTokenizerBase)), - None, - ) - if tokenizer is not None: +def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None): + """ + Tries to convert tokenizers to OV format and export them to disk. + + Arguments: + library_name (`str`): + The library name. + output (`Path`): + Path to save converted tokenizers to. + model (`PreTrainedModel`, *optional*, defaults to None): + Model instance. + preprocessors (`Iterable`, *optional*, defaults to None): + Iterable possibly containing tokenizers to be converted. + """ + from optimum.exporters.openvino.convert import export_tokenizer + + if is_openvino_tokenizers_available(): + if library_name != "diffusers" and preprocessors: + tokenizer = next(filter(lambda it: isinstance(it, PreTrainedTokenizerBase), preprocessors), None) + if tokenizer: try: export_tokenizer(tokenizer, output) except Exception as exception: @@ -378,18 +400,10 @@ class StoreAttr(object): "Could not load tokenizer using specified model ID or path. OpenVINO tokenizer/detokenizer " f"models won't be generated. 
Exception: {exception}" ) - else: - tokenizer = getattr(model, "tokenizer", None) - if tokenizer is not None: - export_tokenizer(tokenizer, output / "tokenizer") - - tokenizer_2 = getattr(model, "tokenizer_2", None) - if tokenizer_2 is not None: - export_tokenizer(tokenizer_2, output / "tokenizer_2") - elif convert_tokenizer and not is_openvino_tokenizers_available(): + elif model: + for tokenizer_name in ("tokenizer", "tokenizer_2"): + tokenizer = getattr(model, tokenizer_name, None) + if tokenizer: + export_tokenizer(tokenizer, output / tokenizer_name) + else: logger.warning("Tokenizer won't be converted.") - - # Unpatch modules after GPTQ export - if do_gptq_patching: - torch.cuda.is_available = orig_cuda_check - GPTQQuantizer.post_init_model = orig_post_init_model diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index d69adc9da3..f78c58589b 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -43,6 +43,7 @@ from .model_patcher import ( AquilaModelPatcher, + ArcticModelPatcher, BaichuanModelPatcher, ChatGLMModelPatcher, CodeGenModelPatcher, @@ -50,9 +51,11 @@ GemmaModelPatcher, InternLM2Patcher, InternLMModelPatcher, + JaisModelPatcher, LlamaModelPatcher, MixtralModelPatcher, MPTModelPatcher, + PersimmonModelPatcher, Phi3ModelPatcher, QwenModelPatcher, XverseModelPatcher, @@ -473,7 +476,7 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): @register_in_tasks_manager("olmo", *["text-generation", "text-generation-with-past"], library_name="transformers") -class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): +class OlmoOpenVINOConfig(LlamaOpenVINOConfig): DEFAULT_ONNX_OPSET = 14 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig @@ -630,6 +633,11 @@ class PersimmonOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return PersimmonModelPatcher(self, model, model_kwargs=model_kwargs) + @register_in_tasks_manager("biogpt", *["text-generation", "text-generation-with-past"], library_name="transformers") class BioGPTOpenVINOConfig(TextDecoderOnnxConfig): @@ -785,3 +793,29 @@ def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None ) -> "ModelPatcher": return DBRXModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager( + "jais", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class JaisOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return JaisModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("arctic", *["text-generation", "text-generation-with-past"], library_name="transformers") +class ArcticOpenVINOConfig(MixtralOpenVINOConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None 
+ ) -> "ModelPatcher": + return ArcticModelPatcher(self, model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 0265b3a5fc..6ce7a658c3 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -20,7 +20,6 @@ import torch import torch.nn.functional as F -from transformers.cache_utils import Cache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils import is_tf_available @@ -36,6 +35,7 @@ if TYPE_CHECKING: + from transformers.cache_utils import Cache from transformers.modeling_utils import PreTrainedModel from optimum.exporters.onnx.config import OnnxConfig @@ -131,7 +131,10 @@ def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torc # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + if is_transformers_version("<", "4.37.0"): + current_hidden_states = expert_layer(current_state, routing_weights[top_x, idx, None]) + else: + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) @@ -153,6 +156,17 @@ def __exit__(self, exc_type, exc_value, traceback): layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward +class ArcticModelPatcher(MixtralModelPatcher): + def __enter__(self): + # model initialize some weights for matrix multiplication in bfloat16, that lead to inconsistency of dtype + try: + self._model.to(torch.float32) + except Exception: + pass + + super().__enter__() + + def _chatglm_transformer_forward( self, input_ids, @@ -1656,9 +1670,10 @@ def _dbrx_update_causal_mask_latest( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_key_values: Cache, + past_key_values: "Cache", output_attentions: bool, ): + from transformers.cache_utils import StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static @@ -1771,3 +1786,222 @@ def __exit__(self, exc_type, exc_value, traceback): self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask for block in self._model.transformer.blocks: block.ffn.experts.forward = block.ffn.experts._orig_forward + + +def _persimmon_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + from transformers.models.persimmon.modeling_persimmon import apply_rotary_pos_emb + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + + bsz, q_len, _ = hidden_states.size() + + # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) + 
+ # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_states, key_states, value_states) = self._split_heads(fused_qkv) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim] + query_states = query_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_emb.dim], + query_states[..., self.rotary_emb.dim :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_emb.dim], + key_states[..., self.rotary_emb.dim :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + scale=1 / math.sqrt(self.head_dim), + dropout_p=self.attention_dropout.p, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.dense(attn_output) + + return attn_output, None, past_key_value + + +class PersimmonModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_persimmon_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +def _jais_attn_forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + position_bias: Optional[torch.FloatTensor] = None, +) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class 
is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `JAISAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn( + query, key, value, attention_mask, head_mask, position_bias + ) + else: + # Difference with original: override attn realization with sdpa if not output_attentions + if not output_attentions: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, position_bias) + else: + attn_output, attn_weights = self._orig_attn(query, key, value, attention_mask, head_mask, position_bias) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +def _jais_attn(self, query, key, value, attention_mask=None, head_mask=None, position_bias=None): + scale = 1.0 + if self.scale_attn_weights: + scale = 1 / self.head_dim**self.attn_scale_power + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + scale = scale / float(self.layer_idx + 1) + + query_length = query.size(-2) + attention_mask_sdpa = torch.ones( + (query.shape[0], query.shape[1], query.shape[2], key.shape[2]), + dtype=query.dtype, + ) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(torch.float16).min + attention_mask_sdpa.masked_fill_(~causal_mask, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attention_mask_sdpa = attention_mask_sdpa + attention_mask + + if position_bias is not None: + attention_mask_sdpa += position_bias.type_as(attention_mask_sdpa).unsqueeze(0) + + # Mask heads if we want to + if head_mask is not None: + attention_mask_sdpa = attention_mask_sdpa * head_mask + + attn_output = F.scaled_dot_product_attention( + query, key, value, attention_mask_sdpa, dropout_p=self.attn_dropout.p, scale=scale + ) + + return attn_output, None + + +class JaisModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + for layer in self._model.transformer.h: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.attn._attn + layer.attn._attn = types.MethodType(_jais_attn, layer.attn) + layer.attn._orig_attn = orig_self_attn_fwd + layer.attn._orig_forward = layer.attn.forward + layer.attn.forward = types.MethodType(_jais_attn_forward, layer.attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in 
self._model.transformer.h: + if hasattr(layer.attn, "_orig_attn"): + layer.attn._attn = layer.attn._orig_attn + layer.attn.forward = layer.attn._orig_forward diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index e929a4ddb8..3750d56227 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -18,7 +18,7 @@ import warnings from pathlib import Path from tempfile import TemporaryDirectory -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import intel_extension_for_pytorch as ipex import torch @@ -50,7 +50,7 @@ from optimum.modeling_base import OptimizedModel from optimum.utils import NormalizedConfigManager -from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model +from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model from ..generation.modeling import prepare_jit_inputs from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device @@ -60,10 +60,11 @@ _IPEX_SUPPORT_MODEL_TYPES = ("llama",) +_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation") def _is_patched_with_ipex(model, task): - if is_ipex_version("<", "2.5.0"): + if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): return False if isinstance(model, torch.jit.ScriptModule): @@ -73,7 +74,12 @@ def _is_patched_with_ipex(model, task): return True return False else: - return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK + # The ipex IAKV op in patched model requires the hidden size at least 64 + return ( + model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES + and task in _IPEX_EXPORTED_TASK + and model.config.hidden_size >= 64 + ) def ipex_jit_trace(model, task, use_cache): @@ -83,6 +89,7 @@ def ipex_jit_trace(model, task, use_cache): if _is_patched_with_ipex(model, task): model = _patch_model(model) + # Todo: integerate in prepare_jit_inputs. sample_inputs = get_dummy_input(model, return_dict=True) # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755. _enable_tpp() @@ -92,9 +99,10 @@ def ipex_jit_trace(model, task, use_cache): model.config.return_dict = False - if "past_key_values" in sample_inputs and use_cache: - # Make sure the model will output past_key_values in generation tasks - model.config.use_cache = True + if "past_key_values" in sample_inputs: + model.config.use_cache = use_cache + if not use_cache: + sample_inputs.pop("past_key_values") model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True) # Disable repack while jit tracing to reduce the memory @@ -522,6 +530,23 @@ def _prepare_past_key_values(self, input_ids): return past_key_values + # Temporary fix, will delete when https://github.com/huggingface/transformers/pull/31226 release. 
+    def _get_initial_cache_position(self, input_ids, model_kwargs):
+        """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        if not model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = None
+            return model_kwargs
+
+        past_length = 0
+        if "past_key_values" in model_kwargs:
+            past_length = model_kwargs["past_key_values"][0][0].shape[-2]
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        else:
+            cur_len = input_ids.shape[-1]
+        model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
+        return model_kwargs
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -561,6 +586,25 @@ def forward(
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
+    def _prepare_generation_config(
+        self, generation_config: Optional[GenerationConfig], **kwargs: Dict
+    ) -> Tuple[GenerationConfig, Dict]:
+        generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
+        generation_method = generation_config.get_generation_mode().value
+        if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
+            raise ValueError(
+                f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, supported methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+            )
+
+        return generation_config, model_kwargs
+
+    def generate(self, *args, **kwargs):
+        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+            raise ValueError(
+                f"Assisted decoding is not supported for patched models for now, supported methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+            )
+        return super().generate(*args, **kwargs)
+
 
 def _prepare_inputs_for_generation_for_llama(
     input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py
index c62c641a6d..4ee285f07d 100644
--- a/optimum/intel/openvino/__init__.py
+++ b/optimum/intel/openvino/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
import logging +import warnings from ..utils.import_utils import is_accelerate_available, is_diffusers_available, is_nncf_available from .utils import ( @@ -25,9 +26,15 @@ ) +warnings.simplefilter(action="ignore", category=FutureWarning) + + if is_nncf_available(): + logging.disable(logging.INFO) import nncf + logging.disable(logging.NOTSET) + # Suppress version mismatch logging nncf.set_log_level(logging.ERROR) from nncf.torch import patch_torch_operators diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index 1c907f2135..6e7fb1e9b4 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -446,6 +446,9 @@ def _from_transformers( save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) + # This attribute is needed to keep one reference on the temporary directory, since garbage collecting + # would end-up removing the directory containing the underlying OpenVINO model + cls._model_save_dir_tempdirectory_instance = save_dir # If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size if load_in_8bit is None and not quantization_config: diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 7937deea52..7f2c08ce90 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -89,7 +89,10 @@ def __init__( self.model = model self.request = None - self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + if self.can_generate(): + self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config)) + else: + self.generation_config = None self._openvino_config = None if quantization_config: @@ -132,6 +135,7 @@ def fix_op_names_duplicates(model: openvino.runtime.Model): if file_name.suffix == ".onnx": model = fix_op_names_duplicates(model) # should be called during model conversion to IR + # TODO: remove this way of applying quantization; instead apply it after instance of OVModel* is loaded if quantization_config: if not is_nncf_available(): raise ImportError( @@ -155,6 +159,14 @@ def _save_pretrained(self, save_directory: Union[str, Path]): """ dst_path = os.path.join(save_directory, OV_XML_FILE_NAME) openvino.save_model(self.model, dst_path, compress_to_fp16=False) + generation_config = getattr(self, "generation_config", None) + if generation_config is not None: + try: + generation_config.save_pretrained(save_directory) + except Exception as exception: + logger.warning( + f"The generation config will not be saved, saving failed with following error:\n{exception}" + ) self._save_openvino_config(save_directory) @@ -240,6 +252,20 @@ def _from_pretrained( quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) model = cls.load_model(model_cache_path, quantization_config=quantization_config) + + try: + generation_config = GenerationConfig.from_pretrained( + model_id, + token=token, + revision=revision, + subfolder=subfolder, + force_download=force_download, + cache_dir=cache_dir, + ) + kwargs["generation_config"] = generation_config + except Exception: + pass + return cls( model, config=config, @@ -366,6 +392,9 @@ def _from_transformers( save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) + # This attribute is needed to keep one reference on the temporary directory, since garbage collecting + # would end-up removing 
the directory containing the underlying OpenVINO model + cls._model_save_dir_tempdirectory_instance = save_dir # If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size if load_in_8bit is None and not quantization_config: diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index fb53f9b2e2..718b2f874e 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -78,7 +78,10 @@ def __init__( self.encoder_model = encoder self.decoder_model = decoder self.decoder_with_past_model = decoder_with_past - self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + if self.can_generate(): + self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config)) + else: + self.generation_config = None self._openvino_config = None if quantization_config: self._openvino_config = OVConfig(quantization_config=quantization_config) @@ -104,6 +107,13 @@ def _save_pretrained(self, save_directory: Union[str, Path]): openvino.save_model(src_file, dst_path, compress_to_fp16=False) self._save_openvino_config(save_directory) + if self.generation_config is not None: + try: + self.generation_config.save_pretrained(save_directory) + except Exception as exception: + logger.warning( + f"The generation config will not be saved, saving failed with following error:\n{exception}" + ) @classmethod def _from_pretrained( @@ -218,6 +228,19 @@ def _from_pretrained( if use_cache: decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config) + try: + generation_config = GenerationConfig.from_pretrained( + model_id, + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + kwargs["generation_config"] = generation_config + except Exception: + pass + return cls( encoder=encoder, decoder=decoder, @@ -281,6 +304,10 @@ def _from_transformers( save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) + # This attribute is needed to keep one reference on the temporary directory, since garbage collecting + # would end-up removing the directory containing the underlying OpenVINO model + cls._model_save_dir_tempdirectory_instance = save_dir + if task is None: task = cls.export_feature if use_cache: diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 72cd1b6487..352c95fc84 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -38,7 +38,7 @@ from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful from ...exporters.openvino.stateful import model_has_state -from ..utils.import_utils import is_nncf_available +from ..utils.import_utils import is_nncf_available, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .configuration import _DEFAULT_4BIT_CONFIGS, OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel @@ -116,7 +116,6 @@ def __init__( quantization_config=quantization_config, **kwargs, ) - self.is_dynamic = dynamic_shapes use_cache = kwargs.pop("use_cache", True) model_has_sinks = model_has_state(self.model) @@ -224,6 +223,14 @@ def _save_pretrained(self, save_directory: 
Union[str, Path]): dst_path = os.path.join(save_directory, OV_XML_FILE_NAME) openvino.save_model(model_to_save, dst_path, compress_to_fp16=False) + if self.generation_config is not None: + try: + self.generation_config.save_pretrained(save_directory) + except Exception as exception: + logger.warning( + f"The generation config will not be saved, saving failed with following error:\n{exception}" + ) + self._save_openvino_config(save_directory) @classmethod @@ -256,6 +263,9 @@ def _from_transformers( save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) + # This attribute is needed to keep one reference on the temporary directory, since garbage collecting + # would end-up removing the directory containing the underlying OpenVINO model + cls._model_save_dir_tempdirectory_instance = save_dir if task is None: task = cls.export_feature @@ -587,11 +597,11 @@ def _deduplicate_inputs(self, model_inputs: Dict): ) for input_name, input_tensor in model_inputs.items(): if input_name not in ["input_ids", "beam_idx"]: - if not isinstance(input_tensor, Tensor): + if input_name not in self.key_value_input_names: upd_model_inputs[input_name] = input_tensor[indicies] else: - shape = input_tensor.shape - dtype = input_tensor.element_type + shape = input_tensor.shape if isinstance(input_tensor, Tensor) else list(input_tensor.shape) + dtype = input_tensor.element_type if isinstance(input_tensor, Tensor) else Type(input_tensor.dtype) upd_batch_size = indicies.shape[0] if self.config.model_type == "bloom": upd_batch_size *= self.config.num_attention_heads @@ -623,8 +633,12 @@ def generate( negative_prompt_attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: - _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) - generation_mode = _generation_config.get_generation_mode(assistant_model) + if is_transformers_version(">=", "4.39.0"): + _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) + generation_mode = _generation_config.get_generation_mode(assistant_model) + else: + _generation_config = generation_config or self.generation_config + generation_mode = self._get_generation_mode(_generation_config, assistant_model) is_beam_search = generation_mode in [ GenerationMode.BEAM_SEARCH, @@ -742,17 +756,7 @@ def _from_pretrained( local_files_only=local_files_only, ) - if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}: - quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config) - - quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) - - load_in_4bit = quantization_config.bits == 4 if quantization_config else False - - model = cls.load_model( - model_cache_path, - quantization_config=None if load_in_4bit else quantization_config, - ) + model = cls.load_model(model_cache_path) model_type = config.model_type.replace("_", "-") if model_type == "bloom": @@ -762,7 +766,25 @@ def _from_pretrained( else: init_cls = cls - enable_compilation = kwargs.pop("compile", True) and not load_in_4bit + if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}: + quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config) + quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) + + enable_compilation = kwargs.pop("compile", True) and not quantization_config + + try: + generation_config = GenerationConfig.from_pretrained( 
+                model_id,
+                token=token,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                local_files_only=local_files_only,
+            )
+            kwargs["generation_config"] = generation_config
+        except Exception:
+            pass
+
         causal_model = init_cls(
             model=model,
             config=config,
@@ -772,7 +794,7 @@ def _from_pretrained(
             **kwargs,
         )
 
-        if load_in_4bit:
+        if quantization_config:
             if not is_nncf_available():
                 raise ImportError(
                     "Quantization of the weights requires nncf, please install it with `pip install nncf`"
diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py
index d491934fc2..2386922a0e 100644
--- a/optimum/intel/pipelines/pipeline_base.py
+++ b/optimum/intel/pipelines/pipeline_base.py
@@ -330,6 +330,12 @@ def pipeline(
            The task defining which pipeline will be returned. Currently accepted tasks are:
 
            - `"text-generation"`: will return a [`TextGenerationPipeline`]:.
+           - `"fill-mask"`: will return a [`FillMaskPipeline`].
+           - `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
+           - `"image-classification"`: will return a [`ImageClassificationPipeline`].
+           - `"text-classification"`: will return a [`TextClassificationPipeline`].
+           - `"token-classification"`: will return a [`TokenClassificationPipeline`].
+           - `"audio-classification"`: will return a [`AudioClassificationPipeline`].
 
        model (`str` or [`PreTrainedModel`], *optional*):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
@@ -344,7 +350,7 @@ def pipeline(
            is not specified or not a string, then the default tokenizer for `config` is loaded (if it is a string).
            However, if `config` is also not given or not a string, then the default tokenizer for the given `task`
            will be loaded.
-       accelerator (`str`, *optional*, defaults to `"ipex"`):
+       accelerator (`str`, *optional*):
            The optimization backends, choose from ["ipex", "inc", "openvino"].
        use_fast (`bool`, *optional*, defaults to `True`):
            Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]).
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index a2cd728354..3541f4f933 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -182,3 +182,12 @@ def recursive_to_device(value, device):
     elif isinstance(value, torch.Tensor):
         return value.to(device)
     return value
+
+
+def _setattr_from_module(new_module, module):
+    for k, v in module.__dict__.items():
+        setattr(new_module, k, v)
+    for k, v in module.__class__.__dict__.items():
+        if k.startswith("__") or k.startswith("forward"):
+            continue
+        setattr(new_module.__class__, k, getattr(module.__class__, k))
diff --git a/optimum/intel/version.py b/optimum/intel/version.py
index 9668d62158..a2a8579447 100644
--- a/optimum/intel/version.py
+++ b/optimum/intel/version.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "1.17.0.dev0" +__version__ = "1.18.0.dev0" diff --git a/setup.py b/setup.py index d00ce1dd92..23fce0f827 100644 --- a/setup.py +++ b/setup.py @@ -32,13 +32,14 @@ "optimum~=1.20", "datasets>=1.4.0", "sentencepiece", + "setuptools", "scipy", "onnx", ] TESTS_REQUIRE = [ "accelerate", - "pytest<8.2", + "pytest>=7.2.0,<8.0.0", "parameterized", "Pillow", "evaluate", @@ -49,11 +50,10 @@ "rjieba", "timm", "invisible-watermark>=0.2.0", - "auto-gptq", "transformers_stream_generator", "einops", "tiktoken", - "sentence_transformers", + "sentence-transformers", ] QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"] @@ -62,7 +62,7 @@ "neural-compressor": ["neural-compressor>=2.2.0", "onnxruntime<1.15.0", "accelerate"], "openvino": ["openvino>=2023.3", "nncf>=2.10.0", "openvino-tokenizers[transformers]"], "nncf": ["nncf>=2.10.0"], - "ipex": ["intel-extension-for-pytorch", "transformers>=4.36.0,<4.39.0"], + "ipex": ["intel-extension-for-pytorch", "transformers>=4.39.0,<=4.41.2"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 2a2f18f6f8..8664b99cee 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -14,6 +14,7 @@ # ruff: noqa +import tempfile import time import unittest @@ -87,10 +88,16 @@ def test_compare_to_transformers(self, model_arch): with torch.no_grad(): transformers_outputs = transformers_model(**tokens) outputs = ipex_model(**tokens) + + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model_outputs = loaded_model(**tokens) # Compare tensor outputs for output_name in {"logits", "last_hidden_state"}: if output_name in transformers_outputs: self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4)) + self.assertTrue(torch.equal(outputs[output_name], loaded_model_outputs[output_name])) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): @@ -139,11 +146,19 @@ def test_compare_to_transformers(self, model_arch): with torch.no_grad(): transformers_outputs = transformers_model(**tokens) outputs = ipex_model(**tokens) + + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model_outputs = loaded_model(**tokens) + self.assertIn("start_logits", outputs) self.assertIn("end_logits", outputs) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.start_logits, transformers_outputs.start_logits, atol=1e-4)) self.assertTrue(torch.allclose(outputs.end_logits, transformers_outputs.end_logits, atol=1e-4)) + self.assertTrue(torch.equal(outputs.start_logits, loaded_model_outputs.start_logits)) + self.assertTrue(torch.equal(outputs.end_logits, loaded_model_outputs.end_logits)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): @@ -171,14 +186,14 @@ class IPEXModelForCausalLMTest(unittest.TestCase): "gpt2", "gpt_neo", "gpt_neox", + "mistral", "llama", "llama2", - "mistral", # "phi", "mpt", "opt", ) - IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama",) + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2",) GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.0 @@ -204,8 +219,14 @@ def test_compare_to_transformers(self, model_arch): transformers_model = AutoModelForCausalLM.from_pretrained(model_id) with torch.no_grad(): transformers_outputs = 
             transformers_outputs = transformers_model(**tokens)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ipex_model.save_pretrained(tmpdirname)
+            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
+            loaded_model_outputs = loaded_model(**inputs)
         # Compare tensor outputs
         self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
+        self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
@@ -219,18 +240,23 @@ def test_pipeline(self, model_arch):
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
 
+    # The highly optimized llama model does not support assisted decoding for now.
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_assisted_decoding(self, model_arch):
+        if model_arch == "llama2":
+            return
         model_id = MODEL_NAMES[model_arch]
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         tokens = tokenizer("This is a sample input", return_tensors="pt")
-        ipex_output = ipex_model.generate(**tokens, do_sample=False)
-        ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model)
-        transformers_output = transformers_model.generate(**tokens, do_sample=False)
+        ipex_output = ipex_model.generate(**tokens, do_sample=False, max_new_tokens=4)
+        ipex_output_assisted = ipex_model.generate(
+            **tokens, do_sample=False, assistant_model=transformers_model, max_new_tokens=4
+        )
+        transformers_output = transformers_model.generate(**tokens, do_sample=False, max_new_tokens=4)
         transformers_output_assisted = transformers_model.generate(
-            **tokens, do_sample=False, assistant_model=ipex_model
+            **tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4
         )
         self.assertTrue(torch.equal(ipex_output, ipex_output_assisted))
         self.assertTrue(torch.equal(transformers_output, transformers_output_assisted))
@@ -243,24 +269,25 @@ def test_assisted_decoding(self, model_arch):
                 }
             )
         )
-    @unittest.skipIf(is_ipex_version("<", "2.5.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
+    @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version >= 2.3.0 supports ipex model patching")
     def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache)
-        self.assertEqual(model.use_cache, use_cache)
         trasnformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+        self.assertEqual(model.use_cache, use_cache)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
         # Test with batch_size is 1 and 2.
texts = ["This is a sample", ["This is the first input", "This is the second input"]] generation_configs = ( - GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=True), - GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True), - GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=True), - GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=True), - GenerationConfig(max_new_tokens=4, do_sample=not use_cache, top_p=1.0, top_k=5, penalty_alpha=0.6), - GenerationConfig(max_new_tokens=4, do_sample=True, top_p=0.9, top_k=0), + GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=False), + GenerationConfig( + max_new_tokens=4, do_sample=False, top_p=0.9, top_k=0, pad_token_id=tokenizer.eos_token_id + ), ) for text in texts: tokens = tokenizer(text, padding=True, return_tensors="pt") @@ -268,7 +295,7 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): outputs = model.generate(**tokens, generation_config=generation_config) transformers_outputs = trasnformers_model.generate(**tokens, generation_config=generation_config) self.assertIsInstance(outputs, torch.Tensor) - self.assertEqual(outputs, transformers_outputs) + self.assertTrue(torch.equal(outputs, transformers_outputs)) def test_compare_with_and_without_past_key_values(self): model_id = "echarlaix/tiny-random-gpt2-torchscript" @@ -326,8 +353,14 @@ def test_compare_to_transformers(self, model_arch): with torch.no_grad(): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) + + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model_outputs = loaded_model(**inputs) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-3)) + self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): @@ -366,9 +399,16 @@ def test_compare_to_transformers(self, model_arch): with torch.no_grad(): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) + + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model_outputs = loaded_model(**inputs) + self.assertIn("logits", outputs) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index 851f8355f5..6d05158dda 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -18,8 +18,10 @@ from tempfile import TemporaryDirectory from typing import Optional +import torch from parameterized import parameterized -from transformers import AutoConfig +from sentence_transformers import SentenceTransformer, models +from transformers import AutoConfig, AutoTokenizer, GenerationConfig from utils_tests import MODEL_NAMES from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED @@ -69,6 +71,8 @@ class 
ExportModelTest(unittest.TestCase): "latent-consistency": OVLatentConsistencyModelPipeline, } + GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper") + def _openvino_export( self, model_type: str, @@ -122,9 +126,54 @@ def _openvino_export( def test_export(self, model_type: str): self._openvino_export(model_type) + @parameterized.expand(GENERATIVE_MODELS) + def test_export_with_custom_gen_config(self, model_type): + auto_model = self.SUPPORTED_ARCHITECTURES[model_type] + task = auto_model.export_feature + model_name = MODEL_NAMES[model_type] + loading_kwargs = {"attn_implementation": "eager"} if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED else {} + + model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs) + + model.generation_config.top_k = 42 + model.generation_config.do_sample = True + + if getattr(model.config, "model_type", None) == "pix2struct": + preprocessors = maybe_load_preprocessors(model_name) + else: + preprocessors = None + + supported_tasks = (task, task + "-with-past") if "text-generation" in task else (task,) + for supported_task in supported_tasks: + with TemporaryDirectory() as tmpdirname: + export_from_model( + model=model, + output=Path(tmpdirname), + task=supported_task, + preprocessors=preprocessors, + ) + + use_cache = supported_task.endswith("-with-past") + ov_model = auto_model.from_pretrained(tmpdirname, use_cache=use_cache) + self.assertIsInstance(ov_model, OVBaseModel) + self.assertTrue(ov_model.can_generate()) + self.assertTrue(ov_model.generation_config is not None) + self.assertIsInstance(ov_model.generation_config, GenerationConfig) + self.assertTrue(ov_model.generation_config.top_k == 42) + + # check that generate config remains after repeated saving + with TemporaryDirectory() as tmpdirname2: + ov_model.save_pretrained(tmpdirname2) + ov_model = auto_model.from_pretrained(tmpdirname2, use_cache=use_cache) + self.assertIsInstance(ov_model, OVBaseModel) + self.assertTrue(ov_model.can_generate()) + self.assertTrue(ov_model.generation_config is not None) + self.assertIsInstance(ov_model.generation_config, GenerationConfig) + self.assertTrue(ov_model.generation_config.top_k == 42) + class CustomExportModelTest(unittest.TestCase): - def test_export_custom_model(self): + def test_custom_export_config_model(self): class BertOnnxConfigWithPooler(BertOnnxConfig): @property def outputs(self): @@ -157,3 +206,26 @@ def outputs(self): self.assertIsInstance(ov_model, OVBaseModel) self.assertTrue(ov_model.output_names == {"last_hidden_state": 0, "pooler_output": 1}) + + def test_export_custom_model(self): + model_id = "hf-internal-testing/tiny-random-BertModel" + word_embedding_model = models.Transformer(model_id, max_seq_length=256) + pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension()) + dense_model = models.Dense( + in_features=pooling_model.get_sentence_embedding_dimension(), + out_features=256, + ) + model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model]) + + with TemporaryDirectory() as tmpdirname: + export_from_model(model, output=tmpdirname, task="feature-extraction") + ov_model = OVModelForCustomTasks.from_pretrained(tmpdirname) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt") + with torch.no_grad(): + model_outputs = model(tokens) + + ov_outputs = ov_model(**tokens) + self.assertTrue(torch.allclose(ov_outputs.token_embeddings, model_outputs.token_embeddings, atol=1e-4)) + 
self.assertTrue(torch.allclose(ov_outputs.sentence_embedding, model_outputs.sentence_embedding, atol=1e-4)) diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index cce25bbae1..c81761bc9f 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -89,6 +89,14 @@ class OVCLIExportTestCase(unittest.TestCase): ("text-generation-with-past", "opt125m", "int4_sym_g64", 62, 86), ("text-generation-with-past", "opt125m", "int4_asym_g64", 62, 86), ("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 16 --all-layers", 0, 32), + ( + "text-generation-with-past", + "llama_awq", + "int4 --ratio 1.0 --sym --group-size 16 --awq --dataset wikitext2 --num-samples 100 " + "--sensitivity-metric max_activation_variance", + 4, + 28, + ), ] def _openvino_export( @@ -197,10 +205,11 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in @parameterized.expand(TEST_4BIT_CONFIGURATONS) def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_int8: int, expected_int4: int): with TemporaryDirectory() as tmpdir: - subprocess.run( + result = subprocess.run( f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}", shell=True, check=True, + capture_output=True, ) model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {} model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs) @@ -208,6 +217,7 @@ def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expec _, num_int8, num_int4 = get_num_quantized_nodes(model) self.assertEqual(expected_int8, num_int8) self.assertEqual(expected_int4, num_int4) + self.assertTrue("--awq" not in option or b"Applying AWQ" in result.stdout) def test_exporters_cli_help(self): subprocess.run( diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index c8d6caa060..5bf5e1ffa7 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -83,7 +83,7 @@ from optimum.intel.openvino.modeling_timm import TimmImageProcessor from optimum.intel.openvino.utils import _print_compiled_model_properties from optimum.intel.pipelines import pipeline as optimum_pipeline -from optimum.intel.utils.import_utils import is_openvino_version +from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version from optimum.utils import ( DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, @@ -597,8 +597,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "chatglm", "codegen", "codegen2", - # "data2vec-text", # TODO : enable when enabled in exporters - "gemma", "gpt2", "gpt_neo", "gpt_neox", @@ -609,15 +607,10 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "mistral", "mixtral", "mpt", - "olmo", "opt", "pegasus", "qwen", - "qwen2", - "stablelm", - "starcoder2", "phi", - "phi3", "internlm2", "orion", "falcon", @@ -625,15 +618,28 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "persimmon", "biogpt", "gpt_neox_japanese", - "cohere", "xglm", "aquila", "aquila2", "xverse", "internlm", - "dbrx", - "qwen2-moe", + "jais", + "arctic", ) + + if is_transformers_version(">=", "4.40.0"): + SUPPORTED_ARCHITECTURES += ( + "gemma", + "olmo", + "stablelm", + "starcoder2", + "dbrx", + "phi3", + "cohere", + "qwen2", + "qwen2-moe", + ) + GENERATION_LENGTH = 100 REMOTE_CODE_MODELS = ( "chatglm", @@ -644,12 
+650,12 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "qwen", "internlm2", "orion", - "phi3", "aquila", "aquila2", "xverse", "internlm", "codegen2", + "arctic", ) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -691,7 +697,7 @@ def test_compare_to_transformers(self, model_arch): set_seed(SEED) transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) - if model_arch == "qwen": + if model_arch in ["qwen", "arctic"]: transformers_model.to(torch.float32) with torch.no_grad(): @@ -945,6 +951,9 @@ def test_beam_search(self, model_arch): model_id, export=True, use_cache=True, stateful=False, **model_kwargs ) transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + + if model_arch == "arctic": + transformers_model.to(torch.float32) tokenizer.pad_token_id = tokenizer.eos_token_id tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) tokens.pop("token_type_ids", None) @@ -955,14 +964,14 @@ def test_beam_search(self, model_arch): ov_model_stateless.config.eos_token_id = None transformers_model.config.eos_token_id = None - for gen_config in gen_configs: + for idx, gen_config in enumerate(gen_configs): if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]: continue transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) + self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs), f"generation config : {idx}") ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) + self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs), f"generation config : {idx}") class OVModelForMaskedLMIntegrationTest(unittest.TestCase): diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index b7ed36d3e6..bae0ad772f 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -63,7 +63,7 @@ from optimum.intel.openvino.configuration import OVQuantizationMethod, OVQuantizationConfigBase from optimum.intel.openvino.quantization import InferRequestWrapper -from optimum.intel.utils.import_utils import is_openvino_version +from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8 _TASK_TO_DATASET = { @@ -89,6 +89,9 @@ def test_automodel_static_quantization(self, model_cls, model_name, expected_fak dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task] file_name = "openvino_quantized_model.xml" + if model_name == "bert" and is_transformers_version("<", "4.41.0"): + expected_fake_quantize = 32 + def preprocess_function(examples, tokenizer): return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True) @@ -114,7 +117,6 @@ def preprocess_function(examples, tokenizer): ov_config=ov_config, ) model = model_cls.from_pretrained(tmp_dir, file_name=file_name) - num_fake_quantize, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_fake_quantize, num_fake_quantize) self.assertEqual(expected_int8, num_int8) diff --git a/tests/openvino/test_training.py b/tests/openvino/test_training.py index 
89d644319c..639a77b4a6 100644 --- a/tests/openvino/test_training.py +++ b/tests/openvino/test_training.py @@ -54,6 +54,7 @@ ) from optimum.intel.openvino.trainer import DEFAULT_QUANTIZATION_CONFIG, OVTrainer from optimum.intel.openvino.utils import OV_XML_FILE_NAME +from optimum.intel.utils.import_utils import is_transformers_version F32_CONFIG = {"INFERENCE_PRECISION_HINT": "f32"} @@ -463,6 +464,7 @@ class OVTrainerTextClassificationTrainingTest(OVTrainerBaseTrainingTest): task = "sequence-classification" @parameterized.expand(OVTRAINER_TEXT_CLASSIFICATION_TEST_DESCRIPTORS.items()) + @unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op") def test_training(self, _, desc: OVTrainerTestDescriptor): self.run_ovtrainer_training_checks(desc) @@ -611,6 +613,7 @@ class OVTrainerImageClassificationTrainingTest(OVTrainerBaseTrainingTest): task = "image-classification" @parameterized.expand(OVTRAINER_IMAGE_CLASSIFICATION_TEST_DESCRIPTORS.items()) + @unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op") def test_training(self, _, desc: OVTrainerTestDescriptor): self.run_ovtrainer_training_checks(desc) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 91500cfc63..7f16a8e053 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -63,9 +63,10 @@ "ibert": "hf-internal-testing/tiny-random-ibert", "internlm": "katuni4ka/tiny-random-internlm", "internlm2": "katuni4ka/tiny-random-internlm2", + "jais": "katuni4ka/tiny-random-jais", "levit": "hf-internal-testing/tiny-random-LevitModel", "longt5": "hf-internal-testing/tiny-random-longt5", - "llama": "fxmarty/tiny-llama-fast-tokenizer", + "llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM", "llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM", "llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ", "m2m_100": "hf-internal-testing/tiny-random-m2m_100", @@ -89,10 +90,10 @@ "persimmon": "hf-internal-testing/tiny-random-PersimmonForCausalLM", "pix2struct": "fxmarty/pix2struct-tiny-random", "phi": "echarlaix/tiny-random-PhiForCausalLM", - "phi3": "katuni4ka/tiny-random-phi3", + "phi3": "Xenova/tiny-random-Phi3ForCausalLM", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "qwen": "katuni4ka/tiny-random-qwen", - "qwen2": "Qwen/Qwen1.5-0.5B", + "qwen2": "fxmarty/tiny-dummy-qwen2", "qwen2-moe": "katuni4ka/tiny-random-qwen1.5-moe", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", @@ -109,6 +110,7 @@ "latent-consistency": "echarlaix/tiny-random-latent-consistency", "sew": "hf-internal-testing/tiny-random-SEWModel", "sew_d": "asapp/sew-d-tiny-100k-ft-ls100h", + "arctic": "katuni4ka/tiny-random-snowflake", "swin": "hf-internal-testing/tiny-random-SwinModel", "t5": "hf-internal-testing/tiny-random-t5", "trocr": "microsoft/trocr-small-handwritten",