Add support for loading torchscript models #25321

Merged (15 commits, Feb 11, 2023)

75 changes: 61 additions & 14 deletions in sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -58,9 +58,11 @@
 
 
 def _load_model(
-    model_class: torch.nn.Module, state_dict_path, device, **model_params):
-  model = model_class(**model_params)
-
+    model_class: torch.nn.Module,
+    state_dict_path,
+    device,
+    model_params,
+    use_torch_script_format=False):
   if device == torch.device('cuda') and not torch.cuda.is_available():
     logging.warning(
         "Model handler specified a 'GPU' device, but GPUs are not available. " \
Expand All @@ -71,18 +73,26 @@ def _load_model(
try:
logging.info(
"Loading state_dict_path %s onto a %s device", state_dict_path, device)
state_dict = torch.load(file, map_location=device)
if not use_torch_script_format:
damccorm marked this conversation as resolved.
Show resolved Hide resolved
model = model_class(**model_params)
state_dict = torch.load(file, map_location=device)
model.load_state_dict(state_dict)
else:
model = torch.jit.load(file, map_location=device)
except RuntimeError as e:
if device == torch.device('cuda'):
message = "Loading the model onto a GPU device failed due to an " \
f"exception:\n{e}\nAttempting to load onto a CPU device instead."
logging.warning(message)
return _load_model(
model_class, state_dict_path, torch.device('cpu'), **model_params)
model_class,
state_dict_path,
torch.device('cpu'),
model_params,
use_torch_script_format)
else:
raise e

model.load_state_dict(state_dict)
model.to(device)
model.eval()
logging.info("Finished loading PyTorch model.")
@@ -149,11 +159,13 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
   def __init__(
       self,
       state_dict_path: str,
-      model_class: Callable[..., torch.nn.Module],
-      model_params: Dict[str, Any],
+      model_class: Optional[Callable[..., torch.nn.Module]] = None,
+      model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
-      inference_fn: TensorInferenceFn = default_tensor_inference_fn):
+      inference_fn: TensorInferenceFn = default_tensor_inference_fn,
+      use_torch_script_format=False,
+  ):
     """Implementation of the ModelHandler interface for PyTorch.
 
     Example Usage::
@@ -174,6 +186,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
+      use_torch_script_format: When `use_torch_script_format` is set to `True`,
+        the model will be loaded using `torch.jit.load()`.
+        `model_class` and `model_params` arguments will be disregarded.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
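
A hedged usage sketch of the tensor handler with the new flag; the model path and pipeline contents below are placeholders for illustration, not part of this PR:

```python
import apache_beam as beam
import torch
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

# 'gs://my-bucket/scripted_model.pt' is a placeholder path to a model
# previously saved with torch.jit.save().
model_handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/scripted_model.pt',
    use_torch_script_format=True)  # model_class / model_params not needed

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([torch.Tensor([1.0, 2.0])])
      | RunInference(model_handler)
      | beam.Map(print))
```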
@@ -188,14 +203,28 @@
     self._model_class = model_class
     self._model_params = model_params
     self._inference_fn = inference_fn
+    self._use_torch_script_format = use_torch_script_format
+
+    self._validate_func_args()
+
+  def _validate_func_args(self):
+    if not self._use_torch_script_format and (self._model_class is None or
+                                              self._model_params is None):
+      raise RuntimeError(
+          "Please pass both `model_class` and `model_params` to the torch "
+          "model handler when using it with PyTorch. "
+          "If you opt to load the entire model that was saved using "
+          "TorchScript, set `use_torch_script_format` to True.")
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
     model, device = _load_model(
         self._model_class,
         self._state_dict_path,
         self._device,
-        **self._model_params)
+        self._model_params,
+        self._use_torch_script_format
+    )
     self._device = device
     return model
 
@@ -323,11 +352,12 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
   def __init__(
       self,
       state_dict_path: str,
-      model_class: Callable[..., torch.nn.Module],
-      model_params: Dict[str, Any],
+      model_class: Optional[Callable[..., torch.nn.Module]] = None,
+      model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
-      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn):
+      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
+      use_torch_script_format: bool = False):
     """Implementation of the ModelHandler interface for PyTorch.
 
     Example Usage::
@@ -352,6 +382,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the function to invoke on run_inference.
         default = default_keyed_tensor_inference_fn
+      use_torch_script_format: When `use_torch_script_format` is set to `True`,
+        the model will be loaded using `torch.jit.load()`.
+        `model_class` and `model_params` arguments will be disregarded.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     on torch>=1.9.0,<1.14.0.
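
Mirroring the tensor handler, a brief sketch of the keyed variant with the flag enabled; the scripted model is assumed to accept keyword tensor arguments, and the path is again a placeholder:

```python
import torch
from apache_beam.ml.inference.pytorch_inference import (
    PytorchModelHandlerKeyedTensor)

# The scripted model is assumed to take keyword tensor inputs,
# e.g. model(input_ids=..., attention_mask=...); path is a placeholder.
keyed_handler = PytorchModelHandlerKeyedTensor(
    state_dict_path='gs://my-bucket/scripted_keyed_model.pt',
    use_torch_script_format=True)

# Elements fed to RunInference are then Dict[str, torch.Tensor], e.g.:
example = {
    'input_ids': torch.tensor([[101, 2023, 102]]),
    'attention_mask': torch.tensor([[1, 1, 1]]),
}
```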
@@ -366,14 +399,19 @@
     self._model_class = model_class
     self._model_params = model_params
     self._inference_fn = inference_fn
+    self._use_torch_script_format = use_torch_script_format
+
+    self._validate_func_args()
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
     model, device = _load_model(
         self._model_class,
         self._state_dict_path,
         self._device,
-        **self._model_params)
+        self._model_params,
+        self._use_torch_script_format
+    )
     self._device = device
     return model
 
@@ -429,3 +467,12 @@ def get_metrics_namespace(self) -> str:
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
+
+  def _validate_func_args(self):
+    if not self._use_torch_script_format and (self._model_class is None or
+                                              self._model_params is None):
+      raise RuntimeError(
+          "Please pass both `model_class` and `model_params` to the torch "
+          "model handler when using it with PyTorch. "
+          "If you opt to load the entire model that was saved using "
+          "TorchScript, set `use_torch_script_format` to True.")
51 changes: 51 additions & 0 deletions in sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -609,6 +609,57 @@ def test_gpu_auto_convert_to_cpu(self):
         "are not available. Switching to CPU.",
         log.output)
 
+  def test_load_torch_script_model(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_script_model = torch.jit.script(torch_model)
+
+    torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
+
+    torch.jit.save(torch_script_model, torch_script_path)
+
+    model_handler = PytorchModelHandlerTensor(
+        state_dict_path=torch_script_path, use_torch_script_format=True)
+
+    torch_script_model = model_handler.load_model()
+
+    self.assertTrue(isinstance(torch_script_model, torch.jit.ScriptModule))
+
+  def test_inference_torch_script_model(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_model.load_state_dict(
+        OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
+                     ('linear.bias', torch.Tensor([0.5]))]))
+
+    torch_script_model = torch.jit.script(torch_model)
+
+    torch_script_path = os.path.join(self.tmpdir, 'torch_script_model.pt')
+
+    torch.jit.save(torch_script_model, torch_script_path)
+
+    model_handler = PytorchModelHandlerTensor(
+        state_dict_path=torch_script_path, use_torch_script_format=True)
+
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
+      predictions = pcoll | RunInference(model_handler)
+      assert_that(
+          predictions,
+          equal_to(
+              TWO_FEATURES_PREDICTIONS, equals_fn=_compare_prediction_result))
+
+  def test_torch_model_class_none(self):
+    torch_model = PytorchLinearRegression(2, 1)
+    torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
+
+    torch.save(torch_model, torch_path)
+
+    with self.assertRaisesRegex(
+        RuntimeError,
+        "Please pass both `model_class` and `model_params` to the torch "
+        "model handler when using it with PyTorch. "
+        "If you opt to load the entire model that was saved using TorchScript"):
+      _ = PytorchModelHandlerTensor(state_dict_path=torch_path)
+
 
 if __name__ == '__main__':
   unittest.main()