diff --git a/README.md b/README.md
index 466bf568aa..8910d2bb46 100644
--- a/README.md
+++ b/README.md
@@ -241,6 +241,7 @@ The following model architectures, tasks and device distributions have been vali
| MiniCPM3 | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Baichuan2 | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| DeepSeek-V2 | | :heavy_check_mark: | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| ChatGLM | DeepSpeed | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling), [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
- Diffusers:
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 048456799b..e61a8ad5a2 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -108,6 +108,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| MiniCPM3 | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Baichuan2 | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| DeepSeek-V2 | | ✅ | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| ChatGLM | DeepSpeed | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling), [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
- Diffusers
diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md
index f10e46d757..23cbe5aacf 100644
--- a/examples/language-modeling/README.md
+++ b/examples/language-modeling/README.md
@@ -131,6 +131,33 @@ python ../gaudi_spawn.py \
This example has been validated with the following DeepSpeed ZeRO-2 config: https://github.com/huggingface/optimum-habana/blob/main/tests/configs/deepspeed_zero_2.json
+
+### Multi-card Training with Deepspeed (chatglm3-6b)
+
+The following command triggers the fine-tuning of [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b) on WikiText-2 with Deepspeed ZeRO-3 on 8 HPUs:
+```bash
+python ../gaudi_spawn.py \
+ --world_size 8 --use_deepspeed run_clm.py \
+ --config_name THUDM/chatglm3-6b \
+ --tokenizer_name THUDM/chatglm3-6b \
+ --dataset_name wikitext \
+ --dataset_config_name wikitext-2-raw-v1 \
+ --per_device_train_batch_size 6 \
+ --per_device_eval_batch_size 4 \
+ --do_train \
+ --do_eval \
+ --deepspeed llama2_ds_zero3_config.json \
+ --output_dir /tmp/test-clm \
+ --gaudi_config_name Habana/gpt2 \
+ --use_habana \
+ --use_lazy_mode \
+ --throughput_warmup_steps 3 \
+ --bf16 \
+ --block_size 1024 \
+ --use_cache False \
+ --overwrite_output_dir \
+ --logging_first_step True \
+ --logging_steps 20
+```
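+
+ChatGLM is also validated for single-card text generation (see the model table in the top-level README). Below is a minimal inference sketch, assuming the standard flags of the [text-generation example](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation); adjust the prompt, batch size and generation length as needed:
+
+```bash
+python ../text-generation/run_generation.py \
+ --model_name_or_path THUDM/chatglm3-6b \
+ --use_hpu_graphs \
+ --use_kv_cache \
+ --bf16 \
+ --max_new_tokens 128 \
+ --batch_size 1 \
+ --prompt "What is DeepSpeed?"
+```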
+
+
## Multi-Node Training with Deepspeed (GPT-NeoX)
The following command triggers the fine-tuning of [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b) on WikiText-2 with Deepspeed ZeRO-2.
diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index 16c615f1bf..b97b634941 100644
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -431,6 +431,10 @@ def main():
config.update_from_string(model_args.config_overrides)
logger.info(f"New config: {config}")
+ # Note that chatglm2/3 ships with float16 as its dtype in config.json, while on Gaudi we need to use bfloat16.
+ if config.model_type == "chatglm":
+ config.dtype = "torch.bfloat16"
+
tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
"use_fast": model_args.use_fast_tokenizer,
@@ -472,8 +476,8 @@ def main():
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
- # We need to skip this test for baichuan pretrain
- if config.model_type not in ("baichuan"):
+ # We need to skip this test for baichuan and chatglm pretraining
+ if config.model_type not in ("baichuan", "chatglm"):
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 9468bea956..72789fd994 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -115,6 +115,7 @@
"minicpm3",
"baichuan",
"deepseek_v2",
+ "chatglm",
]
# Initial generated token index is set to 1 to accomodate SOS (start of string) token.
@@ -1087,8 +1088,9 @@ def generate(
"gemma",
"gemma2",
"baichuan",
+ "chatglm",
]
- ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2 and baichuan at the moment"
+ ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan and chatglm at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 45b4cc33a6..ee092ecff9 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -32,6 +32,10 @@
BaichuanConfig,
BaichuanForCausalLM,
BaichuanTokenizer,
+ ChatGLMConfig,
+ ChatGLMForConditionalGeneration,
+ ChatGLMForSequenceClassification,
+ ChatGLMTokenizer,
DeciLMConfig,
DeciLMForCausalLM,
DeepseekTokenizerFast,
@@ -719,3 +723,11 @@ def adapt_transformers_to_gaudi():
transformers.AutoConfig.register("baichuan", BaichuanConfig)
transformers.AutoTokenizer.register(BaichuanConfig, slow_tokenizer_class=BaichuanTokenizer)
transformers.AutoModelForCausalLM.register(BaichuanConfig, BaichuanForCausalLM)
+
+ # Register chatglm with optimization on Gaudi
+ transformers.AutoConfig.register("chatglm", ChatGLMConfig)
+ transformers.AutoTokenizer.register(ChatGLMConfig, ChatGLMTokenizer)
+ transformers.AutoModel.register(ChatGLMConfig, ChatGLMForConditionalGeneration)
+ transformers.AutoModelForCausalLM.register(ChatGLMConfig, ChatGLMForConditionalGeneration)
+ transformers.AutoModelForSeq2SeqLM.register(ChatGLMConfig, ChatGLMForConditionalGeneration)
+ transformers.AutoModelForSequenceClassification.register(ChatGLMConfig, ChatGLMForSequenceClassification)
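+
+ # Illustrative note: once adapt_transformers_to_gaudi() has run, a ChatGLM checkpoint can be loaded
+ # through the standard Auto classes, e.g.
+ # transformers.AutoModelForCausalLM.from_pretrained("THUDM/chatglm3-6b", torch_dtype=torch.bfloat16)
+ # resolves to the ChatGLMForConditionalGeneration implementation registered above.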
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 0bb122e0c7..2a5e685942 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -36,6 +36,12 @@
gaudi_bloom_convert_to_standard_cache,
gaudi_bloom_model_forward,
)
+from .chatglm import (
+ ChatGLMConfig,
+ ChatGLMForConditionalGeneration,
+ ChatGLMForSequenceClassification,
+ ChatGLMTokenizer,
+)
from .clip import (
GaudiCLIPAttention,
GaudiCLIPEncoder,
diff --git a/optimum/habana/transformers/models/chatglm/__init__.py b/optimum/habana/transformers/models/chatglm/__init__.py
new file mode 100644
index 0000000000..c2eefd8955
--- /dev/null
+++ b/optimum/habana/transformers/models/chatglm/__init__.py
@@ -0,0 +1,6 @@
+from .configuration_chatglm import ChatGLMConfig
+from .modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMForSequenceClassification,
+)
+from .tokenization_chatglm import ChatGLMTokenizer
diff --git a/optimum/habana/transformers/models/chatglm/configuration_chatglm.py b/optimum/habana/transformers/models/chatglm/configuration_chatglm.py
new file mode 100644
index 0000000000..372e29bd5f
--- /dev/null
+++ b/optimum/habana/transformers/models/chatglm/configuration_chatglm.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+# Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company
+###############################################################################
+
+"""
+Adapted from the following sources:
+https://huggingface.co/THUDM/chatglm2-6b/blob/main/configuration_chatglm.py
+https://huggingface.co/THUDM/chatglm3-6b/blob/main/configuration_chatglm.py
+"""
+
+from transformers import PretrainedConfig
+
+
+class ChatGLMConfig(PretrainedConfig):
+ model_type = "chatglm"
+
+ def __init__(
+ self,
+ _name_or_path=None,
+ num_layers=28,
+ padded_vocab_size=65024,
+ hidden_size=4096,
+ ffn_hidden_size=13696,
+ kv_channels=128,
+ num_attention_heads=32,
+ seq_length=2048,
+ hidden_dropout=0.0,
+ classifier_dropout=None,
+ attention_dropout=0.0,
+ layernorm_epsilon=1e-5,
+ rmsnorm=True,
+ apply_residual_connection_post_layernorm=False,
+ post_layer_norm=True,
+ add_bias_linear=False,
+ add_qkv_bias=False,
+ bias_dropout_fusion=True,
+ multi_query_attention=False,
+ multi_query_group_num=1,
+ rope_ratio=1,
+ original_rope=True,
+ apply_query_key_layer_scaling=True,
+ attention_softmax_in_fp32=True,
+ fp32_residual_connection=False,
+ pre_seq_len=None,
+ prefix_projection=False,
+ **kwargs,
+ ):
+ self.name_or_path = _name_or_path
+ self.num_layers = num_layers
+ self.vocab_size = padded_vocab_size
+ self.padded_vocab_size = padded_vocab_size
+ self.hidden_size = hidden_size
+ self.ffn_hidden_size = ffn_hidden_size
+ self.kv_channels = kv_channels
+ self.num_attention_heads = num_attention_heads
+ self.seq_length = seq_length
+ self.hidden_dropout = hidden_dropout
+ self.classifier_dropout = classifier_dropout
+ self.attention_dropout = attention_dropout
+ self.layernorm_epsilon = layernorm_epsilon
+ self.rmsnorm = rmsnorm
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+ self.post_layer_norm = post_layer_norm
+ self.add_bias_linear = add_bias_linear
+ self.add_qkv_bias = add_qkv_bias
+ self.bias_dropout_fusion = bias_dropout_fusion
+ self.multi_query_attention = multi_query_attention
+ self.multi_query_group_num = multi_query_group_num
+ self.rope_ratio = rope_ratio
+ self.original_rope = original_rope
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+ self.fp32_residual_connection = fp32_residual_connection
+ self.pre_seq_len = pre_seq_len
+ self.prefix_projection = prefix_projection
+ super().__init__(**kwargs)
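+
+
+# Note: checkpoints such as THUDM/chatglm2-6b and THUDM/chatglm3-6b override several of the defaults
+# above via their config.json (e.g. multi_query_attention and multi_query_group_num), so the values
+# here mainly serve as fallbacks when a field is absent from the checkpoint configuration.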
diff --git a/optimum/habana/transformers/models/chatglm/modeling_chatglm.py b/optimum/habana/transformers/models/chatglm/modeling_chatglm.py
new file mode 100644
index 0000000000..01c508aa5d
--- /dev/null
+++ b/optimum/habana/transformers/models/chatglm/modeling_chatglm.py
@@ -0,0 +1,1879 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+# Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company
+###############################################################################
+
+"""PyTorch ChatGLM model."""
+
+import copy
+import json
+import math
+import os
+import warnings
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import habana_frameworks.torch.core as htcore
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from torch.nn.utils import skip_init
+from transformers.cache_utils import Cache
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from transformers.utils import logging
+
+from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
+from .configuration_chatglm import ChatGLMConfig
+
+
+"""
+Adapted from the following sources:
+https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py
+https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py
+"""
+
+try:
+ from habana_frameworks.torch.hpex.kernels import FusedSDPA
+except ImportError:
+ print("Cannot import Fused SDPA from Habana Torch")
+ FusedSDPA = None
+
+try:
+ from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV3 as FusedRoPE
+except ImportError:
+ print("Cannot import Fused Rope from Habana Torch")
+ FusedRoPE = None
+
+try:
+ from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
+except ImportError:
+ print("Cannot import Fused RMSNorm from Habana Torch")
+ FusedRMSNorm = None
+
+
+logger = logging.get_logger(__name__)
+
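+# Register ChatGLM in the causal-LM name mapping so that utilities which consult
+# MODEL_FOR_CAUSAL_LM_MAPPING_NAMES treat ChatGLMForConditionalGeneration as a causal LM.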
+MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["chatglm"] = "ChatGLMForConditionalGeneration"
+_CONFIG_FOR_DOC = "ChatGLMConfig"
+
+
+def default_init(cls, *args, **kwargs):
+ return cls(*args, **kwargs)
+
+
+# FusedScaledDotProductAttention
+class ModuleFusedSDPA(torch.nn.Module):
+ def __init__(self, fusedSDPA):
+ super().__init__()
+ self._hpu_kernel_fsdpa = fusedSDPA
+
+ def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
+ return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
+ scores.zero_()
+ scores[..., 198] = 5e4
+ return scores
+
+
+def split_tensor_along_last_dim(
+ tensor: torch.Tensor,
+ num_partitions: int,
+ contiguous_split_chunks: bool = False,
+) -> List[torch.Tensor]:
+ """Split a tensor along its last dimension.
+
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+
+ Returns:
+ A list of Tensors
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = tensor.size()[last_dim] // num_partitions
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+
+class Matmul(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return torch.matmul(x, y)
+
+
+class KVCache(torch.nn.Module):
+ def __init__(self):
+ super(KVCache, self).__init__()
+ self.cache = None
+ self.inp_seq_len = -1
+
+ def allocate(self, inp_seq_len, dtype, device, shape):
+ if self.cache is None or self.cache.shape != shape:
+ self.inp_seq_len = inp_seq_len
+ # The KV cache is always allocated in bfloat16 on Gaudi, regardless of the requested dtype.
+ self.cache = torch.zeros(shape, dtype=torch.bfloat16, device=device)
+ else:
+ assert (
+ self.inp_seq_len == inp_seq_len
+ ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+ self.cache.fill_(0)
+
+ def update(self, prev, cur, dim, idx, inp_seq_len):
+ orig_cur = cur
+ if prev.shape == cur.shape:
+ prev.copy_(cur)
+ return orig_cur
+ if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+ # Initialize
+ prev[:, :, :inp_seq_len, :].copy_(cur)
+ return orig_cur
+ assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
+ if idx is not None:
+ prev.index_copy_(dim, idx - 1, cur)
+ return prev
+ else:
+ return torch.cat((prev, cur), dim=dim)
+
+ def get_shape(self):
+ if self.cache is None:
+ return None
+ return self.cache.shape
+
+ def forward(self, cur, dim, idx):
+ return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
+
+
+def gaudi_chatglm_repeat_kv(
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ attention_mask: torch.Tensor,
+):
+ """
+ Refer https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/models/llama/modeling_llama.py#L109
+ Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+ The only differences are:
+ - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them.
+ - Add new args query_layer, key_layer, value_layer and attention_mask and update the logic for expansion.
+ The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
+ The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
+ """
+ batch, num_query_heads, q_len, head_dim = query_layer.shape
+ batch, num_key_value_heads, kv_len, head_dim = key_layer.shape
+ n_rep = num_query_heads // num_key_value_heads
+ if n_rep == 1 or num_key_value_heads == 1:
+ return query_layer, key_layer, value_layer, attention_mask
+
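+ # Illustration: with 32 query heads and 2 KV groups, n_rep = 16, so the query is viewed as
+ # (batch, 2, 16, q_len, head_dim) while the key/value states only gain a broadcast dim of size 1.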
+ new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
+ key_layer = key_layer.reshape(new_kv_shape)
+ value_layer = value_layer.reshape(new_kv_shape)
+
+ new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
+ query_layer = query_layer.reshape(new_q_shape)
+
+ if attention_mask is not None:
+ # Add groups dim and set to 1
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return query_layer, key_layer, value_layer, attention_mask
+
+
+def _config_to_kwargs(args):
+ common_kwargs = {
+ "dtype": args.torch_dtype,
+ }
+ return common_kwargs
+
+
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+ data_dtype = x.dtype
+ compute_dtype = rope_cache.dtype
+
+ if x.device.type == "hpu" and FusedRoPE:
+ x_out = FusedRoPE.apply(x.to(compute_dtype), rope_cache)
+ else:
+ x = x.to(compute_dtype)
+ # x: [sq, b, np, hn]
+ sq, _b, np, _hn = x.size(0), x.size(1), x.size(2), x.size(3)
+ rot_dim = rope_cache.shape[-2] * 2
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+ # truncate to support variable sizes
+ rope_cache = rope_cache[:sq]
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+ ],
+ -1,
+ )
+ x_out2 = x_out2.flatten(3)
+ x_out = torch.cat((x_out2, x_pass), dim=-1)
+
+ return x_out.to(data_dtype)
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class CoreAttention(torch.nn.Module):
+ def __init__(self, config: ChatGLMConfig, layer_number):
+ super(CoreAttention, self).__init__()
+
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+ if self.apply_query_key_layer_scaling:
+ self.attention_softmax_in_fp32 = True
+ self.layer_number = max(1, layer_number)
+
+ projection_size = config.kv_channels * config.num_attention_heads
+
+ # Per attention head and per partition values.
+ self.hidden_size_per_partition = projection_size
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
+ self.num_attention_heads_per_partition = config.num_attention_heads
+
+ coeff = None
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+ if self.apply_query_key_layer_scaling:
+ coeff = self.layer_number
+ self.norm_factor *= coeff
+ self.coeff = coeff
+
+ self.dropout_rate = config.attention_dropout
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
+
+ self.matmul_qk = Matmul()
+ self.matmul_av = Matmul()
+ self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
+
+ self.q_block_size = 4096
+
+ def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, softmax_mode):
+ """
+ Gaudi version of Flash Attention V1 to support long sequence at prompt phase
+ Causal mask is not supported in this optimization
+ """
+ q_len = query_layer.size(-2)
+ q_tiles = (
+ (q_len // self.q_block_size) if (q_len % self.q_block_size == 0) else math.ceil(q_len / self.q_block_size)
+ )
+ q_padding = q_tiles * self.q_block_size - q_len
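+ # e.g. q_len = 10000 with q_block_size = 4096 gives q_tiles = 3 and q_padding = 2288;
+ # the padded rows are sliced off again after the per-tile outputs are concatenated.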
+ query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
+ if attention_mask is not None:
+ attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", torch.finfo(key_layer.dtype).min)
+
+ row_o_list = []
+ for i in range(q_tiles):
+ s, e = i * self.q_block_size, (i + 1) * self.q_block_size
+ row_q = query_layer[:, :, s:e, :]
+ row_mask = attention_mask[:, :, s:e, :]
+ attn_output_partial = self.fused_scaled_dot_product_attention(
+ row_q, key_layer, value_layer, row_mask, self.dropout_rate, False, None, softmax_mode
+ )
+ row_o_list.append(attn_output_partial)
+ attn_output = torch.cat(row_o_list, dim=-2)
+
+ if q_padding != 0:
+ attn_output = attn_output[:, :, :-q_padding, :]
+
+ return attn_output
+
+ def forward(
+ self,
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ attention_mask: torch.Tensor,
+ cache_position: Optional[torch.LongTensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ ):
+ bsz, _, q_len, _ = query_layer.shape
+
+ if use_flash_attention and FusedSDPA:
+ if not self.training:
+ self.dropout_rate = 0.0
+
+ import habana_frameworks.torch.hpu as ht
+
+ softmax_mode = "fast" if flash_attention_fast_softmax else "None"
+
+ if q_len == 1:
+ # next token
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+ with ht.sdp_kernel(enable_recompute=use_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ self.dropout_rate,
+ False,
+ None,
+ softmax_mode,
+ )
+ else:
+ # first token
+ if flash_attention_causal_mask:
+ # causal masking on first token requires inputs to be of the same length
+ with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_layer, key_layer, value_layer, None, self.dropout_rate, True, None, softmax_mode
+ )
+ else:
+ with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+ if (q_len > 8192 or (q_len >= 6144 and bsz >= 2)) and self.training:
+ attn_output = self.gaudi_flash_attn_v1(
+ query_layer, key_layer, value_layer, attention_mask, softmax_mode
+ )
+ htcore.mark_step()
+ else:
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ self.dropout_rate,
+ False,
+ None,
+ softmax_mode,
+ )
+ else:
+ query_layer, key_layer, value_layer, attention_mask = gaudi_chatglm_repeat_kv(
+ query_layer, key_layer, value_layer, attention_mask
+ )
+ attn_weights = self.matmul_qk(query_layer, key_layer.transpose(-2, -1)) / self.norm_factor
+
+ if self.coeff is not None:
+ attn_weights = attn_weights * self.coeff
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask
+ if cache_position is not None:
+ causal_mask = attention_mask[:, :, cache_position, : key_layer.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ if attn_softmax_bf16:
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_layer.dtype)
+ else:
+ # upcast attention to fp32
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+ query_layer.dtype
+ )
+ if self.training:
+ attn_weights = self.attention_dropout(attn_weights)
+ attn_output = self.matmul_av(attn_weights, value_layer)
+ attn_output = attn_output.reshape(bsz, -1, q_len, self.hidden_size_per_attention_head)
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous()
+
+ context_layer = attn_output.reshape(q_len, bsz, self.hidden_size_per_partition)
+
+ return context_layer
+
+
+class SelfAttention(torch.nn.Module):
+ """Parallel self-attention layer abstract class.
+
+ Self-attention layer takes input with size [s, b, h]
+ and returns output of the same size.
+ """
+
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
+ super(SelfAttention, self).__init__()
+ self.layer_number = max(1, layer_number)
+ self.config = config
+
+ self.projection_size = config.kv_channels * config.num_attention_heads
+
+ # Per attention head and per partition values.
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
+ self.num_attention_heads_per_partition = config.num_attention_heads
+
+ self.multi_query_attention = config.multi_query_attention
+ self.qkv_hidden_size = 3 * self.projection_size
+ if self.multi_query_attention:
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
+ self.qkv_hidden_size = (
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
+ )
+ self.query_key_value = nn.Linear(
+ config.hidden_size,
+ self.qkv_hidden_size,
+ bias=config.add_bias_linear or config.add_qkv_bias,
+ device=device,
+ **_config_to_kwargs(config),
+ )
+ self.core_attention = CoreAttention(config, self.layer_number)
+
+ # Output.
+ self.dense = nn.Linear(
+ self.projection_size,
+ config.hidden_size,
+ bias=config.add_bias_linear,
+ device=device,
+ **_config_to_kwargs(config),
+ )
+
+ self.k_cache = KVCache()
+ self.v_cache = KVCache()
+ self.inp_seq_len = -1
+
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+ cache_shape = (
+ batch_size,
+ self.num_multi_query_groups_per_partition,
+ max_seq_len,
+ self.hidden_size_per_attention_head,
+ )
+ device = self.query_key_value.weight.device
+ dtype = self.config.torch_dtype
+ self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
+ self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
+
+ def reorder(self, tensor, beam_idx, dim_a, dim_b):
+ updated = tensor.index_select(0, beam_idx)
+ tensor.copy_(updated)
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ if self.k_cache.cache is None:
+ return (None, None)
+
+ head_dim = self.k_cache.cache.size(-1)
+ seq_length = self.k_cache.cache.size(-2)
+ self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim)
+ self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
+ return (self.k_cache.cache.shape, self.v_cache.cache.shape)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ prefix_encoder: Optional[torch.Tensor] = None,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ ):
+ # hidden_states: [sq, b, h]
+ q_len, bsz, hiddenSize = hidden_states.size()
+
+ # =================================================
+ # Pre-allocate memory for key-values for inference.
+ # =================================================
+ # =====================
+ # Query, Key, and Value
+ # =====================
+
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+ mixed_x_layer = self.query_key_value(hidden_states)
+
+ if self.multi_query_attention:
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+ [
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ ],
+ dim=-1,
+ )
+ query_layer = query_layer.view(
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
+ )
+ key_layer = key_layer.view(
+ key_layer.size()[:-1]
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
+ )
+ value_layer = value_layer.view(
+ value_layer.size()[:-1]
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
+ )
+ else:
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ if rotary_pos_emb is not None:
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
+
+ if prefix_encoder is not None:
+ prefix_encoder_key, prefix_encoder_value = prefix_encoder
+ if mixed_x_layer.dtype == torch.float8_e4m3fn:
+ from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2
+
+ prefix_encoder_key = cast_to_fp8_v2(prefix_encoder_key, None, False, False, mixed_x_layer.dtype)[0]
+ prefix_encoder_value = cast_to_fp8_v2(prefix_encoder_value, None, False, False, mixed_x_layer.dtype)[0]
+ else:
+ prefix_encoder_key = prefix_encoder_key.to(mixed_x_layer.dtype)
+ prefix_encoder_value = prefix_encoder_value.to(mixed_x_layer.dtype)
+
+ key_layer = torch.cat((prefix_encoder_key, key_layer), dim=0)
+ value_layer = torch.cat((prefix_encoder_value, value_layer), dim=0)
+
+ query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
+ key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
+ value_layer = value_layer.permute(1, 2, 0, 3).contiguous()
+
+ if use_cache:
+ # reuse k, v, self_attention
+ if reuse_cache:
+ key_layer = self.k_cache(key_layer, 2, token_idx)
+ value_layer = self.v_cache(value_layer, 2, token_idx)
+ past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
+ else:
+ if past_key_value is None:
+ past_key = torch.zeros(
+ key_layer.shape, dtype=self.query_key_value.weight.dtype, device=key_layer.device
+ )
+ past_value = torch.zeros(
+ key_layer.shape, dtype=self.query_key_value.weight.dtype, device=key_layer.device
+ )
+ past_key_value = [past_key, past_value]
+ key_layer = self.k_cache.update(past_key_value[0], key_layer, 2, token_idx, self.inp_seq_len)
+ value_layer = self.v_cache.update(past_key_value[1], value_layer, 2, token_idx, self.inp_seq_len)
+ if token_idx is None:
+ past_key_value = (key_layer, value_layer)
+
+ if cache_idx is not None and q_len == 1:
+ key_layer = key_layer[:, :, :cache_idx, :]
+ value_layer = value_layer[:, :, :cache_idx, :]
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :, :, :cache_idx]
+ else:
+ past_key_value = None
+
+ # ==================================
+ # core attention computation
+ # ==================================
+
+ context_layer = self.core_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ cache_position,
+ attn_softmax_bf16,
+ use_flash_attention,
+ flash_attention_recompute,
+ flash_attention_causal_mask,
+ flash_attention_fast_softmax,
+ )
+
+ output = self.dense(context_layer)
+
+ # No output_attention
+ attn_weights = None
+
+ return output, attn_weights, past_key_value
+
+
+class MLP(torch.nn.Module):
+ """MLP.
+
+ MLP will take the input with h hidden state, project it to 4*h
+ hidden dimension, perform nonlinear transformation, and project the
+ state back into h hidden dimension.
+ """
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(MLP, self).__init__()
+
+ self.add_bias = config.add_bias_linear
+
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+ self.dense_h_to_4h = nn.Linear(
+ config.hidden_size,
+ config.ffn_hidden_size * 2,
+ bias=self.add_bias,
+ device=device,
+ **_config_to_kwargs(config),
+ )
+
+ def swiglu(x):
+ x = torch.chunk(x, 2, dim=-1)
+ return F.silu(x[0]) * x[1]
+
+ self.activation_func = swiglu
+
+ # Project back to h.
+ self.dense_4h_to_h = nn.Linear(
+ config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
+ )
+
+ def forward(self, hidden_states):
+ # [s, b, 4hp]
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+ # [s, b, h]
+ output = self.dense_4h_to_h(intermediate_parallel)
+ return output
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.dim = dim
+ self.original_impl = original_impl
+ self.rope_ratio = rope_ratio
+
+ def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
+ """Enhanced Transformer with Rotary Position Embedding.
+
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+ transformers/rope/__init__.py. MIT License:
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+ """
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+ base = base * self.rope_ratio
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
+
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
+
+ # Calculate the product of position index and $\theta_i$
+ idx_theta = torch.outer(seq_idx, theta).float()
+
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
+
+ # this is to mimic the behaviour of complex32, else we will get different results
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
+ return cache
+
+ def forward(self, max_seq_len, offset=0):
+ return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.eps = eps
+
+ def forward(self, hidden_states: torch.Tensor):
+ if hidden_states.device.type == "hpu" and FusedRMSNorm:
+ # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype
+ if hidden_states.dtype != self.weight.dtype:
+ orig_dtype = hidden_states.dtype
+ hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.eps)
+ return hidden_states.to(orig_dtype)
+ else:
+ hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.eps)
+ return hidden_states
+ else:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class GLMBlock(torch.nn.Module):
+ """A single transformer layer.
+
+ Transformer layer takes input with size [s, b, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
+ super(GLMBlock, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+
+ self.fp32_residual_connection = config.fp32_residual_connection
+
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNormFunc(
+ config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
+ )
+
+ # Self attention.
+ self.self_attention = SelfAttention(config, layer_number, device=device)
+ self.hidden_dropout = config.hidden_dropout
+
+ # Layernorm on the attention output
+ self.post_attention_layernorm = LayerNormFunc(
+ config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
+ )
+
+ # MLP
+ self.mlp = MLP(config, device=device)
+
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+ self.self_attention.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ return self.self_attention.reorder_kv_cache(beam_idx)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ prefix_encoder: Optional[torch.Tensor] = None,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ ):
+ # hidden_states: [s, b, h]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, self_attn_weights, present_key_value = self.self_attention(
+ layernorm_output,
+ attention_mask,
+ prefix_encoder,
+ rotary_pos_emb,
+ past_key_value,
+ output_attentions,
+ use_cache,
+ cache_position,
+ token_idx,
+ attn_softmax_bf16,
+ reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ )
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+ layernorm_input = residual + layernorm_input
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+ output = residual + output
+
+ outputs = (output,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class GLMTransformer(torch.nn.Module):
+ """Transformer class."""
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(GLMTransformer, self).__init__()
+
+ self.fp32_residual_connection = config.fp32_residual_connection
+ self.post_layer_norm = config.post_layer_norm
+
+ # Number of layers.
+ self.num_layers = config.num_layers
+
+ # Transformer layers.
+ def build_layer(layer_number):
+ return GLMBlock(config, layer_number, device=device)
+
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
+
+ if self.post_layer_norm:
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+ # Final layer norm before output.
+ self.final_layernorm = LayerNormFunc(
+ config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
+ )
+
+ self.gradient_checkpointing = False
+
+ def _get_layer(self, layer_number):
+ return self.layers[layer_number]
+
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+ for layer in self.layers:
+ layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ prefix_encoders: Optional[List[torch.FloatTensor]] = None,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ lazy_mode: Optional[bool] = True,
+ ):
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ if lazy_mode:
+ htcore.mark_step()
+
+ for index in range(self.num_layers):
+ if (
+ lazy_mode
+ and not self.training
+ and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
+ ):
+ htcore.mark_step()
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ past_key_value = past_key_values[index] if past_key_values is not None else None
+ prefix_encoder = prefix_encoders[index] if prefix_encoders is not None else None
+
+ layer = self._get_layer(index)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(
+ *inputs,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ None,
+ attn_softmax_bf16,
+ False,
+ use_flash_attention,
+ flash_attention_recompute,
+ flash_attention_causal_mask,
+ flash_attention_fast_softmax,
+ None,
+ )
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer),
+ hidden_states,
+ attention_mask,
+ prefix_encoder,
+ rotary_pos_emb,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ prefix_encoder=prefix_encoder,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # Final layer norm.
+ if self.post_layer_norm:
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states, next_cache, all_hidden_states, all_self_attns
+
+
+class ChatGLMPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and
+ a simple interface for downloading and loading pretrained models.
+ """
+
+ is_parallelizable = False
+ supports_gradient_checkpointing = True
+ config_class = ChatGLMConfig
+ base_model_prefix = "transformer"
+ _no_split_modules = ["GLMBlock"]
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ return
+
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
+ batch_size, seq_length = input_ids.shape
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+ full_attention_mask.tril_()
+ past_length = 0
+ if past_key_values:
+ past_length = past_key_values[0][0].shape[0]
+ if past_length:
+ full_attention_mask = torch.cat(
+ (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1
+ )
+ if padding_mask is not None:
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+ if not past_length and padding_mask is not None:
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+ full_attention_mask = (full_attention_mask < 0.5).bool()
+ full_attention_mask.unsqueeze_(1)
+ return full_attention_mask
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, GLMTransformer):
+ module.gradient_checkpointing = value
+
+
+class PrefixEncoder(torch.nn.Module):
+ """
+ The torch.nn model to encode the prefix
+ Input shape: (batch-size, prefix-length)
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
+ """
+
+ def __init__(self, config: ChatGLMConfig):
+ super().__init__()
+ self.prefix_projection = config.prefix_projection
+ if self.prefix_projection:
+ # Use a two-layer MLP to encode the prefix
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
+ self.trans = torch.nn.Sequential(
+ torch.nn.Linear(kv_size, config.hidden_size),
+ torch.nn.Tanh(),
+ torch.nn.Linear(config.hidden_size, kv_size),
+ )
+ else:
+ self.embedding = torch.nn.Embedding(
+ config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2
+ )
+
+ def forward(self, prefix: torch.Tensor):
+ if self.prefix_projection:
+ prefix_tokens = self.embedding(prefix)
+ past_key_values = self.trans(prefix_tokens)
+ else:
+ past_key_values = self.embedding(prefix)
+ return past_key_values
+
+
+class Embedding(torch.nn.Module):
+ """Language model embeddings."""
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(Embedding, self).__init__()
+
+ self.hidden_size = config.hidden_size
+ # Word embeddings (parallel).
+ self.word_embeddings = nn.Embedding(
+ config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
+ )
+ self.fp32_residual_connection = config.fp32_residual_connection
+
+ def forward(self, input_ids):
+ # Embeddings.
+ words_embeddings = self.word_embeddings(input_ids)
+ embeddings = words_embeddings
+ # Data format change to avoid explicit transposes: [b s h] --> [s b h].
+ embeddings = embeddings.transpose(0, 1).contiguous()
+ # If the input flag for fp32 residual connection is set, convert for float.
+ if self.fp32_residual_connection:
+ embeddings = embeddings.float()
+ return embeddings
+
+
+class ChatGLMModel(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=False):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ init_kwargs = {}
+ if device is not None:
+ init_kwargs["device"] = device
+ self.embedding = init_method(Embedding, config, **init_kwargs)
+ self.num_layers = config.num_layers
+ self.multi_query_group_num = config.multi_query_group_num
+ self.kv_channels = config.kv_channels
+
+ # Rotary positional embeddings
+ self.seq_length = config.seq_length
+ rotary_dim = (
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
+ )
+ self.rotary_pos_emb = RotaryEmbedding(
+ rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
+ )
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
+ self.output_layer = init_method(
+ nn.Linear,
+ config.hidden_size,
+ config.padded_vocab_size,
+ bias=False,
+ dtype=config.torch_dtype,
+ **init_kwargs,
+ )
+ self.pre_seq_len = config.pre_seq_len if config.pre_seq_len is not None else 0
+ self.prefix_projection = config.prefix_projection
+ if self.pre_seq_len > 0:
+ for param in self.parameters():
+ param.requires_grad = False
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+ self.dropout = torch.nn.Dropout(0.1)
+
+ def get_input_embeddings(self):
+ return self.embedding.word_embeddings
+
+ def get_prompt(self, batch_size, device, dtype=torch.bfloat16):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+ past_key_values = past_key_values.view(
+ batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels
+ )
+ # seq_len, b, nh, hidden_size
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
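+ # After the permute the layout is (num_layers * 2, pre_seq_len, batch, groups, kv_channels);
+ # split(2) along dim 0 then yields one stacked (key, value) pair of prefix states per layer.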
+
+ return past_key_values
+
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+ self.encoder.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ return self.encoder.reorder_kv_cache(beam_idx)
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length, pre_seq_len
+ ):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length + pre_seq_len,
+ )
+ return combined_attention_mask
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ if pre_seq_len > 0:
+ pre_seq_mask = torch.zeros(
+ [input_shape[0], 1, 1, pre_seq_len],
+ dtype=expanded_attn_mask.dtype,
+ device=expanded_attn_mask.device,
+ )
+ expanded_attn_mask = torch.cat([pre_seq_mask, expanded_attn_mask], dim=-1)
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ lazy_mode: Optional[bool] = True,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids).to(self.embedding.word_embeddings.weight.dtype)
+
+ past_seen_tokens = 0
+
+ if past_key_values is not None and use_cache: # kept for BC (cache positions)
+ if reuse_cache:
+ if isinstance(past_key_values[0][0], torch.Tensor):
+ past_seen_tokens = past_key_values[0][0].shape[2]
+ else:
+ past_seen_tokens = past_key_values[0][0][2]
+ else:
+ past_seen_tokens = past_key_values[0][0].shape[2]
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device
+ )
+ position_ids = position_ids.unsqueeze(0)
+ if position_ids.size(-1) < seq_length:
+ position_ids = F.pad(position_ids, (0, seq_length - position_ids.size(-1)), "constant", 0)
+ cache_position = None
+
+ # embed positions
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+ rotary_pos_emb = rotary_pos_emb[position_ids]
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, past_seen_tokens), dtype=torch.bool, device=inputs_embeds.device)
+
+ prefix_encoders = None
+ if self.pre_seq_len > 0:
+ if token_idx is not None:
+ token_idx = token_idx + self.pre_seq_len
+ if past_key_values is None:
+ prefix_encoders = self.get_prompt(
+ batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
+ )
+ past_seen_tokens += self.pre_seq_len
+ if attention_mask is not None:
+ attention_mask = torch.cat(
+ [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
+ )
+
+ attention_mask = _gaudi_prepare_4d_causal_attention_mask(
+ attention_mask,
+ input_ids.shape if input_ids is not None else (batch_size, seq_length),
+ inputs_embeds,
+ past_seen_tokens,
+ )
+
+ # Run encoder.
+ hidden_states, next_cache, all_hidden_states, all_self_attns = self.encoder(
+ inputs_embeds,
+ attention_mask,
+ prefix_encoders,
+ rotary_pos_emb,
+ past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ lazy_mode=lazy_mode,
+ )
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, empty_init=False, device=None):
+ super().__init__(config)
+
+ self.max_sequence_length = config.max_length
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+ self.config = config
+
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
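+        # Pre-allocate a fixed-size KV cache on the transformer so generation with reuse_cache keeps static shapes.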
+ self.transformer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
+ self.kv_cache_len = max_seq_len
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ return self.transformer.reorder_kv_cache(beam_idx)
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask=None,
+ inputs_embeds=None,
+ token_idx=None,
+ **kwargs,
+ ):
+ reuse_cache = kwargs.get("reuse_cache")
+ bucket_internal = kwargs.get("bucket_internal")
+
+ if past_key_values:
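+            # token_idx marks the current write position in the statically shaped inputs used on Gaudi,
+            # so the last generated token is gathered with index_select instead of dynamic slicing.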
+ if token_idx is not None:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ else:
+ input_ids = input_ids[:, -1:]
+ elif (reuse_cache or bucket_internal) and token_idx is not None:
+            # With reuse_cache or bucket_internal, the KV cache is pre-allocated, so for the first token
+            # the inputs can be sliced up to token_idx for the forward pass
+ input_ids = input_ids[:, :token_idx]
+ attention_mask = attention_mask[:, :token_idx]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "token_idx": token_idx,
+ "trim_logits": kwargs.get("trim_logits"),
+ "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
+ "reuse_cache": reuse_cache,
+ "use_flash_attention": kwargs.get("use_flash_attention"),
+ "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
+ "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
+ "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
+ "cache_idx": kwargs.get("cache_idx"),
+ "lazy_mode": kwargs.get("lazy_mode"),
+ }
+ )
+ return model_inputs
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ trim_logits: Optional[bool] = False,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ lazy_mode: Optional[bool] = True,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ lazy_mode=lazy_mode,
+ )
+
+ hidden_states = outputs[0].transpose(0, 1).contiguous()
+ _, seq_len, _ = hidden_states.shape
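+        # With trim_logits, only the hidden state of the token being generated is kept before the
+        # large output projection, which saves memory and compute during generation.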
+ if seq_len > 1 and trim_logits and not self.training:
+ if token_idx is not None:
+ hidden_states = hidden_states.index_select(1, token_idx - 1)
+ else:
+ hidden_states = hidden_states[:, -1, :]
+
+ lm_logits = self.transformer.output_layer(hidden_states).float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+ """
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+ beam_idx at every generation step.
+
+ Output shares the same memory storage as `past`.
+ """
+ return tuple(
+ (
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
+ )
+ for layer_past in past
+ )
+
+ def process_response(self, output, history):
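+        # Split the raw generation on <|assistant|> markers; plain replies are returned as text,
+        # while tool calls (when the system prompt declares tools) are parsed into a name/parameters dict.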
+ content = ""
+ history = copy.deepcopy(history)
+ for response in output.split("<|assistant|>"):
+ if "\n" in response:
+ metadata, content = response.split("\n", maxsplit=1)
+ else:
+ metadata, content = "", response
+ if not metadata.strip():
+ content = content.strip()
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
+                content = content.replace("[[训练时间]]", "2023年")
+ else:
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
+ if history[0]["role"] == "system" and "tools" in history[0]:
+ parameters = json.loads(content)
+ content = {"name": metadata.strip(), "parameters": parameters}
+ else:
+ content = {"name": metadata.strip(), "content": content}
+ return content, history
+
+ def build_chat_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
+        # chatglm2-6b does not ship a chat template, so its prompt is built with tokenizer.build_prompt instead.
+ if self.config.name_or_path == "THUDM/chatglm2-6b":
+ prompt = tokenizer.build_prompt(query, history=history)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ else:
+ inputs = tokenizer.apply_chat_template(
+ history, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
+ )
+
+ inputs = inputs.to(self.device)
+ return inputs
+
+ def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
+ if history:
+ prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
+ input_ids = input_ids[1:]
+ inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
+ else:
+ prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ return inputs
+
+ @torch.inference_mode()
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Dict] = None,
+ role: str = "user",
+ num_beams=1,
+ do_sample=False,
+ top_p=0.8,
+ temperature=0.8,
+ logits_processor=None,
+ **kwargs,
+ ):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {
+ "num_beams": num_beams,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
+ history.append({"role": role, "content": query})
+ inputs = self.build_chat_inputs(tokenizer, query, history=history)
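+        # Stop on the model EOS token or when the model starts a new <|user|> / <|observation|> turn.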
+ eos_token_id = [
+ tokenizer.eos_token_id,
+ tokenizer.convert_tokens_to_ids("<|user|>"),
+ tokenizer.convert_tokens_to_ids("<|observation|>"),
+ ]
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id, ignore_eos=False)
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
+ response, history = self.process_response(response, history)
+
+ return response, history
+
+ @torch.inference_mode()
+ def stream_chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ past_key_values=None,
+ max_length: int = 8192,
+ do_sample=True,
+ top_p=0.8,
+ temperature=0.8,
+ logits_processor=None,
+ return_past_key_values=False,
+ **kwargs,
+ ):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {
+ "max_length": max_length,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
+ if past_key_values is None and not return_past_key_values:
+            inputs = self.build_chat_inputs(tokenizer, query, history=history)
+ else:
+ inputs = self.build_stream_inputs(tokenizer, query, history=history)
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[0]
+ if self.transformer.pre_seq_len is not None:
+ past_length -= self.transformer.pre_seq_len
+ inputs.position_ids += past_length
+ attention_mask = inputs.attention_mask
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
+ inputs["attention_mask"] = attention_mask
+ for outputs in self.stream_generate(
+ **inputs, past_key_values=past_key_values, return_past_key_values=return_past_key_values, **gen_kwargs
+ ):
+ if return_past_key_values:
+ outputs, past_key_values = outputs
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
+ response = tokenizer.decode(outputs)
+ if response and response[-1] != "�":
+                response, new_history = self.process_response(response, history)
+ if return_past_key_values:
+ yield response, new_history, past_key_values
+ else:
+ yield response, new_history
+
+ @torch.inference_mode()
+ def stream_generate(
+ self,
+ input_ids,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ return_past_key_values=False,
+ **kwargs,
+ ):
+ input_ids_seq_length = input_ids.shape[-1]
+
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs)
+ model_kwargs["use_cache"] = generation_config.use_cache
+ eos_token_id = generation_config.eos_token_id
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+ if not has_default_max_length:
+                logger.warning(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length` (="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                )
+
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ logits_warper = self._get_logits_warper(generation_config)
+
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+ scores = None
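+        # Manual decoding loop: yield the running sequence after every step so callers can stream
+        # partial responses; stop once all sequences hit an EOS token or a stopping criterion fires.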
+ while True:
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ if generation_config.do_sample:
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(probs, dim=-1)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+ unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+ if return_past_key_values:
+ yield input_ids, outputs.past_key_values
+ else:
+ yield input_ids
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ break
+
+
+class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, empty_init=False, device=None):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+
+        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.bfloat16)
+ if config.classifier_dropout is not None:
+ self.dropout = nn.Dropout(config.classifier_dropout)
+ else:
+ self.dropout = None
+ self.config = config
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ lazy_mode: Optional[bool] = True,
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ lazy_mode=lazy_mode,
+ )
+
+ hidden_states = transformer_outputs[0].transpose(0, 1).contiguous()
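+        # The tokenizer pads on the left, so the last position always holds the final real token and
+        # is used as the pooled representation for classification.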
+        pooled_hidden_states = hidden_states[:, -1]
+ if self.dropout is not None:
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
+ logits = self.classifier_head(pooled_hidden_states)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
+ else:
+ loss = loss_fct(logits.float(), labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
diff --git a/optimum/habana/transformers/models/chatglm/tokenization_chatglm.py b/optimum/habana/transformers/models/chatglm/tokenization_chatglm.py
new file mode 100644
index 0000000000..b650893a1c
--- /dev/null
+++ b/optimum/habana/transformers/models/chatglm/tokenization_chatglm.py
@@ -0,0 +1,368 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+# Copyright (C) 2022-2024 Habana Labs, Ltd. an Intel Company
+###############################################################################
+
+"""
+Adapted from the following sources:
+https://huggingface.co/THUDM/chatglm2-6b/blob/main/tokenization_chatglm.py
+https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py
+"""
+
+import json
+import os
+import re
+from typing import Dict, List, Optional, Union
+
+from transformers import PreTrainedTokenizer
+from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
+from transformers.utils import PaddingStrategy, logging
+from transformers.utils.import_utils import is_sentencepiece_available
+
+
+if is_sentencepiece_available():
+ from sentencepiece import SentencePieceProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+class SPTokenizer:
+ def __init__(self, model_path: str):
+ # reload tokenizer
+ assert os.path.isfile(model_path), model_path
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+ # BOS / EOS token IDs
+ self.n_words: int = self.sp_model.vocab_size()
+ self.bos_id: int = self.sp_model.bos_id()
+ self.eos_id: int = self.sp_model.eos_id()
+ self.pad_id: int = self.sp_model.unk_id()
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+ role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
+ special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
+ self.special_tokens = {}
+ self.index_special_tokens = {}
+ for token in special_tokens:
+ self.special_tokens[token] = self.n_words
+ self.index_special_tokens[self.n_words] = token
+ self.n_words += 1
+ self.role_special_token_expression = "|".join(
+ [re.escape(token) for token in special_tokens]
+ ) # for apply_chat_template
+
+ def tokenize(self, s: str, encode_special_tokens=False):
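+        # When encode_special_tokens is set, role markers such as <|user|> are emitted as single
+        # tokens instead of being split into sentencepiece pieces.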
+ if encode_special_tokens:
+ last_index = 0
+ t = []
+ for match in re.finditer(self.role_special_token_expression, s):
+ if last_index < match.start():
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index : match.start()]))
+ t.append(s[match.start() : match.end()])
+ last_index = match.end()
+ if last_index < len(s):
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
+ return t
+ else:
+ return self.sp_model.EncodeAsPieces(s)
+
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+ assert type(s) is str
+ t = self.sp_model.encode(s)
+ if bos:
+ t = [self.bos_id] + t
+ if eos:
+ t = t + [self.eos_id]
+ return t
+
+ def decode(self, t: List[int]) -> str:
+ text, buffer = "", []
+ for token in t:
+ if token in self.index_special_tokens:
+ if buffer:
+ text += self.sp_model.decode(buffer)
+ buffer = []
+ text += self.index_special_tokens[token]
+ else:
+ buffer.append(token)
+ if buffer:
+ text += self.sp_model.decode(buffer)
+ return text
+
+ def decode_tokens(self, tokens: List[str]) -> str:
+ text = self.sp_model.DecodePieces(tokens)
+ return text
+
+ def convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+ return self.sp_model.PieceToId(token)
+
+ def convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ if (
+ index in self.index_special_tokens
+ or index in [self.eos_id, self.bos_id, self.pad_id]
+ or index < 0
+ or index > self.sp_model.vocab_size()
+ ):
+ return ""
+ return self.sp_model.IdToPiece(index)
+
+
+class ChatGLMTokenizer(PreTrainedTokenizer):
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
+
+ def __init__(
+ self,
+ vocab_file,
+ padding_side="left",
+ clean_up_tokenization_spaces=False,
+ encode_special_tokens=False,
+ **kwargs,
+ ):
+ if not is_sentencepiece_available():
+ raise ModuleNotFoundError(
+ "Chatglm requires the Sentencepiece library to be installed. Please install it with: `pip install sentencepiece`"
+ )
+
+ self.__name__ = "ChatGLMTokenizer"
+ self.vocab_file = vocab_file
+
+ self.tokenizer = SPTokenizer(vocab_file)
+ self.special_tokens = {
+ "": self.tokenizer.bos_id if self.tokenizer.bos_id else None,
+ "": self.tokenizer.eos_id,
+ "": self.tokenizer.pad_id,
+ "": self.tokenizer.pad_id,
+ }
+
+ self.encode_special_tokens = encode_special_tokens
+
+ super().__init__(
+ padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs
+ )
+
+ def get_command(self, token):
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.__name__}"
+ return self.tokenizer.special_tokens[token]
+
+ @property
+ def unk_token(self) -> str:
+        return self.tokenizer.sp_model.IdToPiece(self.get_command("<unk>"))
+
+ @property
+ def pad_token(self) -> str:
+        return self.tokenizer.sp_model.IdToPiece(self.get_command("<pad>"))
+
+ @property
+ def eos_token(self) -> str:
+        return self.tokenizer.sp_model.IdToPiece(self.get_command("<eos>"))
+
+ @property
+ def unk_token_id(self) -> int:
+        return self.get_command("<unk>")
+
+ @property
+ def pad_token_id(self) -> int:
+        return self.get_command("<pad>")
+
+ @property
+ def eos_token_id(self):
+        return self.get_command("<eos>")
+
+ @unk_token.setter
+ def unk_token(self, value):
+ logger.warning("Setting unk_token is not supported, use the default one.")
+
+ @pad_token.setter
+ def pad_token(self, value):
+ logger.warning("Setting pad_token is not supported, use the default one.")
+
+ @eos_token.setter
+ def eos_token(self, value):
+ logger.warning("Setting eos_token is not supported, use the default one.")
+
+ @property
+ def vocab_size(self):
+ return self.tokenizer.n_words
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text, **kwargs):
+ return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.tokenizer.convert_token_to_id(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.tokenizer.convert_id_to_token(index)
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ return self.tokenizer.decode_tokens(tokens)
+
+ def save_vocabulary(self, save_directory, filename_prefix=None):
+ """
+ Save the vocabulary and special tokens file to a directory.
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ filename_prefix (`str`, *optional*):
+                An optional prefix to add to the names of the saved files.
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
+ else:
+ vocab_file = save_directory
+
+ with open(self.vocab_file, "rb") as fin:
+ proto_str = fin.read()
+
+ with open(vocab_file, "wb") as writer:
+ writer.write(proto_str)
+
+ return (vocab_file,)
+
+ def get_prefix_tokens(self):
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
+ return prefix_tokens
+
+ def build_prompt(self, query, history=None):
+ if history is None:
+ history = []
+ prompt = ""
+ for i, (old_query, response) in enumerate(history):
+ prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
+ prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+ return prompt
+
+ def build_single_message(self, role, metadata, message):
+ assert role in ["system", "user", "assistant", "observation"], role
+ role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
+ message_tokens = self.tokenizer.encode(message)
+ tokens = role_tokens + message_tokens
+ return tokens
+
+ def build_chat_input(self, query, history=None, role="user"):
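+        # Flatten the chat history into role-tagged token sequences and end with <|assistant|> so the
+        # model continues with an assistant turn.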
+ if history is None:
+ history = []
+ input_ids = []
+ for item in history:
+ content = item["content"]
+ if item["role"] == "system" and "tools" in item:
+ content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
+ input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
+ input_ids.extend(self.build_single_message(role, "", query))
+ input_ids.extend([self.get_command("<|assistant|>")])
+ return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+        Build model inputs from a sequence or a pair of sequences by concatenating and adding special tokens.
+        A ChatGLM sequence has the following format:
+        - single sequence: `[gMASK] sop X`
+        - pair of sequences: `[gMASK] sop A B <eos>`
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ prefix_tokens = self.get_prefix_tokens()
+ token_ids_0 = prefix_tokens + token_ids_0
+ if token_ids_1 is not None:
+            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
+ return token_ids_0
+
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if padding_side is not None:
+ self.padding_side = padding_side
+ assert self.padding_side == "left"
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ seq_length = len(required_input)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * seq_length
+
+ if "position_ids" not in encoded_inputs:
+ encoded_inputs["position_ids"] = list(range(seq_length))
+
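+        # ChatGLM is left-padded: input_ids, attention_mask and position_ids are all prepended with
+        # padding values so the real tokens stay right-aligned.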
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+
+ if "attention_mask" in encoded_inputs:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "position_ids" in encoded_inputs:
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+
+ return encoded_inputs
diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 08fb914ccc..22cb268871 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -973,9 +973,16 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
- # attn_softmax_bf16 and use_flash_attention is enabled only for llama, qwen2, starcoder2, gemma and baichuan
+ # attn_softmax_bf16 and use_flash_attention is enabled only for llama, qwen2, starcoder2, gemma, baichuan and chatglm
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
- if self.model.config.model_type in ["llama", "qwen2", "starcoder2", "gemma", "baichuan"]:
+ if self.model.config.model_type in [
+ "llama",
+ "qwen2",
+ "starcoder2",
+ "gemma",
+ "baichuan",
+ "chatglm",
+ ]:
if self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.generation_config.use_flash_attention:
@@ -1960,9 +1967,9 @@ def evaluation_loop(
if batch_size is None:
batch_size = observed_batch_size
- # attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma and baichuan
+ # attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma, baichuan and chatglm
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
- if self.model.config.model_type in ["llama", "qwen2", "starcoder2", "gemma", "baichuan"]:
+ if self.model.config.model_type in ["llama", "qwen2", "starcoder2", "gemma", "baichuan", "chatglm"]:
if self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.generation_config.use_flash_attention:
diff --git a/tests/baselines/chatglm3_6b.json b/tests/baselines/chatglm3_6b.json
new file mode 100644
index 0000000000..3a8c7a2feb
--- /dev/null
+++ b/tests/baselines/chatglm3_6b.json
@@ -0,0 +1,31 @@
+{
+ "gaudi2": {
+ "wikitext": {
+ "num_train_epochs": 3,
+ "eval_batch_size": 4,
+ "distribution": {
+ "deepspeed": {
+ "learning_rate": 5e-5,
+ "train_batch_size": 4,
+ "perplexity": 16.51629,
+ "train_runtime": 445,
+ "train_samples_per_second": 18.216,
+ "extra_arguments": [
+ "--dataset_name wikitext",
+ "--dataset_config_name wikitext-2-raw-v1",
+ "--block_size 1024",
+ "--use_cache False",
+ "--gradient_checkpointing",
+ "--bf16",
+ "--eval_strategy no",
+ "--save_strategy no",
+ "--throughput_warmup_steps 3",
+ "--logging_first_step True",
+ "--logging_steps 20",
+ "--deepspeed tests/configs/deepspeed_zero_3_gaudi1.json"
+ ]
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/test_examples.py b/tests/test_examples.py
index b6f07b0512..e8ba4ca12d 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -81,7 +81,7 @@ def _get_supported_models_for_script(
def is_valid_model_type(model_type: str) -> bool:
true_model_type = "llama" if model_type == "llama_guard" else model_type
- if model_type == "protst":
+ if model_type in ("protst", "chatglm"):
in_task_mapping = True
else:
# llama_guard is not a model type in Transformers so CONFIG_MAPPING wouldn't find it
@@ -241,6 +241,7 @@ def to_test(
"codellama/CodeLlama-13b-Instruct-hf",
"MIT/ast-finetuned-speech-commands-v2",
"meta-llama/LlamaGuard-7b",
+ "THUDM/chatglm3-6b",
]
case_only_in_gaudi2 = [
@@ -326,6 +327,8 @@ def to_test(
return True
elif "gemma" in model_name and IS_GAUDI2:
return True
+ elif "chatglm3" in model_name and IS_GAUDI2 and deepspeed:
+ return True
return False
@@ -365,6 +368,7 @@ def __new__(
attrs[f"test_{example_name}_{model_name.split('/')[-1]}_{distribution}"] = cls._create_test(
model_name, gaudi_config_name, multi_card, deepspeed, fsdp, torch_compile, fp8
)
+
attrs["EXAMPLE_NAME"] = example_name
return super().__new__(cls, name, bases, attrs)
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 1fcadba9b0..6d5140b691 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -61,6 +61,7 @@
("baichuan-inc/Baichuan2-7B-Chat", 1, True, 108, False),
("baichuan-inc/Baichuan2-13B-Chat", 1, False, 66, False),
("deepseek-ai/DeepSeek-V2-Lite", 1, False, 35, False),
+ ("THUDM/chatglm3-6b", 1, True, 150, False),
],
"fp8": [
("tiiuae/falcon-180B", 4, 950, True, 128, 128, 2506.68),
diff --git a/tests/utils.py b/tests/utils.py
index cad1cc3821..849a3047de 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -64,6 +64,7 @@
"idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")],
"mllama": [("meta-llama/Llama-3.2-11B-Vision-Instruct", "Habana/gpt2")],
"gemma": [("google/gemma-2b-it", "Habana/gpt2")],
+ "chatglm": [("THUDM/chatglm3-6b", "Habana/gpt2")],
}
MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [
@@ -82,7 +83,7 @@
# "distilbert",
]
-MODELS_TO_TEST_FOR_CAUSAL_LANGUAGE_MODELING = ["gpt2", "gpt_neox", "bloom", "code_llama", "gemma"]
+MODELS_TO_TEST_FOR_CAUSAL_LANGUAGE_MODELING = ["gpt2", "gpt_neox", "bloom", "code_llama", "gemma", "chatglm"]
MODELS_TO_TEST_FOR_SEQ2SEQ = ["t5"]