RALLM: Retrieval-Augmented LLM Architecture

Posted by 高颜值的殺生丸 on 2024-05-27

The code below implements a retrieval-augmented LLM: retrieved context is encoded with a BERT / GPT-2 / Longformer encoder, scored and compressed token by token, passed through a small Q-Former (the perceive sampler defined in modeling_perceive_sampler.py), projected into the embedding space of a Baichuan2 chat model, and concatenated with the prompt embeddings for generation and training.

import copy
import os
import sys

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
import contextlib
import torch
import torch.utils.checkpoint
from torch.nn import LayerNorm
from torch import nn
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modeling_perceive_sampler import BertConfig, BertLMHeadModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
import transformers
from transformers import (
    PreTrainedModel, AutoTokenizer, AutoModelForMaskedLM, AutoModel,
    BertTokenizer, GPT2LMHeadModel, PretrainedConfig, GPT2Model, GPT2Tokenizer,
    LongformerTokenizer, LongformerModel,
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use the first GPU
import argparse
import math


class RALLM(nn.Module):

    def __init__(self,args):
        super(RALLM,self).__init__()

        self.is_compress = args.is_compress    
        self.use_lora = args.use_lora  
        print('Init LLM ... ')

        if args.LLM_model == "Baichuan2_13B":
            self.LLM_model_name = "Baichuan2-13B-Chat"
            self.LLM_hidden_size = 5120
        elif args.LLM_model == "Baichuan2_7B":
            self.LLM_model_name = "baichuan2_7B"
            self.LLM_hidden_size = 4096

        
        self.LLM_model = transformers.AutoModelForCausalLM.from_pretrained(
            self.LLM_model_name,
            device_map=f"cuda:{args.local_rank}",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # cache_dir=training_args.cache_dir,
        )
        self.LLM_tokenizer = transformers.AutoTokenizer.from_pretrained(
            self.LLM_model_name,
            use_fast=False,
            trust_remote_code=True,
            model_max_length=4096,
            # cache_dir=training_args.cache_dir,
        )

        self.flag_context_start = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))#.to(self.device)
        self.flag_context_end = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))#.to(self.device)
        self.flag_context_start.requires_grad = False
        self.flag_context_end.requires_grad = False

        self.device = self.LLM_model.device
        self.user_token = self.LLM_tokenizer._convert_id_to_token(195)      # Baichuan2 user role token
        self.assisent_token = self.LLM_tokenizer._convert_id_to_token(196)  # Baichuan2 assistant role token
        self.eoa = self.LLM_tokenizer._convert_id_to_token(2)               # end-of-answer token (</s>)

        print("user_token:",self.user_token,"assisent_token:",self.assisent_token,"eoa:",self.eoa)
        print('Done')

        print('Init context encoder ... ')

        self.init_context_encoder(args)
 
        print('Done')

    def init_Qformer(self,num_query_token,num_features):
        self.Qformer  = self.init_qformer(num_query_token, num_features,cross_attention_freq=1)
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.Qformer.cls = None


    @classmethod
    def init_qformer(cls,
                        num_query_token,
                        vision_width,
                        cross_attention_freq=2,
                        pretrain=True):
        encoder_config = BertConfig()
        encoder_config.num_hidden_layers = 2
        encoder_config.hidden_size = vision_width
        encoder_config.encoder_width = vision_width
        encoder_config.num_attention_heads = vision_width//64
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)

        return Qformer


    def init_context_encoder(self,args):
        
        num_query_token = args.query_tokens = 0

        if args.encoder == "bert_base":
            self.context_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese")
            self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_base_chinese",output_hidden_states=True)
            num_features = 768

        if args.encoder == "bert_large":
            self.context_tokenizer = AutoTokenizer.from_pretrained("bert_large_chinese",max_length=2000)
            self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_large_chinese",output_hidden_states=True)
            num_features = 1024

        if args.encoder == "gpt2_xlarge":
            self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_xlarge")
            self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_xlarge")
            num_features = 1600

        if args.encoder == "gpt2_large":
            self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_large")
            self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_large")
            num_features = 1280

        if args.encoder == "gpt2_large_en":
            self.context_tokenizer = GPT2Tokenizer.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN")
            self.context_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.context_encoder = GPT2Model.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN")
            num_features = 1280

        if args.encoder == "longformer":
            self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer')
            self.context_encoder = LongformerModel.from_pretrained('longformer')
            num_features = 768

        if args.encoder == "longformer_large":
            self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer-large')
            self.context_encoder = LongformerModel.from_pretrained('longformer-large')
            num_features = 1024



        # bert_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese",max_length=2000)
        # bert_encoder = AutoModelForMaskedLM.from_pretrained("longformer_zh",output_hidden_states=True) #.to(device)
        self.context_encoder = self.context_encoder.to(self.device)

        self.context_score = torch.nn.ModuleList([
                torch.nn.Linear(num_features, 64),
                torch.nn.Tanh(),
                torch.nn.Linear(64, 1),
            ])  # num_features is the encoder hidden size; the final Linear outputs one relevance score per token

        self.context2llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size)  # project encoder features to the LLM hidden size
        self.llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size)
        # model.embed2qformer_proj = torch.nn.Linear(num_features, 768)

        self.ln_features = LayerNorm(num_features)
        self.init_Qformer(num_query_token,num_features)

        # del model.internlm_proj
        # del model.Qformer
        # torch.cuda.empty_cache()  # free GPU memory

        # if device:
        # model = self.model.to(self.device)


    def encode_text(self, text, add_special_tokens=False):

        input_ids = self.LLM_tokenizer.encode(text, add_special_tokens=add_special_tokens)
        input_ids = torch.LongTensor([input_ids]).to(self.device)

        if self.use_lora:
            text_embeds = self.LLM_model.base_model.model.model.embed_tokens(input_ids)
        else:
            text_embeds = self.LLM_model.model.embed_tokens(input_ids)
        return text_embeds


    def calculate_compressibility(self,x,k=0):
        return (x * k*(9 / 1000) + 1) * 111.111 / (x + 111.111)
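
    # Worked example (for reference): with k=0 the ratio reduces to 111.111 / (x + 111.111),
    # so short contexts are kept almost intact while long ones are compressed hard:
    #   seq_len =   10 -> ratio ~0.92, K = ceil(10   * 0.92) = 10   (no compression)
    #   seq_len =  100 -> ratio ~0.53, K = ceil(100  * 0.53) = 53
    #   seq_len = 1000 -> ratio ~0.10, K = ceil(1000 * 0.10) = 100
    # K is the number of top-scoring tokens kept by encode_context below.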


    # Batch-encode a list of sentences
    def batch_input_sentences(self,sentences):
        input_ids_list = [self.context_tokenizer.encode(sentence,return_tensors="pt",padding='max_length', max_length=2500, truncation=True) for sentence in sentences]
        max_length = max(len(input_ids[0]) for input_ids in input_ids_list)
        # pad every sequence to the batch maximum (with padding='max_length' above they are already equal length; this is a safety net)
        input_ids_padded = [torch.cat([input_ids, torch.zeros(1, max_length - input_ids.size(1), dtype=torch.long)], dim=1) for input_ids in input_ids_list]
        input_ids_tensor = torch.cat(input_ids_padded, dim=0)
        return input_ids_tensor

    def encode_context(self, text_list):
        if text_list is None:
            return None

        inputs_LLMs = []
        input_atts = []
        # print(text_list)
        for text in text_list:
            # input_ids = self.context_tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
            # encode the text list and pad to the longest sequence
            # encoded_ids = self.context_tokenizer(text, padding=True, return_tensors="pt", truncation=True)
            input_ids = self.batch_input_sentences(text)
            input_ids = input_ids.to(self.device)
            # input_ids = encoded_ids.data["input_ids"].to(self.device)
            # attention_mask = encoded_ids.data["attention_mask"].to(self.device)
            outputs = self.context_encoder(input_ids, output_hidden_states=True)
            # take the embedding-layer output and the last hidden state
            embedding, last_hidden_state = outputs.hidden_states[0], outputs.hidden_states[-1]  # outputs.logits

            x = last_hidden_state
            for layer in self.context_score:
                x = layer(x)

            output = x
            # output = self.context_score(last_hidden_state)  # 進行線性變換
            
            batch,seq_len,ebd_dim = last_hidden_state.size()
            
            # compressibility = -0.0009 * seq_len+1 #壓縮率計算長度越低壓縮率越低,長度越長,壓縮率越高。線性壓縮不好 x*f(x) 不是單調遞減的
            # compressibility = 111.111/(seq_len+111.111) #重新設計非線性壓縮 10以下不壓縮,0-1000 x*f(x) 遞減 
            
            if self.is_compress:
                compressibility = self.calculate_compressibility(seq_len,0)
                K = math.ceil(seq_len*compressibility)
            else:
                K = seq_len

            # use torch.topk to get the indices of the top-K scoring tokens
            topk_indices = torch.topk(output, K, dim=1).indices
            # print(topk_indices)
            topk_indices, sorted_indices = torch.sort(topk_indices, dim=1)   # restore original text order
            # print(topk_indices)

            # gather the last_hidden_state vectors at the top-K indices
            topk_selected_last_hidden_state = torch.gather(last_hidden_state, 1, topk_indices.expand(-1, -1, ebd_dim))
            # print(last_hidden_state)
            # print(topk_selected_last_hidden_state)
            topk_selected_embedding = torch.gather(embedding, 1, topk_indices.expand(-1, -1, ebd_dim))
            # bert_text_atts = torch.gather(attention_mask, 1, torch.squeeze(topk_indices, dim=2))

            bert_text_embeds = self.ln_features(last_hidden_state)
            bert_text_atts = torch.ones(bert_text_embeds.size()[:-1],dtype=torch.long).to(self.device)
            # query_tokens = self.query_tokens.expand(bert_text_atts.shape[0], -1,-1)
            # query_tokens = topk_selected_embedding
            query_tokens = topk_selected_last_hidden_state

            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=bert_text_embeds,
                encoder_attention_mask=bert_text_atts,
                return_dict=True,
            )

            # topk_context_hidden_state = self.context2llm_proj(topk_selected_last_hidden_state)

            inputs_LLM = self.llm_proj(query_output.last_hidden_state)

            inputs_LLM = torch.cat([
                self.flag_context_start.expand(batch, -1, -1),
                # topk_context_hidden_state,
                inputs_LLM,
                self.flag_context_end.expand(batch, -1, -1)
            ],dim=1).view(-1, self.LLM_hidden_size)

            input_att = torch.cat([torch.ones((batch,1)).to(self.device),bert_text_atts,torch.ones((batch,1)).to(self.device)],dim=1).view(-1)
            # print(inputs_LLM.shape)
            inputs_LLMs.append(inputs_LLM)
            input_atts.append(input_att)
        # context_inputs = torch.stack(inputs_LLMs)
        return inputs_LLMs,input_atts
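
    # Note: each element of inputs_LLMs is the compressed context (Q-Former output projected
    # to the LLM hidden size) wrapped between flag_context_start / flag_context_end and
    # flattened to shape (-1, LLM_hidden_size); input_atts holds the matching attention vectors.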



    def wrap_prompt(self,
                    text_embeds,
                    context_embeds=None,
                    history=None,
                    add_special=True):
        if add_special:
            if history is None:
                prompt_segs = [
                    self.user_token,
                    self.assisent_token
                ]
            else:
                prompt_segs = [self.user_token, self.assisent_token]
        else:
            prompt_segs = [self.user_token, self.assisent_token]  # used in wrap history
        prompt_seg_embeds = []
        for i, seg in enumerate(prompt_segs):
            if history is not None:
                add_special_tokens = False
            else:
                add_special_tokens = i == 0
            seg_embeds = self.encode_text(
                seg, add_special_tokens=add_special_tokens)
            prompt_seg_embeds.append(seg_embeds)
        if context_embeds is None:
            context_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                                text_embeds.size(-1))
        else:
            # add a batch dimension at index 0
            context_embeds = context_embeds[0].unsqueeze(0)
        prompt_seg_embeds = [
            prompt_seg_embeds[0], text_embeds,context_embeds,  prompt_seg_embeds[1]
        ]
        prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds
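
    # Prompt layout produced by wrap_prompt:
    #   [user_token] [question embeddings] [compressed context embeddings] [assistant_token]
    # with any previous turns from `history` prepended in front.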


    def generate(self, text, context=None, **kwargs):
        text = text.replace("<context>","").replace(self.user_token,"").replace(self.assisent_token,"")
        text_embeds = self.encode_text(text)
        context_embeds,_ = self.encode_context(context)
        prompt_embeds = self.wrap_prompt(text_embeds, context_embeds)
        # out_embeds = self.LLM_model.generate(input_ids=None,
        #     inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        # out_text = self.decode_text(out_embeds)
        outputs = self.LLM_model.generate(input_ids=None,inputs_embeds=prompt_embeds, generation_config=self.LLM_model.generation_config)
        response = self.LLM_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response

    def chat(self, text, context=None, history=None, **kwargs):
        text_embeds = self.encode_text(text)
        context_embeds, _ = self.encode_context(context)
        prompt_embeds = self.wrap_prompt(text_embeds,
                                         context_embeds,
                                         history=history)
        outputs = self.LLM_model.generate(
            input_ids=None,
            inputs_embeds=prompt_embeds,
            generation_config=self.LLM_model.generation_config)
        out_text = self.LLM_tokenizer.decode(outputs[0], skip_special_tokens=True)

        # re-embed the prompt and the cleaned output to extend the running history
        clean_out_text_token_ids = self.LLM_tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        if self.use_lora:
            clean_out_text_embeds = self.LLM_model.base_model.model.model.embed_tokens(
                clean_out_text_token_ids)
        else:
            clean_out_text_embeds = self.LLM_model.model.embed_tokens(
                clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds,
                                               context_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def align_text(self, samples, has_context=False):  ### append eoa; return the text after <context>

        text_new = []
        if has_context:  ### keep only the part after <context>, which will be wrapped with the context features
            text = [
                t.split("<context>")[-1] for t in samples["text_input"]
            ]
            ]
        else:
            text = [t for t in samples["text_input"]]

        text = [t + self.eoa  for t in text]
        for i in range(len(text)):
            temp = text[i]
            # temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
            # if temp.find(self.eoh) > temp.find(self.eoa):
            #     temp = temp.replace(self.eoa, '', 1)
            text_new.append(temp)
        return text_new

    def prompt_wrap(self, context_embeds,context_atts, prompt_list):
        batch_size = len(context_embeds)
        p_before = [prompt.split('<context>')[0] for prompt in prompt_list]
        p_before_tokens = self.LLM_tokenizer(p_before,
                                        padding=True,
                                        truncation=True,
                                            return_tensors="pt",
                                            add_special_tokens=True).to(
                                                self.device)

        if self.use_lora:
            p_before_embeds = self.LLM_model.base_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        else:
            p_before_embeds = self.LLM_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)

        # wrapped_context_embeds = torch.cat([p_before_embeds, context_embeds], dim=1)
        # wrapped_context_embeds = torch.cat([p_before_embeds]+context_embeds, dim=1)
        wrapped_context_embeds = []
        wrapped_atts_context = []
        wrapped_target = []
        for i, (context_embed,context_att) in enumerate(zip(context_embeds,context_atts)):
            # concatenate each p_before_embeds sequence with the corresponding context embedding along the sequence dimension
            concatenated = torch.cat((p_before_embeds[i], context_embed), dim=0)
            wrapped_context_embeds.append(concatenated)
            # concatenated_att = torch.cat((torch.ones(p_before_embeds[i].size()[:-1],dtype=torch.long).to(self.device),context_att),dim=0)
            wrapped_atts_context.append(torch.ones(concatenated.size()[:-1],dtype=torch.long).to(self.device))
            # wrapped_atts_context.append(concatenated_att)
            target = torch.ones(concatenated.size()[:-1], dtype=torch.long) * -100
            target[0] = 2
            target = target.to(self.device)
            wrapped_target.append(target)

        # wrapped_atts_context = torch.ones(wrapped_context_embeds.size()[:-1],
        #                                 dtype=torch.long).to(self.device)

        # wrapped_target = torch.ones(
        #     batch_size, wrapped_context_embeds.shape[1], dtype=torch.long).to(
        #         self.device) * -100

        return wrapped_context_embeds, wrapped_atts_context, wrapped_target

    def text2emb(self, text):
        to_regress_tokens = self.LLM_tokenizer(text,
                                           return_tensors="pt",
                                           padding="longest",
                                           truncation=True,
                                           max_length=4096,
                                           add_special_tokens=False).to(
                                               self.device)

        targets = self.mask_human_targets(to_regress_tokens.input_ids)
        targets = targets.to(self.device)

        return to_regress_tokens, targets


    def mask_human_targets(self, input_ids, pure=False):
        target_batch = []
        for bs in range(input_ids.shape[0]):
            cur_idx = 0
            ids = input_ids[bs]
            targets = copy.deepcopy(ids)
            last_eoa = 0 
            last_eoh = 0
            for i, temp_id in enumerate(ids):
                if temp_id == 196:  #### assistant token: mask everything up to and including it (the human turn)
                    targets[cur_idx:i+1] = -100

            target_batch.append(targets.unsqueeze(0))

        target_batch = torch.cat(target_batch, dim=0)
        target_batch[target_batch==0]=-100
        # print(input_ids)
        # print(target_batch)
        return target_batch
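
    # Example (for reference): for input_ids [195, q1, q2, 196, a1, a2, 2] the loop masks
    # every position up to and including the assistant marker 196 with -100, leaving only
    # the assistant reply [a1, a2, 2] as supervision targets; padding positions (id 0)
    # are then also set to -100.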

    def forward(self,
                input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                context = None,
                text_input = None,
                **kwargs):

        # samples = kwargs #.get('samples')
        # has_context = 'context' in samples.keys()
        if context:
            has_context = True
        else:
            has_context = False

        samples = {"text_input":text_input,"context":context}

        ### encode text
        text = self.align_text(samples=samples, has_context=has_context)  # get the text after <context>
        to_regress_tokens, targets = self.text2emb(text)  # returns tokens and masked targets

        if self.use_lora:
            to_regress_embeds = self.LLM_model.base_model.model.model.embed_tokens(to_regress_tokens.input_ids)
        else:
            to_regress_embeds = self.LLM_model.model.embed_tokens(to_regress_tokens.input_ids)


        attention_mask = to_regress_tokens.attention_mask

        if has_context:
            prompt = samples["text_input"]

            ### encode context
            context = samples["context"]
            context_embeds,context_atts = self.encode_context(context)
            context_embeds, atts_context, wrapped_target = self.prompt_wrap(
                context_embeds,context_atts, prompt)
            ### combine context and text

            to_regress_embeds_ = []
            attention_mask_ = []
            targets_ = []
            for i, (tensor0,tensor1,tensor2) in enumerate(zip(to_regress_embeds,attention_mask,targets)):
                # concatenate each sample's context embeddings with its text embeddings along the sequence dimension
                to_regress_embed = torch.cat((context_embeds[i], tensor0), dim=0)
                to_regress_embeds_.append(to_regress_embed)
                attention_m = torch.cat((atts_context[i], tensor1), dim=0)
                attention_mask_.append(attention_m)
                target = torch.cat((wrapped_target[i], tensor2), dim=0)
                targets_.append(target)


            # to_regress_embeds = torch.cat([context_embeds, to_regress_embeds],
            #                                 dim=1)
            # attention_mask = torch.cat([atts_context, attention_mask], dim=1)
            # targets = torch.cat([wrapped_target, targets], dim=1)

            # determine the maximum sequence length in the batch
            max_len = max(t.size(0) for t in to_regress_embeds_)

            # pad tensors to the same length
            padded_to_regress_embeds_ = []
            padded_attention_mask_ = []
            padded_targets_ = []
            for (t,a,l) in zip(to_regress_embeds_,attention_mask_,targets_):
                if t.size(0) < max_len:
                    # compute how much padding is needed
                    padding_size = max_len - t.size(0)
                    # pad along the sequence dimension
                    padded_regress = torch.nn.functional.pad(t, (0, 0, 0, padding_size))
                    padded_attention = torch.nn.functional.pad(a, (0, padding_size), value=0)
                    padded_target = torch.nn.functional.pad(l, (0, padding_size), value=-100)

                    padded_to_regress_embeds_.append(padded_regress)
                    padded_attention_mask_.append(padded_attention)
                    padded_targets_.append(padded_target)
                else:
                    padded_to_regress_embeds_.append(t)
                    padded_attention_mask_.append(a)
                    padded_targets_.append(l)


            # stack into batch tensors
            to_regress_embeds = torch.stack(padded_to_regress_embeds_)
            attention_mask = torch.stack(padded_attention_mask_)
            targets = torch.stack(padded_targets_)


        outputs = self.LLM_model(
            inputs_embeds=to_regress_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        return outputs
    






if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--output", default="output", type=str)
    parser.add_argument("--encoder", default="gpt2_large", type=str)
    parser.add_argument("--query_tokens", default=32, type=int)
    parser.add_argument("--load_path", default="/data2/xinyuuliu/InternLM-XComposer/output_rerank", type=str)
    parser.add_argument("--local_rank", default="0", type=str)
    # flags required by RALLM.__init__; the defaults here are assumptions
    parser.add_argument("--LLM_model", default="Baichuan2_13B", type=str)
    parser.add_argument("--is_compress", action="store_true")
    parser.add_argument("--use_lora", action="store_true")


    args = parser.parse_args()


    model = RALLM(args)

    print(model)

    # model.encode_context("我愛北京天安門")
    # model.encode_text("我愛北京天安門")
    # #<ContextHere>
    # query = "Q:請重複內容:<cont_s><ContextHere><cont_e> \n A:"
    # context = ["電飯煲不知道怎麼選?想要吃一碗香噴噴的米飯,除了米要好之外,還需要一款效能優秀的電飯煲,所以大家在選購電飯煲的時候,一定要多花點心思看看攻略避免踩雷。我前前後後給親朋好友選購過不下5臺電飯煲,也算是積攢了不少選購經驗,今天特意總結了一下想分享給大家。1、容量選擇市面上電飯煲容量普遍在3L-5L之間,這個範圍的容量足夠滿足絕大部分家庭使用,3L一般可以滿足1-3人的家庭,4L一般可以滿足2-5人的家庭,5L一般可以滿足2-8人的家庭,如果人口超過8人建議直接選擇5L以上的容量,使用會更方便。"]

    # model.interleav_wrap(query,context)
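
    # Minimal usage sketch (assumes the checkpoints referenced above are available locally):
    # query = "Q:請重複內容:<context> \n A:"
    # context = [["..."]]  # a list containing one list of passages, e.g. the example above
    # response = model.generate(query, context)
    # print(response)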

modeling_perceive_sampler.py

"""
 * Copyright (c) 2023, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
"""

import math
from typing import Tuple

import torch
import torch.utils.checkpoint
from torch import Tensor, device
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.models.bert.configuration_bert import BertConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size,
                                            padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length:
                                             seq_length +
                                             past_key_values_length].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" %
                (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if (self.position_embedding_type == "relative_key"
                or self.position_embedding_type == "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))

        if (self.position_embedding_type == "relative_key"
                or self.position_embedding_type == "relative_key_query"):
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length,
                                          dtype=torch.long,
                                          device=hidden_states.device).view(
                                              -1, 1)
            position_ids_r = torch.arange(seq_length,
                                          dtype=torch.long,
                                          device=hidden_states.device).view(
                                              1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = (attention_scores +
                                    relative_position_scores_query +
                                    relative_position_scores_key)

        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = ((context_layer, attention_probs) if output_attentions else
                   (context_layer, ))

        outputs = outputs + (past_key_value, )
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(
            heads)
        self.self.all_head_size = (self.self.attention_head_size *
                                   self.self.num_attention_heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,
                   ) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if (self.config.add_cross_attention
                and layer_num % self.config.cross_attention_freq == 0):
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (past_key_value[:2]
                                    if past_key_value is not None else None)
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output, ) + outputs

        outputs = outputs + (present_key_value, )

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (() if output_attentions
                                and self.config.add_cross_attention else None)

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[
                i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing",
                       False) and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value,
                                      output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1], )
            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    layer_outputs[1], )
                all_cross_attentions = all_cross_attentions + (
                    layer_outputs[2], )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            ] if v is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """
    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (seq_ids[None, None, :].repeat(
                    batch_size, seq_length, 1) <= seq_ids[None, :, None])

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[
                        1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1],
                                 prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = (causal_mask[:, None, :, :] *
                                           attention_mask[:, None, None, :])
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})"
                .format(input_shape, attention_mask.shape))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"

        # past_key_values_length
        past_key_values_length = (past_key_values[0][0].shape[2] -
                                  self.config.query_length
                                  if past_key_values is not None else 0)

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)),
                device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size,
                                    encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask)
                    for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape,
                                                    device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask,
                                       self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (self.pooler(sequence_output)
                         if self.pooler is not None else None)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [
        r"position_ids", r"predictions.decoder.bias"
    ]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1]:, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :
                                                          -1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction,
                                        label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores, ) + outputs[2:]
            return ((lm_loss, ) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      query_embeds,
                                      past=None,
                                      attention_mask=None,
                                      **model_kwargs):
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids":
            input_ids,
            "query_embeds":
            query_embeds,
            "attention_mask":
            attention_mask,
            "past_key_values":
            past,
            "encoder_hidden_states":
            model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask":
            model_kwargs.get("encoder_attention_mask", None),
            "is_decoder":
            True,
        }

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx)
                for past_state in layer_past), )
        return reordered_past


class BertForMaskedLM(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [
        r"position_ids", r"predictions.decoder.bias"
    ]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        """

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1))

        if not return_dict:
            output = (prediction_scores, ) + outputs[2:]
            return (((masked_lm_loss, ) +
                     output) if masked_lm_loss is not None else output)

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
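
The decoder branch of get_extended_attention_mask above is what lets the Q-Former mix its query (context) tokens with text inside one BERT stack: a causal mask is built over the text positions and a block of ones is prepended so every text token can also attend to the query prefix, while the query rows stay blocked from the text (the UniLM-style has_query branch). Below is a minimal standalone sketch of the resulting mask with made-up sizes (2 query tokens, 3 text tokens); it reproduces the final layout, not the exact code path.

import torch

batch_size, query_len, text_len = 1, 2, 3          # illustrative sizes only
seq_length = query_len + text_len

# attention_mask covers query + text; all positions are valid in this toy example
attention_mask = torch.ones(batch_size, seq_length)

# causal mask over the text positions: token i attends to tokens <= i
seq_ids = torch.arange(text_len)
causal = (seq_ids[None, None, :].repeat(batch_size, text_len, 1)
          <= seq_ids[None, :, None]).float()

# text rows: ones over the query prefix, causal over the text columns
text_rows = torch.cat([torch.ones(batch_size, text_len, query_len), causal], dim=-1)
# query rows: ones over the query columns, zeros over the text columns (no peeking ahead)
query_rows = torch.cat([torch.ones(batch_size, query_len, query_len),
                        torch.zeros(batch_size, query_len, text_len)], dim=-1)
causal_mask = torch.cat([query_rows, text_rows], dim=1)          # (B, seq, seq)

extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
extended = (1.0 - extended) * -10000.0                           # additive attention bias
print(extended[0, 0])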

dataset_batch.py

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import random


class QADataset(Dataset):
    def __init__(self, data_path,train) -> None:
        super().__init__()
 
        self.data = []

        
        data = pd.read_csv(data_path).dropna()
        print(data.columns)
        condition = (data['answer'].str.len() <= 1000) & (data['summary'].str.len() <= 500)

        filtered_data = data[condition]

        with open("data/corpus.tsv","r") as f_read:
            corpus = [i.split()[-1] for i in f_read.readlines()]

        retell_prompts = ["請複述這段被壓縮的內容",
                        "複述這段被壓縮的內容",
                        "請將被壓縮的內容複述出來",]

        summary_prompts = ["請總結被壓縮的資訊",
                    "還原被壓縮資訊的主要內容",
                    "請寫出被壓縮資訊的主要內容",
                    "請對之前壓縮的資訊進行概括",
                    "請提煉出之前被壓縮資訊的核心要點",
                    "請歸納一下之前被壓縮的內容的主旨"]

        if train:
            # filter out articles that satisfy the length condition
            # filtered_data1000 = list(filter(self.filter_by_length1000, data["answer"]))

            for idx in range(5000):
            
                # if not line or line == "" or len(line) < 50 or len(line) > 2000:
                #     continue
                
                # randomly choose how many passages to concatenate (1 to 10)
                repeat_count = random.randint(1, 10)

                flag_context = "<context> "*repeat_count

                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)

                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{selected_articles_}'
                test_data = {"context":selected_articles,"text_input":text,"label":selected_articles_}

                self.data.append(
                    test_data
                )

            # for idx in range(5000):
            #     repeat_count = random.randint(1, 1)

            #     flag_context = "<context> "*repeat_count

            #     selected_articles = random.sample(filtered_data150, repeat_count)
            #     selected_articles_ = " ".join(selected_articles)

            #     text = f'<|User|>:請複述這段話{flag_context} <|Bot|>:{selected_articles_}'
            #     test_data = {"samples":{"context":selected_articles,"text_input":[text]}}

            #     self.data.append(
            #         test_data
            #     )    

            for idx,(answer,summary) in enumerate(zip(filtered_data["answer"],filtered_data["summary"])):
                
                answer = [answer[:1000]]
                flag_context = "<context> "

                prompt = random.choice(summary_prompts)

                # user_token: <reserved_106> assistant_token: <reserved_107> eoa: </s>
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{summary}'
                test_data = {"context":answer,"text_input":text,"label":summary}

                self.data.append(
                    test_data
                )   

            # for idx in range(10000):
            #     repeat_count = random.randint(1, 1)

            #     flag_context = "<context> "*repeat_count

            #     selected_articles = random.sample(filtered_data1500, repeat_count)
            #     selected_articles_ = " ".join(selected_articles)

            #     text = f'<|User|>:請複述這段話{flag_context} <|Bot|>:{selected_articles_}'
            #     test_data = {"samples":{"context":selected_articles,"text_input":[text]}}

            #     self.data.append(
            #         test_data
            #     )


            print("data load , size:", len(self.data))

        else:
            for idx in range(100):
            
                # if not line or line == "" or len(line) < 50 or len(line) > 2000:
                #     continue
                
                # randomly choose how many passages to concatenate (3 to 5)
                repeat_count = random.randint(3, 5)

                flag_context = "<context> "*repeat_count

                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)

                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>'
                test_data = {"context":selected_articles,"text_input":text,"label":selected_articles_}

                self.data.append(
                    test_data
                )


    # helper functions for filtering articles by length
    @staticmethod
    def filter_by_length150(article):
        return 180 <= len(article) <= 200

    @staticmethod
    def filter_by_length1000(article):
        return 50 <= len(article) <= 1000

    @staticmethod
    def filter_by_length1500(article):
        return 500 <= len(article) <= 1500

    def __getitem__(self, index):
        item_data = self.data[index]

        return item_data

    def __len__(self):
        return len(self.data)


if __name__ == "__main__":
    data_path = "QA_5000_summary.csv"
    dataset = QADataset(data_path,train=True)
    # print(dataset[0])
    val_params = {
        "batch_size": 2,
        "shuffle": False,
        "num_workers": 0,
    }

    def collate_fn(batch):
        """
        Collate a batch of samples into a dict of lists.
        :param batch: list of per-sample dicts returned by __getitem__
        :return: dict mapping each key to the list of values across the batch
        """
        # initialize an empty dict to store the merged result
        merged_dict = {}

        # iterate over every dict in the batch
        for d in batch:
            # iterate over the key/value pairs of each dict
            for key, value in d.items():
                # if the key already exists in merged_dict, append the value to its list
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    # otherwise create a new list for this key
                    merged_dict[key] = [value]
        # inspect the merged result
        # print(merged_dict)

        return merged_dict

    val_loader = DataLoader(dataset, **val_params,collate_fn=collate_fn)

    for i in val_loader:
        print(i)
        break
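
Every record built above follows one template: the Baichuan2 user token <reserved_106>, an instruction, one "<context> " placeholder per passage (later replaced by compressed context embeddings), then the assistant token <reserved_107> followed by the target text. A small helper that reproduces that template; the function name build_sample and the dummy passages are only for illustration.

def build_sample(passages, prompt, target):
    """Build one training record in the same format QADataset emits."""
    flag_context = "<context> " * len(passages)   # one placeholder per passage
    text = f"<reserved_106>{prompt}{flag_context}<reserved_107>{target}"
    return {"context": passages, "text_input": text, "label": target}

# usage sketch
passages = ["passage one ...", "passage two ..."]
print(build_sample(passages, "請複述這段被壓縮的內容", "[SEP]".join(passages))["text_input"])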
    

train_batch.py

# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
# from dataset_batch_en import QADataset
# from dataset_rerank import QADataset
# from dataset_rerank_en_gpt import QADataset
from dataset_rerank_en import QADataset

from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys
import numpy as np
from modeling_RALLM import RALLM
import argparse
import deepspeed
from torch.nn.parallel import DataParallel


# Set CUDA device visibility, e.g. use only the first GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=True, type=bool)
parser.add_argument("--compressibility_factor", default=0, type=float,dest="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_english_longformer_rerank100k_2", type=str)
parser.add_argument("--encoder", default="longformer", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_msmarco2019/checkpoint-87500", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-3, type=int)
parser.add_argument("--weight_decay", default=0.005, type=int)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--batch_size", default=1, type=int)



args = parser.parse_args()


def train(epoch, model, loader, optimizer,scaler, gradient_accumulation_steps,model_output_dir):
    model.train()

    time1 = time.time()
    losses = []
    train_bar = tqdm(loader,total=len(loader))
    for index, data in enumerate(train_bar):
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda",dtype=torch.float16):
            
            # print(data)
            outputs = model(model,**data)
            loss = outputs.loss
        # backward pass: compute gradients for the current batch
            loss.requires_grad_(True)
            losses.append(loss.item())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        if (index+1) % 5000 == 0:
            model_output_dir_ = os.path.join(model_output_dir,f"epoch{epoch}")
            model_save_path = os.path.join(model_output_dir_,"index_{}".format(index))
            if os.path.exists(model_save_path):
                pass
            else:
                os.makedirs(model_save_path)

            torch.save(model.state_dict(), os.path.join(model_save_path,"LLM_model_{:.6f}.pth".format(np.mean(losses))))

        train_bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch,index,np.mean(losses)))



def validate( model,  loader):
    model.eval()

    predictions = []
    actuals = []
    
    with torch.no_grad():
        with torch.autocast(device_type="cuda",dtype=torch.float16):
            for _, data in enumerate(tqdm(loader, file=sys.stdout, desc="Validation Data")):

                text = data["text_input"]
                context = data["context"]
                label = data["label"]
                print(text)
                print("len context:",len(context))
                for text_,context_ in zip(text,context):
                    preds = model.generate(
                        text=text_,context = [context_]
                    )
                    print(preds)
                    print(label)
                    predictions.append(preds)
                actuals.extend(label)
    return predictions, actuals


def main():
    epochs = args.epochs
    batch_size = args.batch_size
    lr = 1e-5
    gradient_accumulation_steps = 16
    model_output_dir = args.output
    # train_path = "qa_3w_summary.csv"
    train_path = args.train_dataset
    val_path = args.train_dataset

    device = torch.device(f"cuda:{args.local_rank}")
    model = RALLM(args)
    model = model.to(device)

    if args.use_lora:
        print("使用lora訓練模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)

    if args.use_lora_gpt2:
        print("使用lora訓練模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["wte","c_attn",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)


    print(model)

    torch.cuda.empty_cache()  # free GPU memory

    if args.load_path:
        base_load_path = args.load_path
        # list the filenames of all sharded model parameter files

        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path,map_location=device)
        else:
            file_list = ['pytorch_model.bin']
            # create an empty state dict
            state_dict = {}
            # iterate over all shard files and load them
            for file_name in file_list:
                # load the parameters from a single shard
                part_state_dict = torch.load(os.path.join(base_load_path,file_name),map_location=device)

                # merge the loaded parameters into the full state dict
                state_dict.update(part_state_dict)
        # load the merged state dict into the model
        print("state_dict:")
        print(state_dict.keys())
    
        model.load_state_dict(state_dict,strict=False)


    for param in model.context_encoder.parameters():
        param.requires_grad = False
    # layers_to_modify = [30,31,32,33,34,35]
    # # Iterate over all named parameters in the model
    # for name, param in model.context_encoder.named_parameters():
    #     # Check if the parameter belongs to the specified layers
    #     if any(f"context_encoder.h.{layer}." in name for layer in layers_to_modify):
    #         # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer
    for param in model.ln_features.parameters():
        param.requires_grad = True
    for param in model.Qformer.parameters():
        param.requires_grad = True

    # iterate over every layer and freeze its parameters
    # for param in model.LLM_model.parameters():
    #     param.requires_grad = False

    # freeze everything except the lora_A and lora_B layers
    # trained = []
    # untrained = []
    # for name, param in model.LLM_model.named_parameters():
    #     # if 'v_proj.lora_A' in name or 'v_proj.lora_B' in name or 'q_proj.lora_B' in name or 'q_proj.lora_B' in name:
    #     if 'lora_A' in name or 'lora_B' in name:
    #         param.requires_grad = True
    #         trained.append(name)
    #     else:
    #         param.requires_grad = False
    #         untrained.append(name)


    # Print trainable and non-trainable parameters
    trainable_params = []
    non_trainable_params = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)

    print("Trainable Parameters:")
    print("\n".join(trainable_params))

    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))

    # setup peft
    # peft_config = LoraConfig(
    #     task_type=TaskType.CAUSAL_LM,
    #     target_modules=["q_proj","v_proj"], #W_pack. query_key_value
    #     inference_mode=False,
    #     r=lora_rank,
    #     lora_alpha=lora_alpha,
    #     lora_dropout=0.1
    # )
    # model = get_peft_model(model, peft_config)


    # model.is_parallelizable = True
    # model.model_parallel = True
    # model.print_trainable_parameters()
    # convert to half precision
    # model.LLM_model = model.LLM_model.half()
    # model.float()

    scaler = torch.cuda.amp.GradScaler()
    def collate_fn(batch):
        """
        Collate a batch of samples into a dict of lists.
        :param batch: list of per-sample dicts returned by __getitem__
        :return: dict mapping each key to the list of values across the batch
        """
        # initialize an empty dict to store the merged result
        merged_dict = {}

        # iterate over every dict in the batch
        for d in batch:
            # iterate over the key/value pairs of each dict
            for key, value in d.items():
                # if the key already exists in merged_dict, append the value to its list
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    # otherwise create a new list for this key
                    merged_dict[key] = [value]
        # inspect the merged result
        # print(merged_dict)

        return merged_dict

    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = QADataset(train_path,train=True)
    training_loader = DataLoader(training_set, **train_params,collate_fn=collate_fn)
    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }
    val_set = QADataset(val_path,train=False)
    val_loader = DataLoader(val_set, **val_params,collate_fn=collate_fn)

    # optimizer = torch.optim.AdamW([{'params': model.bert_encoder.parameters(), 'lr': 1e-5},
    #                                 {'params': model.Qformer.parameters(), 'lr': 1e-3},
    #                                 {'params': model.ln_features.parameters(), 'lr': 1e-3},
    #                                 {'params': model.internlm_model.parameters(), 'lr': 1e-5},
    #                                 {'params': query_tokens_clone, 'lr': 1e-3}] #
    #                                 )

    optimizer = torch.optim.AdamW([{'params': model.parameters(), 'lr': lr}])

    # device_ids = [1,3,6,7]
    # model = DataParallel(model, device_ids=device_ids)


    print("Start Training...")
    for epoch in range(epochs):
        # train(epoch, model, training_loader, optimizer,scaler, gradient_accumulation_steps,model_output_dir)
        # print("Save Model To ",    加)
        # model.save_pretrained(model_output_dir)
        # 驗證
        # print("Start Validation...")
        with torch.no_grad():
            predictions, actuals = validate(model, val_loader)
            # 驗證結果儲存
            final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
            val_data_path = os.path.join(model_output_dir, f"predictions_{epoch}.csv")
            final_df.to_csv(val_data_path)
            print("Validation Data To ", val_data_path)


if __name__ == '__main__':
    main()
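
One thing worth noting about train() above: it receives gradient_accumulation_steps but steps the optimizer on every batch, and the backward/step calls sit inside the autocast block. If accumulation is actually wanted, the usual AMP pattern with the same GradScaler looks roughly like the sketch below; this is a generic recipe (assuming the model returns an object with a .loss attribute), not code taken from this repository.

import torch

def train_accum(model, loader, optimizer, scaler, accum_steps):
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model(**batch).loss / accum_steps   # scale loss by the accumulation factor
        scaler.scale(loss).backward()                  # accumulate scaled gradients outside autocast
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)                     # unscales gradients, then optimizer.step()
            scaler.update()
            optimizer.zero_grad()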

test_chat.py

# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer,AutoModelForMaskedLM,BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline,AutoModelForCausalLM,AutoConfig
from transformers.generation.utils import GenerationConfig

from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys
import numpy as np
from modeling_RALLM import RALLM
import argparse
from torch import autocast

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=False, type=bool)
parser.add_argument("--compressibility_factor", default=0.1, type=float,dest="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_corpus_lora/checkpoint-200004", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=int)
parser.add_argument("--weight_decay", default=0.005, type=int)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=False, type=bool)
parser.add_argument("--use_lora_gpt2", default=True, type=bool)



args = parser.parse_args()

def chat( model):

    while True:
        context1 = input("Enter context 1: ")
        context2 = input("Enter context 2: ")
    #     context = """NVIDIA的A6000顯示卡是一款面向專業領域的高效能顯示卡。關於它的雙精度(Double Precision)、單精度(Single Precision)和半精度(Half Precision)的算力,我們可以參考官方提供的規格引數。截至我最後更新的資訊(2023年4月),以下是A6000顯示卡的相關算力資料:雙精度(Double Precision): A6000顯示卡在雙精度計算方面的效能通常不如單精度和半精度,因為雙精度計算需要更多的計算資源和頻寬。具體數值因顯示卡的不同批次和製造工藝的微小差異可能有所不同。單精度(Single Precision): A6000在單精度計算方面的效能通常很高,適合於大多數圖形處理和一些科學計算任務。單精度計算是大多數顯示卡的主要優勢。半精度(Half Precision): 半精度計算主要用於某些機器學習和深度學習應用,能提供更高的吞吐量。A6000顯示卡在半精度計算方面的效能通常很高。
    #     """
        flag_context = "<context> "*2
        text = f'<reserved_106>請複述這段被壓縮的內容{flag_context} <reserved_107>'
        data = {"context":[[context1,context2]],"text_input":text}

        model.eval()
        with torch.no_grad():
            with autocast(device_type="cuda",dtype=torch.float16):
                text = data["text_input"]
                context = data["context"]
                preds = model.generate(
                    text=text,context = context
                )
        print("輸出:",preds) 


def main():
   
    model = RALLM(args)   # initialize the RALLM model
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)


    if args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)


    if args.use_lora_gpt2:
        print("使用lora訓練模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["c_attn",],
            inference_mode=False,
            r=64,
            lora_alpha=256,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)


    print(model)


    base_load_path = "output_qa3w_lora_gpt2_base_corpus"
    # list the filenames of all sharded model parameter files
    file_list = ['pytorch_model.bin']

    # create an empty state dict
    state_dict = {}

    # iterate over all shard files and load them
    for file_name in file_list:
        # load the parameters from a single shard
        part_state_dict = torch.load(os.path.join(base_load_path,file_name),map_location=f"cuda:{args.local_rank}")

        # merge the loaded parameters into the full state dict
        state_dict.update(part_state_dict)

    # load the merged state dict into the model
    model.load_state_dict(state_dict)
    model.LLM_model.generation_config = GenerationConfig.from_pretrained(base_load_path)



    # load model parameters from a specific checkpoint
    # load_path = '/data2/xinyuuliu/InternLM-XComposer/output12/epoch9/index_29999/LLM_model_0.109371.pth'
    # checkpoint = torch.load(load_path,map_location="cuda:0") #,map_location="cuda:3"
    # # load the parameters into the model
    # model.load_state_dict(checkpoint)
    # convert to half precision
    # model.LLM_model = model.LLM_model.half()
    model = model.half()
    # model.float()

    chat(model)
 


if __name__ == '__main__':
    main()


ds_config.json

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 10,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false

}
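
A quick sanity check on this config: with ZeRO stage 2 the "auto" batch fields are filled in by the HF Trainer from its own TrainingArguments when deepspeed="ds_config.json" is passed (as fine-tune.py does below), and the resolved global batch size is micro-batch x gradient-accumulation x world size. A tiny sketch of that arithmetic, with example numbers taken from fine-tune.sh:

import json

with open("ds_config.json") as f:
    ds_cfg = json.load(f)
assert ds_cfg["zero_optimization"]["stage"] == 2

micro_batch_per_gpu = 1   # fills train_micro_batch_size_per_gpu = "auto" (per_device_train_batch_size)
grad_accum_steps = 1      # fills gradient_accumulation_steps = "auto"
world_size = 8            # GPUs listed in deepspeed --include localhost:0,...,7

print("effective train_batch_size:", micro_batch_per_gpu * grad_accum_steps * world_size)  # 8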

fine-tune.py

# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.
from dataclasses import dataclass, field
import json
import math
import logging
import os
import random
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import Trainer, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType
from torchvision import transforms
from typing import Dict, Optional, Sequence, List
from modeling_RALLM import RALLM
# from dataset_batch import QADataset
from dataset_rerank_en import QADataset
# from dataset_rerank_en_gpt import QADataset
# from dataset_rerank import QADataset
import argparse
from transformers import AutoModel, AutoTokenizer,AutoModelForMaskedLM,BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline,AutoModelForCausalLM,AutoConfig
from torch.optim import AdamW


IGNORE_TOKEN_ID = LabelSmoother.ignore_index

parser = argparse.ArgumentParser()

parser.add_argument("--is_compress", default=True, type=bool)
parser.add_argument("--compressibility_factor", default=0.1, type=float,dest="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_lora2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_rerank100k/checkpoint-112356", type=str)
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--weight_decay", default=0.01, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)



args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.local_rank)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    local_rank: int = field(default=None)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of samples into a dict of lists.
        :param instances: list of per-sample dicts returned by __getitem__
        :return: dict mapping each key to the list of values across the batch
        """

        # initialize an empty dict to store the merged result
        merged_dict = {}

        # iterate over every dict in the batch
        for d in instances:
            # iterate over the key/value pairs of each dict
            for key, value in d.items():
                # if the key already exists in merged_dict, append the value to its list
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    # otherwise create a new list for this key
                    merged_dict[key] = [value]
        # inspect the merged result
        # print(merged_dict)

        return merged_dict



def train():
    global model

    train_path = args.train_dataset
    # train_path = "data/news_summary_30w.csv"
    # val_path = "QA_5000_summary.csv"
    
    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1


    torch.cuda.device(0)
    torch.cuda.empty_cache()  # free GPU memory


    # init model and tokenizer
    model = RALLM(args)   # initialize the RALLM model


    
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)
    # torch.cuda.device(0)
    torch.cuda.empty_cache()  # free GPU memory

    if args.use_lora:
        print("使用lora訓練模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)


    if args.use_lora_gpt2:
        print("使用lora訓練模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            # target_modules=["wte","c_attn",],
            target_modules=["query","key","value","query_global","key_global","value_global"],
            inference_mode=False,
            r=128,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)



    print(model)

    for param in model.context_encoder.parameters():
        param.requires_grad = False
    # layers_to_modify = [27,28,29,30,31,32,33,34, 35]
    #     # Iterate over all named parameters in the model
    # for name, param in model.context_encoder.named_parameters():
    #     # Check if the parameter belongs to the specified layers
    #     if any(f"h.{layer}."  in name for layer in layers_to_modify):
    #         # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer
    #     # if f"ln_f"  in name:
    #     #     # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #     #     param.requires_grad = True  # or False if you want to freeze the layer
    #     if f"wte"  in name:
    #     #     # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer
    #     if f"wpe"  in name:
    #     #     # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer

    for param in model.ln_features.parameters():
        param.requires_grad = True


    # iterate over every layer and freeze its parameters
    # for param in model.LLM_model.parameters():
    #     param.requires_grad = False

    # freeze everything except the lora_A and lora_B layers
    trained = []
    untrained = []
    # for name, param in model.LLM_model.named_parameters():
    #     # if 'v_proj.lora_A' in name or 'v_proj.lora_B' in name or 'q_proj.lora_B' in name or 'q_proj.lora_B' in name:
    #     # if 'lora_A' in name or 'lora_B' in name or "layers.30" in name or "layers.31" in name or "embed_tokens" in name: 
    #     if 'lora_A' in name or 'lora_B' in name or "embed_tokens" in name: 

    #         param.requires_grad = True
    #         trained.append(name)
    #     else:
    #         param.requires_grad = False
    #         untrained.append(name)

    # print("可訓練的大模型層",trained)
    # print("不可訓練的大模型層",untrained)

    # Print trainable and non-trainable parameters
    trainable_params = []
    non_trainable_params = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)

    print("Trainable Parameters:")
    print("\n".join(trainable_params))

    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))


    if args.load_path:
        base_load_path = args.load_path
    # list the filenames of all sharded model parameter files

        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path,map_location=device)
        else:
            file_list = ['pytorch_model.bin']
            # create an empty state dict
            state_dict = {}
            # iterate over all shard files and load them
            for file_name in file_list:
                # load the parameters from a single shard
                part_state_dict = torch.load(os.path.join(base_load_path,file_name),map_location=device)

                # merge the loaded parameters into the full state dict
                state_dict.update(part_state_dict)
        # load the merged state dict into the model
        print("state_dict:")
        print(state_dict.keys())
        
        model.load_state_dict(state_dict,strict=False)

    # # separate model.Qformer's parameters from all other parameters
    # qformer_params = set(model.Qformer.parameters())
    # other_params = [p for p in model.parameters() if p not in qformer_params]

    # # build parameter groups
    # param_groups = [
    #     {'params': list(qformer_params), 'lr': 1e-3},
    #     {'params': other_params, 'lr': 1e-5}
    # ]


    # build the AdamW optimizer from the parameter groups
    # optimizer = AdamW(param_groups)
    

    training_set = QADataset(train_path,train=True)
    # val_set = QADataset(val_path,train=False)

    print(training_set[0])

    # training arguments
    training_args = TrainingArguments(
        local_rank=args.local_rank,
        output_dir=args.output,          # output directory
        num_train_epochs=args.num_train_epochs,              # number of training epochs
        per_device_train_batch_size=args.per_device_train_batch_size,  # per-device batch size
        warmup_steps=500,                # warmup steps
        weight_decay=0.01,               # weight decay
        logging_dir='./logs',            # logging directory
        deepspeed = "ds_config.json",
        gradient_accumulation_steps = 1 ,
        save_strategy = "epoch" ,
        learning_rate = 5e-5 ,
        # lr_scheduler_type='linear',
        # logging_steps= 100,
    )

    data_collator = DataCollatorForSupervisedDataset()

    # Start the trainer
    trainer = Trainer(
        model = model,
        tokenizer = model.LLM_tokenizer,
        train_dataset=training_set,
        # eval_dataset=val_set,
        data_collator=data_collator,
        args = training_args,
        # optimizers=(optimizer, None)  # custom optimizer

    )

    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=args.output)



if __name__ == "__main__":
    train()

#https://arxiv.org/pdf/2102.05951.pdf
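
With use_lora=True, the checkpoint saved by trainer.save_model keeps the base Baichuan2 weights plus the LoRA adapter tensors. If a single set of plain weights is wanted for serving, peft can fold the adapters back into the base model; a minimal sketch (assuming model.LLM_model is the get_peft_model(...) wrapper created above, and merged_baichuan2 is just an illustrative output directory):

def merge_lora_for_serving(peft_llm, out_dir="merged_baichuan2"):
    """Fold the LoRA deltas into the base weights and save a plain HF checkpoint."""
    merged = peft_llm.merge_and_unload()   # peft: returns the base model with adapters merged
    merged.save_pretrained(out_dir)
    return merged

# usage, e.g. at the end of train(): merge_lora_for_serving(model.LLM_model)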

fine-tune.sh

hostfile=""

# deepspeed --include localhost:1,2,3 --hostfile=$hostfile fine-tune.py  \
#     --report_to "none" \
#     --data_path "/data1/xinyuuliu/qa_data/professional_data/train_二階段.json" \
#     --model_name_or_path "/data1/xinyuuliu/Baichuan2-13B-Chat" \
#     --output_dir "output_lora3_1_2" \
#     --model_max_length  4000\
#     --num_train_epochs 10 \
#     --per_device_train_batch_size 4 \
#     --gradient_accumulation_steps 1 \
#     --save_strategy epoch \
#     --learning_rate 2e-4 \
#     --lr_scheduler_type constant \
#     --adam_beta1 0.9 \
#     --adam_beta2 0.98 \
#     --adam_epsilon 1e-8 \
#     --max_grad_norm 1.0 \
#     --weight_decay 1e-4 \
#     --warmup_ratio 0.0 \
#     --logging_steps 1 \
#     --gradient_checkpointing True \
#     --deepspeed ds_config.json \
#     --bf16 True \
#     --tf32 True \
#     --use_lora True \
#     --load_lora_path /data1/xinyuuliu/Baichuan2-main/fine-tune/output_lora3_1/checkpoint-8260
    # --use_NEFT True
    # --use_frozen True
# export CUDA_LAUNCH_BLOCKING=1

# CUDA_VISIBLE_DEVICES="2,3,4,5,6,7"
deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 29501 --hostfile=$hostfile fine-tune.py \
    --encoder longformer \
    --query_tokens 32 \
    --output output_english_longformer_msmarco2019 \
    --num_train_epochs 20 \
    --per_device_train_batch_size 1
    # --load_path /data2/xinyuuliu/Baichuan2_qformer_bert/output_30w/checkpoint-22488 \
