diff --git a/examples/legacy/pytorch-lightning/requirements.txt b/examples/legacy/pytorch-lightning/requirements.txt index cb218847c67e..7a3030197745 100644 --- a/examples/legacy/pytorch-lightning/requirements.txt +++ b/examples/legacy/pytorch-lightning/requirements.txt @@ -19,3 +19,4 @@ pytest conllu sentencepiece != 0.1.92 protobuf +ray diff --git a/examples/research_projects/rag/README.md b/examples/research_projects/rag/README.md index 12da66fa7e35..c16104d3c062 100644 --- a/examples/research_projects/rag/README.md +++ b/examples/research_projects/rag/README.md @@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \ ``` You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script. +## Document Retrieval +When running distributed fine-tuning, each training worker needs to retrieve contextual documents +for its input by querying an index loaded into memory. RAG provides two implementations for document retrieval, +one using the [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other +using [`Ray`](https://docs.ray.io/en/master/). + +This option can be configured with the `--distributed_retriever` flag, which can be set to either `pytorch` or `ray`. +By default this flag is set to `pytorch`. + +For the PyTorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used +to collect the inputs from the other training workers and send back the corresponding document embeddings. + +For the Ray implementation, the index is loaded in *separate* process(es). The training workers randomly select which +retriever worker to query. To use Ray for distributed retrieval, set the `--distributed_retriever` arg to `ray`. +To configure the number of retrieval workers (the number of processes that load the index), set the `--num_retrieval_workers` flag. +Also make sure to start the Ray cluster before running fine-tuning. + +```bash +# Start a single-node Ray cluster. +ray start --head + +python examples/rag/finetune_rag.py \ + --data_dir $DATA_DIR \ + --output_dir $OUTPUT_DIR \ + --model_name_or_path $MODEL_NAME_OR_PATH \ + --model_type rag_sequence \ + --fp16 \ + --gpus 8 \ + --distributed_retriever ray \ + --num_retrieval_workers 4 + +# Stop the Ray cluster once fine-tuning has finished. +ray stop +``` + +Using Ray can lead to retrieval speedups in multi-GPU settings since multiple processes load the index rather than +just the rank 0 training worker. Using Ray also allows you to load the index on GPU, since the index is loaded in a separate +process from the model, while with PyTorch distributed retrieval both are loaded in the same process, potentially leading to GPU OOM. # Evaluation Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
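(Editor's illustration, not part of the diff: to make the gather/scatter description in the README hunk above concrete, here is a minimal, hypothetical sketch of that pattern. It is not the code in `distributed_pytorch_retriever.py`; it assumes a `gloo` process group is already initialized and that `search_index` is a stand-in that returns one embedding per query with the same shape as the query.)

```python
import torch
import torch.distributed as dist


def rank0_retrieve(query: torch.Tensor, search_index=None) -> torch.Tensor:
    """Gather queries on rank 0, look them up there, scatter the results back."""
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # 1. Gather every worker's query vectors on rank 0, which holds the index.
    gathered = [torch.empty_like(query) for _ in range(world_size)] if rank == 0 else None
    dist.gather(query, gather_list=gathered, dst=0)

    # 2. Only rank 0 performs the (hypothetical) index lookup.
    results = [search_index(q) for q in gathered] if rank == 0 else None

    # 3. Scatter each worker's result back to the worker that asked for it.
    output = torch.empty_like(query)
    dist.scatter(output, scatter_list=results, src=0)
    return output
```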
diff --git a/examples/research_projects/rag/_test_finetune_rag.py b/examples/research_projects/rag/_test_finetune_rag.py index 164ecfd93211..1be5ecbb89db 100644 --- a/examples/research_projects/rag/_test_finetune_rag.py +++ b/examples/research_projects/rag/_test_finetune_rag.py @@ -9,6 +9,7 @@ from transformers.testing_utils import ( TestCasePlus, execute_subprocess_async, + require_ray, require_torch_gpu, require_torch_multi_gpu, ) @@ -29,7 +30,7 @@ def _create_dummy_data(self, data_dir): with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f: f.write(content) - def _run_finetune(self, gpus: int): + def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -66,6 +67,7 @@ def _run_finetune(self, gpus: int): --gradient_accumulation_steps 1 \ --distributed-port 8787 \ --use_dummy_dataset 1 \ + --distributed_retriever {distributed_retriever} \ """.split() if gpus > 0: @@ -94,3 +96,15 @@ def test_finetune_gpu(self): def test_finetune_multigpu(self): result = self._run_finetune(gpus=2) self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2) + + @require_torch_gpu + @require_ray + def test_finetune_gpu_ray_retrieval(self): + result = self._run_finetune(gpus=1, distributed_retriever="ray") + self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2) + + @require_torch_multi_gpu + @require_ray + def test_finetune_multigpu_ray_retrieval(self): + result = self._run_finetune(gpus=1, distributed_retriever="ray") + self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2) diff --git a/examples/research_projects/rag/distributed_retriever.py b/examples/research_projects/rag/distributed_pytorch_retriever.py similarity index 99% rename from examples/research_projects/rag/distributed_retriever.py rename to examples/research_projects/rag/distributed_pytorch_retriever.py index cedd2c33409f..0edbc969a5d0 100644 --- a/examples/research_projects/rag/distributed_retriever.py +++ b/examples/research_projects/rag/distributed_pytorch_retriever.py @@ -31,14 +31,13 @@ class RagPyTorchDistributedRetriever(RagRetriever): If specified, use this index instead of the one built using the configuration """ - _init_retrieval = False - def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None): super().__init__( config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer, index=index, + init_retrieval=False, ) self.process_group = None diff --git a/examples/research_projects/rag/distributed_ray_retriever.py b/examples/research_projects/rag/distributed_ray_retriever.py new file mode 100644 index 000000000000..69fd719cbcc4 --- /dev/null +++ b/examples/research_projects/rag/distributed_ray_retriever.py @@ -0,0 +1,154 @@ +import logging +import random + +import ray +from transformers import RagConfig, RagRetriever, RagTokenizer +from transformers.file_utils import requires_datasets, requires_faiss +from transformers.models.rag.retrieval_rag import CustomHFIndex + + +logger = logging.getLogger(__name__) + + +class RayRetriever: + def __init__(self): + self.initialized = False + + def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index): + if not self.initialized: + self.retriever = RagRetriever( + config, + question_encoder_tokenizer=question_encoder_tokenizer, + generator_tokenizer=generator_tokenizer, + index=index, + init_retrieval=False, + ) + self.initialized = True + + def 
init_retrieval(self): + self.retriever.index.init_index() + + def retrieve(self, question_hidden_states, n_docs): + doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs) + return doc_ids, retrieved_doc_embeds + + +class RagRayDistributedRetriever(RagRetriever): + """ + A distributed retriever built on top of the ``Ray`` API, a library + for building distributed applications (https://docs.ray.io/en/master/). + During training, all training workers initialize their own + instance of a `RagRayDistributedRetriever`, and each instance of + this distributed retriever shares a common set of Retrieval Ray + Actors (https://docs.ray.io/en/master/walkthrough.html#remote + -classes-actors) that load the index on separate processes. Ray + handles the communication between the `RagRayDistributedRetriever` + instances and the remote Ray actors. If training is done in a + non-distributed setup, the index will simply be loaded in the same + process as the training worker and Ray will not be used. + + Args: + config (:class:`~transformers.RagConfig`): + The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build. + question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`): + The tokenizer that was used to tokenize the question. + It is used to decode the question and then use the generator_tokenizer. + generator_tokenizer (:class:`~transformers.PretrainedTokenizer`): + The tokenizer used for the generator part of the RagModel. + retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors. + These actor classes run on remote processes and are responsible for performing the index lookup. + index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration): + If specified, use this index instead of the one built using the configuration + """ + + def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None): + if index is not None and index.is_initialized() and len(retrieval_workers) > 0: + raise ValueError( + "When using Ray for distributed fine-tuning, " + "you'll need to provide the paths instead, " + "as the dataset and the index are loaded " + "separately. More info in examples/rag/use_own_knowledge_dataset.py " + ) + super().__init__( + config, + question_encoder_tokenizer=question_encoder_tokenizer, + generator_tokenizer=generator_tokenizer, + index=index, + init_retrieval=False, + ) + self.retrieval_workers = retrieval_workers + if len(self.retrieval_workers) > 0: + ray.get( + [ + worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index) + for worker in self.retrieval_workers + ] + ) + + def init_retrieval(self): + """ + Retriever initialization function, which needs to be called from the + training process. This function triggers retrieval initialization + for all retrieval actors if using a distributed setting, or loads + the index into the current process if training is not distributed. + """ + logger.info("initializing retrieval") + + if len(self.retrieval_workers) > 0: + ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers]) + else: + # Non-distributed training. Load index into this same process. + self.index.init_index() + + def retrieve(self, question_hidden_states, n_docs): + """ + Retrieves documents for specified ``question_hidden_states``.
If + running training with multiple workers, a random retrieval actor is + selected to perform the index lookup and return the result. + + Args: + question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`): + A batch of query vectors to retrieve with. + n_docs (:obj:`int`): + The number of docs retrieved per query. + + Output: + retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)` + The retrieval embeddings of the retrieved docs per query. + doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`) + The ids of the documents in the index + doc_dicts (:obj:`List[dict]`): + The retrieved_doc_embeds examples per query. + """ + if len(self.retrieval_workers) > 0: + # Select a random retrieval actor. + random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)] + doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs)) + else: + doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs) + return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids) + + @classmethod + def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs): + return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs) + + @classmethod + def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs): + requires_datasets(cls) + requires_faiss(cls) + config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs) + rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config) + question_encoder_tokenizer = rag_tokenizer.question_encoder + generator_tokenizer = rag_tokenizer.generator + if indexed_dataset is not None: + config.index_name = "custom" + index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset) + else: + index = cls._build_index(config) + return cls( + config, + question_encoder_tokenizer=question_encoder_tokenizer, + generator_tokenizer=generator_tokenizer, + retrieval_workers=actor_handles, + index=index, + ) diff --git a/examples/research_projects/rag/finetune_rag.py b/examples/research_projects/rag/finetune_rag.py index b62da19688ce..ef4c1a37854a 100644 --- a/examples/research_projects/rag/finetune_rag.py +++ b/examples/research_projects/rag/finetune_rag.py @@ -29,6 +29,12 @@ T5ForConditionalGeneration, ) from transformers import logging as transformers_logging +from transformers.integrations import is_ray_available + + +if is_ray_available(): + import ray + from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever from callbacks_rag import ( # noqa: E402 # isort:skipq @@ -36,7 +42,8 @@ get_early_stopping_callback, Seq2SeqLoggingCallback, ) -from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip + +from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip from utils_rag import ( # noqa: E402 # isort:skip calculate_exact_match, flatten_list, @@ -88,7 +95,12 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi os.environ["MASTER_PORT"] = str(self.distributed_port) super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks) if module.is_rag_model: - module.model.rag.retriever.init_retrieval(self.distributed_port) + if module.distributed_retriever == "pytorch": + module.model.rag.retriever.init_retrieval(self.distributed_port) + 
elif module.distributed_retriever == "ray" and global_rank == 0: + # For the Ray retriever, only initialize it once when global + # rank is 0. + module.model.rag.retriever.init_retrieval() class GenerativeQAModule(BaseTransformer): @@ -127,7 +139,13 @@ def __init__(self, hparams, **kwargs): config.generator.prefix = hparams.prefix config.label_smoothing = hparams.label_smoothing hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator) - retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config) + if hparams.distributed_retriever == "pytorch": + retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config) + elif hparams.distributed_retriever == "ray": + # The Ray retriever needs the handles to the retriever actors. + retriever = RagRayDistributedRetriever.from_pretrained( + hparams.model_name_or_path, hparams.actor_handles, config=config + ) model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever) prefix = config.question_encoder.prefix else: @@ -180,7 +198,12 @@ def __init__(self, hparams, **kwargs): # For single GPU training, init_ddp_connection is not called. # So we need to initialize the retrievers here. if hparams.gpus <= 1: - self.model.retriever.init_retrieval(self.distributed_port) + if hparams.distributed_retriever == "ray": + self.model.retriever.init_retrieval() + elif hparams.distributed_retriever == "pytorch": + self.model.retriever.init_retrieval(self.distributed_port) + + self.distributed_retriever = hparams.distributed_retriever def forward(self, input_ids, **kwargs): return self.model(input_ids, **kwargs) @@ -420,6 +443,7 @@ def add_model_specific_args(parser, root_dir): type=str, help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path", ) + return parser @staticmethod @@ -442,12 +466,58 @@ def add_retriever_specific_args(parser): default=None, help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`", ) + parser.add_argument( + "--distributed_retriever", + choices=["ray", "pytorch"], + type=str, + default="pytorch", + help="Which implementation to use for the distributed retriever. If " + "pytorch is selected, the index is loaded on training " + "worker 0, and torch.distributed is used to handle " + "communication between training worker 0 and the other " + "training workers. If ray is selected, the Ray library is " + "used to load the index on separate processes, " + "and Ray handles the communication between the training " + "workers and the retrieval actors.", + ) parser.add_argument( "--use_dummy_dataset", type=bool, default=False, help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`", ) + + parser.add_argument( + "--num_retrieval_workers", + type=int, + default=1, + help="The number of retrieval actors to use when Ray is selected " + "for the distributed retriever. Has no effect when " + "distributed_retriever is set to pytorch.", + ) + + @staticmethod + def add_ray_specific_args(parser): + parser.add_argument( + "--num_retrieval_workers", + type=int, + default=1, + help="The number of retrieval actors to use when Ray is selected " + "for the distributed retriever. 
Has no effect when " + "distributed_retriever is set to pytorch.", + ) + + # Ray cluster address. + parser.add_argument( + "--ray-address", + default="auto", + type=str, + help="The address of the Ray cluster to connect to. If not " + "specified, Ray will attempt to automatically detect the " + "cluster. Has no effect if pytorch is used as the distributed " + "retriever.", + ) + return parser @@ -461,6 +531,46 @@ def main(args=None, model=None) -> GenerativeQAModule: args = args or parser.parse_args() Path(args.output_dir).mkdir(exist_ok=True) + + named_actors = [] + if args.distributed_retriever == "ray" and args.gpus > 1: + if not is_ray_available(): + raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.") + # Connect to an existing Ray cluster. + try: + ray.init(address=args.ray_address) + except (ConnectionError, ValueError): + logger.warning( + "Connection to Ray cluster failed. Make sure a Ray" + "cluster is running by either using Ray's cluster " + "launcher (`ray up`) or by manually starting Ray on " + "each node via `ray start --head` for the head node " + "and `ray start --address=':6379'` for " + "additional nodes. See " + "https://docs.ray.io/en/master/cluster/index.html " + "for more info." + ) + raise + + # Create Ray actors only for rank 0. + if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and ( + "NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0 + ): + remote_cls = ray.remote(RayRetriever) + named_actors = [ + remote_cls.options(name="retrieval_worker_{}".format(i)).remote() + for i in range(args.num_retrieval_workers) + ] + else: + logger.info( + "Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format( + os.environ["NODE_RANK"], os.environ["LOCAL_RANK"] + ) + ) + named_actors = [ray.get_actor("retrieval_worker_{}".format(i)) for i in range(args.num_retrieval_workers)] + args.actor_handles = named_actors + assert args.actor_handles == named_actors + if model is None: model: GenerativeQAModule = GenerativeQAModule(args) @@ -471,17 +581,17 @@ def main(args=None, model=None) -> GenerativeQAModule: or str(args.output_dir).startswith("/tmp") or str(args.output_dir).startswith("/var") ): - logger = True # don't pollute wandb logs unnecessarily + training_logger = True # don't pollute wandb logs unnecessarily elif args.logger_name == "wandb": from pytorch_lightning.loggers import WandbLogger project = os.environ.get("WANDB_PROJECT", dataset) - logger = WandbLogger(name=model.output_dir.name, project=project) + training_logger = WandbLogger(name=model.output_dir.name, project=project) elif args.logger_name == "wandb_shared": from pytorch_lightning.loggers import WandbLogger - logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") + training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") es_callback = ( get_early_stopping_callback(model.val_metric, args.early_stopping_patience) @@ -495,8 +605,9 @@ def main(args=None, model=None) -> GenerativeQAModule: logging_callback=Seq2SeqLoggingCallback(), checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric), early_stopping_callback=es_callback, - logger=logger, + logger=training_logger, accelerator=CustomAccel() if args.gpus > 1 else None, + profiler=pl.profiler.AdvancedProfiler() if args.profile else None, ) pickle_save(model.hparams, model.output_dir / "hparams.pkl") @@ -509,4 +620,19 @@ def main(args=None, model=None) -> GenerativeQAModule: if __name__ == "__main__": - main() + parser = 
argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd()) + parser = GenerativeQAModule.add_retriever_specific_args(parser) + parser = GenerativeQAModule.add_ray_specific_args(parser) + + # Pytorch Lightning Profiler + parser.add_argument( + "--profile", + action="store_true", + help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.", + ) + + args = parser.parse_args() + + main(args) diff --git a/examples/research_projects/rag/finetune_rag.sh b/examples/research_projects/rag/finetune_rag.sh index 577b6ebd0dbd..8fd1fea3e546 100755 --- a/examples/research_projects/rag/finetune_rag.sh +++ b/examples/research_projects/rag/finetune_rag.sh @@ -2,7 +2,7 @@ export PYTHONPATH="../":"${PYTHONPATH}" # A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path -# run ./examples/rag/finetune.sh --help to see all the possible options +# run ./examples/rag/finetune_rag.sh --help to see all the possible options python examples/rag/finetune_rag.py \ --data_dir $DATA_DIR \ @@ -11,10 +11,10 @@ python examples/rag/finetune_rag.py \ --model_type rag_sequence \ --fp16 \ --gpus 8 \ + --profile \ --do_train \ --do_predict \ --n_val -1 \ - --val_check_interval 0.25 \ --train_batch_size 8 \ --eval_batch_size 1 \ --max_source_length 128 \ @@ -31,4 +31,4 @@ python examples/rag/finetune_rag.py \ --learning_rate 3e-05 \ --num_train_epochs 100 \ --warmup_steps 500 \ - --gradient_accumulation_steps 1 \ No newline at end of file + --gradient_accumulation_steps 1 \ diff --git a/examples/research_projects/rag/finetune_rag_ray.sh b/examples/research_projects/rag/finetune_rag_ray.sh new file mode 100755 index 000000000000..7c8e7b97e77c --- /dev/null +++ b/examples/research_projects/rag/finetune_rag_ray.sh @@ -0,0 +1,44 @@ +# Sample script to finetune RAG using Ray for distributed retrieval. + +# Add parent directory to python path to access lightning_base.py +export PYTHONPATH="../":"${PYTHONPATH}" + +# Start a single-node Ray cluster. +ray start --head + +# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path +# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options + +python examples/rag/finetune_rag.py \ + --data_dir $DATA_DIR \ + --output_dir $OUTPUT_DIR \ + --model_name_or_path $MODEL_NAME_OR_PATH \ + --model_type rag_sequence \ + --fp16 \ + --gpus 8 \ + --profile \ + --do_train \ + --do_predict \ + --n_val -1 \ + --train_batch_size 8 \ + --eval_batch_size 1 \ + --max_source_length 128 \ + --max_target_length 25 \ + --val_max_target_length 25 \ + --test_max_target_length 25 \ + --label_smoothing 0.1 \ + --dropout 0.1 \ + --attention_dropout 0.1 \ + --weight_decay 0.001 \ + --adam_epsilon 1e-08 \ + --max_grad_norm 0.1 \ + --lr_scheduler polynomial \ + --learning_rate 3e-05 \ + --num_train_epochs 100 \ + --warmup_steps 500 \ + --gradient_accumulation_steps 1 \ + --distributed_retriever ray \ + --num_retrieval_workers 4 + +# Stop the Ray cluster. 
+ray stop diff --git a/examples/research_projects/rag/test_distributed_retriever.py b/examples/research_projects/rag/test_distributed_retriever.py index e7a5d9ba91a3..8865a3098959 100644 --- a/examples/research_projects/rag/test_distributed_retriever.py +++ b/examples/research_projects/rag/test_distributed_retriever.py @@ -13,15 +13,27 @@ import faiss from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available +from transformers.integrations import is_ray_available from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES -from transformers.models.rag.retrieval_rag import CustomHFIndex +from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES -from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me +from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip -from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip +if is_torch_available(): + from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip +else: + RagPyTorchDistributedRetriever = None + +if is_ray_available(): + import ray # noqa: E402 # isort:skip + from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever # noqa: E402 # isort:skip +else: + ray = None + RagRayDistributedRetriever = None + RayRetriever = None def require_distributed_retrieval(test_case): @@ -32,8 +44,8 @@ def require_distributed_retrieval(test_case): These tests are skipped when respective libraries are not installed. """ - if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()): - test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case) + if not (is_datasets_available() and is_faiss_available() and is_psutil_available()): + test_case = unittest.skip("test requires Datasets, Faiss, psutil")(test_case) return test_case @@ -144,7 +156,31 @@ def get_dummy_pytorch_distributed_retriever( retriever.init_retrieval(port) return retriever - def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345): + def get_dummy_ray_distributed_retriever(self, init_retrieval: bool) -> RagRayDistributedRetriever: + # Have to run in local mode because sys.path modifications at top of + # file are not propogated to remote workers. 
+ # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder + ray.init(local_mode=True) + config = RagConfig( + retrieval_vector_size=self.retrieval_vector_size, + question_encoder=DPRConfig().to_dict(), + generator=BartConfig().to_dict(), + ) + remote_cls = ray.remote(RayRetriever) + workers = [remote_cls.remote() for _ in range(1)] + with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset: + mock_load_dataset.return_value = self.get_dummy_dataset() + retriever = RagRayDistributedRetriever( + config, + question_encoder_tokenizer=self.get_dpr_tokenizer(), + generator_tokenizer=self.get_bart_tokenizer(), + retrieval_workers=workers, + ) + if init_retrieval: + retriever.init_retrieval() + return retriever + + def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345): dataset = self.get_dummy_dataset() config = RagConfig( retrieval_vector_size=self.retrieval_vector_size, @@ -175,13 +211,51 @@ def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: b retriever.init_retrieval(port) return retriever - @require_torch_non_multi_gpu_but_fix_me - def test_pytorch_distributed_retriever_retrieve(self): - n_docs = 1 - retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True) - hidden_states = np.array( - [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool): + # Have to run in local mode because sys.path modifications at top of + # file are not propogated to remote workers. + # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder + ray.init(local_mode=True) + dataset = self.get_dummy_dataset() + config = RagConfig( + retrieval_vector_size=self.retrieval_vector_size, + question_encoder=DPRConfig().to_dict(), + generator=BartConfig().to_dict(), + index_name="custom", ) + remote_cls = ray.remote(RayRetriever) + workers = [remote_cls.remote() for _ in range(1)] + if from_disk: + config.passages_path = os.path.join(self.tmpdirname, "dataset") + config.index_path = os.path.join(self.tmpdirname, "index.faiss") + dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss")) + dataset.drop_index("embeddings") + dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset")) + del dataset + retriever = RagRayDistributedRetriever( + config, + question_encoder_tokenizer=self.get_dpr_tokenizer(), + generator_tokenizer=self.get_bart_tokenizer(), + retrieval_workers=workers, + index=CustomHFIndex.load_from_disk( + vector_size=config.retrieval_vector_size, + dataset_path=config.passages_path, + index_path=config.index_path, + ), + ) + else: + retriever = RagRayDistributedRetriever( + config, + question_encoder_tokenizer=self.get_dpr_tokenizer(), + generator_tokenizer=self.get_bart_tokenizer(), + retrieval_workers=workers, + index=CustomHFIndex(config.retrieval_vector_size, dataset), + ) + if init_retrieval: + retriever.init_retrieval() + return retriever + + def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.array, n_docs: int) -> None: retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs) self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) self.assertEqual(len(doc_dicts), 2) @@ -192,33 +266,76 @@ def test_pytorch_distributed_retriever_retrieve(self): 
self.assertListEqual(doc_ids.tolist(), [[1], [0]]) @require_torch_non_multi_gpu_but_fix_me - def test_custom_hf_index_retriever_retrieve(self): + def test_pytorch_distributed_retriever_retrieve(self): + n_docs = 1 + hidden_states = np.array( + [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + ) + + self.distributed_retriever_check( + self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs + ) + + @require_torch_non_multi_gpu_but_fix_me + def test_custom_hf_index_pytorch_retriever_retrieve(self): n_docs = 1 - retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False) hidden_states = np.array( [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 ) - retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs) - self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) - self.assertEqual(len(doc_dicts), 2) - self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"]) - self.assertEqual(len(doc_dicts[0]["id"]), n_docs) - self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc - self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc - self.assertListEqual(doc_ids.tolist(), [[1], [0]]) + + self.distributed_retriever_check( + self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=False), + hidden_states, + n_docs, + ) @require_torch_non_multi_gpu_but_fix_me def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self): n_docs = 1 - retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True) hidden_states = np.array( [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 ) - retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs) - self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size)) - self.assertEqual(len(doc_dicts), 2) - self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"]) - self.assertEqual(len(doc_dicts[0]["id"]), n_docs) - self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc - self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc - self.assertListEqual(doc_ids.tolist(), [[1], [0]]) + + self.distributed_retriever_check( + self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=True), + hidden_states, + n_docs, + ) + + @require_ray + def test_ray_distributed_retriever_retrieve(self): + n_docs = 1 + hidden_states = np.array( + [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + ) + + self.distributed_retriever_check( + self.get_dummy_ray_distributed_retriever(init_retrieval=True), hidden_states, n_docs + ) + ray.shutdown() + + @require_ray + def test_custom_hf_index_ray_retriever_retrieve(self): + n_docs = 1 + hidden_states = np.array( + [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + ) + with self.assertRaises(ValueError): + self.distributed_retriever_check( + self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=False), + hidden_states, + n_docs, + ) + ray.shutdown() + + @require_ray + def test_custom_ray_distributed_retriever_retrieve_from_disk(self): + n_docs = 1 + hidden_states = np.array( + 
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 + ) + + self.distributed_retriever_check( + self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=True), hidden_states, n_docs + ) + ray.shutdown() diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0e39ed7ba5a9..4586fe5363f3 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -219,6 +219,7 @@ is_comet_available, is_optuna_available, is_ray_available, + is_ray_tune_available, is_tensorboard_available, is_wandb_available, ) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index ecc2a9f635a6..2d673087e832 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -63,8 +63,16 @@ import ray # noqa: F401 _has_ray = True + try: + # Ray Tune has additional dependencies. + from ray import tune # noqa: F401 + + _has_ray_tune = True + except (ImportError): + _has_ray_tune = False except (ImportError): _has_ray = False + _has_ray_tune = False try: from torch.utils.tensorboard import SummaryWriter # noqa: F401 @@ -127,6 +135,10 @@ def is_ray_available(): return _has_ray +def is_ray_tune_available(): + return _has_ray_tune + + def is_azureml_available(): return _has_azureml @@ -143,7 +155,7 @@ def hp_params(trial): if is_optuna_available(): if isinstance(trial, optuna.Trial): return trial.params - if is_ray_available(): + if is_ray_tune_available(): if isinstance(trial, dict): return trial @@ -153,7 +165,7 @@ def hp_params(trial): def default_hp_search_backend(): if is_optuna_available(): return "optuna" - elif is_ray_available(): + elif is_ray_tune_available(): return "ray" diff --git a/src/transformers/models/rag/retrieval_rag.py b/src/transformers/models/rag/retrieval_rag.py index 8db18a1d65d9..ff85560e5933 100644 --- a/src/transformers/models/rag/retrieval_rag.py +++ b/src/transformers/models/rag/retrieval_rag.py @@ -370,9 +370,8 @@ class RagRetriever: """ - _init_retrieval = True - - def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None): + def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True): + self._init_retrieval = init_retrieval requires_datasets(self) requires_faiss(self) super().__init__() diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8cbc8ea299b5..10a911bd2743 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -37,7 +37,7 @@ is_fairscale_available, is_mlflow_available, is_optuna_available, - is_ray_available, + is_ray_tune_available, is_tensorboard_available, is_wandb_available, run_hp_search_optuna, @@ -145,7 +145,7 @@ if is_optuna_available(): import optuna -if is_ray_available(): +if is_ray_tune_available(): from ray import tune if is_azureml_available(): @@ -1062,7 +1062,7 @@ def hyperparameter_search( backend = HPSearchBackend(backend) if backend == HPSearchBackend.OPTUNA and not is_optuna_available(): raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.") - if backend == HPSearchBackend.RAY and not is_ray_available(): + if backend == HPSearchBackend.RAY and not is_ray_tune_available(): raise RuntimeError( "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`." 
) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index a63959f48cce..4dc7874a0d59 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -132,9 +132,9 @@ def default_hp_space_optuna(trial) -> Dict[str, float]: def default_hp_space_ray(trial) -> Dict[str, float]: - from .integrations import is_ray_available + from .integrations import is_ray_tune_available - assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`" + assert is_ray_tune_available(), "This function needs ray installed: `pip " "install ray[tune]`" from ray import tune return {
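(Editor's illustration, not part of the diff: the sketch below condenses how the pieces added by this PR fit together, following the pattern used in `finetune_rag.py`'s `main()`. The checkpoint name is the standard `facebook/rag-sequence-nq`; `NUM_RETRIEVAL_WORKERS` is illustrative.)

```python
import ray
from transformers import RagSequenceForGeneration

from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever

NUM_RETRIEVAL_WORKERS = 4

# Connect to a cluster previously started with `ray start --head`.
ray.init(address="auto")

# Each named actor loads the index in its own process.
remote_cls = ray.remote(RayRetriever)
workers = [
    remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
    for i in range(NUM_RETRIEVAL_WORKERS)
]

# The retriever ships its config, tokenizers and index to every actor on construction.
retriever = RagRayDistributedRetriever.from_pretrained(
    "facebook/rag-sequence-nq", actor_handles=workers
)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

# init_retrieval() then tells each actor to load the index into memory.
retriever.init_retrieval()
```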