From 391cd450ded9038c06099906a667f4d9a6e6ed3e Mon Sep 17 00:00:00 2001 From: Robin Bakker Date: Mon, 24 Jun 2024 14:27:54 +0200 Subject: [PATCH] add language to words _collate_word_timestamps uses the return_language flag to determine whether the language of the chunk should be added to the word's information --- .../models/whisper/tokenization_whisper.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 303822de65f8b0..c5015f3b6e934a 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -1013,7 +1013,7 @@ def new_chunk(): chunk["text"] = resolved_text if return_timestamps == "word": chunk["words"] = _collate_word_timestamps( - tokenizer, resolved_tokens, resolved_token_timestamps, last_language + tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language ) chunks.append(chunk) @@ -1065,7 +1065,7 @@ def new_chunk(): chunk["text"] = resolved_text if return_timestamps == "word": chunk["words"] = _collate_word_timestamps( - tokenizer, resolved_tokens, resolved_token_timestamps, last_language + tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language ) chunks.append(chunk) @@ -1197,12 +1197,16 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None): return total_sequence, [] -def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language): +def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language, return_language): words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language) + + optional_language_field = {"language": language} if return_language else {} + timings = [ { "text": word, "timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]), + **optional_language_field } for word, indices in zip(words, token_indices) ]