PrefixEncoder
# Look up the prefix embeddings from the prefix IDs.
# The prefix embeddings are later concatenated onto K and V after the heads are split.
class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        # Whether to enable prefix projection, i.e. pass the prefix embeddings through a two-layer MLP
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # KVSize = NLayer * 2 * NGroup * HeadSize
            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
            # Embedding layer mapping prefix IDs to embeddings, [PreSeqLen, KVSize]
            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
            # MLP that processes the embeddings:
            # project to HidSize, apply tanh, then project back to KVSize
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_size)
            )
        else:
            # Embedding layer mapping prefix IDs to embeddings
            self.embedding = torch.nn.Embedding(config.pre_seq_len,
                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)

    def forward(self, prefix: torch.Tensor):
        # Prefix IDs have shape [BatchSize, PreSeqLen]
        # Look up the embeddings for the prefix IDs, giving shape [BatchSize, PreSeqLen, KVSize]
        # If projection is enabled, pass the embeddings through the two-layer MLP
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
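To make the shapes concrete, here is a minimal sketch (not part of the model source) that instantiates PrefixEncoder with a stand-in config object; the field values mirror the ChatGLM2-6B defaults but are assumptions chosen for illustration.

# Minimal shape check for PrefixEncoder; `cfg` is a stand-in for ChatGLMConfig
# and its values are assumptions for illustration only.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(
    prefix_projection=False,  # skip the two-layer MLP
    pre_seq_len=16,           # PreSeqLen
    num_layers=28,            # NLayer
    kv_channels=128,          # HeadSize
    multi_query_group_num=2,  # NGroup
    hidden_size=4096,         # HidSize, only used when prefix_projection=True
)
encoder = PrefixEncoder(cfg)
prefix_ids = torch.arange(cfg.pre_seq_len).unsqueeze(0)  # [1, PreSeqLen]
past_key_values = encoder(prefix_ids)
# KVSize = 28 * 128 * 2 * 2 = 14336
print(past_key_values.shape)  # torch.Size([1, 16, 14336])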
ChatGLMPreTrainedModel
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    # Build the default (lower-triangular, causal) attention mask from the input token IDs and the KV cache
    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        # Token IDs have shape [BatchSize, SeqLen]
        batch_size, seq_length = input_ids.shape
        # Initialize the mask to all ones, shape [BatchSize, SeqLen, SeqLen], one matrix per input sequence
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        # Keep the lower-triangular part and set the rest to 0
        full_attention_mask.tril_()
        # CacheSeqLen: sequence length of the KV cache
        # 0 if no cache is provided, otherwise read it from the cache
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        # If a KV cache is provided, pad each mask on the left with a block of ones
        # of shape [BatchSize, SeqLen, CacheSeqLen], so cached positions are always visible
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                        device=input_ids.device), full_attention_mask), dim=-1)
        # If a padding mask ([BatchSize, (Cache)SeqLen]) is provided,
        # reshape it to [BatchSize, 1, (Cache)SeqLen]
        # and multiply it into the attention mask,
        # zeroing out the columns where the padding mask is 0
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        # If a padding mask is provided and there is no KV cache,
        # reshape it to [BatchSize, SeqLen, 1]
        # and add 1 to the rows where the padding mask is 0
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        # Values below 0.5 become True and values above 0.5 become False, which inverts the mask:
        # the masked positions (the strict upper triangle and padded columns) are now True
        full_attention_mask = (full_attention_mask < 0.5).bool()
        # Add a head dimension, giving shape [BatchSize, 1, SeqLen, (Cache)SeqLen]
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    # Build the default (zero-based) position IDs from the input token IDs
    def get_position_ids(self, input_ids, device):
        # Token IDs have shape [BatchSize, SeqLen]
        batch_size, seq_length = input_ids.shape
        # Create the position IDs as a 1-D array 0..(SeqLen-1),
        # reshape to [1, SeqLen], then repeat the first dimension BatchSize times to get [BatchSize, SeqLen]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value
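The masking steps are easier to see on a toy example. The following standalone sketch (not part of the model source) replays the core of get_masks for BatchSize=1, SeqLen=3 and a KV cache of length 2, without a padding mask; the numbers are assumptions chosen for illustration.

# Standalone replay of the get_masks logic; True means "masked out".
import torch

batch_size, seq_length, past_length = 1, 3, 2
mask = torch.ones(batch_size, seq_length, seq_length).tril_()  # causal lower triangle
mask = torch.cat((torch.ones(batch_size, seq_length, past_length), mask), dim=-1)  # cache always visible
mask = (mask < 0.5).bool().unsqueeze(1)  # invert and add the head dimension
print(mask[0, 0].int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)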
ChatGLMForConditionalGeneration.stream_generate()
@torch.inference_mode()
def stream_generate(
        self,
        input_ids,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        return_past_key_values=False,
        **kwargs,
):
    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
    if generation_config is None:
        generation_config = self.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    model_kwargs["use_cache"] = generation_config.use_cache
    bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warn(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                UserWarning,
            )
    # Warn if SeqLen is already greater than or equal to the configured maximum length
    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )
    # If no logits processors were provided, initialize them as an empty list
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    # If no stopping criteria were provided, initialize them as an empty list
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    # Build the logits processors from the generation config and related objects
    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )
    # Build the stopping criteria from the generation config and related objects
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    # Build the logits warper from the generation config
    logits_warper = self._get_logits_warper(generation_config)
    # Unfinished flags: one entry per sequence indicating whether it is still being generated,
    # initialized as an all-ones array of shape [BatchSize]
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        # Assemble the passed-in arguments into an input dict; see that method's definition
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # Feed the token IDs to the model to get the next-token logits for every prefix position,
        # [BatchSize, SeqLen, VocabSize]
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        # Take the last position along the SeqLen dimension to get the next-token logits for the whole sequence,
        # [BatchSize, VocabSize]
        next_token_logits = outputs.logits[:, -1, :]
        # Run the logits through the processors and warpers to adjust them
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)
        # Apply softmax to get probabilities
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        # If sampling is enabled, draw one multinomial sample;
        # otherwise take the argmax.
        # Either way, the next token IDs have shape [BatchSize]
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)
        # Reshape the next token IDs to [BatchSize, 1] and append them to the input token IDs
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        # Update the KV cache, attention mask and position IDs from the current outputs
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # Reshape `next_tokens` to [1, BatchSize], then tile the first dimension NEOS times, giving [NEOS, BatchSize]
        # Reshape `eos_token_id_tensor` to [NEOS, 1]; its second dimension broadcasts to [NEOS, BatchSize]
        # Compare the two elementwise for inequality, giving a result of shape [NEOS, BatchSize]
        # Take the product over the NEOS dimension and multiply it into the unfinished flags, [BatchSize]
        # If a sequence's token equals any token in the EOS set, the comparison contains a 0,
        # so its unfinished flag becomes 0.
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )
        # If return_past_key_values is set,
        # yield the input IDs together with the IDs generated so far,
        # plus the KV cache;
        # otherwise yield only the former
        if return_past_key_values:
            yield input_ids, outputs.past_key_values
        else:
            yield input_ids
        # Stop generating once all unfinished flags are zero (every sequence is done)
        # or a stopping criterion is met
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
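stream_generate is a generator: every iteration yields the full input_ids (prompt plus everything generated so far), so a caller decodes the suffix after the prompt to obtain the streamed text. The sketch below is a hedged usage example, assuming a ChatGLM2-6B checkpoint loaded with trust_remote_code; the prompt and generation parameters are illustrative only.

# Hypothetical caller of stream_generate; the model/tokenizer setup is an assumption.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda().eval()

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
prompt_len = inputs["input_ids"].shape[-1]

printed = ""
for output_ids in model.stream_generate(**inputs, max_new_tokens=64, do_sample=True):
    # Each yield contains the whole sequence so far; decode only the part after the prompt.
    text = tokenizer.decode(output_ids[0, prompt_len:].tolist(), skip_special_tokens=True)
    print(text[len(printed):], end="", flush=True)
    printed = text

With return_past_key_values=True the generator would instead yield (input_ids, past_key_values) tuples, which callers such as stream_chat use to reuse the KV cache across rounds.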