Pull with master #2

Merged
merged 8 commits into from
Dec 21, 2020
9 changes: 9 additions & 0 deletions docs/source/model_doc/blenderbot.rst
@@ -100,6 +100,15 @@ BlenderbotSmallTokenizer
    :members:


BlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

See :obj:`transformers.BartModel` for arguments to `forward` and `generate`

.. autoclass:: transformers.BlenderbotModel
    :members:


BlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

7 changes: 7 additions & 0 deletions docs/source/model_doc/mbart.rst
@@ -97,6 +97,13 @@ MBartTokenizerFast
    :members:


MBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MBartModel
    :members:


MBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

6 changes: 6 additions & 0 deletions docs/source/model_doc/pegasus.rst
@@ -119,6 +119,12 @@ PegasusTokenizerFast
    :members:


PegasusModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.PegasusModel


PegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions examples/legacy/pytorch-lightning/requirements.txt
@@ -19,3 +19,4 @@ pytest
conllu
sentencepiece != 0.1.92
protobuf
ray
38 changes: 38 additions & 0 deletions examples/research_projects/rag/README.md
@@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \
```
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.

## Document Retrieval
When running distributed fine-tuning, each training worker needs to retrieve contextual documents
for its input by querying an index loaded into memory. RAG provides two implementations for document retrieval,
one using the [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other
using [`Ray`](https://docs.ray.io/en/master/).

The implementation is selected with the `--distributed_retriever` flag, which can be set to either `pytorch` or `ray`.
By default this flag is set to `pytorch`.

For the PyTorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used
to collect the inputs from the other training workers and send back the corresponding document embeddings.
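
The gather/scatter pattern described above can be sketched in a single process. This is only an illustration and not the code used by the retriever: it does not call `torch.distributed`, the toy index and the names `retrieve` and `gather_scatter_retrieve` are hypothetical, and it returns document ids rather than embeddings to keep the example short.

```python
# Single-process simulation of the gather/scatter retrieval pattern.
# The toy index and all function names here are hypothetical.

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Toy index: document id -> embedding. Only "worker 0" holds it.
INDEX = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.7, 0.7]}


def retrieve(query, n_docs):
    """Return the ids of the n_docs documents with the highest inner product."""
    ranked = sorted(INDEX, key=lambda doc_id: dot(INDEX[doc_id], query), reverse=True)
    return ranked[:n_docs]


def gather_scatter_retrieve(queries_per_worker, n_docs):
    # Gather: worker 0 collects the query batches of every worker, in rank order.
    gathered = [query for batch in queries_per_worker for query in batch]
    # Lookup: only worker 0 queries the index.
    results = [retrieve(query, n_docs) for query in gathered]
    # Scatter: split the results back into one chunk per worker.
    scattered, start = [], 0
    for batch in queries_per_worker:
        scattered.append(results[start:start + len(batch)])
        start += len(batch)
    return scattered


# Worker 0 has one query, worker 1 has two.
queries = [[[1.0, 0.1]], [[0.1, 1.0], [1.0, 1.0]]]
per_worker = gather_scatter_retrieve(queries, n_docs=1)
print(per_worker)
```

Because every lookup goes through worker 0, that worker becomes the bottleneck as the number of training workers grows.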

For the Ray implementation, the index is loaded in *separate* process(es). Each training worker randomly selects which
retrieval worker to query. To use Ray for distributed retrieval, set the `--distributed_retriever` argument to `ray`.
To configure the number of retrieval workers (the number of processes that load the index), set the `--num_retrieval_workers` flag.
Make sure to start the Ray cluster before running fine-tuning.

```bash
# Start a single-node Ray cluster.
ray start --head

python examples/rag/finetune_rag.py \
    --data_dir $DATA_DIR \
    --output_dir $OUTPUT_DIR \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --fp16 \
    --gpus 8 \
    --distributed_retriever ray \
    --num_retrieval_workers 4

# Stop the ray cluster once fine-tuning has finished.
ray stop
```

Using Ray can lead to retrieval speedups in multi-GPU settings, since multiple processes load the index rather than
just the rank 0 training worker. Using Ray also allows you to load the index on GPU, since the index is loaded in a separate
process from the model. With PyTorch distributed retrieval, both are loaded in the same process, which can lead to GPU OOM.
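
The random selection of a retrieval worker can be illustrated without Ray. The sketch below is a stand-in under stated assumptions: `ToyRetrievalWorker` is a hypothetical plain Python class playing the role of a Ray actor that holds its own copy of the index, and `retrieve_from_random_worker` is a hypothetical helper. In the real setup the call would go to a remote process.

```python
import random


class ToyRetrievalWorker:
    """Hypothetical stand-in for a Ray actor holding its own copy of the index."""

    def __init__(self, name, index):
        self.name = name
        self.index = dict(index)  # each worker loads a separate copy
        self.calls = 0

    def retrieve(self, query, n_docs):
        self.calls += 1

        def score(doc_id):
            # Inner product between the query and the document embedding.
            return sum(a * b for a, b in zip(self.index[doc_id], query))

        return sorted(self.index, key=score, reverse=True)[:n_docs]


def retrieve_from_random_worker(workers, query, n_docs, rng=random):
    # Pick any worker: all of them hold the same index, so the result is
    # identical and the load is spread across processes.
    worker = rng.choice(workers)
    return worker.retrieve(query, n_docs)


index = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.7, 0.7]}
workers = [ToyRetrievalWorker(f"retrieval_worker_{i}", index) for i in range(4)]

rng = random.Random(0)  # seeded only to make the run repeatable
results = [retrieve_from_random_worker(workers, [1.0, 0.1], 1, rng) for _ in range(100)]
total_calls = sum(worker.calls for worker in workers)
```

Every query returns the same documents regardless of which worker answers it, while the 100 lookups are spread over the four workers instead of being served by a single process.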

# Evaluation
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
16 changes: 15 additions & 1 deletion examples/research_projects/rag/_test_finetune_rag.py
@@ -9,6 +9,7 @@
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    require_ray,
    require_torch_gpu,
    require_torch_multi_gpu,
)
@@ -29,7 +30,7 @@ def _create_dummy_data(self, data_dir):
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
f.write(content)

def _run_finetune(self, gpus: int):
def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

@@ -66,6 +67,7 @@ def _run_finetune(self, gpus: int):
--gradient_accumulation_steps 1 \
--distributed-port 8787 \
--use_dummy_dataset 1 \
--distributed_retriever {distributed_retriever} \
""".split()

if gpus > 0:
@@ -94,3 +96,15 @@ def test_finetune_gpu(self):
def test_finetune_multigpu(self):
result = self._run_finetune(gpus=2)
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_gpu
    @require_ray
    def test_finetune_gpu_ray_retrieval(self):
        result = self._run_finetune(gpus=1, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_multi_gpu
    @require_ray
    def test_finetune_multigpu_ray_retrieval(self):
        # Use 2 GPUs so the multi-GPU path is exercised, matching test_finetune_multigpu.
        result = self._run_finetune(gpus=2, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
@@ -31,14 +31,13 @@ class RagPyTorchDistributedRetriever(RagRetriever):
If specified, use this index instead of the one built using the configuration
"""

_init_retrieval = False

def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
super().__init__(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.process_group = None

154 changes: 154 additions & 0 deletions examples/research_projects/rag/distributed_ray_retriever.py
@@ -0,0 +1,154 @@
import logging
import random

import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.file_utils import requires_datasets, requires_faiss
from transformers.models.rag.retrieval_rag import CustomHFIndex


logger = logging.getLogger(__name__)


class RayRetriever:
    def __init__(self):
        self.initialized = False

    def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
        if not self.initialized:
            self.retriever = RagRetriever(
                config,
                question_encoder_tokenizer=question_encoder_tokenizer,
                generator_tokenizer=generator_tokenizer,
                index=index,
                init_retrieval=False,
            )
            self.initialized = True

    def init_retrieval(self):
        self.retriever.index.init_index()

    def retrieve(self, question_hidden_states, n_docs):
        doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
        return doc_ids, retrieved_doc_embeds


class RagRayDistributedRetriever(RagRetriever):
    """
    A distributed retriever built on top of the ``Ray`` API, a library
    for building distributed applications (https://docs.ray.io/en/master/).
    During training, all training workers initialize their own
    instance of a `RagRayDistributedRetriever`, and each instance of
    this distributed retriever shares a common set of Retrieval Ray
    Actors (https://docs.ray.io/en/master/walkthrough.html#remote-classes-actors)
    that load the index on separate processes. Ray
    handles the communication between the `RagRayDistributedRetriever`
    instances and the remote Ray actors. If training is done in a
    non-distributed setup, the index will simply be loaded in the same
    process as the training worker and Ray will not be used.

    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question.
            It is used to decode the question and then use the generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
        retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
            These actor classes run on remote processes and are responsible for performing the index lookup.
        index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built using the configuration
    """

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
        if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
            raise ValueError(
                "When using Ray for distributed fine-tuning, "
                "you'll need to provide the paths instead, "
                "as the dataset and the index are loaded "
                "separately. More info in examples/rag/use_own_knowledge_dataset.py "
            )
        super().__init__(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
        )
        self.retrieval_workers = retrieval_workers
        if len(self.retrieval_workers) > 0:
            ray.get(
                [
                    worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
                    for worker in self.retrieval_workers
                ]
            )

    def init_retrieval(self):
        """
        Retriever initialization function, needs to be called from the
        training process. This function triggers retrieval initialization
        for all retrieval actors if using distributed setting, or loads
        index into current process if training is not distributed.
        """
        logger.info("initializing retrieval")

        if len(self.retrieval_workers) > 0:
            ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
        else:
            # Non-distributed training. Load index into this same process.
            self.index.init_index()

    def retrieve(self, question_hidden_states, n_docs):
        """
        Retrieves documents for specified ``question_hidden_states``. If
        running training with multiple workers, a random retrieval actor is
        selected to perform the index lookup and return the result.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`)
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`)
                The ids of the documents in the index
            doc_dicts (:obj:`List[dict]`):
                The retrieved_doc_embeds examples per query.
        """
        if len(self.retrieval_workers) > 0:
            # Select a random retrieval actor.
            random_worker = random.choice(self.retrieval_workers)
            doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs))
        else:
            doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
        return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

    @classmethod
    def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)

    @classmethod
    def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
        requires_datasets(cls)
        requires_faiss(cls)
        config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        question_encoder_tokenizer = rag_tokenizer.question_encoder
        generator_tokenizer = rag_tokenizer.generator
        if indexed_dataset is not None:
            config.index_name = "custom"
            index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
        else:
            index = cls._build_index(config)
        return cls(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            retrieval_workers=actor_handles,
            index=index,
        )