diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 303822de65f8b0..c5015f3b6e934a 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -1013,7 +1013,7 @@ def new_chunk():
             chunk["text"] = resolved_text
             if return_timestamps == "word":
                 chunk["words"] = _collate_word_timestamps(
-                    tokenizer, resolved_tokens, resolved_token_timestamps, last_language
+                    tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
                 )
             chunks.append(chunk)
 
@@ -1065,7 +1065,7 @@ def new_chunk():
             chunk["text"] = resolved_text
             if return_timestamps == "word":
                 chunk["words"] = _collate_word_timestamps(
-                    tokenizer, resolved_tokens, resolved_token_timestamps, last_language
+                    tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
                 )
             chunks.append(chunk)
 
@@ -1197,12 +1197,16 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
     return total_sequence, []
 
 
-def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
+def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language, return_language):
     words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
+
+    optional_language_field = {"language": language} if return_language else {}
+
     timings = [
         {
             "text": word,
             "timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
+            **optional_language_field
         }
         for word, indices in zip(words, token_indices)
     ]
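
The diff threads the existing `return_language` flag down into `_collate_word_timestamps`, so that when both word timestamps and language reporting are requested, each word entry gains a `"language"` key via a conditional dict-unpack. Below is a minimal usage sketch, not part of the diff itself: it assumes the ASR pipeline forwards `return_language` to the tokenizer's `_decode_asr`, and the checkpoint name and audio path are placeholders.

```python
# Sketch of how the propagated flag surfaces at the pipeline level.
# "openai/whisper-tiny" and "audio.wav" are illustrative placeholders.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    return_timestamps="word",
)
result = asr("audio.wav", return_language=True)

# With this patch, each word-level chunk carries the detected language, e.g.:
# {"text": " Hello", "timestamp": (0.0, 0.42), "language": "english"}
print(result["chunks"][0])
```

Building the optional field once as `{"language": language} if return_language else {}` and splatting it with `**optional_language_field` keeps the per-word dict literal unchanged when the flag is off, so callers that never ask for the language see exactly the old output shape.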