多分類Fisher線性判別演算法

li_huifei發表於2017-10-17

Fisher線性判別法也即FLD實在PCA降維的基礎上再進一步考慮樣本間的資訊。演算法目標是找到一個投影軸,使各分類的類內樣本在投影軸上的投影間距最小,同時樣本間的投影間距最大。原理不難,公式推導遍地都是,儘管看不太懂吧..但是掌握核心幾個公式以後就不妨礙我們用程式來實現它。

但是網上的例子多數是基於二分類的,那麼對於多類別的樣本如何使用FLD判別呢,這個問題沒有太多的論述。所以想出瞭如下的辦法去在多個分類的樣本中應用FLD:對樣本中的分類兩兩結合計算對應的w向量,測試資料進來後分別應用這些w向量對其進行FLD判斷,判斷成功的分類在最終結果中加一分,最後遍歷了所有w後,得分最高的那個分類即是測試資料的分類結果。辦法雖然笨,但是基本還算能完成功能吧..

import os  
import sys  
import numpy as np  
from numpy import *  
import operator  
import matplotlib  
import matplotlib.pyplot as plt

def class_mean(samples):#求樣本均值
    aver = np.mean(samples,axis = 1)
    return aver

def withclass_scatter(samples,mean):#求類內散度
    dim,num = samples.shape()
    samples_m = samples - mean
    s_with = 0
    for i in range(num):
        x = samples_mean[:,i]
        s_in += dot(x,x.T)
    return s_in

def get_w(s_in1,s_in2,mean1,mean2):#得到權向量
    sw = s_in1 + s_in2
    w = dot(sw.I,(mean1-mean2))
    return w

def classify(test,w,mean1,mean2):#分類演算法
    cen_1 = dot(w.T,mean1)
    cen_2 = dot(w.T,mean2)
    g = dot(w.T,sample)
    return abs(pos - cen_1)<abs(pos - cen2)


if __name__=='__main__':
    class_num = 分類數
    class_name={}
    test = 測試資料
    ws = {}
    result={}
    for i in range(k):
        class_name[i]=group[i]
    for i in range(k):
        result[i]=0
    for i in range(k):#第一次迴圈計算權向量的值
        for j in range(k):
            if i==j:
                break
            mean1 = class_mean(class_name[i])
            mean2 = class_mean(class_name[j])
            s_in1 = withclass_scatter(class_name[i],mean1)
            s_in2 = withclass_scatter(class_name[j],mean2)
            w[i,j] = get_w(s_in1,s_in2,mean1,mean2)
    for i in range(k):#第二次迴圈在測試資料上應用權向量
        for j in range(k):
            if i==j:
                break
            w = w[i,j]
            if classify(test,w,mean1,mean2):
                result[i]+=1
            else result[j]+=1
    final_result = filter(lambda x:max(result.values())==result[x],result)[0]#找出得分最高的分類作為最終結果
                 
            


相關文章