diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py
index 4bb6143023..3a3c2e1dff 100644
--- a/examples/dynamo/mutable_torchtrt_module_example.py
+++ b/examples/dynamo/mutable_torchtrt_module_example.py
@@ -14,6 +14,7 @@
 1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
 2. Save a Mutable Torch TensorRT Module
 3. Integration with Huggingface pipeline in LoRA use case
+4. Usage of dynamic shapes with Mutable Torch TensorRT Module
 """
 
 import numpy as np
@@ -25,89 +26,129 @@
 torch.manual_seed(5)
 inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
 
-# %%
-# Initialize the Mutable Torch TensorRT Module with settings.
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-settings = {
-    "use_python": False,
-    "enabled_precisions": {torch.float32},
-    "immutable_weights": False,
-}
+# # %%
+# # Initialize the Mutable Torch TensorRT Module with settings.
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# settings = {
+#     "use_python": False,
+#     "enabled_precisions": {torch.float32},
+#     "immutable_weights": False,
+# }
 
-model = models.resnet18(pretrained=True).eval().to("cuda")
-mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
-# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
-mutable_module(*inputs)
+# model = models.resnet18(pretrained=True).eval().to("cuda")
+# mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
+# # You can use the mutable module just like the original pytorch module. The compilation happens when you first call the mutable module.
+# mutable_module(*inputs)
 
-# %%
-# Make modifications to the mutable module.
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# # %%
+# # Make modifications to the mutable module.
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-# %%
-# Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
-model2 = models.resnet18(pretrained=False).eval().to("cuda")
-mutable_module.load_state_dict(model2.state_dict())
+# # %%
+# # Making changes to the mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
+# model2 = models.resnet18(pretrained=False).eval().to("cuda")
+# mutable_module.load_state_dict(model2.state_dict())
 
-# Check the output
-# The refit happens while you call the mutable module again.
-expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
-for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
-    assert torch.allclose(
-        expected_output, refitted_output, 1e-2, 1e-2
-    ), "Refit Result is not correct. Refit failed"
+# # Check the output
+# # The refit happens when you call the mutable module again.
+# expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
+# for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+#     assert torch.allclose(
+#         expected_output, refitted_output, 1e-2, 1e-2
+#     ), "Refit Result is not correct. Refit failed"
Refit failed" -print("Refit successfully!") +# print("Refit successfully!") -# %% -# Saving Mutable Torch TensorRT Module -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# # %% +# # Saving Mutable Torch TensorRT Module +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Currently, saving is only enabled for C++ runtime, not python runtime. -torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl") -reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") +# # Currently, saving is only when "use_python" = False in settings +# torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl") +# reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") -# %% -# Stable Diffusion with Huggingface -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# # %% +# # Stable Diffusion with Huggingface +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# The LoRA checkpoint is from https://civitai.com/models/12597/moxin +# # The LoRA checkpoint is from https://civitai.com/models/12597/moxin -from diffusers import DiffusionPipeline +# from diffusers import DiffusionPipeline -with torch.no_grad(): - settings = { - "use_python_runtime": True, - "enabled_precisions": {torch.float16}, - "debug": True, - "immutable_weights": False, - } +# with torch.no_grad(): +# settings = { +# "use_python_runtime": True, +# "enabled_precisions": {torch.float16}, +# "debug": True, +# "immutable_weights": False, +# } - model_id = "stabilityai/stable-diffusion-xl-base-1.0" - device = "cuda:0" +# model_id = "stabilityai/stable-diffusion-xl-base-1.0" +# device = "cuda:0" - prompt = "cinematic photo elsa, police uniform , . 35mm photograph, film, bokeh, professional, 4k, highly detailed" - negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude" +# prompt = "cinematic photo elsa, police uniform , . 
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c=None):
+        x = torch.matmul(a, b)
+        x = torch.matmul(c["a"], c["b"].T)
+        x = 2 * c["b"][1]
+        return x
+
+
+model = Model().eval().cuda()
+inputs = (torch.rand(10, 3), torch.rand(3, 30))
+kwargs = {
+    "c": {"a": torch.rand(10, 30), "b": torch.rand(10, 30)},
+}
+
+dim = torch.export.Dim("dim", min=1, max=50)
+dim2 = torch.export.Dim("dim2", min=1, max=50)
+args_dynamic_shapes = ({1: dim}, {0: dim})
+kwarg_dynamic_shapes = {
+    "c": {"a": {}, "b": {0: dim2}},
+}
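+# The shape hints mirror the structure of the inputs to forward(): one dict
+# per positional arg keyed by dynamic dimension index, and a nested dict for
+# kwargs. An empty dict (e.g. for c["a"]) marks a shape that stays static.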
+# Register the expected dynamic shape range before the first compilation;
+# inputs outside this range will trigger a recompilation with static shapes.
+trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True)
+trt_gm.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+# Run inference
+trt_gm(*inputs, **kwargs)
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index cabed5b601..b215c4218c 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -196,12 +196,11 @@ def __init__(
         }
         self.arg_dynamic_shapes: Optional[tuple[Any]] = None
         self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
-        self.total_dynamic_shape: Optional[dict[Any, Any]] = None
         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
         self.state_dict_metadata: dict[str, torch.Size] = {}
-        self.store_state_dict_metadata()
+        self._store_state_dict_metadata()
 
         cls = self.__class__
         self.__class__ = type(
@@ -211,11 +210,30 @@ def __init__(
         )
         self.init_finished = True
 
-    def set_dynamic_shape_hint(
+    def set_expected_dynamic_shape_range(
         self,
         args_dynamic_shape: tuple[dict[Any, Any]],
         kwargs_dynamic_shape: dict[str, Any],
     ) -> None:
+        """
+        Set the expected dynamic shape range. The shape hints should follow the same structure as arg_inputs and kwarg_inputs
+        in the forward function, and no input may be omitted. If an input does not need a dynamic shape, an empty dictionary
+        should be given as the shape hint for that input.
+
+        Example:
+            def forward(a, b, c=0, d=0):
+                pass
+
+            seq_len = torch.export.Dim("seq_len", min=1, max=10)
+            args_dynamic_shape = ({0: seq_len}, {}) # b does not have a dynamic shape
+            kwargs_dynamic_shape = {'c': {0: seq_len}, 'd': {}} # d does not have a dynamic shape
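+            # `module` is a hypothetical MutableTorchTensorRTModule instance
+            module.set_expected_dynamic_shape_range(args_dynamic_shape, kwargs_dynamic_shape)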
This may take some time.") + self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE) + self.store_inputs(args, kwargs) + return + + # If input does not equal or does not fall into dynamic shape range, recompile the engine try: - if ( - not self.arg_inputs - or not MutableTorchTensorRTModule.check_inputs_equal( - self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes - ) - or not MutableTorchTensorRTModule.check_inputs_equal( - self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes - ) + if not MutableTorchTensorRTModule._check_inputs_shape( + self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes + ) or not MutableTorchTensorRTModule._check_inputs_shape( + self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes ): logger.info("Input change detected.") self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE) @@ -356,7 +388,9 @@ def _validate_inputs(self, *args: Any, **kwargs: Any) -> None: except DynamicShapeOutOfRangeException as e: logger.info("Input change detected.") logger.warning(e) - logger.warning("Recompiling the engine with static shape") + logger.warning( + "Provided inputs are outside the set expected shape range, recompiling module for the provided input shapes (static)" + ) self.arg_dynamic_shapes = None self.kwarg_dynamic_shapes = None self.total_dynamic_shape = None @@ -368,12 +402,12 @@ def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None: self.kwarg_inputs = kwarg_inputs @staticmethod - def process_kwarg_inputs(inputs: Any) -> Any: + def _process_kwarg_inputs(inputs: Any) -> Any: # Process kwarg inputs to be acceptable for Torch-TensorRT if isinstance(inputs, dict): # None should be excluded. AOT compile also does not allow dynamic control flow, bool is also excluded. return { - k: MutableTorchTensorRTModule.process_kwarg_inputs(v) + k: MutableTorchTensorRTModule._process_kwarg_inputs(v) for k, v in inputs.items() if (v is not None and not isinstance(v, bool)) } @@ -384,7 +418,10 @@ def process_kwarg_inputs(inputs: Any) -> Any: elif isinstance(inputs, (list, tuple)): if None not in inputs: return type(inputs)( - [MutableTorchTensorRTModule.process_kwarg_inputs(v) for v in inputs] + [ + MutableTorchTensorRTModule._process_kwarg_inputs(v) + for v in inputs + ] ) raise ValueError( @@ -394,7 +431,7 @@ def process_kwarg_inputs(inputs: Any) -> Any: def forward(self, *args: Any, **kwargs: Any) -> Any: # Step 1: Check whether the input shape has changed - kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs) + kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs) self._validate_inputs(*args, **kwargs) # Step 2: If the flag is unknown, it could be a recompile or refit. @@ -406,7 +443,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE: logger.info("(Re)Compiling the engine...") self.compile() - self.store_state_dict_metadata() + self._store_state_dict_metadata() self.refit_state.set_state(RefitFlag.LIVE) elif self.refit_state.get_state() == RefitFlag.NEEDS_REFIT: @@ -417,7 +454,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: logger.error(e) logger.error("Model refit failed. Recompiling the graph module.") self.compile() - self.store_state_dict_metadata() + self._store_state_dict_metadata() self.refit_state.set_state(RefitFlag.LIVE) result = self.gm(*args, **kwargs) @@ -427,7 +464,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: def to(self, device: str) -> None: logger.warning("Original PyTorch model is moved. 
CPU offload may failed.") - self.orignial_model.to(device) + self.original_model.to(device) def __deepcopy__(self, memo: Any) -> Any: cls = self.__class__ @@ -479,7 +516,7 @@ def __setattr__(self, name: str, value: Any) -> None: object.__setattr__(self, name, value) @staticmethod - def check_inputs_equal( + def _check_inputs_shape( input1: Any, input2: Any, dynamic_shapes: Any = None, @@ -498,7 +535,7 @@ def check_inputs_equal( return False else: tensor_dynamic_shape = dynamic_shapes[i] - if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes( + if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes( a, b, tensor_dynamic_shape ): return False @@ -516,20 +553,20 @@ def check_inputs_equal( return False else: tensor_dynamic_shape = dynamic_shapes[ka] - if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes( + if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes( va, vb, tensor_dynamic_shape ): return False elif isinstance( va, (list, tuple, dict) - ) and not MutableTorchTensorRTModule.check_inputs_equal( + ) and not MutableTorchTensorRTModule._check_inputs_shape( va, vb, dynamic_shapes[ka] if dynamic_shapes else None ): return False return True @staticmethod - def check_tensor_shapes_with_dynamic_shapes( + def _check_tensor_shapes_with_dynamic_shapes( t1: torch.tensor, t2: torch.tensor, dynamic_shape: dict[int, Any] ) -> bool: for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape): diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index 8a4b2a376e..05f388656f 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -35,6 +35,17 @@ def test_check_output_equal(): msg=f"test_check_output_equal is not correct.", ) + torch.manual_seed(1) + c = { + "a": torch.rand(10, 30), + "b": [torch.rand(10, 30), torch.rand(5, 5)], + "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, + } + assertions.assertFalse( + check_output_equal(a, c), + msg=f"test_check_output_equal is not correct.", + ) + @pytest.mark.unit def test_check_input_shape_dynamic(): @@ -54,11 +65,11 @@ def test_check_input_shape_dynamic(): dim = torch.export.Dim("dim", min=1, max=50) dynamic_shape = {"a": {1: dim}, "b": [{}, {}], "c": {"a": {}, "b": [{}, {}]}} assertions.assertFalse( - torch_trt.MutableTorchTensorRTModule.check_inputs_equal(a, b), + torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b), msg=f"test_check_output_equal is not correct.", ) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_inputs_equal(a, b, dynamic_shape), + torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b, dynamic_shape), msg=f"test_check_output_equal is not correct.", ) @@ -92,7 +103,7 @@ def forward(self, a, b, c=None): # Export the model first with custom dynamic shape constraints # exp_program = torch.export.export(model, tuple(inputs), kwargs=k trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True) - trt_gm.set_dynamic_shape_hint(args_dynamic_shapes, kwarg_dynamic_shapes) + trt_gm.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes) # Run inference trt_gm(*inputs, **kwargs) @@ -102,7 +113,7 @@ def forward(self, a, b, c=None): "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]}, } - kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_2) + kwargs = 
+    assertions.assertTrue(
+        trt_gm.arg_dynamic_shapes is None,
+        msg=f"arg_dynamic_shapes should be cleared after an out-of-range input.",
+    )
+    assertions.assertTrue(
+        trt_gm.kwarg_dynamic_shapes is None,
+        msg=f"kwarg_dynamic_shapes should be cleared after an out-of-range input.",
+    )
 
 
 @unittest.skipIf(
@@ -308,7 +326,7 @@ def test_resnet18_modify_attribute_no_refit():
     for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
         assertions.assertTrue(
             torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
-            msg=f"The output of refitted Mutable Module is not correct.",
+            msg=f"The output of original and refitted Mutable Module is not the same.",
        )
 
     # # Clean up model env
@@ -375,7 +393,7 @@ def forward(self, x, b=5, c=None, d=None):
     )
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The output of original and refitted Mutable Module is not the same.",
     )
 
     # Clean up model env
@@ -438,7 +456,7 @@ def set_weights(self):
     expected_outputs, refitted_outputs = model(*args), mutable_module(*args)
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The output of original and refitted Mutable Module is not the same.",
     )
 
     # Clean up model env
@@ -501,7 +519,7 @@ def set_layer(self):
     expected_outputs, refitted_outputs = model(*args), mutable_module(*args)
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The output of original and refitted Mutable Module is not the same.",
     )
 
     # Clean up model env
@@ -571,7 +589,7 @@ def forward(self, x, b=5, c=None, d=None):
     )
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The output of original and refitted Mutable Module is not the same.",
     )
 
     # Clean up model env