Skip to content

Commit

Permalink
drop python 3.7 support (openai#889)
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook authored and abyesilyurt committed Nov 13, 2023
1 parent 37bfedb commit aead305
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 95 deletions.
19 changes: 6 additions & 13 deletions whisper/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,10 @@ def __init__(self, temperature: float, eot: int):
self.eot = eot

def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
if self.temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
next_tokens = Categorical(logits=logits / self.temperature).sample()

logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
Expand Down Expand Up @@ -511,10 +510,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:

def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
prefix = self.options.prefix
prompt = self.options.prompt

if prefix:
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
)
Expand All @@ -523,7 +520,7 @@ def _get_initial_tokens(self) -> Tuple[int]:
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens

if prompt:
if prompt := self.options.prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
)
Expand Down Expand Up @@ -698,13 +695,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
single = mel.ndim == 2
if single:
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)

result = DecodingTask(model, options).run(mel)

if single:
result = result[0]

return result
return result[0] if single else result
38 changes: 13 additions & 25 deletions whisper/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
from functools import lru_cache
from functools import lru_cache, cached_property
from typing import List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -156,43 +156,35 @@ def decode_with_timestamps(self, tokens) -> str:
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
return "".join(outputs)

@cached_property
def eot(self) -> int:
    """Token id marking end-of-transcript (the tokenizer's EOS token).

    Cached on first access; the value never changes for a given tokenizer.
    """
    return self.tokenizer.eos_token_id

@cached_property
def sot(self) -> int:
    """Token id of the "<|startoftranscript|>" special token (cached)."""
    return self._get_single_token_id("<|startoftranscript|>")

@cached_property
def sot_lm(self) -> int:
    """Token id of the "<|startoflm|>" special token (cached)."""
    return self._get_single_token_id("<|startoflm|>")

@cached_property
def sot_prev(self) -> int:
    """Token id of the "<|startofprev|>" special token (cached)."""
    return self._get_single_token_id("<|startofprev|>")

@cached_property
def no_speech(self) -> int:
    """Token id of the "<|nospeech|>" special token (cached)."""
    return self._get_single_token_id("<|nospeech|>")

@cached_property
def no_timestamps(self) -> int:
    """Token id of the "<|notimestamps|>" special token (cached)."""
    return self._get_single_token_id("<|notimestamps|>")

@cached_property
def timestamp_begin(self) -> int:
    """Token id of the first timestamp token (cached).

    Timestamp tokens occupy the id range immediately after the last
    special token, so the first one is `all_special_ids[-1] + 1`.
    """
    return self.tokenizer.all_special_ids[-1] + 1

@property
@lru_cache()
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
Expand All @@ -210,8 +202,7 @@ def language_token(self) -> int:

raise KeyError(f"Language {self.language} not found in tokenizer.")

@property
@lru_cache()
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in zip(
Expand All @@ -222,18 +213,15 @@ def all_language_tokens(self) -> Tuple[int]:
result.append(token_id)
return tuple(result)

@cached_property
def all_language_codes(self) -> Tuple[str]:
    """Language codes (e.g. "en") for every language token, in the same
    order as `all_language_tokens` (cached).

    Each token decodes to a string like "<|en|>"; stripping the "<|>"
    characters leaves the bare language code.
    """
    return tuple(self.decode([token]).strip("<|>") for token in self.all_language_tokens)

@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
    """The start-of-transcript token sequence with "<|notimestamps|>"
    appended, for decoding without timestamp prediction (cached).
    """
    return tuple(list(self.sot_sequence) + [self.no_timestamps])

@property
@lru_cache()
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
Expand Down
Loading

0 comments on commit aead305

Please sign in to comment.