中文新聞情感分類 Bert-Pytorch-transformers

ckxllf 發表於 2019-12-24

  中文新聞情感分類 Bert-Pytorch-transformers

  使用 PyTorch 框架、transformers 套件,以及 BERT 的中文預訓練模型(bert-base-chinese)。

  檔案目錄:

  ├── data/
  │   ├── Train_DataSet.csv
  │   └── Train_DataSet_Label.csv
  ├── main.py
  └── NewsData.py

  #main.py

  from transformers import BertTokenizer

  from transformers import BertForSequenceClassification

  from transformers import BertConfig

  from transformers import BertPreTrainedModel

  import torch

  import torch.nn as nn

  from transformers import BertModel

  import time

  import argparse

  from NewsData import NewsData

  import os

  def _str2bool(value):
      """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

      argparse's ``type=bool`` calls ``bool(str)``, which is True for every
      non-empty string -- so ``--gpu False`` used to still enable the GPU.
      """
      if isinstance(value, bool):
          return value
      lowered = value.strip().lower()
      if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
          return True
      if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
          return False
      raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


  def get_train_args(argv=None):
      """Parse and print the training options.

      Args:
          argv: optional list of argument strings; ``None`` (the default)
              parses ``sys.argv`` exactly as before.

      Returns:
          argparse.Namespace with batch_size, nepoch, lr, gpu, num_workers,
          num_labels and data_path.
      """
      parser = argparse.ArgumentParser()
      # Help strings (Chinese) are user-facing output and kept verbatim:
      # batch size, number of epochs, learning rate, whether to use the GPU,
      # number of DataLoader workers, number of classes, data directory.
      parser.add_argument('--batch_size', type=int, default=10, help='每批資料的數量')
      parser.add_argument('--nepoch', type=int, default=3, help='訓練的輪次')
      parser.add_argument('--lr', type=float, default=0.001, help='學習率')
      parser.add_argument('--gpu', type=_str2bool, default=True, help='是否使用gpu')
      parser.add_argument('--num_workers', type=int, default=2, help='dataloader使用的執行緒數量')
      parser.add_argument('--num_labels', type=int, default=3, help='分類類數')
      parser.add_argument('--data_path', type=str, default='./data', help='資料路徑')
      opt = parser.parse_args(argv)
      print(opt)
      return opt

  def get_model(opt):
      """Build a BertForSequenceClassification from the 'bert-base-chinese' weights.

      ``from_pretrained`` is a class method that loads the pretrained model;
      ``num_labels`` sizes the classification head to ``opt.num_labels`` classes.
      """
      checkpoint = 'bert-base-chinese'
      return BertForSequenceClassification.from_pretrained(
          checkpoint,
          num_labels=opt.num_labels,
      )

  def get_data(opt):
      """Create the train and test DataLoaders.

      Both datasets are ``NewsData`` instances (a torch Dataset subclass) read
      from ``opt.data_path``. The training loader shuffles; the test loader
      keeps dataset order. Batch size and worker count come from
      ``opt.batch_size`` and ``opt.num_workers``.

      Returns:
          (trainloader, testloader)
      """
      def _make_loader(is_train, shuffle):
          dataset = NewsData(opt.data_path, is_train=is_train)
          return torch.utils.data.DataLoader(
              dataset,
              batch_size=opt.batch_size,
              shuffle=shuffle,
              num_workers=opt.num_workers,
          )

      loaders = (_make_loader(1, True), _make_loader(0, False))
      return loaders

  def train(epoch, model, trainloader, testloader, optimizer, opt):
      """Run one training epoch over ``trainloader``.

      Args:
          epoch: zero-based epoch index (printed as ``epoch + 1``).
          model: called as ``model(ids, position_ids=..., labels=...)``; the
              first element of its output is taken as the loss.
          trainloader: sized iterable of ``(token_ids, label, position_ids)``
              batches.
          testloader: unused here; kept so the signature mirrors ``test``.
          optimizer: torch optimizer over the model's parameters.
          opt: parsed options; only ``opt.gpu`` is read.
      """
      print('\ntrain-Epoch: %d' % (epoch + 1))
      model.train()
      start_time = time.time()
      num_batches = len(trainloader)
      # Log about ten times per epoch. Clamp to 1: with fewer than 10 batches
      # the step would be 0 and ``batch_idx % 0`` raises ZeroDivisionError.
      print_step = max(1, num_batches // 10)
      for batch_idx, (sue, label, posi) in enumerate(trainloader):
          # Labels get a trailing dimension on both CPU and GPU paths,
          # matching how ``test`` prepares them.
          label = label.unsqueeze(1)
          if opt.gpu:
              sue = sue.cuda()
              posi = posi.cuda()
              label = label.cuda()
          optimizer.zero_grad()
          # Inputs: token ids, position ids, labels.
          outputs = model(sue, position_ids=posi, labels=label)
          loss = outputs[0]
          loss.backward()
          optimizer.step()
          if batch_idx % print_step == 0:
              print("Epoch:%d [%d|%d] loss:%f"
                    % (epoch + 1, batch_idx, num_batches, loss.mean().item()))
      print("time:%.3f" % (time.time() - start_time))

  def test(epoch,model,trainloader,testloader,opt):

  print('\ntest-Epoch: %d' % (epoch+1))

  model.eval()

  total=0

  correct=0

  with torch.no_grad():

  for batch_idx,(sue,label,posi) in enumerate(testloader):

  if opt.gpu:

  sue = sue.cuda()

  posi = posi.cuda()

  labels = label.unsqueeze(1).cuda()

  label = label.cuda()

  else:

  labels = label.unsqueeze(1)

  outputs = model(sue, labels=labels)

  loss, logits = outputs[:2]

  _,predicted=torch.max(logits.data,1)

  total+=sue.size(0)

  correct+=predicted.data.eq(label.data).cpu().sum()

  s = ("Acc:%.3f" %((1.0*correct.numpy())/total))

  print(s)

  if __name__ == '__main__':
      # Entry point: train for opt.nepoch epochs and save weights to
      # ./model.pth, or -- if that file already exists -- load it and only test.
      opt = get_train_args()
      # --gpu defaults to True; fall back to CPU rather than crashing on
      # machines without CUDA.
      if opt.gpu and not torch.cuda.is_available():
          print('CUDA is not available; falling back to CPU')
          opt.gpu = False
      model = get_model(opt)
      trainloader, testloader = get_data(opt)
      if opt.gpu:
          model.cuda()
      optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)
      model_path = './model.pth'
      if not os.path.exists(model_path):
          for epoch in range(opt.nepoch):
              train(epoch, model, trainloader, testloader, optimizer, opt)
              test(epoch, model, trainloader, testloader, opt)
          torch.save(model.state_dict(), model_path)
      else:
          # map_location='cpu' lets a checkpoint saved from GPU tensors load on
          # a CPU-only machine; load_state_dict copies values into the model's
          # existing parameters on whatever device the model is on.
          model.load_state_dict(torch.load(model_path, map_location='cpu'))
          print('模型存在,直接test')
          test(0, model, trainloader, testloader, opt)

  #NewsData.py

  from transformers import BertTokenizer

  from transformers import BertForSequenceClassification

  from transformers import BertConfig

  from transformers import BertPreTrainedModel

  import torch

  import torch.nn as nn

  from transformers import BertModel

  import time

  import argparse

  class NewsData(torch.utils.data.Dataset):
      """Chinese news sentiment dataset read from two CSV files under ``root``.

      Texts come from ``Train_DataSet.csv`` and labels from
      ``Train_DataSet_Label.csv``. The first line of each file is skipped as a
      header, and the i-th remaining text row is paired with the i-th label row.
      NOTE(review): pairing is by line position, not by an id column -- this
      assumes both files list samples in the same order; confirm.

      Each item is ``(token_ids, label, position_ids)``:
        * token_ids: int64 tensor of length ``max_len + 2`` laid out as
          ``[CLS] tokens... [SEP] [PAD]...``;
        * label: 0-dim int64 tensor parsed from the last character of the row;
        * position_ids: ``torch.arange(max_len + 2)``.

      Samples ``[0, split)`` form the training portion (``is_train == 1``);
      the rest form the test portion.

      Args:
          root: directory containing the two CSV files.
          is_train: 1 selects the training portion, any other value the test one.
          data_num: maximum number of data rows read from each file (header
              excluded).
          max_len: maximum number of word-piece tokens kept per text.
          split: index at which samples are divided into train / test.

      Raises:
          ValueError: if the two files yield different numbers of rows.
      """

      def __init__(self, root, is_train=1, data_num=7346, max_len=100, split=6000):
          self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
          self.data_num = data_num
          self.max_len = max_len
          self.x_list = []
          self.y_list = []
          self.posi = []

          with open(root + '/Train_DataSet.csv', encoding='UTF-8') as f:
              for i in range(self.data_num + 1):
                  line = f.readline()
                  if not line:  # EOF: fewer rows than data_num
                      break
                  if i == 0:  # header row
                      continue
                  self.x_list.append(self._encode(self._extract_text(line)))
                  self.posi.append(torch.arange(self.max_len + 2))

          with open(root + '/Train_DataSet_Label.csv', encoding='UTF-8') as f:
              for i in range(self.data_num + 1):
                  line = f.readline()
                  if not line:  # EOF
                      break
                  if i == 0:  # header row
                      continue
                  # The label is the last character of the row (a single digit).
                  # rstrip handles LF, CRLF and a final line without a newline.
                  self.y_list.append(torch.tensor(int(line.rstrip('\r\n')[-1])))

          if len(self.x_list) != len(self.y_list):
              raise ValueError(
                  'text/label row count mismatch: %d texts vs %d labels'
                  % (len(self.x_list), len(self.y_list)))

          # Train / test split.
          part = slice(0, split) if is_train == 1 else slice(split, None)
          self.x_list = self.x_list[part]
          self.y_list = self.y_list[part]
          self.posi = self.posi[part]
          self.len = len(self.x_list)

      @staticmethod
      def _extract_text(line):
          """Return the string to tokenize for one raw CSV text row."""
          # Strip the line terminator and append a fixed sentence ("this is a
          # neutral piece of data"), as the original author did, so the last
          # field is never empty.
          fields = (line.rstrip('\r\n') + '這是一箇中性的資料').split(',')
          # Naive comma split: only the last two fields are used (presumably
          # title and content -- TODO confirm against the CSV header); commas
          # inside the text shift which pieces land here. Index len-2 falls back
          # to the only field when the row contains no comma.
          title = fields[len(fields) - 2]
          content = fields[-1]
          if len(title) < 6:
              # Very short titles are extended with up to 200 characters of content.
              return title + ',' + content[:200]
          return title

      def _encode(self, text):
          """Tokenize ``text`` into an int64 tensor of length ``max_len + 2``.

          Layout: [CLS] + first ``max_len`` word pieces + [SEP], right-padded
          with [PAD]. [SEP] must directly follow the real text, so it goes
          before the padding, not after it. For bert-base-chinese these ids
          are expected to be 101 / 102 / 0.
          """
          tok = self.tokenizer
          ids = tok.encode(text, add_special_tokens=False)[:self.max_len]
          ids = [tok.cls_token_id] + ids + [tok.sep_token_id]
          ids.extend([tok.pad_token_id] * (self.max_len + 2 - len(ids)))
          return torch.tensor(ids)

      def __getitem__(self, index):
          return self.x_list[index], self.y_list[index], self.posi[index]

      def __len__(self):
          return self.len


來自「ITPUB部落格」,連結:http://blog.itpub.net/69945560/viewspace-2670147/ ,如需轉載,請註明出處,否則將追究法律責任。

相關文章