Refactored modules/tokenizers to be a subdir of modules/transforms (#…
Ankur-singh authored Jan 27, 2025
1 parent 5764650 commit 3cceb86
Showing 41 changed files with 113 additions and 72 deletions.
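In short: every import of `torchtune.modules.tokenizers` moves under `torchtune.modules.transforms`. The symbol names are unchanged; only the module path moves, as in this minimal before/after sketch:

# Before this commit
from torchtune.modules.tokenizers import ModelTokenizer

# After this commit
from torchtune.modules.transforms.tokenizers import ModelTokenizer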
12 changes: 6 additions & 6 deletions docs/source/api_ref_modules.rst
@@ -48,10 +48,10 @@ model specific tokenizers.
:toctree: generated/
:nosignatures:

- tokenizers.SentencePieceBaseTokenizer
- tokenizers.TikTokenBaseTokenizer
- tokenizers.ModelTokenizer
- tokenizers.BaseTokenizer
+ transforms.tokenizers.SentencePieceBaseTokenizer
+ transforms.tokenizers.TikTokenBaseTokenizer
+ transforms.tokenizers.ModelTokenizer
+ transforms.tokenizers.BaseTokenizer

Tokenizer Utilities
-------------------
@@ -61,8 +61,8 @@ These are helper methods that can be used by any tokenizer.
:toctree: generated/
:nosignatures:

- tokenizers.tokenize_messages_no_special_tokens
- tokenizers.parse_hf_tokenizer_json
+ transforms.tokenizers.tokenize_messages_no_special_tokens
+ transforms.tokenizers.parse_hf_tokenizer_json


PEFT Components
2 changes: 1 addition & 1 deletion docs/source/basics/custom_components.rst
@@ -117,7 +117,7 @@ our models in torchtune - see :func:`~torchtune.models.llama3_2_vision.llama3_2_
#
from torchtune.datasets import SFTDataset, PackedDataset
from torchtune.data import InputOutputToMessages
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer
# Example builder function for a custom code instruct dataset not in torchtune, but using
# different dataset building blocks from torchtune
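The diff is truncated here, but a minimal sketch of such a builder under the new import path could look like the following (the dataset ID, column names, and the `max_seq_len` attribute are illustrative assumptions, not part of this commit):

from torchtune.data import InputOutputToMessages
from torchtune.datasets import PackedDataset, SFTDataset
from torchtune.modules.transforms.tokenizers import ModelTokenizer

def custom_code_instruct_dataset(tokenizer: ModelTokenizer, packed: bool = False):
    # Map the dataset's hypothetical "prompt"/"completion" columns onto the
    # "input"/"output" keys that InputOutputToMessages expects
    message_transform = InputOutputToMessages(
        column_map={"input": "prompt", "output": "completion"}
    )
    ds = SFTDataset(
        source="my-org/code-instruct",  # hypothetical Hugging Face dataset ID
        message_transform=message_transform,
        model_transform=tokenizer,  # any ModelTokenizer can serve as the model_transform
        split="train",
    )
    # Optionally pack samples to a fixed length for training efficiency;
    # assumes the tokenizer exposes max_seq_len, as torchtune's concrete tokenizers do
    return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len) if packed else ds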
2 changes: 1 addition & 1 deletion docs/source/basics/model_transforms.rst
@@ -101,7 +101,7 @@ The following methods are required on the model transform:

.. code-block:: python
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
class MyMultimodalTransform(ModelTokenizer, Transform):
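The class body is cut off here; a bare-bones sketch of how such a transform can satisfy both interfaces (the delegation to an underlying tokenizer and image transform is an illustrative assumption, not the docs' actual example):

from typing import Any, List, Mapping, Tuple

from torchtune.data import Message
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import ModelTokenizer

class MyMultimodalTransform(ModelTokenizer, Transform):
    def __init__(self, tokenizer: ModelTokenizer, image_transform: Transform):
        self.tokenizer = tokenizer  # converts Messages to token IDs
        self.image_transform = image_transform  # preprocesses images for the vision encoder

    def tokenize_messages(
        self, messages: List[Message], **kwargs: Any
    ) -> Tuple[List[int], List[bool]]:
        # Delegate text tokenization to the underlying model tokenizer
        return self.tokenizer.tokenize_messages(messages, **kwargs)

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # Preprocess images first, then tokenize the accompanying messages
        sample = self.image_transform(sample)
        tokens, mask = self.tokenize_messages(sample["messages"])
        return {**sample, "tokens": tokens, "mask": mask}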
10 changes: 5 additions & 5 deletions docs/source/basics/tokenizers.rst
@@ -168,7 +168,7 @@ For example, here we change the ``"<|begin_of_text|>"`` and ``"<|end_of_text|>"`
Base tokenizers
---------------

- :class:`~torchtune.modules.tokenizers.BaseTokenizer` are the underlying byte-pair encoding modules that perform the actual raw string to token ID conversion and back.
+ :class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer` are the underlying byte-pair encoding modules that perform the actual raw string to token ID conversion and back.
In torchtune, they are required to implement ``encode`` and ``decode`` methods, which are called by the :ref:`model_tokenizers` to convert
between raw text and token IDs.

@@ -202,13 +202,13 @@ between raw text and token IDs.
"""
pass
- If you load any :ref:`model_tokenizers`, you can see that it calls its underlying :class:`~torchtune.modules.tokenizers.BaseTokenizer`
+ If you load any :ref:`model_tokenizers`, you can see that it calls its underlying :class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer`
to do the actual encoding and decoding.

.. code-block:: python
from torchtune.models.mistral import mistral_tokenizer
- from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
+ from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer
m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model")
# Mistral uses SentencePiece for its underlying BPE
@@ -227,7 +227,7 @@
Model tokenizers
----------------

- :class:`~torchtune.modules.tokenizers.ModelTokenizer` are specific to a particular model. They are required to implement the ``tokenize_messages`` method,
+ :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` are specific to a particular model. They are required to implement the ``tokenize_messages`` method,
which converts a list of Messages into a list of token IDs.

.. code-block:: python
@@ -259,7 +259,7 @@ is because they add all the necessary special tokens or prompt templates require
.. code-block:: python
from torchtune.models.mistral import mistral_tokenizer
- from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
+ from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer
from torchtune.data import Message
m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model")
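Since the ``encode``/``decode`` contract is the whole interface, a toy implementation makes the new location concrete (the whitespace vocabulary is a made-up example, not a torchtune component):

from typing import Any, Dict, List

from torchtune.modules.transforms.tokenizers import BaseTokenizer

class WhitespaceTokenizer(BaseTokenizer):
    """Toy tokenizer: maps whitespace-separated words to IDs with a fixed vocab."""

    def __init__(self, vocab: Dict[str, int]):
        self.vocab = vocab
        self.inv_vocab = {token_id: word for word, token_id in vocab.items()}

    def encode(self, text: str, **kwargs: Any) -> List[int]:
        return [self.vocab[word] for word in text.split()]

    def decode(self, token_ids: List[int], **kwargs: Any) -> str:
        return " ".join(self.inv_vocab[token_id] for token_id in token_ids)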
2 changes: 1 addition & 1 deletion recipes/eleuther_eval.py
@@ -31,8 +31,8 @@
from torchtune.modules import TransformerDecoder
from torchtune.modules.common_utils import local_kv_cache
from torchtune.modules.model_fusion import DeepFusionModel
- from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer
from torchtune.recipe_interfaces import EvalRecipeInterface
from torchtune.training import FullModelTorchTuneCheckpointer

2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -20,8 +20,8 @@
import torch
from torch import nn
from torchtune.data import Message, PromptTemplate, truncate
- from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer

skip_if_cuda_not_available = unittest.skipIf(
not torch.cuda.is_available(), "CUDA is not available"
@@ -7,7 +7,7 @@
import pytest

from tests.common import ASSETS
- from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
+ from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer


class TestSentencePieceBaseTokenizer:
@@ -8,7 +8,7 @@

from tests.common import ASSETS
from torchtune.models.llama3._tokenizer import CL100K_PATTERN
- from torchtune.modules.tokenizers import TikTokenBaseTokenizer
+ from torchtune.modules.transforms.tokenizers import TikTokenBaseTokenizer


class TestTikTokenBaseTokenizer:
@@ -9,7 +9,7 @@
from tests.test_utils import DummyTokenizer
from torchtune.data import Message

- from torchtune.modules.tokenizers import tokenize_messages_no_special_tokens
+ from torchtune.modules.transforms.tokenizers import tokenize_messages_no_special_tokens


class TestTokenizerUtils:
7 changes: 4 additions & 3 deletions torchtune/data/_messages.py
@@ -22,9 +22,10 @@
class Message:
"""
This class represents individual messages in a fine-tuning dataset. It supports
- text-only content, text with interleaved images, and tool calls. The :class:`~torchtune.modules.tokenizers.ModelTokenizer`
- will tokenize the content of the message using ``tokenize_messages`` and attach
- the appropriate special tokens based on the flags set in this class.
+ text-only content, text with interleaved images, and tool calls. The
+ :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize
+ the content of the message using ``tokenize_messages`` and attach the appropriate
+ special tokens based on the flags set in this class.
Args:
role (Role): role of the message writer. Can be "system" for system prompts,
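To see the contract this docstring describes in use, mirroring the Mistral example from the docs above (the /tmp checkpoint path is the same placeholder the docs use):

from torchtune.data import Message
from torchtune.models.mistral import mistral_tokenizer

m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model")
messages = [
    Message(role="user", content="What is tokenization?"),
    Message(role="assistant", content="Mapping text to token IDs."),
]
# tokenize_messages returns the token IDs plus a loss mask aligned to them
tokens, mask = m_tokenizer.tokenize_messages(messages)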
2 changes: 1 addition & 1 deletion torchtune/datasets/_alpaca.py
@@ -12,7 +12,7 @@

from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def alpaca_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_chat.py
@@ -9,7 +9,7 @@
from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def chat_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_cnn_dailymail.py
@@ -8,7 +8,7 @@

from torchtune.datasets._text_completion import TextCompletionDataset

- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def cnn_dailymail_articles_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_grammar.py
@@ -10,7 +10,7 @@
from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def grammar_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_hh_rlhf_helpful.py
@@ -8,7 +8,7 @@

from torchtune.data import ChosenRejectedToMessages
from torchtune.datasets._preference import PreferenceDataset
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def hh_rlhf_helpful_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_instruct.py
@@ -9,7 +9,7 @@
from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def instruct_dataset(
6 changes: 3 additions & 3 deletions torchtune/datasets/_preference.py
@@ -11,10 +11,10 @@
from torch.utils.data import Dataset

from torchtune.data import ChosenRejectedToMessages, CROSS_ENTROPY_IGNORE_IDX

- from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform

+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


class PreferenceDataset(Dataset):
"""
@@ -84,7 +84,7 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
of messages are stored in the ``"chosen"`` and ``"rejected"`` keys.
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
Since PreferenceDataset only supports text data, it requires a
- :class:`~torchtune.modules.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
+ :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
:class:`~torchtune.datasets.SFTDataset`.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
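A hedged sketch of constructing the dataset this docstring describes (the dataset ID is hypothetical; it must provide the "chosen"/"rejected" columns mentioned above):

from torchtune.data import ChosenRejectedToMessages
from torchtune.datasets import PreferenceDataset
from torchtune.models.mistral import mistral_tokenizer

tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model")
ds = PreferenceDataset(
    source="my-org/preference-data",  # hypothetical dataset with "chosen"/"rejected" columns
    message_transform=ChosenRejectedToMessages(),
    tokenizer=tokenizer,  # text-only data, so a ModelTokenizer is required here
    split="train",
)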
2 changes: 1 addition & 1 deletion torchtune/datasets/_samsum.py
@@ -10,7 +10,7 @@
from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def samsum_dataset(
12 changes: 7 additions & 5 deletions torchtune/datasets/_sft.py
@@ -69,11 +69,13 @@ class SFTDataset(Dataset):
multimodal datasets requires processing the images in a way specific to the vision
encoder being used by the model and is agnostic to the specific dataset.
- Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`
- can be treated as a ``model_transform`` since it uses the model-specific tokenizer to
- transform the list of messages outputted from the ``message_transform`` into tokens
- used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer`
- into ``model_transform``. Tokenizers handle prompt templating, if configured.
+ Tokenization is handled by the ``model_transform``. All
+ :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` can be treated as
+ a ``model_transform`` since it uses the model-specific tokenizer to transform the
+ list of messages outputted from the ``message_transform`` into tokens used by the
+ model for training. Text-only datasets will simply pass the
+ :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` into ``model_transform``.
+ Tokenizers handle prompt templating, if configured.
Args:
source (str): path to dataset repository on Hugging Face. For local datasets,
2 changes: 1 addition & 1 deletion torchtune/datasets/_slimorca.py
@@ -10,7 +10,7 @@
from torchtune.datasets._packed import PackedDataset

from torchtune.datasets._sft import SFTDataset
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def slimorca_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_stack_exchange_paired.py
@@ -8,8 +8,8 @@

from torchtune.data import Message
from torchtune.datasets._preference import PreferenceDataset
- from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


class StackExchangePairedToMessages(Transform):
2 changes: 1 addition & 1 deletion torchtune/datasets/_text_completion.py
@@ -10,7 +10,7 @@
from torch.utils.data import Dataset
from torchtune.data._utils import truncate
from torchtune.datasets._packed import PackedDataset
- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


class TextCompletionDataset(Dataset):
2 changes: 1 addition & 1 deletion torchtune/datasets/_wikitext.py
@@ -13,7 +13,7 @@
TextCompletionDataset,
)

- from torchtune.modules.tokenizers import ModelTokenizer
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


def wikitext_dataset(
2 changes: 1 addition & 1 deletion torchtune/models/clip/_tokenizer.py
@@ -7,7 +7,7 @@

import regex as re

- from torchtune.modules.tokenizers._utils import BaseTokenizer
+ from torchtune.modules.transforms.tokenizers._utils import BaseTokenizer

WORD_BOUNDARY = "</w>"

4 changes: 2 additions & 2 deletions torchtune/models/gemma/_tokenizer.py
@@ -7,12 +7,12 @@
from typing import Any, List, Mapping, Optional, Tuple

from torchtune.data import Message, PromptTemplate
- from torchtune.modules.tokenizers import (
+ from torchtune.modules.transforms import Transform
+ from torchtune.modules.transforms.tokenizers import (
    ModelTokenizer,
    SentencePieceBaseTokenizer,
    tokenize_messages_no_special_tokens,
)
- from torchtune.modules.transforms import Transform

WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]

4 changes: 2 additions & 2 deletions torchtune/models/llama2/_tokenizer.py
@@ -8,12 +8,12 @@

from torchtune.data import Message, PromptTemplate
from torchtune.models.llama2._prompt_template import Llama2ChatTemplate
- from torchtune.modules.tokenizers import (
+ from torchtune.modules.transforms import Transform
+ from torchtune.modules.transforms.tokenizers import (
    ModelTokenizer,
    SentencePieceBaseTokenizer,
    tokenize_messages_no_special_tokens,
)
- from torchtune.modules.transforms import Transform

WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]

2 changes: 1 addition & 1 deletion torchtune/models/llama3/_model_builders.py
@@ -13,7 +13,7 @@

from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
- from torchtune.modules.tokenizers import parse_hf_tokenizer_json
+ from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json


"""
5 changes: 4 additions & 1 deletion torchtune/models/llama3/_tokenizer.py
@@ -8,8 +8,11 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

from torchtune.data import Message, PromptTemplate, truncate
- from torchtune.modules.tokenizers import ModelTokenizer, TikTokenBaseTokenizer
from torchtune.modules.transforms import Transform
+ from torchtune.modules.transforms.tokenizers import (
+     ModelTokenizer,
+     TikTokenBaseTokenizer,
+ )


CL100K_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa
2 changes: 1 addition & 1 deletion torchtune/models/llama3_2_vision/_transform.py
@@ -10,8 +10,8 @@

from torchtune.models.clip import CLIPImageTransform
from torchtune.models.llama3 import llama3_tokenizer
- from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform, VisionCrossAttentionMask
+ from torchtune.modules.transforms.tokenizers import ModelTokenizer


class Llama3VisionTransform(ModelTokenizer, Transform):
4 changes: 2 additions & 2 deletions torchtune/models/mistral/_tokenizer.py
@@ -8,12 +8,12 @@

from torchtune.data import Message, PromptTemplate
from torchtune.models.mistral._prompt_template import MistralChatTemplate
- from torchtune.modules.tokenizers import (
+ from torchtune.modules.transforms import Transform
+ from torchtune.modules.transforms.tokenizers import (
    ModelTokenizer,
    SentencePieceBaseTokenizer,
    tokenize_messages_no_special_tokens,
)
- from torchtune.modules.transforms import Transform

WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]

2 changes: 1 addition & 1 deletion torchtune/models/phi3/_model_builders.py
@@ -6,7 +6,7 @@
from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
from functools import partial
- from torchtune.modules.tokenizers import parse_hf_tokenizer_json
+ from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json
from torchtune.data._prompt_templates import _TemplateType
from torchtune.data._prompt_templates import _get_prompt_template

5 changes: 4 additions & 1 deletion torchtune/models/phi3/_tokenizer.py
@@ -9,8 +9,11 @@
from torchtune.data._messages import Message
from torchtune.data._prompt_templates import PromptTemplate
from torchtune.data._utils import truncate
- from torchtune.modules.tokenizers import ModelTokenizer, SentencePieceBaseTokenizer
from torchtune.modules.transforms import Transform
+ from torchtune.modules.transforms.tokenizers import (
+     ModelTokenizer,
+     SentencePieceBaseTokenizer,
+ )

PHI3_SPECIAL_TOKENS = {
"<|endoftext|>": 32000,