diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index cadad286d9..462d8d323c 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -37,9 +37,14 @@ else: Dataset = object +# Skip doctests if requirements aren't available +if not _TEXT_AVAILABLE: + __doctest_skip__ = ["TextClassificationData", "TextClassificationData.*"] + class TextClassificationData(DataModule): - """Data Module for text classification tasks.""" + """The ``TextClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of + classmethods for loading data for text classification.""" input_transform_cls = TransformersInputTransform @@ -61,28 +66,96 @@ def from_csv( max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": - """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given CSV - files. + """Load the :class:`~flash.text.classification.data.TextClassificationData` from CSV files containing text + snippets and their corresponding targets. + + Input text snippets will be extracted from the ``input_field`` column in the CSV files. + The targets will be extracted from the ``target_fields`` in the CSV files and can be in any of our + :ref:`supported classification target formats <formatting_classification_targets>`. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. Args: - input_field: The field (column) in the pandas ``Dataset`` to use for the input. - target_fields: The field or fields (columns) in the pandas ``Dataset`` to use for the target. - train_file: The CSV file containing the training data. - val_file: The CSV file containing the validation data. - test_file: The CSV file containing the testing data. - predict_file: The CSV file containing the data to use when predicting. 
- train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - max_length: The maximum sequence length. + input_field: The field (column name) in the CSV files containing the text snippets. + target_fields: The field (column name) or list of fields in the CSV files containing the targets. + train_file: The CSV file to use when training. + val_file: The CSV file to use when validating. + test_file: The CSV file to use when testing. + predict_file: The CSV file to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + max_length: The maximum length to pad / truncate sequences to. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. 
Returns: - The constructed data module. + The constructed :class:`~flash.text.classification.data.TextClassificationData`. + + Examples + ________ + + .. testsetup:: + + >>> import os + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "reviews": ["Best movie ever!", "Not good", "Fine I guess"], + ... "targets": ["positive", "negative", "neutral"], + ... }).to_csv("train_data.csv", index=False) + >>> DataFrame.from_dict({ + ... "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"], + ... }).to_csv("predict_data.csv", index=False) + + The file ``train_data.csv`` contains the following: + + .. code-block:: + + reviews,targets + Best movie ever!,positive + Not good,negative + Fine I guess,neutral + + The file ``predict_data.csv`` contains the following: + + .. code-block:: + + reviews + Worst movie ever! + I didn't enjoy it + It was ok + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.text import TextClassifier, TextClassificationData + >>> datamodule = TextClassificationData.from_csv( + ... "reviews", + ... "targets", + ... train_file="train_data.csv", + ... predict_file="predict_data.csv", + ... batch_size=2, + ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Downloading... + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['negative', 'neutral', 'positive'] + >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. 
testcleanup:: + + >>> os.remove("train_data.csv") + >>> os.remove("predict_data.csv") """ ds_kw = dict( @@ -121,29 +194,95 @@ def from_json( max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": - """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given JSON - files. + """Load the :class:`~flash.text.classification.data.TextClassificationData` from JSON files containing text + snippets and their corresponding targets. + + Input text snippets will be extracted from the ``input_field`` in the JSON objects. + The targets will be extracted from the ``target_fields`` in the JSON objects and can be in any of our + :ref:`supported classification target formats <formatting_classification_targets>`. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. Args: - input_field: The field (column) in the pandas ``Dataset`` to use for the input. - target_fields: The field or fields (columns) in the pandas ``Dataset`` to use for the target. - train_file: The JSON file containing the training data. - val_file: The JSON file containing the validation data. - test_file: The JSON file containing the testing data. - predict_file: The JSON file containing the data to use when predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
+ input_field: The field in the JSON objects containing the text snippets. + target_fields: The field or list of fields in the JSON objects containing the targets. + train_file: The JSON file to use when training. + val_file: The JSON file to use when validating. + test_file: The JSON file to use when testing. + predict_file: The JSON file to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: To specify the field that holds the data in the JSON file. - max_length: The maximum sequence length. + max_length: The maximum length to pad / truncate sequences to. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: - The constructed data module. + The constructed :class:`~flash.text.classification.data.TextClassificationData`. + + Examples + ________ + + .. testsetup:: + + >>> import os + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "reviews": ["Best movie ever!", "Not good", "Fine I guess"], + ... "targets": ["positive", "negative", "neutral"], + ... }).to_json("train_data.json", orient="records", lines=True) + >>> DataFrame.from_dict({ + ... "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"], + ... }).to_json("predict_data.json", orient="records", lines=True) + + The file ``train_data.json`` contains the following: + + .. 
code-block:: + + {"reviews":"Best movie ever!","targets":"positive"} + {"reviews":"Not good","targets":"negative"} + {"reviews":"Fine I guess","targets":"neutral"} + + The file ``predict_data.json`` contains the following: + + .. code-block:: + + {"reviews":"Worst movie ever!"} + {"reviews":"I didn't enjoy it"} + {"reviews":"It was ok"} + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.text import TextClassifier, TextClassificationData + >>> datamodule = TextClassificationData.from_json( + ... "reviews", + ... "targets", + ... train_file="train_data.json", + ... predict_file="predict_data.json", + ... batch_size=2, + ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Downloading... + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['negative', 'neutral', 'positive'] + >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> os.remove("train_data.json") + >>> os.remove("predict_data.json") """ ds_kw = dict( @@ -182,28 +321,96 @@ def from_parquet( max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": - """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given PARQUET - files. + """Load the :class:`~flash.text.classification.data.TextClassificationData` from PARQUET files containing + text snippets and their corresponding targets. + + Input text snippets will be extracted from the ``input_field`` column in the PARQUET files. + The targets will be extracted from the ``target_fields`` in the PARQUET files and can be in any of our + :ref:`supported classification target formats <formatting_classification_targets>`. 
+ To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. Args: - input_field: The field (column) in the pandas ``Dataset`` to use for the input. - target_fields: The field or fields (columns) in the pandas ``Dataset`` to use for the target. - train_file: The PARQUET file containing the training data. - val_file: The PARQUET file containing the validation data. - test_file: The PARQUET file containing the testing data. - predict_file: The PARQUET file containing the data to use when predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - max_length: The maximum sequence length. + input_field: The field (column name) in the PARQUET files containing the text snippets. + target_fields: The field (column name) or list of fields in the PARQUET files containing the targets. + train_file: The PARQUET file to use when training. + val_file: The PARQUET file to use when validating. + test_file: The PARQUET file to use when testing. + predict_file: The PARQUET file to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. 
+ test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + max_length: The maximum length to pad / truncate sequences to. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: - The constructed data module. + The constructed :class:`~flash.text.classification.data.TextClassificationData`. + + Examples + ________ + + .. testsetup:: + + >>> import os + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "reviews": ["Best movie ever!", "Not good", "Fine I guess"], + ... "targets": ["positive", "negative", "neutral"], + ... }).to_parquet("train_data.parquet", index=False) + >>> DataFrame.from_dict({ + ... "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"], + ... }).to_parquet("predict_data.parquet", index=False) + + The file ``train_data.parquet`` contains the following contents encoded in the PARQUET format: + + .. code-block:: + + reviews,targets + Best movie ever!,positive + Not good,negative + Fine I guess,neutral + + The file ``predict_data.parquet`` contains the following contents encoded in the PARQUET format: + + .. code-block:: + + reviews + Worst movie ever! + I didn't enjoy it + It was ok + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.text import TextClassifier, TextClassificationData + >>> datamodule = TextClassificationData.from_parquet( + ... "reviews", + ... "targets", + ... train_file="train_data.parquet", + ... predict_file="predict_data.parquet", + ... batch_size=2, + ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Downloading... 
+ >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['negative', 'neutral', 'positive'] + >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> os.remove("train_data.parquet") + >>> os.remove("predict_data.parquet") """ ds_kw = dict( @@ -241,28 +448,72 @@ def from_hf_datasets( max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": - """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given Hugging - Face datasets ``Dataset`` objects. + """Load the :class:`~flash.text.classification.data.TextClassificationData` from Hugging Face ``Dataset`` + objects containing text snippets and their corresponding targets. + + Input text snippets will be extracted from the ``input_field`` column in the ``Dataset`` objects. + The targets will be extracted from the ``target_fields`` in the ``Dataset`` objects and can be in any of our + :ref:`supported classification target formats <formatting_classification_targets>`. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. Args: - input_field: The field (column) in the pandas ``Dataset`` to use for the input. - target_fields: The field or fields (columns) in the pandas ``Dataset`` to use for the target. - train_hf_dataset: The pandas ``Dataset`` containing the training data. - val_hf_dataset: The pandas ``Dataset`` containing the validation data. - test_hf_dataset: The pandas ``Dataset`` containing the testing data. - predict_hf_dataset: The pandas ``Dataset`` containing the data to use when predicting. 
- train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - max_length: The maximum sequence length. + input_field: The field (column name) in the ``Dataset`` objects containing the text snippets. + target_fields: The field (column name) or list of fields in the ``Dataset`` objects containing the targets. + train_hf_dataset: The ``Dataset`` to use when training. + val_hf_dataset: The ``Dataset`` to use when validating. + test_hf_dataset: The ``Dataset`` to use when testing. + predict_hf_dataset: The ``Dataset`` to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + max_length: The maximum length to pad / truncate sequences to. 
+ data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: - The constructed data module. + The constructed :class:`~flash.text.classification.data.TextClassificationData`. + + Examples + ________ + + .. doctest:: + + >>> from datasets import Dataset + >>> from flash import Trainer + >>> from flash.text import TextClassifier, TextClassificationData + >>> train_data = Dataset.from_dict( + ... { + ... "reviews": ["Best movie ever!", "Not good", "Fine I guess"], + ... "targets": ["positive", "negative", "neutral"], + ... } + ... ) + >>> predict_data = Dataset.from_dict( + ... { + ... "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"], + ... } + ... ) + >>> datamodule = TextClassificationData.from_hf_datasets( + ... "reviews", + ... "targets", + ... train_hf_dataset=train_data, + ... predict_hf_dataset=predict_data, + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['negative', 'neutral', 'positive'] + >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... """ ds_kw = dict( @@ -300,28 +551,73 @@ def from_data_frame( max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": - """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given pandas - ``DataFrame`` objects. + """Load the :class:`~flash.text.classification.data.TextClassificationData` from Pandas ``DataFrame`` + objects containing text snippets and their corresponding targets. + + Input text snippets will be extracted from the ``input_field`` column in the ``DataFrame`` objects. 
+ The targets will be extracted from the ``target_fields`` in the ``DataFrame`` objects and can be in any of our + :ref:`supported classification target formats <formatting_classification_targets>`. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. Args: - input_field: The field (column) in the pandas ``DataFrame`` to use for the input. - target_fields: The field or fields (columns) in the pandas ``DataFrame`` to use for the target. - train_data_frame: The pandas ``DataFrame`` containing the training data. - val_data_frame: The pandas ``DataFrame`` containing the validation data. - test_data_frame: The pandas ``DataFrame`` containing the testing data. - predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - max_length: The maximum sequence length. + input_field: The field (column name) in the ``DataFrame`` objects containing the text snippets. + target_fields: The field (column name) or list of fields in the ``DataFrame`` objects containing the + targets. + train_data_frame: The ``DataFrame`` to use when training. + val_data_frame: The ``DataFrame`` to use when validating. + test_data_frame: The ``DataFrame`` to use when testing. + predict_data_frame: The ``DataFrame`` to use when predicting. 
+ train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + max_length: The maximum length to pad / truncate sequences to. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: - The constructed data module. + The constructed :class:`~flash.text.classification.data.TextClassificationData`. + + Examples + ________ + + .. doctest:: + + >>> from pandas import DataFrame + >>> from flash import Trainer + >>> from flash.text import TextClassifier, TextClassificationData + >>> train_data = DataFrame.from_dict( + ... { + ... "reviews": ["Best movie ever!", "Not good", "Fine I guess"], + ... "targets": ["positive", "negative", "neutral"], + ... } + ... ) + >>> predict_data = DataFrame.from_dict( + ... { + ... "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"], + ... } + ... ) + >>> datamodule = TextClassificationData.from_data_frame( + ... "reviews", + ... "targets", + ... train_data_frame=train_data, + ... predict_data_frame=predict_data, + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['negative', 'neutral', 'positive'] + >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... 
+ >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... """ ds_kw = dict( @@ -360,32 +656,59 @@ def from_lists( max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": - """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given Python - lists. + """Load the :class:`~flash.text.classification.data.TextClassificationData` from lists of text snippets and + corresponding lists of targets. + + The targets can be in any of our + :ref:`supported classification target formats <formatting_classification_targets>`. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. Args: - train_data: A list of sentences to use as the train inputs. - train_targets: A list of targets to use as the train targets. For multi-label classification, the targets - should be provided as a list of lists, where each inner list contains the targets for a sample. - val_data: A list of sentences to use as the validation inputs. - val_targets: A list of targets to use as the validation targets. For multi-label classification, the targets - should be provided as a list of lists, where each inner list contains the targets for a sample. - test_data: A list of sentences to use as the test inputs. - test_targets: A list of targets to use as the test targets. For multi-label classification, the targets - should be provided as a list of lists, where each inner list contains the targets for a sample. - predict_data: A list of sentences to use when predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. 
- test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - max_length: The maximum sequence length. + train_data: The list of text snippets to use when training. + train_targets: The list of targets to use when training. + val_data: The list of text snippets to use when validating. + val_targets: The list of targets to use when validating. + test_data: The list of text snippets to use when testing. + test_targets: The list of targets to use when testing. + predict_data: The list of text snippets to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + max_length: The maximum length to pad / truncate sequences to. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: - The constructed data module. + The constructed :class:`~flash.text.classification.data.TextClassificationData`. + + Examples + ________ + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.text import TextClassifier, TextClassificationData + >>> datamodule = TextClassificationData.from_lists( + ... 
train_data=["Best movie ever!", "Not good", "Fine I guess"], + ... train_targets=["positive", "negative", "neutral"], + ... predict_data=["Worst movie ever!", "I didn't enjoy it", "It was ok"], + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['negative', 'neutral', 'positive'] + >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... """ ds_kw = dict(