[RAG] Add Ray implementation for distributed retrieval (#9197)
* wip
* wip
* wip
* wip
* wip
* wip
* wip
* wip
* uncomment
* uncomment
* wip
* updates
* add docstring
* updates
* fix arg
* fixes
* add unit tests
* update readme
* update readme
* update finetune script
* update test
* add test
* add ray to test dependencies
* separate ray and ray tune
* formatting
* shutdown ray at end of test
* fix tests
* formatting
* formatting
* even more formatting
* address comments
* formatting
* add files
* Update examples/research_projects/rag/test_distributed_retriever.py
  Co-authored-by: Sylvain Gugger <[email protected]>
* address comments
* addressing comments

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
1 parent f38c4ad · commit a4b21cd · 14 changed files with 561 additions and 56 deletions.
```diff
@@ -19,3 +19,4 @@ pytest
 conllu
 sentencepiece != 0.1.92
 protobuf
+ray
```
examples/research_projects/rag/distributed_ray_retriever.py (154 additions, 0 deletions)
```python
import logging
import random

import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.file_utils import requires_datasets, requires_faiss
from transformers.models.rag.retrieval_rag import CustomHFIndex


logger = logging.getLogger(__name__)


class RayRetriever:
    def __init__(self):
        self.initialized = False

    def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
        # Only build the underlying RagRetriever once per actor.
        if not self.initialized:
            self.retriever = RagRetriever(
                config,
                question_encoder_tokenizer=question_encoder_tokenizer,
                generator_tokenizer=generator_tokenizer,
                index=index,
                init_retrieval=False,
            )
            self.initialized = True

    def init_retrieval(self):
        self.retriever.index.init_index()

    def retrieve(self, question_hidden_states, n_docs):
        doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
        return doc_ids, retrieved_doc_embeds


class RagRayDistributedRetriever(RagRetriever):
    """
    A distributed retriever built on top of the ``Ray`` API, a library
    for building distributed applications (https://docs.ray.io/en/master/).
    During training, all training workers initialize their own
    instance of a `RagRayDistributedRetriever`, and each instance of
    this distributed retriever shares a common set of Retrieval Ray
    Actors (https://docs.ray.io/en/master/walkthrough.html#remote-classes-actors)
    that load the index on separate processes. Ray handles the communication
    between the `RagRayDistributedRetriever` instances and the remote Ray
    actors. If training is done in a non-distributed setup, the index will
    simply be loaded in the same process as the training worker and Ray
    will not be used.

    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question.
            It is used to decode the question and then use the generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
        retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`):
            A list of already initialized `RayRetriever` actors.
            These actor classes run on remote processes and are responsible for performing the index lookup.
        index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built using the configuration.
    """

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
        if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
            raise ValueError(
                "When using Ray for distributed fine-tuning, "
                "you'll need to provide the paths instead, "
                "as the dataset and the index are loaded "
                "separately. More info in examples/rag/use_own_knowledge_dataset.py "
            )
        super().__init__(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
        )
        self.retrieval_workers = retrieval_workers
        if len(self.retrieval_workers) > 0:
            # Build a RagRetriever inside every actor; ray.get blocks until all are ready.
            ray.get(
                [
                    worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
                    for worker in self.retrieval_workers
                ]
            )

    def init_retrieval(self):
        """
        Retriever initialization function, needs to be called from the
        training process. This function triggers retrieval initialization
        for all retrieval actors if using a distributed setting, or loads
        the index into the current process if training is not distributed.
        """
        logger.info("initializing retrieval")

        if len(self.retrieval_workers) > 0:
            ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
        else:
            # Non-distributed training. Load index into this same process.
            self.index.init_index()

    def retrieve(self, question_hidden_states, n_docs):
        """
        Retrieves documents for the specified ``question_hidden_states``. If
        running training with multiple workers, a random retrieval actor is
        selected to perform the index lookup and return the result.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
                The ids of the documents in the index.
            doc_dicts (:obj:`List[dict]`):
                The document dictionaries for the retrieved docs per query.
        """
        if len(self.retrieval_workers) > 0:
            # Select a random retrieval actor.
            random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
            doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs))
        else:
            doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
        return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

    @classmethod
    def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)

    @classmethod
    def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
        requires_datasets(cls)
        requires_faiss(cls)
        config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        question_encoder_tokenizer = rag_tokenizer.question_encoder
        generator_tokenizer = rag_tokenizer.generator
        if indexed_dataset is not None:
            config.index_name = "custom"
            index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
        else:
            index = cls._build_index(config)
        return cls(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            retrieval_workers=actor_handles,
            index=index,
        )
```
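For orientation, here is a minimal sketch of how a training script might wire these pieces together. The worker count and model checkpoint below are illustrative assumptions, not values fixed by this commit:

```python
import ray

from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever

# Connect to (or start) a local Ray cluster.
ray.init()

# Spawn a pool of retrieval actors; each one holds the index in its own
# process. The count of 4 is an assumption, tune it to your hardware.
workers = [ray.remote(RayRetriever).remote() for _ in range(4)]

# All training workers share these actor handles; Ray routes the
# retrieve() calls to the remote processes.
retriever = RagRayDistributedRetriever.from_pretrained(
    "facebook/rag-token-nq",  # assumption: any RAG checkpoint should work here
    actor_handles=workers,
)

# Load the index on every actor before the first lookup.
retriever.init_retrieval()
```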
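And a sketch of a single retrieval step, with shapes following the docstring above; the 768-dimensional random query vectors are a stand-in assumption matching the default DPR question encoder:

```python
import numpy as np

# assumption: in real training these come from the question encoder, not random data
question_hidden_states = np.random.randn(2, 768).astype(np.float32)

retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(
    question_hidden_states, n_docs=5
)
# retrieved_doc_embeds: (2, 5, 768), doc_ids: (2, 5),
# doc_dicts: one dict of retrieved passages per query
```

Note the design choice in `retrieve()`: picking a random actor per call is a simple way to spread index lookups across the retrieval processes without any coordination overhead.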