Fix broken type annotations, and replace pytype disables with typing.cast #1384

Open · wants to merge 1 commit into main
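This change widens the return annotations on `_compute_logits` and `score_batch` to a `Union` (a bare array, or an array plus auxiliary data) and narrows the call sites with `typing.cast`, instead of silencing pytype with `# pytype: disable` pragmas. A minimal sketch of the pattern, with hypothetical names rather than the actual t5x signatures:

```python
import typing
from typing import Any, Mapping, Tuple, Union

import jax.numpy as jnp


def compute_logits(
    return_aux: bool = False,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]:
  """Returns logits, or (logits, aux) when return_aux is True."""
  logits = jnp.zeros((2, 4))
  return (logits, {'intermediates': ()}) if return_aux else logits


# The caller knows return_aux=False yields a bare array; typing.cast records
# that fact for the type checker instead of disabling the check with a pragma.
logits = typing.cast(jnp.ndarray, compute_logits(return_aux=False))
```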
105 changes: 50 additions & 55 deletions t5x/models.py
@@ -23,6 +23,7 @@
import dataclasses
import functools
import inspect
import typing
from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union

from absl import logging
@@ -210,17 +211,15 @@ def predict_batch_with_aux(
predictions: the model predictions
aux: auxiliary data
"""
pass

@abc.abstractmethod
def score_batch(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
return_intermediates: bool = False,
) -> jnp.ndarray:
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]:
"""Computes scores for batch."""
pass

@abc.abstractmethod
def get_initial_variables(
@@ -230,7 +229,6 @@ def get_initial_variables(
input_types: Optional[Mapping[str, jnp.dtype]] = None,
) -> flax_scope.FrozenVariableDict:
"""Returns the initial variables of the model."""
pass


class BaseTransformerModel(BaseModel):
@@ -281,9 +279,8 @@ def _compute_logits(
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray] = None,
) -> jnp.ndarray:
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]:
"""Computes logits via a forward pass of the model."""
pass

def loss_fn(
self,
@@ -384,9 +381,8 @@ def __init__(
default_decoder_params: Optional[DecoderParams] = None,
):
if feature_converter_cls is not None:
self.FEATURE_CONVERTER_CLS = (
feature_converter_cls # pylint: disable=invalid-name
)
# pylint: disable-next=invalid-name
self.FEATURE_CONVERTER_CLS = feature_converter_cls
self._default_decoder_params = default_decoder_params or DecoderParams()
super().__init__(
module=module,
@@ -455,7 +451,7 @@ def get_initial_variables(
)
return initial_variables

def _compute_logits( # pytype: disable=signature-mismatch # jax-ndarray
def _compute_logits(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
@@ -665,7 +661,7 @@ def predict_batch_with_aux(
else:
decoder_prompt_inputs = jnp.zeros_like(decoder_input_tokens)

encoded_inputs = self.module.apply(
encoded_inputs: jnp.ndarray = self.module.apply(
{'params': params},
encoder_input_tokens,
enable_dropout=False,
@@ -759,7 +755,7 @@ def predict_batch_with_aux(
else:
return decodes[:, -1, :], {'scores': scores[:, -1]}

def score_batch( # pytype: disable=signature-mismatch # jax-ndarray
def score_batch(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
@@ -773,23 +769,9 @@ def score_batch( # pytype: disable=signature-mismatch # jax-ndarray
logits, modified_variables = self._compute_logits(
params=params, batch=batch, mutable=['intermediates']
)

# Inside self.module, we called nn.Module.sow to track various
# intermediate values. We extract them here.
intermediates = flax_core.unfreeze(
modified_variables.get('intermediates', {})
)

# Track per-token labels and loss weights as well. These are not
# intermediate values of logit computation, so we manually add them here.
intermediates.setdefault('decoder', {})
intermediates['decoder']['target_tokens'] = (target_tokens,)
intermediates['decoder']['loss_weights'] = (weights,)
# Note that the values are singleton tuples. This is because values inside
# `intermediates` should be tuples tracking all instantiations of a value.
# These values each have just one instantiation, hence singletons.
else:
logits = self._compute_logits(params, batch) # type: jnp.ndarray # pytype: disable=annotation-type-mismatch # jax-ndarray
logits = typing.cast(jnp.ndarray, self._compute_logits(params, batch))
modified_variables = {}

# Purposefully don't use config.z_loss because that term is for training
# stability and shouldn't affect our reported scores.
@@ -803,12 +785,26 @@
)[0]
* weights
)
if return_intermediates:
intermediates['decoder']['token_scores'] = (token_scores,)

sequence_scores = token_scores.sum(-1)

if return_intermediates:
# Inside self.module, we called nn.Module.sow to track various
# intermediate values. We extract them here.
intermediates = flax_core.unfreeze(
modified_variables.get('intermediates', {})
)

# Track per-token labels and loss weights as well. These are not
# intermediate values of logit computation, so we manually add them here.
intermediates.setdefault('decoder', {})
intermediates['decoder']['target_tokens'] = (target_tokens,)
intermediates['decoder']['loss_weights'] = (weights,)
# Note that the values are singleton tuples. This is because values inside
# `intermediates` should be tuples tracking all instantiations of a value.
# These values each have just one instantiation, hence singletons.
intermediates['decoder']['token_scores'] = (token_scores,)
return sequence_scores, intermediates

return sequence_scores
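The `intermediates` handling above relies on Flax's `sow` mechanism: values recorded with `nn.Module.sow` come back as tuples in the `'intermediates'` collection when `apply` is called with `mutable=['intermediates']`. A small self-contained sketch of that mechanism (a toy module, not the t5x one):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyModule(nn.Module):

  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    # Sown values are stored as tuples, one entry per sow call at this name,
    # which is why score_batch sees singleton tuples.
    self.sow('intermediates', 'hidden', h)
    return nn.Dense(2)(h)


m = TinyModule()
x = jnp.ones((1, 3))
variables = m.init(jax.random.PRNGKey(0), x)
# Requesting the 'intermediates' collection makes apply return (output, vars).
logits, modified = m.apply(variables, x, mutable=['intermediates'])
hidden = modified['intermediates']['hidden']  # a singleton tuple: (array,)
```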
@@ -847,9 +843,8 @@ def __init__(
] = None,
):
if feature_converter_cls is not None:
self.FEATURE_CONVERTER_CLS = (
feature_converter_cls # pylint: disable=invalid-name
)
# pylint: disable-next=invalid-name
self.FEATURE_CONVERTER_CLS = feature_converter_cls
self._inputs_bidirectional_attention = inputs_bidirectional_attention
super().__init__(
module,
@@ -901,7 +896,7 @@ def _compute_logits(
dropout_rng: Optional[jax.random.KeyArray] = None,
mutable: flax_scope.CollectionFilter = False,
other_variables: Optional[PyTree] = None,
) -> jnp.ndarray:
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]:
"""Computes logits via a forward pass of `self.module`."""
rngs = {'dropout': dropout_rng} if dropout_rng is not None else None
decoder_causal_attention = self._get_decoder_causal_attention(batch)
@@ -954,7 +949,7 @@ def score_batch(
params: PyTree,
batch: Mapping[str, jnp.ndarray],
return_intermediates: bool = False,
) -> jnp.ndarray:
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]:
"""Compute log likelihood score on a batch."""

decoder_target_tokens = batch['decoder_target_tokens']
@@ -967,7 +962,26 @@
dropout_rng=None,
mutable=['intermediates'],
)
else:
logits = typing.cast(
jnp.ndarray,
self._compute_logits(params=params, batch=batch, dropout_rng=None),
)
modified_variables = {}

token_scores = (
-losses.cross_entropy_with_logits(
logits,
common_utils.onehot(
decoder_target_tokens, logits.shape[-1], on_value=1, off_value=0
),
z_loss=0.0,
)[0]
* weights
)
sequence_scores = token_scores.sum(-1)

if return_intermediates:
# Inside self.module, we called nn.Module.sow to track various
# intermediate values. We extract them here.
intermediates = flax_core.unfreeze(
@@ -982,28 +996,9 @@
# Note that the values are singleton tuples. This is because values inside
# `intermediates` should be tuples tracking all instantiations of a value.
# These values each have just one instantiation, hence singletons.
else:
logits = self._compute_logits(
params=params, batch=batch, dropout_rng=None
)

token_scores = (
-losses.cross_entropy_with_logits(
logits,
common_utils.onehot(
decoder_target_tokens, logits.shape[-1], on_value=1, off_value=0
),
z_loss=0.0,
)[0]
* weights
)
if return_intermediates:
intermediates['decoder']['token_scores'] = (token_scores,)

sequence_scores = token_scores.sum(-1)

if return_intermediates:
return sequence_scores, intermediates # pytype: disable=bad-return-type # jax-ndarray
return sequence_scores, intermediates

return sequence_scores
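For reference, the scoring math that both `score_batch` implementations share: per-token scores are the negative cross entropy against the one-hot targets (with `z_loss=0.0`, per the comment in the diff), masked by the loss weights and summed over the length dimension. A minimal equivalent sketch using plain JAX in place of `losses.cross_entropy_with_logits` and `common_utils.onehot`:

```python
import jax
import jax.numpy as jnp


def sequence_scores(logits, target_tokens, weights):
  """Log-likelihood scores per sequence, mirroring score_batch's math.

  logits: [batch, length, vocab]; target_tokens: [batch, length] int ids;
  weights: [batch, length] mask that zeroes out padding positions.
  """
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  onehot_targets = jax.nn.one_hot(target_tokens, logits.shape[-1])
  # Negative cross entropy per position, masked by the loss weights.
  token_scores = jnp.sum(log_probs * onehot_targets, axis=-1) * weights
  return token_scores.sum(-1)  # shape [batch]
```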
