Subclass torch.distributed.pipelining.PipelineStage for PFO
#131
Labels: enhancement (New feature or request)
PyTorch 2.4 has a new API for pipeline parallelism, which includes `PipelineStage`. With this, we can subclass `PipelineStage` and override `forward_one_chunk` and `backward_one_chunk`, where each first sets the GPU's frequency using the async frequency controller and then runs the actual forward/backward.

In case users already have an instance of `PipelineStage` (manual splitting) or `_PipelineStage` (automatic splitting with `pipeline`), we can provide a static method on our `PipelineStage` subclass that converts the user's pipeline stage into an instance of ours.

A POC can be done on TorchTitan's `train.py` without having to modify TorchTitan.
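To make the proposal concrete, here is a rough sketch of the two ideas (overriding the per-chunk hooks, and adopting an existing stage). It is not tested against PyTorch: `StubStage` is a stand-in for `PipelineStage` so the snippet runs without `torch`, and `FrequencyController`, `RecordingController`, `FrequencyControlMixin`, `adopt`, and the fixed `forward_mhz`/`backward_mhz` values are all hypothetical names chosen for illustration. The hooks take `*args, **kwargs` so the sketch does not depend on the exact `forward_one_chunk`/`backward_one_chunk` signatures. It also uses a mixin plus an in-place class swap instead of a direct subclass, which is one possible way to handle both `PipelineStage` and `_PipelineStage` instances.

```python
from __future__ import annotations

from typing import Any, Protocol


class FrequencyController(Protocol):
    """Hypothetical interface for the async frequency controller."""

    def set_frequency(self, mhz: int) -> None: ...


class RecordingController:
    """Test double: records requested frequencies instead of touching a GPU."""

    def __init__(self) -> None:
        self.calls: list[int] = []

    def set_frequency(self, mhz: int) -> None:
        self.calls.append(mhz)


class StubStage:
    """Stand-in for PipelineStage; models only the two hooks named above."""

    def __init__(self) -> None:
        self.log: list[str] = []

    def forward_one_chunk(self, *args: Any, **kwargs: Any) -> str:
        self.log.append("forward")
        return "fwd-out"

    def backward_one_chunk(self, *args: Any, **kwargs: Any) -> None:
        self.log.append("backward")


class FrequencyControlMixin:
    """Sets the frequency, then defers to the wrapped stage's own hook."""

    _controller: FrequencyController
    _forward_mhz: int
    _backward_mhz: int

    def forward_one_chunk(self, *args: Any, **kwargs: Any) -> Any:
        self._controller.set_frequency(self._forward_mhz)
        return super().forward_one_chunk(*args, **kwargs)

    def backward_one_chunk(self, *args: Any, **kwargs: Any) -> Any:
        self._controller.set_frequency(self._backward_mhz)
        return super().backward_one_chunk(*args, **kwargs)

    @staticmethod
    def adopt(
        stage: Any,
        controller: FrequencyController,
        forward_mhz: int,
        backward_mhz: int,
    ) -> Any:
        """Convert an existing stage instance in place and return it."""
        base = type(stage)
        if not isinstance(stage, FrequencyControlMixin):
            # Build a subclass of whatever class the user's stage already has,
            # with the mixin first in the MRO so its hooks run before the base's.
            stage.__class__ = type(
                f"FrequencyControlled{base.__name__}",
                (FrequencyControlMixin, base),
                {},
            )
        stage._controller = controller
        stage._forward_mhz = forward_mhz
        stage._backward_mhz = backward_mhz
        return stage
```

With something like this, user code would keep constructing its stage as before and then call `FrequencyControlMixin.adopt(stage, controller, ...)` once before handing the stage to the schedule. Whether swapping `__class__` is acceptable for the real `PipelineStage` classes, and how per-chunk frequencies should be chosen, would need to be checked in the POC.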