From b0cec9c696d8fc2d35765fb965d7458abe4bf91b Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Wed, 18 Dec 2024 13:13:45 +0000
Subject: [PATCH] set actual device

Signed-off-by: jiqing-feng
---
 optimum/exporters/ipex/model_patcher.py  | 1 +
 optimum/exporters/ipex/modeling_utils.py | 8 ++++----
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 03937754a..5b3dbe42f 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -133,6 +133,7 @@ def _patch_vit_model(model):


 def _patch_model(model):
+    setattr(model.config, "device", model.device)
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
         raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching")
     if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index ec9a18e04..1b8c2da41 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -599,7 +599,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
@@ -779,7 +779,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
@@ -812,7 +812,7 @@ def __init__(self, module, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
@@ -911,7 +911,7 @@ class _IPEXIntermediate(nn.Module):
     def __init__(self, module, config):
         super().__init__()
         _setattr_from_module(self, module)
-        self.module_device = next(module.parameters()).device
+        self.module_device = config.device
         if getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
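
For context, a minimal runnable sketch of the pattern this patch adopts; TinyModel, PatchedBlock, and patch_model below are hypothetical stand-ins, not optimum-intel code. The patcher records the model's device on the config once, and each wrapper then reads config.device instead of probing next(module.parameters()).device, an expression that raises StopIteration when a module has no parameters of its own.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a transformers PreTrainedModel: it carries a
    bare config object and exposes a .device property, like model.device above."""

    def __init__(self):
        super().__init__()
        self.config = type("Config", (), {})()  # empty namespace acting as a config
        self.block = nn.Linear(4, 4)

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device


class PatchedBlock(nn.Module):
    """Hypothetical wrapper mirroring the patched __init__ methods above: the
    device comes from config.device, set once by the patcher, rather than from
    next(module.parameters()).device probed in every wrapper."""

    def __init__(self, module, config):
        super().__init__()
        self.module = module
        self.config = config
        self.module_device = config.device  # populated by patch_model below

    def forward(self, x):
        return self.module(x)


def patch_model(model):
    # Mirrors the added line in _patch_model: record the actual device on the
    # config so every wrapper reads it from a single place.
    setattr(model.config, "device", model.device)
    model.block = PatchedBlock(model.block, model.config)
    return model


model = patch_model(TinyModel())
print(model.block.module_device)  # cpu

One plausible benefit of this indirection is that a wrapper around a parameter-free or offloaded submodule no longer depends on where its own weights happen to live at patch time.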