中文新聞情感分類 Bert-Pytorch-transformers
中文新聞情感分類 Bert-Pytorch-transformers
使用pytorch框架以及transformers包,以及Bert的中文預訓練模型
檔案目錄
data
Train_DataSet.csv
Train_DataSet_Label.csv
main.py
NewsData.py
#main.py
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import BertConfig
from transformers import BertPreTrainedModel
import torch
import torch.nn as nn
from transformers import BertModel
import time
import argparse
from NewsData import NewsData
import os
def get_train_args():
parser=argparse.ArgumentParser()
parser.add_argument('--batch_size',type=int,default=10,help = '每批資料的數量')
parser.add_argument('--nepoch',type=int,default=3,help = '訓練的輪次')
parser.add_argument('--lr',type=float,default=0.001,help = '學習率')
parser.add_argument('--gpu',type=bool,default=True,help = '是否使用gpu')
parser.add_argument('--num_workers',type=int,default=2,help='dataloader使用的執行緒數量')
parser.add_argument('--num_labels',type=int,default=3,help='分類類數')
parser.add_argument('--data_path',type=str,default='./data',help='資料路徑')
opt=parser.parse_args()
print(opt)
return opt
def get_model(opt):
#類方法.from_pretrained()獲取預訓練模型,num_labels是分類的類數
model = BertForSequenceClassification.from_pretrained('bert-base-chinese',num_labels=opt.num_labels)
return model
def get_data(opt):
#NewsData繼承於pytorch的Dataset類
trainset = NewsData(opt.data_path,is_train = 1)
trainloader=torch.utils.data.DataLoader(trainset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers)
testset = NewsData(opt.data_path,is_train = 0)
testloader=torch.utils.data.DataLoader(testset,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers)
return trainloader,testloader
def train(epoch,model,trainloader,testloader,optimizer,opt):
print('\ntrain-Epoch: %d' % (epoch+1))
model.train()
start_time = time.time()
print_step = int(len(trainloader)/10)
for batch_idx,(sue,label,posi) in enumerate(trainloader):
if opt.gpu:
sue = sue.cuda()
posi = posi.cuda()
label = label.unsqueeze(1).cuda()
optimizer.zero_grad()
#輸入引數為詞列表、位置列表、標籤
outputs = model(sue, position_ids=posi,labels = label)
loss, logits = outputs[0],outputs[1]
loss.backward()
optimizer.step()
if batch_idx % print_step == 0:
print("Epoch:%d [%d|%d] loss:%f" %(epoch+1,batch_idx,len(trainloader),loss.mean()))
print("time:%.3f" % (time.time() - start_time))
def test(epoch,model,trainloader,testloader,opt):
print('\ntest-Epoch: %d' % (epoch+1))
model.eval()
total=0
correct=0
with torch.no_grad():
for batch_idx,(sue,label,posi) in enumerate(testloader):
if opt.gpu:
sue = sue.cuda()
posi = posi.cuda()
labels = label.unsqueeze(1).cuda()
label = label.cuda()
else:
labels = label.unsqueeze(1)
outputs = model(sue, labels=labels)
loss, logits = outputs[:2]
_,predicted=torch.max(logits.data,1)
total+=sue.size(0)
correct+=predicted.data.eq(label.data).cpu().sum()
s = ("Acc:%.3f" %((1.0*correct.numpy())/total))
print(s)
if __name__=='__main__':
opt = get_train_args()
model = get_model(opt)
trainloader,testloader = get_data(opt)
if opt.gpu:
model.cuda()
optimizer=torch.optim.SGD(model.parameters(),lr=opt.lr,momentum=0.9)
if not os.path.exists('./model.pth'):
for epoch in range(opt.nepoch):
train(epoch,model,trainloader,testloader,optimizer,opt)
test(epoch,model,trainloader,testloader,opt)
torch.save(model.state_dict(),'./model.pth')
else: 鄭州治療婦科哪個醫院好
model.load_state_dict(torch.load('model.pth'))
print('模型存在,直接test')
test(0,model,trainloader,testloader,opt)
#NewsData.py
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import BertConfig
from transformers import BertPreTrainedModel
import torch
import torch.nn as nn
from transformers import BertModel
import time
import argparse
class NewsData(torch.utils.data.Dataset):
def __init__(self,root,is_train = 1):
self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
self.data_num = 7346
self.x_list = []
self.y_list = []
self.posi = []
with open(root + '/Train_DataSet.csv',encoding='UTF-8') as f:
for i in range(self.data_num+1):
line = f.readline()[:-1] + '這是一箇中性的資料'
data_one_str = line.split(',')[len(line.split(','))-2]
data_two_str = line.split(',')[len(line.split(','))-1]
if len(data_one_str) < 6:
z = len(data_one_str)
data_one_str = data_one_str + ',' + data_two_str[0:min(200,len(data_two_str))]
else:
data_one_str = data_one_str
if i==0:
continue
word_l = self.tokenizer.encode(data_one_str, add_special_tokens=False)
if len(word_l)<100:
while(len(word_l)!=100):
word_l.append(0)
else:
word_l = word_l[0:100]
word_l.append(102)
l = word_l
word_l = [101]
word_l.extend(l)
self.x_list.append(torch.tensor(word_l))
self.posi.append(torch.tensor([i for i in range(102)]))
with open(root + '/Train_DataSet_Label.csv',encoding='UTF-8') as f:
for i in range(self.data_num+1):
#print(i)
label_one = f.readline()[-2]
if i==0:
continue
label_one = int(label_one)
self.y_list.append(torch.tensor(label_one))
#訓練集或者是測試集
if is_train == 1:
self.x_list = self.x_list[0:6000]
self.y_list = self.y_list[0:6000]
self.posi = self.posi[0:6000]
else:
self.x_list = self.x_list[6000:]
self.y_list = self.y_list[6000:]
self.posi = self.posi[6000:]
self.len = len(self.x_list)
def __getitem__(self, index):
return self.x_list[index], self.y_list[index],self.posi[index]
def __len__(self):
return self.len
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69945560/viewspace-2670147/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 中文情感分類單標籤
- 數理統計——新聞分類
- pyhanlp文字分類與情感分析HanLP文字分類
- 使用貝葉斯進行新聞分類
- 利用LSTM做語言情感分類
- 如何用Python和機器學習訓練中文文字情感分類模型?Python機器學習模型
- 樸素貝葉斯--新浪新聞分類例項
- snownlp類庫(中文情感分析)原始碼註釋及使用原始碼
- NLP入門競賽,搜狗新聞文字分類!拿幾十萬獎金!文字分類
- 如何用50行程式碼構建情感分類器行程
- 深度學習之電影二分類的情感問題深度學習
- flutter實戰3:解析HTTP請求資料和製作新聞分類列表FlutterHTTP
- 深度學習之新聞多分類問題深度學習
- 文字自動摘要:基於TextRank的中文新聞摘要
- Spark機器學習實戰 (十一) - 文字情感分類專案實戰Spark機器學習
- 淺談NLP 文字分類/情感分析 任務中的文字預處理工作文字分類
- 小程式雲開發之新聞類專案分析
- 聊聊新的遊戲分類方式遊戲
- AI Challenger 2018:細粒度使用者評論情感分類冠軍思路總結AI
- IT新聞類軟文營銷的三大寫作技巧
- 人民新聞網
- 3 分鐘建立 Serverless Job 定時獲取新聞熱搜!Server
- 3 分鐘建立 Serverless Job 定時獲取新聞熱搜Server
- 利用transformer進行中文文字分類(資料集是復旦中文語料)ORM文字分類
- 利用TfidfVectorizer進行中文文字分類(資料集是復旦中文語料)文字分類
- 萬字總結Keras深度學習中文文字分類Keras深度學習文字分類
- 央視新聞《 五分快速三必中方法 》手機搜狐網
- 【牛腩新聞】——CSS(一)CSS
- 京東獲得jd商品分類API介面(父分類、根分類、子分類)API
- 綠色花草養殖新聞資訊類織夢dedecms網站模板網站
- 新聞類APP軟體開發有哪些特性?北京銳智互動APP
- HarmonyOS SDK 助力新浪新聞打造精緻易用的新聞應用
- 如何提升新聞營銷的效果?新聞稿釋出的技巧
- 基於機器學習和TFIDF的情感分類演算法,詳解自然語言處理機器學習演算法自然語言處理
- 新聞新體驗!3DCAT助力開啟紅網“元宇宙”新聞直播間3D元宇宙
- 靈玖軟體為你全方位介紹中文情感分析
- 新聞稿釋出渠道分析 如何選擇新聞稿釋出公司
- CSS1(新聞案例)CSS