Skip to content

Commit

Permalink
Add more audio tasks for OpenVINO inference (#396)
Browse files Browse the repository at this point in the history
* Add more audio tasks for OpenVINO inference

* Fix style

* Add reference docs for new audio tasks

* Add reference link to top of OpenVINO inference docs

* Allow import from optimum.intel directly

* Add relative link to reference docs

* Add more imports to optimum.intel

* Use optimum.intel imports in test
  • Loading branch information
helena-intel authored Aug 2, 2023
1 parent 44cebca commit 0a80e20
Show file tree
Hide file tree
Showing 8 changed files with 464 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/source/inference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ outputs = cls_pipe("He's a dreadful magician.")
[{'label': 'NEGATIVE', 'score': 0.9919503927230835}]
```

See the [reference documentation](reference_ov) for more information about parameters, and examples for different tasks.

To easily save the resulting model, you can use the `save_pretrained()` method, which will save both the BIN and XML files describing the graph. It is useful to save the tokenizer to the same directory, to enable easy loading of the tokenizer for the model.


Expand Down
12 changes: 11 additions & 1 deletion docs/source/reference_ov.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,21 @@ limitations under the License.

[[autodoc]] openvino.modeling.OVModelForTokenClassification


## OVModelForAudioClassification

[[autodoc]] openvino.modeling.OVModelForAudioClassification

## OVModelForAudioFrameClassification

[[autodoc]] openvino.modeling.OVModelForAudioFrameClassification

## OVModelForCTC

[[autodoc]] openvino.modeling.OVModelForCTC

## OVModelForAudioXVector

[[autodoc]] openvino.modeling.OVModelForAudioXVector

## OVModelForImageClassification

Expand Down
6 changes: 6 additions & 0 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
_import_structure["openvino"].extend(
[
"OVModelForAudioClassification",
"OVModelForAudioFrameClassification",
"OVModelForAudioXVector",
"OVModelForCTC",
"OVModelForCausalLM",
"OVModelForFeatureExtraction",
"OVModelForImageClassification",
Expand Down Expand Up @@ -176,7 +179,10 @@
else:
from .openvino import (
OVModelForAudioClassification,
OVModelForAudioFrameClassification,
OVModelForAudioXVector,
OVModelForCausalLM,
OVModelForCTC,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
Expand Down
3 changes: 3 additions & 0 deletions optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@

from .modeling import (
OVModelForAudioClassification,
OVModelForAudioFrameClassification,
OVModelForAudioXVector,
OVModelForCTC,
OVModelForFeatureExtraction,
OVModelForImageClassification,
OVModelForMaskedLM,
Expand Down
245 changes: 245 additions & 0 deletions optimum/intel/openvino/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
AutoConfig,
AutoModel,
AutoModelForAudioClassification,
AutoModelForAudioFrameClassification,
AutoModelForAudioXVector,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
Expand All @@ -32,13 +35,17 @@
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
ImageClassifierOutput,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
XVectorOutput,
)

from optimum.exporters import TasksManager

from .modeling_base import OVBaseModel


Expand Down Expand Up @@ -93,6 +100,13 @@
Pixel values can be obtained from encoded images using [`AutoFeatureExtractor`](https://huggingface.co/docs/transformers/autoclass_tutorial#autofeatureextractor).
"""

# Shared Args section for the audio models' forward() docstrings; "{0}" is
# filled with the expected input shape via str.format.
AUDIO_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.Tensor` of shape `({0})`):
            Float values of input raw speech waveform.
            Input values can be obtained from an audio file loaded into an array using [`AutoFeatureExtractor`](https://huggingface.co/docs/transformers/autoclass_tutorial#autofeatureextractor).
"""


class OVModel(OVBaseModel):
base_model_prefix = "openvino_model"
Expand Down Expand Up @@ -575,3 +589,234 @@ def forward(
outputs = self.request(inputs)
logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
return SequenceClassifierOutput(logits=logits)


# Doctest-style usage example appended to OVModelForCTC.forward's docstring via
# str.format; placeholders: {processor_class}, {model_class}, {checkpoint}.
CTC_EXAMPLE = r"""
    Example of CTC:

    ```python
    >>> from transformers import {processor_class}
    >>> from optimum.intel import {model_class}
    >>> from datasets import load_dataset
    >>> import numpy as np

    >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
    >>> dataset = dataset.sort("id")
    >>> sampling_rate = dataset.features["audio"].sampling_rate

    >>> processor = {processor_class}.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}", export=True)

    >>> # audio file is decoded on the fly
    >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="np")
    >>> logits = model(**inputs).logits
    >>> predicted_ids = np.argmax(logits, axis=-1)
    >>> transcription = processor.batch_decode(predicted_ids)
    ```
"""


@add_start_docstrings(
    """
    OpenVINO Model with a language modeling head on top for Connectionist Temporal Classification (CTC).
    """,
    MODEL_START_DOCSTRING,
)
class OVModelForCTC(OVModel):
    """
    CTC model for OpenVINO.
    """

    auto_model_class = AutoModelForCTC
    export_feature = TasksManager.infer_task_from_model(auto_model_class)

    @add_start_docstrings_to_model_forward(
        AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + CTC_EXAMPLE.format(
            processor_class=_FEATURE_EXTRACTOR_FOR_DOC,
            model_class="OVModelForCTC",
            checkpoint="facebook/hubert-large-ls960-ft",
        )
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None,
        **kwargs,
    ):
        # Mirror the caller's container type: numpy in -> numpy out,
        # anything else is converted to numpy for inference and the logits
        # are returned as torch tensors.
        np_inputs = isinstance(input_values, np.ndarray)
        if not np_inputs:
            input_values = np.array(input_values)
            attention_mask = np.array(attention_mask) if attention_mask is not None else attention_mask

        inputs = {
            "input_values": input_values,
        }

        # Add the attention_mask when needed
        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = attention_mask

        # Run inference
        outputs = self.request(inputs)
        logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
        return CausalLMOutput(logits=logits)


AUDIO_XVECTOR_EXAMPLE = r"""
Example of Audio XVector:
```python
>>> from transformers import {processor_class}
>>> from optimum.intel import {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> # audio file is decoded on the fly
>>> inputs = feature_extractor(
... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
... )
>>> embeddings = model(**inputs).embeddings
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
>>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
>>> similarity = cosine_sim(embeddings[0], embeddings[1])
>>> threshold = 0.7
>>> if similarity < threshold:
... print("Speakers are not the same!")
>>> round(similarity.item(), 2)
```
"""


@add_start_docstrings(
    """
    OpenVINO Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    MODEL_START_DOCSTRING,
)
class OVModelForAudioXVector(OVModel):
    """
    Audio XVector model for OpenVINO.
    """

    auto_model_class = AutoModelForAudioXVector
    export_feature = TasksManager.infer_task_from_model(auto_model_class)

    @add_start_docstrings_to_model_forward(
        AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + AUDIO_XVECTOR_EXAMPLE.format(
            processor_class=_FEATURE_EXTRACTOR_FOR_DOC,
            model_class="OVModelForAudioXVector",
            checkpoint="anton-l/wav2vec2-base-superb-sv",
        )
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Mirror the caller's container type: numpy in -> numpy out, otherwise
        # convert to numpy for inference and return torch tensors.
        np_inputs = isinstance(input_values, np.ndarray)
        if not np_inputs:
            input_values = np.array(input_values)
            attention_mask = np.array(attention_mask) if attention_mask is not None else attention_mask

        inputs = {
            "input_values": input_values,
        }

        # Add the attention_mask when needed
        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = attention_mask

        # Run inference
        outputs = self.request(inputs)
        logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
        embeddings = (
            torch.from_numpy(outputs["embeddings"]).to(self.device) if not np_inputs else outputs["embeddings"]
        )

        return XVectorOutput(logits=logits, embeddings=embeddings)


# Doctest-style usage example appended to OVModelForAudioFrameClassification.forward's
# docstring via str.format (placeholders: processor_class, model_class, checkpoint).
AUDIO_FRAME_CLASSIFICATION_EXAMPLE = r"""
Example of audio frame classification:
```python
>>> from transformers import {processor_class}
>>> from optimum.intel import {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
>>> logits = model(**inputs).logits
>>> probabilities = torch.sigmoid(torch.as_tensor(logits)[0])
>>> labels = (probabilities > 0.5).long()
>>> labels[0].tolist()
```
"""


@add_start_docstrings(
    """
    OpenVINO Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    MODEL_START_DOCSTRING,
)
class OVModelForAudioFrameClassification(OVModel):
    """
    Audio Frame Classification model for OpenVINO.
    """

    auto_model_class = AutoModelForAudioFrameClassification
    export_feature = TasksManager.infer_task_from_model(auto_model_class)

    @add_start_docstrings_to_model_forward(
        AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + AUDIO_FRAME_CLASSIFICATION_EXAMPLE.format(
            processor_class=_FEATURE_EXTRACTOR_FOR_DOC,
            model_class="OVModelForAudioFrameClassification",
            checkpoint="anton-l/wav2vec2-base-superb-sd",
        )
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Mirror the caller's container type: numpy in -> numpy out, otherwise
        # convert to numpy for inference and return torch tensors.
        np_inputs = isinstance(input_values, np.ndarray)
        if not np_inputs:
            input_values = np.array(input_values)
            attention_mask = np.array(attention_mask) if attention_mask is not None else attention_mask

        inputs = {
            "input_values": input_values,
        }

        # Add the attention_mask when needed
        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = attention_mask

        # Run inference
        outputs = self.request(inputs)
        logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]

        return TokenClassifierOutput(logits=logits)
33 changes: 33 additions & 0 deletions optimum/intel/utils/dummy_openvino_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,39 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["openvino"])


# Import-time placeholder: any attempt to construct the class or call
# from_pretrained defers to requires_backends, which reports that the
# "openvino" backend is missing.
class OVModelForAudioFrameClassification(metaclass=DummyObject):
    _backends = ["openvino"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["openvino"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["openvino"])


# Import-time placeholder: any attempt to construct the class or call
# from_pretrained defers to requires_backends, which reports that the
# "openvino" backend is missing.
class OVModelForAudioXVector(metaclass=DummyObject):
    _backends = ["openvino"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["openvino"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["openvino"])


# Import-time placeholder: any attempt to construct the class or call
# from_pretrained defers to requires_backends, which reports that the
# "openvino" backend is missing.
class OVModelForCTC(metaclass=DummyObject):
    _backends = ["openvino"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["openvino"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["openvino"])


class OVModelForCausalLM(metaclass=DummyObject):
_backends = ["openvino"]

Expand Down
Loading

0 comments on commit 0a80e20

Please sign in to comment.