Statistical Learning Methods, Chapter 20 Exercise: a Code Implementation of the Gibbs Sampling Algorithm for Latent Dirichlet Allocation (LDA)

Posted by 演算法只是工具 on 2021-01-01

Latent Dirichlet Allocation (LDA): the Gibbs Sampling Algorithm
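The script below implements the collapsed Gibbs sampler for LDA. Each word occurrence $w_{mn}$ (position $n$ in document $m$) carries a topic assignment $z_{mn}$; one sweep removes an assignment from the count matrices and redraws it from the full conditional

$$
p(z_{mn}=k \mid \mathbf{z}_{-mn}, \mathbf{w}) \;\propto\; \frac{n_{kv}+\beta_v}{n_k+\sum_{v'}\beta_{v'}}\,\bigl(n_{mk}+\alpha_k\bigr),
$$

where $n_{kv}$ is the number of times word $v$ is assigned to topic $k$, $n_k=\sum_v n_{kv}$, $n_{mk}$ is the number of words in document $m$ assigned to topic $k$, and all counts exclude the word being resampled. After sampling, the parameters are estimated as $\hat{\theta}_{mk}\propto n_{mk}+\alpha_k$ (the document-topic distribution, sita_mk in the code) and $\hat{\varphi}_{kv}\propto n_{kv}+\beta_v$ (the topic-word distribution, yta_kv in the code).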

import numpy as np
import jieba

class LDA:
    def __init__(self, text_list, k):
        self.k = k                                      # number of topics
        self.text_list = text_list
        self.text_num = len(text_list)
        self.get_X()
        self.NKV = np.zeros((self.k, self.word_num))    # n_kv: topic-word counts
        self.NMK = np.zeros((self.text_num, self.k))    # n_mk: document-topic counts
        self.nm = np.zeros(self.text_num)               # n_m: words per document
        self.nk = np.zeros(self.k)                      # n_k: words per topic
        self.zmn = [[] for _ in range(self.text_num)]   # topic assignment of each word
        # Dirichlet hyperparameters, here drawn as random positive integers
        self.alpha = np.random.randint(1, self.k, size=self.k)
        self.beta = np.random.randint(1, self.word_num, size=self.word_num)


    def get_X(self):
        # segment every document with jieba and build the vocabulary
        self.cuted_text = [jieba.lcut(text, cut_all=True) for text in self.text_list]
        self.word_all = []
        for words in self.cuted_text:
            self.word_all.extend(words)
        self.word_set = list(set(self.word_all))
        self.word_num = len(self.word_set)              # vocabulary size
        self.word_dict = {word: index for index, word in enumerate(self.word_set)}

    def initial_K(self):
        # randomly assign an initial topic to every word occurrence
        # and accumulate the four count statistics
        for doc_num in range(self.text_num):
            for word in self.cuted_text[doc_num]:
                k = np.random.choice(self.k)
                self.zmn[doc_num].append(k)
                v = self.word_dict[word]
                self.NMK[doc_num, k] += 1
                self.nm[doc_num] += 1
                self.NKV[k, v] += 1
                self.nk[k] += 1

    def iter_jbs(self):
        # one sweep of collapsed Gibbs sampling over every word of every document
        for doc_num in range(self.text_num):
            for word_index in range(len(self.cuted_text[doc_num])):
                v = self.word_dict[self.cuted_text[doc_num][word_index]]
                k = self.zmn[doc_num][word_index]
                # remove the current assignment from all counts
                self.NMK[doc_num, k] -= 1
                self.nm[doc_num] -= 1
                self.NKV[k, v] -= 1
                self.nk[k] -= 1
                # full conditional: p(z=k) ∝ (n_kv + β_v)/(n_k + Σβ) · (n_mk + α_k)
                p_klist = (self.NKV[:, v] + self.beta[v]) / (self.nk + np.sum(self.beta)) \
                          * (self.NMK[doc_num] + self.alpha)
                p_klist = p_klist / np.sum(p_klist)
                k_choice = np.random.choice(self.k, p=p_klist)
                self.zmn[doc_num][word_index] = k_choice
                # add the new assignment back into the counts
                self.NMK[doc_num, k_choice] += 1
                self.nm[doc_num] += 1
                self.NKV[k_choice, v] += 1
                self.nk[k_choice] += 1

    def get_sita_y(self):
        # point estimates of theta (document-topic) and phi (topic-word),
        # normalized so that every row sums to 1
        self.sita_mk = np.zeros((self.text_num, self.k))
        self.yta_kv = np.zeros((self.k, self.word_num))
        for i in range(self.text_num):
            self.sita_mk[i] = (self.NMK[i] + self.alpha) / np.sum(self.NMK[i] + self.alpha)
        for j in range(self.k):
            self.yta_kv[j] = (self.NKV[j] + self.beta) / np.sum(self.NKV[j] + self.beta)

    def fit(self, max_iter=100):
        self.initial_K()
        for it in range(max_iter):
            print(it)          # simple progress indicator
            self.iter_jbs()
        self.get_sita_y()

def main():
    # three short Chinese news snippets used as the toy corpus; the string
    # literals are concatenated implicitly so no source indentation leaks
    # into the text handed to jieba
    text_list = [
        '一個月前,足協盃十六進八的比賽,遼足費盡周折對調主客場,目的只是為了葫蘆島體育場的啟用儀式。'
        '那場球遼足50痛宰“主力休息”的天津泰達。幾天後中超聯賽遼足客場對天津,輪到遼足“全替補”,'
        '13輸球,甘為天津泰達保級的祭品。那時,遼足以“聯賽保級問題不大,足協盃拼一拼”作為主力和外援聯賽全部缺陣的理由。',
        '被一腳踹進“忘恩負義”坑裡的孫楊,剛剛爬出來,又有手伸出來,要把孫楊再往坑裡推。'
        '即使是陪伴孫楊參加世錦賽的張亞東(微博)教練,也沒敢大義凜然地伸出援手,'
        '“孫楊願意回去我不攔”,球又踢給了孫楊。張亞東教練怕什麼呢?',
        '孫楊成績的利益分配,以及榮譽的分享,圈裡人都知道,拿了世界冠軍和全運冠軍,運動員都會有相應的高額獎金,'
        '那麼主管教練也會得到與之對應的豐厚獎勵,所以孫楊獲得的榮譽,也會惠及主管教練。',
    ]
    k = 2
    lda = LDA(text_list,k)
    lda.fit()
    print(lda.sita_mk)
    print(lda.yta_kv)

if __name__ == '__main__':
    main()
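To make the fitted model easier to inspect, one could add a small helper (hypothetical, not part of the original script) that uses yta_kv together with word_set to list the highest-probability words of each topic:

# Hypothetical helper: print the n most probable words of each topic.
# Assumes an already-fitted LDA instance as defined above.
def print_top_words(lda, n=10):
    for topic in range(lda.k):
        top = np.argsort(lda.yta_kv[topic])[::-1][:n]
        print(f'topic {topic}:', ' '.join(lda.word_set[v] for v in top))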

# Result (output of one run) --------------------------

[[0.20689655 0.81034483]
 [0.7        0.32222222]
 [0.50666667 0.52      ]]
 
[[1.2295082  0.12295082 0.58196721 1.21311475 1.08196721 0.06557377
  1.12295082 0.18032787 0.98360656 0.16393443 0.78688525 1.01639344
  0.7704918  1.12295082 1.01639344 0.43442623 1.00819672 0.72131148
  0.70491803 0.21311475 0.78688525 0.14754098 0.6147541  0.53278689
  0.59836066 1.20491803 0.6557377  0.01639344 1.05737705 0.53278689
  1.22131148 0.71311475 1.29508197 1.23770492 0.59016393 1.20491803
  0.13114754 0.04918033 0.99180328 0.93442623 1.27868852 1.1557377
  0.90983607 0.66393443 1.08196721 1.07377049 0.57377049 0.08196721
  0.17213115 0.54098361 1.14754098 0.98360656 0.17213115 0.26229508
  0.6557377  1.12295082 0.80327869 0.77868852 1.10655738 0.81967213
  0.79508197 0.41803279 0.63934426 0.36065574 1.29508197 0.74590164
  0.99180328 1.14754098 0.67213115 0.33606557 0.40163934 0.73770492
  0.67213115 0.86885246 0.18852459 0.17213115 0.75409836 0.33606557
  0.07377049 1.13114754 0.40163934 0.63934426 0.36885246 1.27868852
  1.19672131 0.35245902 1.10655738 0.21311475 1.19672131 0.71311475
  0.29508197 0.67213115 1.02459016 0.87704918 0.81147541 1.04918033
  0.1147541  1.1147541  0.40163934 1.05737705 0.31147541 0.40983607
  0.31147541 0.59016393 0.74590164 1.18852459 1.32786885 0.74590164
  0.48360656 0.42622951 0.8442623  1.22131148 0.95901639 0.69672131
  0.09836066 1.26229508 1.1147541  0.63934426 1.1557377  0.14754098
  1.18032787 0.1557377  0.93442623 0.63114754 0.45901639 0.52459016
  1.28688525 1.13114754 0.91803279 1.27868852 0.82786885 0.31147541
  0.33606557 0.41803279 1.30327869 0.99180328 1.31147541 1.17213115
  0.97540984 1.19672131 0.24590164 0.90983607 0.59016393 0.49180328
  0.87704918 1.08196721 0.42622951 0.27868852 0.49180328 0.69672131
  0.08196721 0.48360656 0.5        0.7704918  0.95081967 1.
  0.52459016 0.16393443 1.1147541  0.18852459 0.82786885 1.09016393
  0.1147541  0.93442623]
 [0.99371069 0.10691824 0.44025157 0.93710692 0.83647799 0.05660377
  0.88679245 0.14465409 0.76100629 0.1509434  0.61006289 0.78616352
  0.59119497 0.87421384 0.77987421 0.35220126 0.77358491 0.55974843
  0.53459119 0.17610063 0.60377358 0.11949686 0.47798742 0.40880503
  0.45283019 0.93081761 0.50314465 0.00628931 0.81761006 0.41509434
  0.93081761 0.55345912 0.98742138 0.94339623 0.44654088 0.91823899
  0.10691824 0.03144654 0.7672956  0.72327044 0.98742138 0.88050314
  0.71069182 0.51572327 0.83647799 0.81761006 0.44654088 0.05660377
  0.12578616 0.40880503 0.87421384 0.74842767 0.12578616 0.19496855
  0.50943396 0.85534591 0.62264151 0.60377358 0.8427673  0.66037736
  0.6163522  0.32704403 0.48427673 0.27044025 1.01257862 0.56603774
  0.74213836 0.86792453 0.50943396 0.25157233 0.31446541 0.57232704
  0.52830189 0.65408805 0.1509434  0.12578616 0.58490566 0.26415094
  0.05660377 0.8490566  0.31446541 0.50314465 0.27044025 0.98742138
  0.9245283  0.27672956 0.8427673  0.16981132 0.9245283  0.54716981
  0.2327044  0.52201258 0.77987421 0.67924528 0.63522013 0.79874214
  0.08805031 0.8490566  0.29559748 0.81761006 0.24528302 0.32075472
  0.25157233 0.4591195  0.57861635 0.90566038 1.02515723 0.56603774
  0.36477987 0.32075472 0.65408805 0.93710692 0.72955975 0.52830189
  0.08176101 0.97484277 0.8490566  0.49685535 0.89308176 0.11949686
  0.89937107 0.11320755 0.70440252 0.49056604 0.33962264 0.40880503
  0.98113208 0.86163522 0.69811321 0.97484277 0.66037736 0.24528302
  0.25157233 0.32075472 0.99371069 0.75471698 1.01257862 0.89308176
  0.75471698 0.93081761 0.18238994 0.6918239  0.4591195  0.36477987
  0.66666667 0.80503145 0.32075472 0.22012579 0.3836478  0.52830189
  0.05660377 0.37735849 0.37735849 0.58490566 0.74213836 0.77358491
  0.40251572 0.13207547 0.8490566  0.1509434  0.64150943 0.83647799
  0.09433962 0.72327044]]
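The first matrix is sita_mk: row m is the estimated topic mixture of document m. The second is yta_kv: row k is topic k's distribution over the vocabulary, with one column per entry of word_set. Because both the topic initialization and the hyperparameters are drawn at random, the numbers differ from run to run.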
