modeling_RALLM.py

import copy
import os
import sys

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)

import contextlib
import argparse
import math

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import transformers
from transformers import (PreTrainedModel, AutoTokenizer, AutoModelForMaskedLM, AutoModel,
                          BertTokenizer, GPT2LMHeadModel, PretrainedConfig, GPT2Model,
                          GPT2Tokenizer, LongformerTokenizer, LongformerModel)
from transformers.utils import logging

from modeling_perceive_sampler import BertConfig, BertLMHeadModel

logger = logging.get_logger(__name__)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use the first GPU


class RALLM(nn.Module):

    def __init__(self, args):
        super(RALLM, self).__init__()
        self.is_compress = args.is_compress
        self.use_lora = args.use_lora

        print('Init LLM ... ')
        if args.LLM_model == "Baichuan2_13B":
            self.LLM_model_name = "Baichuan2-13B-Chat"
            self.LLM_hidden_size = 5120
        elif args.LLM_model == "Baichuan2_7B":
            self.LLM_model_name = "baichuan2_7B"
            self.LLM_hidden_size = 4096
        self.LLM_model = transformers.AutoModelForCausalLM.from_pretrained(
            self.LLM_model_name,
            device_map=f"cuda:{args.local_rank}",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # cache_dir=training_args.cache_dir,
        )
        self.LLM_tokenizer = transformers.AutoTokenizer.from_pretrained(
            self.LLM_model_name,
            use_fast=False,
            trust_remote_code=True,
            model_max_length=4096,
            # cache_dir=training_args.cache_dir,
        )
        self.flag_context_start = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))  # .to(self.device)
        self.flag_context_end = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))  # .to(self.device)
        self.flag_context_start.requires_grad = False
        self.flag_context_end.requires_grad = False
        self.device = self.LLM_model.device
        self.user_token = self.LLM_tokenizer._convert_id_to_token(195)
        self.assisent_token = self.LLM_tokenizer._convert_id_to_token(196)
        self.eoa = self.LLM_tokenizer._convert_id_to_token(2)
        print("user_token:", self.user_token, "assisent_token:", self.assisent_token, "eoa:", self.eoa)
        print('Done')

        print('Init context encoder ... ')
        self.init_context_encoder(args)
        print('Done')

    def init_Qformer(self, num_query_token, num_features):
        self.Qformer = self.init_qformer(num_query_token, num_features, cross_attention_freq=1)
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.Qformer.cls = None

    @classmethod
    def init_qformer(cls, num_query_token, vision_width, cross_attention_freq=2, pretrain=True):
        encoder_config = BertConfig()
        encoder_config.num_hidden_layers = 2
        encoder_config.hidden_size = vision_width
        encoder_config.encoder_width = vision_width
        encoder_config.num_attention_heads = vision_width // 64
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        return Qformer

    def init_context_encoder(self, args):
        num_query_token = args.query_tokens = 0
        if args.encoder == "bert_base":
            self.context_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese")
            self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_base_chinese", output_hidden_states=True)
            num_features = 768
        if args.encoder == "bert_large":
            self.context_tokenizer = AutoTokenizer.from_pretrained("bert_large_chinese", max_length=2000)
            self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_large_chinese", output_hidden_states=True)
            num_features = 1024
        if args.encoder == "gpt2_xlarge":
            self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_xlarge")
            self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_xlarge")
            num_features = 1600
        if args.encoder == "gpt2_large":
            self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_large")
            self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_large")
            num_features = 1280
        if args.encoder == "gpt2_large_en":
            self.context_tokenizer = GPT2Tokenizer.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN")
            self.context_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.context_encoder = GPT2Model.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN")
            num_features = 1280
        if args.encoder == "longformer":
            self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer')
            self.context_encoder = LongformerModel.from_pretrained('longformer')
            num_features = 768
        if args.encoder == "longformer_large":
            self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer-large')
            self.context_encoder = LongformerModel.from_pretrained('longformer-large')
            num_features = 1024
        # bert_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese", max_length=2000)
        # bert_encoder = AutoModelForMaskedLM.from_pretrained("longformer_zh", output_hidden_states=True)  # .to(device)

        self.context_encoder = self.context_encoder.to(self.device)
        self.context_score = torch.nn.ModuleList([
            torch.nn.Linear(num_features, 64),
            torch.nn.Tanh(),
            torch.nn.Linear(64, 1),
        ])  # num_features is the encoder hidden size, 1 is the score output dimension
        self.context2llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size)  # project encoder features to the LLM hidden size
        self.llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size)  # project encoder features to the LLM hidden size
        # model.embed2qformer_proj = torch.nn.Linear(num_features, 768)
        self.ln_features = LayerNorm(num_features)
        self.init_Qformer(num_query_token, num_features)
        # del model.internlm_proj
        # del model.Qformer
        # torch.cuda.empty_cache()  # free GPU memory
        # if device:
        #     model = self.model.to(self.device)
    def encode_text(self, text, add_special_tokens=False):
        input_ids = self.LLM_tokenizer.encode(text)
        input_ids = torch.LongTensor([input_ids]).to(self.device)
        if self.use_lora:
            text_embeds = self.LLM_model.base_model.model.model.embed_tokens(input_ids)
        else:
            text_embeds = self.LLM_model.model.embed_tokens(input_ids)
        return text_embeds

    def calculate_compressibility(self, x, k=0):
        return (x * k * (9 / 1000) + 1) * 111.111 / (x + 111.111)

    # batch-encode a list of sentences
    def batch_input_sentences(self, sentences):
        input_ids_list = [
            self.context_tokenizer.encode(sentence, return_tensors="pt", padding='max_length',
                                          max_length=2500, truncation=True)
            for sentence in sentences
        ]
        max_length = max(len(input_ids[0]) for input_ids in input_ids_list)
        input_ids_padded = [
            torch.cat([input_ids, torch.zeros(1, max_length - input_ids.size(1), dtype=torch.long)], dim=1)
            for input_ids in input_ids_list
        ]
        input_ids_tensor = torch.cat(input_ids_padded, dim=0)
        return input_ids_tensor

    def encode_context(self, text_list):
        if text_list is None:
            return None
        inputs_LLMs = []
        input_atts = []
        # print(text_list)
        for text in text_list:
            # input_ids = self.context_tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
            # encode the text list and pad to the longest sequence
            # encoded_ids = self.context_tokenizer(text, padding=True, return_tensors="pt", truncation=True)
            input_ids = self.batch_input_sentences(text)
            input_ids = input_ids.to(self.device)
            # input_ids = encoded_ids.data["input_ids"].to(self.device)
            # attention_mask = encoded_ids.data["attention_mask"].to(self.device)
            outputs = self.context_encoder(input_ids, output_hidden_states=True)
            # extract the embedding-layer output and the last hidden states
            embedding, last_hidden_state = outputs.hidden_states[0], outputs.hidden_states[-1]  # outputs.logits
            x = last_hidden_state
            for layer in self.context_score:
                x = layer(x)
            output = x
            # output = self.context_score(last_hidden_state)  # linear scoring head

            batch, seq_len, ebd_dim = last_hidden_state.size()
            # compressibility = -0.0009 * seq_len + 1  # linear schedule: short inputs are compressed less,
            # long inputs more; discarded because x * f(x) is not monotonically decreasing
            # compressibility = 111.111 / (seq_len + 111.111)  # redesigned non-linear schedule: inputs under ~10
            # tokens are barely compressed, and x * f(x) decreases over 0-1000
            if self.is_compress:
                compressibility = self.calculate_compressibility(seq_len, 0)
                K = math.ceil(seq_len * compressibility)
            else:
                K = seq_len
            # take the top-k indices with torch.topk
            topk_indices = torch.topk(output, K, dim=1).indices
            # print(topk_indices)
            topk_indices, sorted_indices = torch.sort(topk_indices, dim=1)  # restore the original token order
            # print(topk_indices)
            # gather the last_hidden_state vectors at the top-k positions
            topk_selected_last_hidden_state = torch.gather(last_hidden_state, 1, topk_indices.expand(-1, -1, ebd_dim))
            # print(last_hidden_state)
            # print(topk_selected_last_hidden_state)
            topk_selected_embedding = torch.gather(embedding, 1, topk_indices.expand(-1, -1, ebd_dim))
            # bert_text_atts = torch.gather(attention_mask, 1, torch.squeeze(topk_indices, dim=2))
            bert_text_embeds = self.ln_features(last_hidden_state)
            bert_text_atts = torch.ones(bert_text_embeds.size()[:-1], dtype=torch.long).to(self.device)
            # query_tokens = self.query_tokens.expand(bert_text_atts.shape[0], -1, -1)
            # query_tokens = topk_selected_embedding
            query_tokens = topk_selected_last_hidden_state
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=bert_text_embeds,
                encoder_attention_mask=bert_text_atts,
                return_dict=True,
            )
            # topk_context_hidden_state = self.context2llm_proj(topk_selected_last_hidden_state)
            inputs_LLM = self.llm_proj(query_output.last_hidden_state)
            inputs_LLM = torch.cat([
                self.flag_context_start.expand(batch, -1, -1),
                # topk_context_hidden_state,
                inputs_LLM,
                self.flag_context_end.expand(batch, -1, -1)
            ], dim=1).view(-1, self.LLM_hidden_size)
            input_att = torch.cat([torch.ones((batch, 1)).to(self.device),
                                   bert_text_atts,
                                   torch.ones((batch, 1)).to(self.device)], dim=1).view(-1)
            # print(inputs_LLM.shape)
            inputs_LLMs.append(inputs_LLM)
            input_atts.append(input_att)
        # context_inputs = torch.stack(inputs_LLMs)
        return inputs_LLMs, input_atts

    def wrap_prompt(self, text_embeds, context_embeds=None, history=None, add_special=True):
        if add_special:
            if history is None:
                prompt_segs = [self.user_token, self.assisent_token]
            else:
                prompt_segs = [self.user_token, self.assisent_token]
        else:
            prompt_segs = [self.user_token, self.assisent_token]  # used in wrap history
        prompt_seg_embeds = []
        for i, seg in enumerate(prompt_segs):
            if history is not None:
                add_special_tokens = False
            else:
                add_special_tokens = i == 0
            seg_embeds = self.encode_text(seg, add_special_tokens=add_special_tokens)
            prompt_seg_embeds.append(seg_embeds)
        if context_embeds is None:
            context_embeds = text_embeds.new_empty(text_embeds.size(0), 0, text_embeds.size(-1))
        else:
            # take the first context and add a batch dimension at index 0
            context_embeds = context_embeds[0].unsqueeze(0)
        prompt_seg_embeds = [prompt_seg_embeds[0], text_embeds, context_embeds, prompt_seg_embeds[1]]
        prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds

    def generate(self, text, context=None, **kwargs):
        text = text.replace("<context>", "").replace(self.user_token, "").replace(self.assisent_token, "")
        text_embeds = self.encode_text(text)
        if context is None:
            context_embeds = None
        else:
            context_embeds, _ = self.encode_context(context)
        prompt_embeds = self.wrap_prompt(text_embeds, context_embeds)
        # out_embeds = self.LLM_model.generate(input_ids=None,
        #                                      inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        # out_text = self.decode_text(out_embeds)
        outputs = self.LLM_model.generate(input_ids=None, inputs_embeds=prompt_embeds,
                                          generation_config=self.LLM_model.generation_config)
        response = self.LLM_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response

    def chat(self, text, context=None, history=None, **kwargs):
        text_embeds = self.encode_text(text)
        if context is None:
            context_embeds = None
        else:
            context_embeds, _ = self.encode_context(context)
        prompt_embeds = self.wrap_prompt(text_embeds, context_embeds, history=history)
        outputs = self.LLM_model.generate(input_ids=None, inputs_embeds=prompt_embeds,
                                          generation_config=self.LLM_model.generation_config)
        out_text = self.LLM_tokenizer.decode(outputs[0], skip_special_tokens=True)  # trunc at eoh and eoa
        # re-embed the cleaned output and append the turn to the running history
        clean_out_text_token_ids = self.LLM_tokenizer(out_text, return_tensors='pt').input_ids.to(self.device)
        if self.use_lora:
            clean_out_text_embeds = self.LLM_model.base_model.model.model.embed_tokens(clean_out_text_token_ids)
        else:
            clean_out_text_embeds = self.LLM_model.model.embed_tokens(clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds, context_embeds, add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds], dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def align_text(self, samples, has_context=False):
        ### add eoa; return the text after <context>
        text_new = []
        if has_context:
            ### drop the first user turn so the context features can be wrapped in
            text = [t.split("<context>")[-1] for t in samples["text_input"]]
        else:
            text = [t for t in samples["text_input"]]
        text = [t + self.eoa for t in text]
        for i in range(len(text)):
            temp = text[i]
            # temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
            # if temp.find(self.eoh) > temp.find(self.eoa):
            #     temp = temp.replace(self.eoa, '', 1)
            text_new.append(temp)
        return text_new
    def prompt_wrap(self, context_embeds, context_atts, prompt_list):
        batch_size = len(context_embeds)
        p_before = [prompt.split('<context>')[0] for prompt in prompt_list]
        p_before_tokens = self.LLM_tokenizer(p_before, padding=True, truncation=True,
                                             return_tensors="pt", add_special_tokens=True).to(self.device)
        if self.use_lora:
            p_before_embeds = self.LLM_model.base_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        else:
            p_before_embeds = self.LLM_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        # wrapped_context_embeds = torch.cat([p_before_embeds, context_embeds], dim=1)
        # wrapped_context_embeds = torch.cat([p_before_embeds] + context_embeds, dim=1)
        wrapped_context_embeds = []
        wrapped_atts_context = []
        wrapped_target = []
        for i, (context_embed, context_att) in enumerate(zip(context_embeds, context_atts)):
            # concatenate each p_before_embeds sequence with its context tensor along the sequence dimension
            concatenated = torch.cat((p_before_embeds[i], context_embed), dim=0)
            wrapped_context_embeds.append(concatenated)
            # concatenated_att = torch.cat((torch.ones(p_before_embeds[i].size()[:-1], dtype=torch.long).to(self.device), context_att), dim=0)
            wrapped_atts_context.append(torch.ones(concatenated.size()[:-1], dtype=torch.long).to(self.device))
            # wrapped_atts_context.append(concatenated_att)
            target = torch.ones(concatenated.size()[:-1], dtype=torch.long) * -100
            target[0] = 2
            target = target.to(self.device)
            wrapped_target.append(target)
        # wrapped_atts_context = torch.ones(wrapped_context_embeds.size()[:-1], dtype=torch.long).to(self.device)
        # wrapped_target = torch.ones(batch_size, wrapped_context_embeds.shape[1], dtype=torch.long).to(self.device) * -100
        return wrapped_context_embeds, wrapped_atts_context, wrapped_target

    def text2emb(self, text):
        to_regress_tokens = self.LLM_tokenizer(text,
                                               return_tensors="pt",
                                               padding="longest",
                                               truncation=True,
                                               max_length=4096,
                                               add_special_tokens=False).to(self.device)
        targets = self.mask_human_targets(to_regress_tokens.input_ids)
        targets = targets.to(self.device)
        return to_regress_tokens, targets

    def mask_human_targets(self, input_ids, pure=False):
        target_batch = []
        for bs in range(input_ids.shape[0]):
            cur_idx = 0
            ids = input_ids[bs]
            targets = copy.deepcopy(ids)
            last_eoa = 0
            last_eoh = 0
            for i, temp_id in enumerate(ids):
                if temp_id == 196:  #### end of the human turn (assistant marker)
                    targets[cur_idx:i + 1] = -100
            target_batch.append(targets.unsqueeze(0))
        target_batch = torch.cat(target_batch, dim=0)
        target_batch[target_batch == 0] = -100
        # print(input_ids)
        # print(target_batch)
        return target_batch

    def forward(self,
                input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                context=None,
                text_input=None,
                **kwargs):
        # samples = kwargs  # .get('samples')
        # has_context = 'context' in samples.keys()
        if context:
            has_context = True
        else:
            has_context = False
        samples = {"text_input": text_input, "context": context}

        ### encode text
        text = self.align_text(samples=samples, has_context=has_context)  # get the text after <context>
        to_regress_tokens, targets = self.text2emb(text)  # returns the tokens and the targets
        if self.use_lora:
            to_regress_embeds = self.LLM_model.base_model.model.model.embed_tokens(to_regress_tokens.input_ids)
        else:
            to_regress_embeds = self.LLM_model.model.embed_tokens(to_regress_tokens.input_ids)
        attention_mask = to_regress_tokens.attention_mask

        if has_context:
            prompt = samples["text_input"]
            ### encode context
            context = samples["context"]
            context_embeds, context_atts = self.encode_context(context)
            context_embeds, atts_context, wrapped_target = self.prompt_wrap(context_embeds, context_atts, prompt)

            ### combine text and context features
            to_regress_embeds_ = []
            attention_mask_ = []
            targets_ = []
            for i, (tensor0, tensor1, tensor2) in enumerate(zip(to_regress_embeds, attention_mask, targets)):
                # concatenate each context embedding with the corresponding text tensors along the sequence dimension
                to_regress_embed = torch.cat((context_embeds[i], tensor0), dim=0)
                to_regress_embeds_.append(to_regress_embed)
                attention_m = torch.cat((atts_context[i], tensor1), dim=0)
                attention_mask_.append(attention_m)
                target = torch.cat((wrapped_target[i], tensor2), dim=0)
                targets_.append(target)
            # to_regress_embeds = torch.cat([context_embeds, to_regress_embeds], dim=1)
            # attention_mask = torch.cat([atts_context, attention_mask], dim=1)
            # targets = torch.cat([wrapped_target, targets], dim=1)

            # determine the maximum length
            max_len = max(t.size(0) for t in to_regress_embeds_)
            # pad the tensors
            padded_to_regress_embeds_ = []
            padded_attention_mask_ = []
            padded_targets_ = []
            for (t, a, l) in zip(to_regress_embeds_, attention_mask_, targets_):
                if t.size(0) < max_len:
                    # how much padding is needed
                    padding_size = max_len - t.size(0)
                    # pad along the sequence dimension
                    padded_regress = torch.nn.functional.pad(t, (0, 0, 0, padding_size))
                    padded_attention = torch.nn.functional.pad(a, (0, padding_size), value=0)
                    padded_target = torch.nn.functional.pad(l, (0, padding_size), value=-100)
                    padded_to_regress_embeds_.append(padded_regress)
                    padded_attention_mask_.append(padded_attention)
                    padded_targets_.append(padded_target)
                else:
                    padded_to_regress_embeds_.append(t)
                    padded_attention_mask_.append(a)
                    padded_targets_.append(l)
            # stack into batch tensors
            to_regress_embeds = torch.stack(padded_to_regress_embeds_)
            attention_mask = torch.stack(padded_attention_mask_)
            targets = torch.stack(padded_targets_)

        outputs = self.LLM_model(
            inputs_embeds=to_regress_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        return outputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", default="output", type=str)
    parser.add_argument("--encoder", default="gpt2_large", type=str)
    parser.add_argument("--query_tokens", default=32, type=int)
    parser.add_argument("--load_path", default="/data2/xinyuuliu/InternLM-XComposer/output_rerank", type=str)
    parser.add_argument("--local_rank", default="0", type=str)
    # arguments consumed by RALLM.__init__ (defaults mirror train_batch.py)
    parser.add_argument("--is_compress", default=True, type=bool)
    parser.add_argument("--use_lora", default=False, type=bool)
    parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
    args = parser.parse_args()

    model = RALLM(args)
    print(model)

    # model.encode_context("我愛北京天安門")
    # model.encode_text("我愛北京天安門")
    # # <ContextHere>
    # query = "Q:請重複內容:<cont_s><ContextHere><cont_e> \n A:"
    # context = ["電飯煲不知道怎麼選?想要吃一碗香噴噴的米飯,除了米要好之外,還需要一款效能優秀的電飯煲,所以大家在選購電飯煲的時候,一定要多花點心思看看攻略避免踩雷。我前前後後給親朋好友選購過不下5臺電飯煲,也算是積攢了不少選購經驗,今天特意總結了一下想分享給大家。1、容量選擇市面上電飯煲容量普遍在3L-5L之間,這個範圍的容量足夠滿足絕大部分家庭使用,3L一般可以滿足1-3人的家庭,4L一般可以滿足2-5人的家庭,5L一般可以滿足2-8人的家庭,如果人口超過8人建議直接選擇5L以上的容量,使用會更方便。"]
    # model.interleav_wrap(query, context)
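
# --- Illustrative sketch (not part of the original file) ---------------------
# The schedule in RALLM.calculate_compressibility (with k=0) reduces to
# f(x) = 111.111 / (x + 111.111), and encode_context keeps K = ceil(x * f(x))
# token positions out of x: short passages are barely compressed, while very
# long passages are squeezed towards roughly 111 kept positions. The helper
# below only prints that mapping for a few lengths; the name
# `_compressibility_demo` is introduced here for illustration and is not
# called anywhere in the training code.
def _compressibility_demo(lengths=(10, 100, 500, 1000, 2500)):
    for x in lengths:
        ratio = 111.111 / (x + 111.111)  # same value as calculate_compressibility(x, k=0)
        print(f"seq_len={x:5d}  kept K={math.ceil(x * ratio):4d}  ratio={ratio:.3f}")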
modeling_perceive_sampler.py
""" * Copyright (c) 2023, salesforce.com, inc. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause * By Junnan Li * Based on huggingface code base * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert """ import math from typing import Tuple import torch import torch.utils.checkpoint from torch import Tensor, device from torch import nn from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MaskedLMOutput, ) from transformers.modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer, ) from transformers.models.bert.configuration_bert import BertConfig from transformers.utils import logging logger = logging.get_logger(__name__) class BertEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config def forward( self, input_ids=None, position_ids=None, query_embeds=None, past_key_values_length=0, ): if input_ids is not None: seq_length = input_ids.size()[1] else: seq_length = 0 if position_ids is None: position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length].clone() if input_ids is not None: embeddings = self.word_embeddings(input_ids) if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings if query_embeds is not None: embeddings = torch.cat((query_embeds, embeddings), dim=1) else: embeddings = query_embeds embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class BertSelfAttention(nn.Module): def __init__(self, config, is_cross_attention): super().__init__() self.config = config if config.hidden_size % config.num_attention_heads != 0 and not hasattr( config, "embedding_size"): raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) if is_cross_attention: self.key = nn.Linear(config.encoder_width, self.all_head_size) self.value = nn.Linear(config.encoder_width, self.all_head_size) else: self.key = nn.Linear(config.hidden_size, 
self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding( 2 * config.max_position_embeddings - 1, self.attention_head_size) self.save_attention = False def save_attn_gradients(self, attn_gradients): self.attn_gradients = attn_gradients def get_attn_gradients(self): return self.attn_gradients def save_attention_map(self, attention_map): self.attention_map = attention_map def get_attention_map(self): return self.attention_map def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None if is_cross_attention: key_layer = self.transpose_for_scores( self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores( self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) mixed_query_layer = self.query(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"): seq_length = hidden_states.size()[1] position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view( -1, 1) position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view( 1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding( distance + self.max_position_embeddings - 1) positional_embedding = positional_embedding.to( dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum( "bhld,lrd->bhlr", query_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": relative_position_scores_query = torch.einsum( "bhld,lrd->bhlr", query_layer, positional_embedding) relative_position_scores_key = torch.einsum( "bhrd,lrd->bhlr", key_layer, positional_embedding) attention_scores = (attention_scores + relative_position_scores_query + relative_position_scores_key) attention_scores = attention_scores / math.sqrt( self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) if is_cross_attention and self.save_attention: self.save_attention_map(attention_probs) attention_probs.register_hook(self.save_attn_gradients) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
attention_probs_dropped = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs_dropped = attention_probs_dropped * head_mask context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + ( self.all_head_size, ) context_layer = context_layer.view(*new_context_layer_shape) outputs = ((context_layer, attention_probs) if output_attentions else (context_layer, )) outputs = outputs + (past_key_value, ) return outputs class BertSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() self.self = BertSelfAttention(config, is_cross_attention) self.output = BertSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads, ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len( heads) self.self.all_head_size = (self.self.attention_head_size * self.self.num_attention_heads) self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): self_outputs = self.self( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them return outputs class BertIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertLayer(nn.Module): def __init__(self, config, layer_num): super().__init__() 
self.config = config self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = BertAttention(config) self.layer_num = layer_num if (self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0): self.crossattention = BertAttention( config, is_cross_attention=self.config.add_cross_attention) self.has_cross_attention = True else: self.has_cross_attention = False self.intermediate = BertIntermediate(config) self.output = BertOutput(config) self.intermediate_query = BertIntermediate(config) self.output_query = BertOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, query_length=0, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = (past_key_value[:2] if past_key_value is not None else None) self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] if query_length > 0: query_attention_output = attention_output[:, :query_length, :] if self.has_cross_attention: assert ( encoder_hidden_states is not None ), "encoder_hidden_states must be given for cross-attention layers" cross_attention_outputs = self.crossattention( query_attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions=output_attentions, ) query_attention_output = cross_attention_outputs[0] outputs = ( outputs + cross_attention_outputs[1:-1] ) # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk_query, self.chunk_size_feed_forward, self.seq_len_dim, query_attention_output, ) if attention_output.shape[1] > query_length: layer_output_text = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[:, query_length:, :], ) layer_output = torch.cat([layer_output, layer_output_text], dim=1) else: layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, ) outputs = (layer_output, ) + outputs outputs = outputs + (present_key_value, ) return outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output def feed_forward_chunk_query(self, attention_output): intermediate_output = self.intermediate_query(attention_output) layer_output = self.output_query(intermediate_output, attention_output) return layer_output class BertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList( [BertLayer(config, i) for i in range(config.num_hidden_layers)]) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, query_length=0, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = (() if output_attentions and self.config.add_cross_attention else None) 
next_decoder_cache = () if use_cache else None for i in range(self.config.num_hidden_layers): layer_module = self.layer[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[ i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: logger.warn( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, query_length) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, query_length, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1], ) if output_attentions: all_self_attentions = all_self_attentions + ( layer_outputs[1], ) all_cross_attentions = all_cross_attentions + ( layer_outputs[2], ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: return tuple(v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) class BertPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) if isinstance(config.hidden_act, str): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLMPredictionHead(nn.Module): def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. 
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states class BertOnlyMLMHead(nn.Module): def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) return prediction_scores class BertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = BertConfig base_model_prefix = "bert" _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() class BertModel(BertPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in `Attention is all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass. """ def __init__(self, config, add_pooling_layer=False): super().__init__(config) self.config = config self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) if add_pooling_layer else None self.init_weights() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool, has_query: bool = False, ) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (:obj:`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. input_shape (:obj:`Tuple[int]`): The shape of the input to the model. device: (:obj:`torch.device`): The device of the input to the model. Returns: :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if is_decoder: batch_size, seq_length = input_shape seq_ids = torch.arange(seq_length, device=device) causal_mask = (seq_ids[None, None, :].repeat( batch_size, seq_length, 1) <= seq_ids[None, :, None]) # add a prefix ones mask to the causal mask # causal and attention masks must have same type with pytorch version < 1.3 causal_mask = causal_mask.to(attention_mask.dtype) if causal_mask.shape[1] < attention_mask.shape[1]: prefix_seq_len = attention_mask.shape[ 1] - causal_mask.shape[1] if has_query: # UniLM style attention mask causal_mask = torch.cat( [ torch.zeros( (batch_size, prefix_seq_len, seq_length), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=1, ) causal_mask = torch.cat( [ torch.ones( (batch_size, causal_mask.shape[1], prefix_seq_len), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=-1, ) extended_attention_mask = (causal_mask[:, None, :, :] * attention_mask[:, None, None, :]) else: extended_attention_mask = attention_mask[:, None, None, :] else: raise ValueError( "Wrong shape for input_ids (shape {}) or attention_mask (shape {})" .format(input_shape, attention_mask.shape)) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to( dtype=self.dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, is_decoder=False, ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). """ output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) # use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is None: assert ( query_embeds is not None ), "You have to specify query_embeds when input_ids is None" # past_key_values_length past_key_values_length = (past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0) query_length = query_embeds.shape[1] if query_embeds is not None else 0 embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, query_embeds=query_embeds, past_key_values_length=past_key_values_length, ) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device if attention_mask is None: attention_mask = torch.ones( ((batch_size, seq_length + past_key_values_length)), device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if is_decoder: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_ids.shape, device, is_decoder, has_query=(query_embeds is not None), ) else: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_shape, device, is_decoder) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: if type(encoder_hidden_states) == list: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ 0].size() else: ( encoder_batch_size, encoder_sequence_length, _, ) = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [ self.invert_attention_mask(mask) for mask in encoder_attention_mask ] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, query_length=query_length, ) sequence_output = encoder_outputs[0] pooled_output = (self.pooler(sequence_output) if self.pooler is not None else None) if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) class BertLMHeadModel(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [ r"position_ids", r"predictions.decoder.bias" ] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, past_key_values=None, use_cache=True, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=True, reduction="mean", ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the 
encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). Returns: Example:: >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig >>> import torch >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') >>> config = BertConfig.from_pretrained("bert-base-cased") >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) >>> prediction_logits = outputs.logits """ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) if labels is not None: use_cache = False if past_key_values is not None: query_embeds = None outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) sequence_output = outputs[0] if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1]:, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores[:, :-1, :].contiguous() lm_loss = None if labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one shifted_prediction_scores = prediction_scores[:, : -1, :].contiguous() labels = labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) lm_loss = loss_fct( shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1), ) if reduction == "none": lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) if not return_dict: output = (prediction_scores, ) + outputs[2:] return ((lm_loss, ) + output) if lm_loss is not None else 
output return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs): # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) query_mask = input_ids.new_ones(query_embeds.shape[:-1]) attention_mask = torch.cat([query_mask, attention_mask], dim=-1) # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "query_embeds": query_embeds, "attention_mask": attention_mask, "past_key_values": past, "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), "is_decoder": True, } def _reorder_cache(self, past, beam_idx): reordered_past = () for layer_past in past: reordered_past += (tuple( past_state.index_select(0, beam_idx) for past_state in layer_past), ) return reordered_past class BertForMaskedLM(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [ r"position_ids", r"predictions.decoder.bias" ] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=False, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` """ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1]:, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores masked_lm_loss = None if labels is not None: loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct( prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: output = (prediction_scores, ) + outputs[2:] return (((masked_lm_loss, ) + output) if masked_lm_loss is not None else output) return MaskedLMOutput( loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )
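
# --- Illustrative sketch (not part of the original file) ---------------------
# A minimal example of how modeling_RALLM.py drives this module as a
# perceive sampler / Q-Former: the BERT stack receives its query embeddings
# directly (so word/position embeddings are unused) and cross-attends to the
# frozen context-encoder states. The shapes and the name `_qformer_demo` are
# placeholders introduced here purely for illustration.
def _qformer_demo():
    width = 768                                   # context-encoder feature size
    cfg = BertConfig()
    cfg.num_hidden_layers = 2
    cfg.hidden_size = width
    cfg.encoder_width = width
    cfg.num_attention_heads = width // 64
    cfg.add_cross_attention = True
    cfg.cross_attention_freq = 1                  # cross-attention in every layer
    cfg.query_length = 0                          # queries are supplied, not learned
    qformer = BertLMHeadModel(config=cfg)

    query_embeds = torch.randn(1, 16, width)      # e.g. top-k encoder hidden states
    encoder_states = torch.randn(1, 64, width)    # full encoder sequence
    encoder_mask = torch.ones(1, 64, dtype=torch.long)
    out = qformer.bert(
        query_embeds=query_embeds,
        encoder_hidden_states=encoder_states,
        encoder_attention_mask=encoder_mask,
        return_dict=True,
    )
    print(out.last_hidden_state.shape)            # torch.Size([1, 16, 768])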
dataset_batch.py
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import random


class QADataset(Dataset):
    def __init__(self, data_path, train) -> None:
        super().__init__()
        self.data = []
        data = pd.read_csv(data_path).dropna()
        print(data.columns)
        condition = (data['answer'].str.len() <= 1000) & (data['summary'].str.len() <= 500)
        filtered_data = data[condition]

        with open("data/corpus.tsv", "r") as f_read:
            corpus = [i.split()[-1] for i in f_read.readlines()]

        retell_prompts = ["請複述這段被壓縮的內容",
                          "複述這段被壓縮的內容",
                          "請將被壓縮的內容複述出來"]
        summary_prompts = ["請總結被壓縮的資訊",
                           "還原被壓縮資訊的主要內容",
                           "請寫出被壓縮資訊的主要內容",
                           "請對之前壓縮的資訊進行概括",
                           "請提煉出之前被壓縮資訊的核心要點",
                           "請歸納一下之前被壓縮的內容的主旨"]

        if train:
            # Filter out articles that satisfy the length condition
            # filtered_data1000 = list(filter(self.filter_by_length1000, data["answer"]))
            for idx in range(5000):
                # if not line or line == "" or len(line) < 50 or len(line) > 2000:
                #     continue
                # Randomly choose how many passages to concatenate (1 to 10)
                repeat_count = random.randint(1, 10)
                flag_context = "<context> " * repeat_count
                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{selected_articles_}'
                test_data = {"context": selected_articles, "text_input": text, "label": selected_articles_}
                self.data.append(test_data)

            # for idx in range(5000):
            #     repeat_count = random.randint(1, 1)
            #     flag_context = "<context> " * repeat_count
            #     selected_articles = random.sample(filtered_data150, repeat_count)
            #     selected_articles_ = " ".join(selected_articles)
            #     text = f'<|User|>:請複述這段話{flag_context} <|Bot|>:{selected_articles_}'
            #     test_data = {"samples": {"context": selected_articles, "text_input": [text]}}
            #     self.data.append(test_data)

            for idx, (answer, summary) in enumerate(zip(filtered_data["answer"], filtered_data["summary"])):
                answer = [answer[:1000]]
                flag_context = "<context> "
                prompt = random.choice(summary_prompts)
                # user_token: <reserved_106>  assistant_token: <reserved_107>  eoa: </s>
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{summary}'
                test_data = {"context": answer, "text_input": text, "label": summary}
                self.data.append(test_data)

            # for idx in range(10000):
            #     repeat_count = random.randint(1, 1)
            #     flag_context = "<context> " * repeat_count
            #     selected_articles = random.sample(filtered_data1500, repeat_count)
            #     selected_articles_ = " ".join(selected_articles)
            #     text = f'<|User|>:請複述這段話{flag_context} <|Bot|>:{selected_articles_}'
            #     test_data = {"samples": {"context": selected_articles, "text_input": [text]}}
            #     self.data.append(test_data)

            print("data loaded, size:", len(self.data))
        else:
            for idx in range(100):
                # Randomly choose how many passages to concatenate (3 to 5)
                repeat_count = random.randint(3, 5)
                flag_context = "<context> " * repeat_count
                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>'
                test_data = {"context": selected_articles, "text_input": text, "label": selected_articles_}
                self.data.append(test_data)

    # Helpers to filter articles by length
    @staticmethod
    def filter_by_length150(article):
        return 180 <= len(article) <= 200

    @staticmethod
    def filter_by_length1000(article):
        return 50 <= len(article) <= 1000

    @staticmethod
    def filter_by_length1500(article):
        return 500 <= len(article) <= 1500

    def __getitem__(self, index):
        item_data = self.data[index]
        return item_data

    def __len__(self):
        return len(self.data)


if __name__ == "__main__":
    data_path = "QA_5000_summary.csv"
    dataset = QADataset(data_path, train=True)
    # print(dataset[0])
    val_params = {
        "batch_size": 2,
        "shuffle": False,
        "num_workers": 0,
    }

    def collate_fn(batch):
        """
        Merge a batch of samples.
        :param batch: [result of one __getitem__, result of __getitem__, ...]
        :return: dict of lists
        """
        # Accumulate the values of each key into a list
        merged_dict = {}
        for d in batch:
            for key, value in d.items():
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    merged_dict[key] = [value]
        # print(merged_dict)
        return merged_dict

    val_loader = DataLoader(dataset, **val_params, collate_fn=collate_fn)
    for i in val_loader:
        print(i)
        break
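Because collate_fn only merges the per-sample dicts into lists (no padding or tensorization happens here), it helps to see what a collated batch actually contains. The snippet below is a minimal sketch, not part of the repository; the passage strings are hypothetical placeholders, and the keys mirror the test_data dicts built above.

# Sketch: what collate_fn produces for a batch of two samples (placeholder strings).
batch = [
    {"context": ["passage A"], "text_input": "<reserved_106>...<reserved_107>passage A", "label": "passage A"},
    {"context": ["passage B1", "passage B2"], "text_input": "<reserved_106>...<reserved_107>passage B1[SEP]passage B2", "label": "passage B1[SEP]passage B2"},
]

merged = {}
for sample in batch:
    for key, value in sample.items():
        merged.setdefault(key, []).append(value)

# merged == {
#     "context":    [["passage A"], ["passage B1", "passage B2"]],
#     "text_input": ["<reserved_106>...passage A", "<reserved_106>...passage B1[SEP]passage B2"],
#     "label":      ["passage A", "passage B1[SEP]passage B2"],
# }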
train_batch.py
# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
# from dataset_batch_en import QADataset
# from dataset_rerank import QADataset
# from dataset_rerank_en_gpt import QADataset
from dataset_rerank_en import QADataset
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys
import numpy as np
from modeling_RALLM import RALLM
import argparse
import deepspeed
from torch.nn.parallel import DataParallel

# Restrict CUDA device visibility, e.g. use only one GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=True, type=bool)
parser.add_argument("--compressibility_factor", default=0, type=float, help="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_english_longformer_rerank100k_2", type=str)
parser.add_argument("--encoder", default="longformer", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_msmarco2019/checkpoint-87500", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-3, type=float)
parser.add_argument("--weight_decay", default=0.005, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--batch_size", default=1, type=int)
args = parser.parse_args()


def train(epoch, model, loader, optimizer, scaler, gradient_accumulation_steps, model_output_dir):
    model.train()
    time1 = time.time()
    losses = []
    train_bar = tqdm(loader, total=len(loader))
    for index, data in enumerate(train_bar):
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # print(data)
            outputs = model(model, **data)
            loss = outputs.loss
        # Backward pass: compute the gradients for the current step
        loss.requires_grad_(True)
        losses.append(loss.item())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if (index + 1) % 5000 == 0:
            model_output_dir_ = os.path.join(model_output_dir, f"epoch{epoch}")
            model_save_path = os.path.join(model_output_dir_, "index_{}".format(index))
            if os.path.exists(model_save_path):
                pass
            else:
                os.makedirs(model_save_path)
            torch.save(model.state_dict(),
                       os.path.join(model_save_path, "LLM_model_{:.6f}.pth".format(np.mean(losses))))
        train_bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch, index, np.mean(losses)))


def validate(model, loader):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for _, data in enumerate(tqdm(loader, file=sys.stdout, desc="Validation Data")):
                text = data["text_input"]
                context = data["context"]
                label = data["label"]
                print(text)
                print("len context:", len(context))
                for text_, context_ in zip(text, context):
                    preds = model.generate(text=text_, context=[context_])
                    print(preds)
                    print(label)
                    predictions.append(preds)
                actuals.extend(label)
    return predictions, actuals


def main():
    epochs = args.epochs
    batch_size = args.batch_size
    lr = 1e-5
    gradient_accumulation_steps = 16
    model_output_dir = args.output
    # train_path = "qa_3w_summary.csv"
    train_path = args.train_dataset
    val_path = args.train_dataset

    device = torch.device(f"cuda:{args.local_rank}")
    model = RALLM(args)
    model = model.to(device)

    if args.use_lora:
        print("Training with LoRA" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack", ],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)

    if args.use_lora_gpt2:
        print("Training with LoRA" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["wte", "c_attn", ],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)

    print(model)
    torch.cuda.empty_cache()  # free GPU memory

    if args.load_path:
        base_load_path = args.load_path
        # Load either a single .pth file or a sharded checkpoint
        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path, map_location=device)
        else:
            file_list = ['pytorch_model.bin']
            # Empty state dict to merge the shards into
            state_dict = {}
            for file_name in file_list:
                # Load one shard of model parameters
                part_state_dict = torch.load(os.path.join(base_load_path, file_name), map_location=device)
                # Merge it into the full state dict
                state_dict.update(part_state_dict)
        # Load the merged state dict into the model
        print("state_dict:")
        print(state_dict.keys())
        model.load_state_dict(state_dict, strict=False)

    for param in model.context_encoder.parameters():
        param.requires_grad = False

    # layers_to_modify = [30, 31, 32, 33, 34, 35]
    # # Iterate over all named parameters in the model
    # for name, param in model.context_encoder.named_parameters():
    #     # Check if the parameter belongs to the specified layers
    #     if any(f"context_encoder.h.{layer}." in name for layer in layers_to_modify):
    #         # Set requires_grad to True or False depending on whether the parameter should be trainable
    #         param.requires_grad = True  # or False if you want to freeze the layer

    for param in model.ln_features.parameters():
        param.requires_grad = True
    for param in model.Qformer.parameters():
        param.requires_grad = True

    # Freeze every LLM layer except the LoRA adapters
    # trained = []
    # untrained = []
    # for name, param in model.LLM_model.named_parameters():
    #     # if 'v_proj.lora_A' in name or 'v_proj.lora_B' in name or 'q_proj.lora_B' in name or 'q_proj.lora_B' in name:
    #     if 'lora_A' in name or 'lora_B' in name:
    #         param.requires_grad = True
    #         trained.append(name)
    #     else:
    #         param.requires_grad = False
    #         untrained.append(name)

    # Print trainable and non-trainable parameters
    trainable_params = []
    non_trainable_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)
    print("Trainable Parameters:")
    print("\n".join(trainable_params))
    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))

    # setup peft
    # peft_config = LoraConfig(
    #     task_type=TaskType.CAUSAL_LM,
    #     target_modules=["q_proj", "v_proj"],  # W_pack. query_key_value
    #     inference_mode=False,
    #     r=lora_rank,
    #     lora_alpha=lora_alpha,
    #     lora_dropout=0.1
    # )
    # model = get_peft_model(model, peft_config)
    # model.is_parallelizable = True
    # model.model_parallel = True
    # model.print_trainable_parameters()

    # Cast to half precision
    # model.LLM_model = model.LLM_model.half()
    # model.float()

    scaler = torch.cuda.amp.GradScaler()

    def collate_fn(batch):
        """
        Merge a batch of samples.
        :param batch: [result of one __getitem__, result of __getitem__, ...]
        :return: dict of lists
        """
        # Accumulate the values of each key into a list
        merged_dict = {}
        for d in batch:
            for key, value in d.items():
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    merged_dict[key] = [value]
        # print(merged_dict)
        return merged_dict

    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = QADataset(train_path, train=True)
    training_loader = DataLoader(training_set, **train_params, collate_fn=collate_fn)

    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }
    val_set = QADataset(val_path, train=False)
    val_loader = DataLoader(val_set, **val_params, collate_fn=collate_fn)

    # optimizer = torch.optim.AdamW([{'params': model.bert_encoder.parameters(), 'lr': 1e-5},
    #                                {'params': model.Qformer.parameters(), 'lr': 1e-3},
    #                                {'params': model.ln_features.parameters(), 'lr': 1e-3},
    #                                {'params': model.internlm_model.parameters(), 'lr': 1e-5},
    #                                {'params': query_tokens_clone, 'lr': 1e-3}])
    optimizer = torch.optim.AdamW([{'params': model.parameters(), 'lr': lr}])

    # device_ids = [1, 3, 6, 7]
    # model = DataParallel(model, device_ids=device_ids)

    print("Start Training...")
    for epoch in range(epochs):
        # train(epoch, model, training_loader, optimizer, scaler, gradient_accumulation_steps, model_output_dir)
        # print("Save Model To ", model_output_dir)
        # model.save_pretrained(model_output_dir)

        # Validation
        # print("Start Validation...")
        with torch.no_grad():
            predictions, actuals = validate(model, val_loader)
        # Save the validation results
        final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
        val_data_path = os.path.join(model_output_dir, f"predictions_{epoch}.csv")
        final_df.to_csv(val_data_path)
        print("Validation Data To ", val_data_path)


if __name__ == '__main__':
    main()
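Note that train() receives gradient_accumulation_steps but the loop above steps the optimizer on every batch. The sketch below shows one way accumulation could be wired into the same GradScaler loop; it is an illustration only, assuming the model/loader/optimizer/scaler objects from main() and the outputs.loss convention used above.

# Sketch only: gradient accumulation with AMP, assuming the objects defined in main().
import torch

def train_with_accumulation(model, loader, optimizer, scaler, accumulation_steps=16):
    model.train()
    optimizer.zero_grad()
    for step, data in enumerate(loader):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(model, **data)            # same calling convention as train()
            loss = outputs.loss / accumulation_steps  # scale so gradients average over the window
        scaler.scale(loss).backward()
        if (step + 1) % accumulation_steps == 0:
            scaler.step(optimizer)                    # unscales and skips the step on inf/nan
            scaler.update()
            optimizer.zero_grad()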
test_chat.py
# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, GPT2LMHeadModel, \
    TextGenerationPipeline, AutoModelForCausalLM, AutoConfig
from transformers.generation.utils import GenerationConfig
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys
import numpy as np
from modeling_RALLM import RALLM
import argparse
from torch import autocast

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=False, type=bool)
parser.add_argument("--compressibility_factor", default=0.1, type=float, help="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_corpus_lora/checkpoint-200004", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--weight_decay", default=0.005, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=False, type=bool)
parser.add_argument("--use_lora_gpt2", default=True, type=bool)
args = parser.parse_args()


def chat(model):
    while True:
        context1 = input("Enter context 1: ")
        context2 = input("Enter context 2: ")
        # context = """NVIDIA的A6000顯示卡是一款面向專業領域的高效能顯示卡。關於它的雙精度(Double Precision)、單精度(Single Precision)和半精度(Half Precision)的算力,我們可以參考官方提供的規格引數。截至我最後更新的資訊(2023年4月),以下是A6000顯示卡的相關算力資料:雙精度(Double Precision): A6000顯示卡在雙精度計算方面的效能通常不如單精度和半精度,因為雙精度計算需要更多的計算資源和頻寬。具體數值因顯示卡的不同批次和製造工藝的微小差異可能有所不同。單精度(Single Precision): A6000在單精度計算方面的效能通常很高,適合於大多數圖形處理和一些科學計算任務。單精度計算是大多數顯示卡的主要優勢。半精度(Half Precision): 半精度計算主要用於某些機器學習和深度學習應用,能提供更高的吞吐量。A6000顯示卡在半精度計算方面的效能通常很高。
        # """
        flag_context = "<context> " * 2
        text = f'<reserved_106>請複述這段被壓縮的內容{flag_context} <reserved_107>'
        data = {"context": [[context1, context2]], "text_input": text}

        model.eval()
        with torch.no_grad():
            with autocast(device_type="cuda", dtype=torch.float16):
                text = data["text_input"]
                context = data["context"]
                preds = model.generate(text=text, context=context)
                print("Output:", preds)


def main():
    model = RALLM(args)
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)

    if args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack", ],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)

    if args.use_lora_gpt2:
        print("Training with LoRA" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["c_attn", ],
            inference_mode=False,
            r=64,
            lora_alpha=256,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)

    print(model)

    base_load_path = "output_qa3w_lora_gpt2_base_corpus"
    # List of checkpoint shard file names
    file_list = ['pytorch_model.bin']
    # Empty state dict to merge the shards into
    state_dict = {}
    for file_name in file_list:
        # Load one shard of model parameters
        part_state_dict = torch.load(os.path.join(base_load_path, file_name),
                                     map_location=f"cuda:{args.local_rank}")
        # Merge it into the full state dict
        state_dict.update(part_state_dict)
    # Load the merged state dict into the model
    model.load_state_dict(state_dict)
    model.LLM_model.generation_config = GenerationConfig.from_pretrained(base_load_path)

    # Load parameters from a raw .pth checkpoint instead
    # load_path = '/data2/xinyuuliu/InternLM-XComposer/output12/epoch9/index_29999/LLM_model_0.109371.pth'
    # checkpoint = torch.load(load_path, map_location="cuda:0")  # , map_location="cuda:3"
    # model.load_state_dict(checkpoint)

    # Cast to half precision
    # model.LLM_model = model.LLM_model.half()
    model = model.half()
    # model.float()

    chat(model)


if __name__ == '__main__':
    main()
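The number of <context> placeholders in the prompt above is hard-coded to 2 and must match the number of passages passed in context. The helper below is a small hedged sketch (not part of the repository) of how the prompt could be built for an arbitrary number of passages, mirroring the <reserved_106>/<reserved_107> template used throughout.

# Sketch only: build the chat prompt so the number of <context> slots matches the passages.
def build_retell_prompt(passages, instruction="請複述這段被壓縮的內容"):
    flag_context = "<context> " * len(passages)
    text = f"<reserved_106>{instruction}{flag_context} <reserved_107>"
    # model.generate expects a batch of passage lists: [[p1, p2, ...]]
    return {"context": [list(passages)], "text_input": text}

# data = build_retell_prompt([context1, context2])
# preds = model.generate(text=data["text_input"], context=data["context"])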
ds_config.json
{ "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "zero_optimization": { "stage": 2, "offload_optimizer": { "device": "none", "pin_memory": true }, "allgather_partitions": true, "allgather_bucket_size": 2e8, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 2e8, "contiguous_gradients": true }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "steps_per_print": 10, "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false }
fine-tune.py
# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.

from dataclasses import dataclass, field
import json
import math
import logging
import os
import random
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import Trainer, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType
from torchvision import transforms
from typing import Dict, Optional, Sequence, List
from modeling_RALLM import RALLM
# from dataset_batch import QADataset
from dataset_rerank_en import QADataset
# from dataset_rerank_en_gpt import QADataset
# from dataset_rerank import QADataset
import argparse
from transformers import AutoModel, AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, GPT2LMHeadModel, \
    TextGenerationPipeline, AutoModelForCausalLM, AutoConfig
from torch.optim import AdamW

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=True, type=bool)
parser.add_argument("--compressibility_factor", default=0.1, type=float, help="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_lora2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_rerank100k/checkpoint-112356", type=str)
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--weight_decay", default=0.01, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.local_rank)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    local_rank: int = field(default=None)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """
        Merge a batch of samples.
        :param instances: [result of one __getitem__, result of __getitem__, ...]
        :return: dict of lists
        """
        # Accumulate the values of each key into a list
        merged_dict = {}
        for d in instances:
            for key, value in d.items():
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    merged_dict[key] = [value]
        # print(merged_dict)
        return merged_dict


def train():
    global model

    train_path = args.train_dataset
    # train_path = "data/news_summary_30w.csv"
    # val_path = "QA_5000_summary.csv"

    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1

    torch.cuda.device(0)
    torch.cuda.empty_cache()  # free GPU memory

    # init model and tokenizer
    model = RALLM(args)
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)
    # torch.cuda.device(0)
    torch.cuda.empty_cache()  # free GPU memory

    if args.use_lora:
        print("Training with LoRA" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack", ],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)

    if args.use_lora_gpt2:
        print("Training with LoRA" + "*" * 10)
        from peft import LoraConfig, TaskType, get_peft_model, PeftModel, AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            # target_modules=["wte", "c_attn", ],
            target_modules=["query", "key", "value", "query_global", "key_global", "value_global"],
            inference_mode=False,
            r=128,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)

    print(model)

    for param in model.context_encoder.parameters():
        param.requires_grad = False

    # layers_to_modify = [27, 28, 29, 30, 31, 32, 33, 34, 35]
    # # Iterate over all named parameters in the model
    # for name, param in model.context_encoder.named_parameters():
    #     # Check if the parameter belongs to the specified layers
    #     if any(f"h.{layer}." in name for layer in layers_to_modify):
    #         # Set requires_grad to True or False depending on whether the parameter should be trainable
    #         param.requires_grad = True  # or False if you want to freeze the layer
    #     # if f"ln_f" in name:
    #     #     param.requires_grad = True
    #     if f"wte" in name:
    #         param.requires_grad = True
    #     if f"wpe" in name:
    #         param.requires_grad = True

    for param in model.ln_features.parameters():
        param.requires_grad = True

    # Freeze every LLM layer except the LoRA adapters
    trained = []
    untrained = []
    # for name, param in model.LLM_model.named_parameters():
    #     # if 'v_proj.lora_A' in name or 'v_proj.lora_B' in name or 'q_proj.lora_B' in name or 'q_proj.lora_B' in name:
    #     # if 'lora_A' in name or 'lora_B' in name or "layers.30" in name or "layers.31" in name or "embed_tokens" in name:
    #     if 'lora_A' in name or 'lora_B' in name or "embed_tokens" in name:
    #         param.requires_grad = True
    #         trained.append(name)
    #     else:
    #         param.requires_grad = False
    #         untrained.append(name)
    # print("Trainable LLM layers:", trained)
    # print("Frozen LLM layers:", untrained)

    # Print trainable and non-trainable parameters
    trainable_params = []
    non_trainable_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)
    print("Trainable Parameters:")
    print("\n".join(trainable_params))
    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))

    if args.load_path:
        base_load_path = args.load_path
        # Load either a single .pth file or a sharded checkpoint
        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path, map_location=device)
        else:
            file_list = ['pytorch_model.bin']
            # Empty state dict to merge the shards into
            state_dict = {}
            for file_name in file_list:
                # Load one shard of model parameters
                part_state_dict = torch.load(os.path.join(base_load_path, file_name), map_location=device)
                # Merge it into the full state dict
                state_dict.update(part_state_dict)
        # Load the merged state dict into the model
        print("state_dict:")
        print(state_dict.keys())
        model.load_state_dict(state_dict, strict=False)

    # # Separate model.Qformer parameters from all other parameters
    # qformer_params = set(model.Qformer.parameters())
    # other_params = [p for p in model.parameters() if p not in qformer_params]
    # # Build parameter groups
    # param_groups = [
    #     {'params': list(qformer_params), 'lr': 1e-3},
    #     {'params': other_params, 'lr': 1e-5}
    # ]
    # # Build an AdamW optimizer from the parameter groups
    # optimizer = AdamW(param_groups)

    training_set = QADataset(train_path, train=True)
    # val_set = QADataset(val_path, train=False)
    print(training_set[0])

    # Training arguments
    training_args = TrainingArguments(
        local_rank=args.local_rank,
        output_dir=args.output,                                         # output directory
        num_train_epochs=args.num_train_epochs,                         # number of training epochs
        per_device_train_batch_size=args.per_device_train_batch_size,   # batch size per device
        warmup_steps=500,                                               # warm-up steps
        weight_decay=0.01,                                              # weight decay
        logging_dir='./logs',                                           # log directory
        deepspeed="ds_config.json",
        gradient_accumulation_steps=1,
        save_strategy="epoch",
        learning_rate=5e-5,
        # lr_scheduler_type='linear',
        # logging_steps=100,
    )

    data_collator = DataCollatorForSupervisedDataset()

    # Start trainer
    trainer = Trainer(
        model=model,
        tokenizer=model.LLM_tokenizer,
        train_dataset=training_set,
        # eval_dataset=val_set,
        data_collator=data_collator,
        args=training_args,
        # optimizers=(optimizer, None)  # custom optimizer
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=args.output)


if __name__ == "__main__":
    train()

# https://arxiv.org/pdf/2102.05951.pdf
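LoRA is only injected into modules whose names match one of the target_modules entries ("W_pack" is the packed QKV projection in Baichuan2; the query/key/value projections belong to the Longformer encoder). Before a long run it can be worth checking which submodules a target string would actually hit. The sketch below is illustrative only: the ToyAttention class and its attribute names are hypothetical stand-ins, not the real Baichuan2 layers.

# Sketch only: check which nn.Linear submodules a LoRA target name would match.
import torch.nn as nn

def find_lora_targets(model, keywords):
    return [name for name, module in model.named_modules()
            if isinstance(module, nn.Linear) and name.split(".")[-1] in keywords]

# Toy stand-in for an attention block with a packed QKV projection (illustrative names only).
class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_pack = nn.Linear(16, 48)
        self.o_proj = nn.Linear(16, 16)

print(find_lora_targets(nn.Sequential(ToyAttention(), ToyAttention()), ["W_pack"]))
# -> ['0.W_pack', '1.W_pack']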
fine-tune.sh
hostfile="" # deepspeed --include localhost:1,2,3 --hostfile=$hostfile fine-tune.py \ # --report_to "none" \ # --data_path "/data1/xinyuuliu/qa_data/professional_data/train_二階段.json" \ # --model_name_or_path "/data1/xinyuuliu/Baichuan2-13B-Chat" \ # --output_dir "output_lora3_1_2" \ # --model_max_length 4000\ # --num_train_epochs 10 \ # --per_device_train_batch_size 4 \ # --gradient_accumulation_steps 1 \ # --save_strategy epoch \ # --learning_rate 2e-4 \ # --lr_scheduler_type constant \ # --adam_beta1 0.9 \ # --adam_beta2 0.98 \ # --adam_epsilon 1e-8 \ # --max_grad_norm 1.0 \ # --weight_decay 1e-4 \ # --warmup_ratio 0.0 \ # --logging_steps 1 \ # --gradient_checkpointing True \ # --deepspeed ds_config.json \ # --bf16 True \ # --tf32 True \ # --use_lora True \ # --load_lora_path /data1/xinyuuliu/Baichuan2-main/fine-tune/output_lora3_1/checkpoint-8260 # --use_NEFT True # --use_frozen True # export CUDA_LAUNCH_BLOCKING=1 # CUDA_VISIBLE_DEVICES=“2,3,4,5,6,7” deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 29501 --hostfile=$hostfile fine-tune.py \ --encoder longformer \ --query_tokens 32 \ --output output_english_longformer_msmarco2019\ --num_train_epochs 20 \ --per_device_train_batch_size 1 \ # --load_path /data2/xinyuuliu/Baichuan2_qformer_bert/output_30w/checkpoint-22488 \