Skip to content
This repository has been archived by the owner on Sep 17, 2024. It is now read-only.

Commit

Permalink
add new language test
Browse files Browse the repository at this point in the history
test that the pipeline can return both the language and timestamp
  • Loading branch information
robinderat committed Jul 2, 2024
1 parent 88f0f0e commit 0352e8e
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion tests/pipelines/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ def test_torch_large_with_input_features(self):

@slow
@require_torch
@slow
def test_return_timestamps_in_preprocess(self):
pipe = pipeline(
task="automatic-speech-recognition",
Expand Down Expand Up @@ -368,6 +367,63 @@ def test_return_timestamps_in_preprocess(self):
)
# fmt: on

@slow
@require_torch
def test_return_timestamps_and_language_in_preprocess(self):
    """Check that language tags are returned alongside every timestamp mode.

    A chunked `openai/whisper-tiny` pipeline is built with
    `return_language=True` and run three times on the first streamed
    LibriSpeech test sample: without timestamps, with
    `return_timestamps=True`, and with `return_timestamps="word"`.
    Each returned chunk is expected to carry a `"language"` entry.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny",
        chunk_length_s=8,
        stride_length_s=1,
        return_language=True,
    )
    dataset = load_dataset(
        "openslr/librispeech_asr",
        "clean",
        split="test",
        streaming=True,
        trust_remote_code=True,
    )
    first_sample = next(iter(dataset))
    # Force English transcription through the decoder prompt ids.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language="en", task="transcribe"
    )

    audio = first_sample["audio"]["array"]
    transcript = " Conquered returned to its place amidst the tents."
    language = "english"

    # 1) No timestamps requested: a single chunk with language and text only.
    expected_plain = {
        "text": transcript,
        "chunks": [{"language": language, "text": transcript}],
    }
    self.assertEqual(asr(audio), expected_plain)

    # 2) return_timestamps=True: the chunk additionally carries a timestamp.
    expected_segments = {
        "text": transcript,
        "chunks": [{"timestamp": (0.0, 3.36), "language": language, "text": transcript}],
    }
    self.assertEqual(asr(audio, return_timestamps=True), expected_segments)

    # 3) return_timestamps="word": alignment heads are set on the generation
    #    config before the call, then one chunk per word is expected.
    asr.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
    word_spans = [
        (" Conquered", (0.5, 1.2)),
        (" returned", (1.2, 1.64)),
        (" to", (1.64, 1.84)),
        (" its", (1.84, 2.02)),
        (" place", (2.02, 2.28)),
        (" amidst", (2.28, 2.8)),
        (" the", (2.8, 2.98)),
        (" tents.", (2.98, 3.48)),
    ]
    expected_words = {
        "text": transcript,
        "chunks": [
            {"language": language, "text": word, "timestamp": span}
            for word, span in word_spans
        ],
    }
    self.assertEqual(asr(audio, return_timestamps="word"), expected_words)

@slow
@require_torch
def test_return_timestamps_in_preprocess_longform(self):
Expand Down

0 comments on commit 0352e8e

Please sign in to comment.