【火爐煉AI】機器學習026-股票資料聚類分析-近鄰傳播演算法
(本文所使用的Python庫和版本號: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2, tushare 1.2)
有一位朋友很擅長炒股,聽說其資產已經達到了兩百多萬,我聽後對其敬佩得五體投地,遂虛心向其請教炒股之祕訣,他聽後,點了一根菸,深深地吸了一口,然後慢悠悠地告訴我,祕訣其實很簡單,你先準備一千萬,炒著炒著就能炒到兩百萬。。。我聽後狂噴鼻血。。。
雖然沒有取到真經,但我仍不死心,仍然覺得人工智慧應該可以用於炒股,AI的能力都能夠輕鬆解決圍棋這一世界性難題,難道還不能打敗股票市場嗎?
下面我們用機器學習的方法來研究一下股票資料,由於股票資料之間沒有任何標記,故而這是一類比較典型的無監督學習問題。但在我們著手股票研究之前,需要了解一下什麼是近鄰傳播演算法。
1. 近鄰傳播演算法簡介
近鄰傳播聚類演算法(Affinity Propagation, AP)是2007年在Science雜誌上提出的一種新的聚類演算法,它根據N個資料點之間的相似度進行聚類,這些相似度可以是對稱的,即兩個資料點相互之間的相似度一樣(如歐式距離),也可以是不對稱的,即兩個資料點相互之間的相似度不等,這些相似度組成N*N的相似度矩陣S。
這種演算法會同時考慮所有資料點都是潛在的代表點,通過結點之間的資訊傳遞,最後得到高質量的聚類,這個資訊的傳遞,是基於sum-product或者max-product的更新原則,在任意一個時刻,這個資訊幅度都代表著近鄰的程度,也就是一個資料點選擇另一個資料點作為代表點有多靠譜,這也是近鄰傳播名字的由來。
關於演算法原理和公式推導,有很多很好地文章,比如python 實現 AP近鄰傳播聚類演算法和Affinity propagation 近鄰傳播演算法,讀者可以閱讀這些文章進行深入研究。
2. 準備股票資料
2.1 從網上獲取股票資料
股票資料雖然可以從網上獲取,但是要想輕易地得到結構化的資料,還是要花費一番功夫的,幸好,我找到了一個很好地財經類python介面模組--tushare,這個模組可以快速的從網站上爬取股票資料,並可以輕鬆儲存和做進一步的資料分析,用起來非常方便。
下面我先自定義了三個工具函式,用於輔助我從網上下載股票資料和對股票資料進行整理。如下程式碼:
# 準備資料集,使用tushare來獲取股票資料
# 準備幾個函式,用來獲取資料
import tushare as ts
def get_K_dataframe(code,start,end):
'''get day-K data of code, from start date to end date
params:
code: stock code eg: 600123, 002743
start: start date, eg: 2016-10-01
end: end date, eg: 2016-10-31
return:
dataframe with columns [date, open, close, high, low]
'''
df=ts.get_k_data(code,start=start,end=end)
df.drop(['volume'],axis=1, inplace=True)
return df
複製程式碼
這個函式獲取單隻股票的估計資料,其時間跨度為start到end,返回獲取到的股票資料DataFrame。當然,一次獲取一隻股票的資料太慢了,下面這個函式我們可以一次獲取多隻股票資料。
def get_batch_K_df(codes_list,start,end):
'''get batch stock K data'''
df=pd.DataFrame()
print('fetching data. pls wait...')
for code in codes_list:
# print('fetching K data of {}...'.format(code))
df=df.append(get_K_dataframe(code,start,end))
return df
複製程式碼
此處我選擇上證50指數的成分股作為研究物件
2.2 對股票資料進行規整
由於tushare模組已經將股票資料進行了基本的規整,此處我們只需要將資料處理成我們專案所需要的樣子即可。
此處對股票資料的規整包括有幾個方面:
1,計算需要聚類的資料,此處我用收盤價減去開盤價作分析,即一天的漲跌幅度。或許用一天的漲幅%形式可能更合適。
2,由於上面的get_batch_k_df()函式獲取的批量股票資料都是將多個股票資料在縱向上合併而來,故而此處我們要將各種不同股票的漲跌幅度放在DataFrame的列上,以股票程式碼為列名。
3,在pd.merge()過程中,由於有的股票在某些交易日停牌,所以沒有資料,這幾個交易日就被刪掉(因為後面的聚類演算法中不允許存在NaN),所以相當於要選擇所有股票都有交易資料的日期,這個選擇相當於取股票資料的交集,最終得到很少一部分資料,資料量太少時,得到的聚類結果也沒有太多說服力。故而我的解決方法是,刪除一些交易日明顯很少的股票,不對其進行pd.merge(),最終得到603個交易日的有效資料,選取了41只股票,捨棄了9只停牌日太多的股票。
這三部分的規整過程我都整合到一個函式中實現,如下是這個函式的程式碼:
# 資料規整函式,用於對獲取的df進行資料處理
def preprocess_data(stock_df,min_K_num=1000):
'''preprocess the stock data.
Notice: min_K_num: the minimum stock K number.
because some stocks was halt trading in this time period,
the some K data was missing.
if the K data number is less than min_K_num, the stock is discarded.'''
df=stock_df.copy()
df['diff']=df.close-df.open # 此處用收盤價與開盤價的差值做分析
df.drop(['open','close','high','low'],axis=1,inplace=True)
result_df=None
#下面一部分是將不同的股票diff資料整合為不同的列,列名為股票程式碼
for name, group in df[['date','diff']].groupby(df.code):
if len(group.index)<min_K_num: continue
if result_df is None:
result_df=group.rename(columns={'diff':name})
else:
result_df=pd.merge(result_df,
group.rename(columns={'diff':name}),
on='date',how='inner') # 一定要inner,要不然會有很多日期由於股票停牌沒資料
result_df.drop(['date'],axis=1,inplace=True)
# 然後將股票資料DataFrame轉變為np.ndarray
stock_dataset=np.array(result_df).astype(np.float64)
# 資料歸一化,此處使用相關性而不是協方差的原因是在結構恢復時更高效
stock_dataset/=np.std(stock_dataset,axis=0)
return stock_dataset,result_df.columns.tolist()
複製程式碼
函式定義好了之後,我們就可以獲取股票資料,並對其進行規整分析,如下所示:
# 上面準備了各種函式,下面開始準備資料集
# 我們此處分析上證50指數的成分股,看看這些股票有哪些特性
sz50_df=ts.get_sz50s()
stock_list=sz50_df.code.tolist()
# print(stock_list) # 沒有問題
batch_K_data=get_batch_K_df(stock_list,start='2013-09-01',end='2018-09-01') # 檢視最近五年的資料
print(batch_K_data.info())
複製程式碼
------------------------輸---------出--------------------------------
fetching data. pls wait...
<class 'pandas.core.frame.DataFrame'>
Int64Index: 56246 entries, 158 to 1356
Data columns (total 6 columns):
date 56246 non-null object
open 56246 non-null float64
close 56246 non-null float64
high 56246 non-null float64
low 56246 non-null float64
code 56246 non-null object
dtypes: float64(4), object(2)
memory usage: 3.0+ MB
None
--------------------------------完-------------------------------------
stock_dataset,selected_stocks=preprocess_data(batch_K_data,min_K_num=1100)
print(stock_dataset.shape) # (603, 41) 由此可以看出得到了603個交易日的資料,其中有41只股票被選出。
# 其他的9只股票因為不滿足最小交易日的要求而被刪除。這603個交易日是所有41只股票都在交易,都沒有停牌的資料。
print(selected_stocks) # 這是實際使用的股票列表
複製程式碼
-------------------------------------輸---------出---------------
(603, 41) ['600000', '600016', '600019', '600028', '600029', '600030', '600036', '600048', '600050', '600104', '600111', '600276', '600340', '600519', '600547', '600585', '600690', '600703', '600887', '600999', '601006', '601088', '601166', '601169', '601186', '601288', '601318', '601328', '601336', '601390', '601398', '601601', '601628', '601668', '601688', '601766', '601800', '601818', '601857', '601988', '603993']
---------------------------------------完------------------------
到此為止,股票資料也從網上下載下來了,我們也對其進行了資料處理,可以滿足後面聚類演算法的要求了。
########################小**********結###############################
1,tushare一個非常好用的獲取股票資料,基金資料,區塊連資料等各種財經資料的模組,強烈推薦。
2,此處我自定義了幾個函式,get_K_dataframe(), get_batch_K_df()和preprocess_data()都是具有一定通用性的,以後要獲取股票資料或者處理股票資料,可以直接搬用或在此基礎上稍微修改即可。
3,作為演示,此處我只獲取了上證50只股票的最近五年資料,並且刪除掉一些停牌太多的股票,得到了41只股票的共603個有效交易日資料。
#################################################################
3. 用近鄰傳播演算法聚類股票資料
首先我們構建了協方差圖模型,從相關性中學習其圖結構
# 從相關性中學習其圖形結構
from sklearn.covariance import GraphLassoCV
edge_model=GraphLassoCV()
edge_model.fit(stock_dataset)
複製程式碼
然後再構建近鄰傳播演算法結構模型,並訓練LassoCV graph中的相關性資料
# 使用近鄰傳播演算法構建模型,並訓練LassoCV graph
from sklearn.cluster import affinity_propagation
_,labels=affinity_propagation(edge_model.covariance_)
複製程式碼
此處已經構建並訓練了該聚類演算法模型,但是怎麼看結果了?
如下程式碼:
n_labels=max(labels)
# 對這41只股票進行了聚類,labels裡面是每隻股票對應的類別標號
print('Stock Clusters: {}'.format(n_labels+1)) # 10,即得到10個類別
sz50_df2=sz50_df.set_index('code')
# print(sz50_df2)
for i in range(n_labels+1):
# print('Cluster: {}----> stocks: {}'.format(i,','.join(np.array(selected_stocks)[labels==i]))) # 這個只有股票程式碼而不是股票名稱
# 下面列印出股票名稱,便於觀察
stocks=np.array(selected_stocks)[labels==i].tolist()
names=sz50_df2.loc[stocks,:].name.tolist()
print('Cluster: {}----> stocks: {}'.format(i,','.join(names)))
複製程式碼
------------------------輸---------出--------------------------------
Stock Clusters: 10
Cluster: 0----> stocks: 寶鋼股份,南方航空,華夏幸福,海螺水泥,中國神華
Cluster: 1----> stocks: 中信證券,保利地產,招商證券,華泰證券
Cluster: 2----> stocks: 北方稀土,洛陽鉬業
Cluster: 3----> stocks: 恆瑞醫藥,三安光電
Cluster: 4----> stocks: 山東黃金
Cluster: 5----> stocks: 貴州茅臺,青島海爾,伊利股份
Cluster: 6----> stocks: 中國聯通,大秦鐵路,中國鐵建,中國中鐵,中國建築,中國中車,中國交建
Cluster: 7----> stocks: 中國平安,新華保險,中國太保,中國人壽
Cluster: 8----> stocks: 浦發銀行,民生銀行,招商銀行,上汽集團,興業銀行,北京銀行,農業銀行,交通銀行,工商銀行,光大銀行,中國銀行
Cluster: 9----> stocks: 中國石化,中國石油
-----------------------------完-------------------------------------
從結果中可以看出,這41只股票已經被劃分為10個簇群,從這個聚類結果中,我們也可以看到,比較類似的股票都被劃分到同一個簇群中,比如Cluster1中大部分都是證券公司,而Cluster6中都是“鐵公基”類股票,而Cluster8中都是銀行類的股票。這和我們普遍認為的概念分類的股票相吻合。
雖然此處我們進行了合理分類,但是我還將這種分類結果繪製到圖中,便於直觀感受他們的簇群距離,所以我此處自定義了一個函式visual_stock_relationship()專門來視覺化聚類演算法的結果。這個函式的程式碼太長,此處我就不貼程式碼了,可以參考我的github上程式碼,得到的聚類結果為:
這個圖看起來一團糟,但是每一部分都代表不同的含義:
1,這個圖形結合了本專案的股票資料,GraphLassoCV圖結構模型,近鄰傳播演算法的分類結果,故而可以說是整個專案的結晶。
2,圖中每一個節點代表一隻股票,旁邊有股票名稱,節點的顏色表示該股票所屬類別的種類,用節點顏色來區分股票所屬簇群。
3,GraphLassoCV圖結構模型中的稀疏逆協方差資訊用節點之間的線條來表示,線條越粗,表示股票之間的關聯性越強。
4,股票在圖形中的位置是由2D巢狀演算法來決定的,距離越遠,表示其相關性越弱,簇間距離越遠。
這個圖得來不易,花了我整整一天時間來做這個專案,汗,裡面的各種股票資料處理,太讓我頭疼了,所以再來具體研究一下這張圖。
########################小**********結###############################
1,本專案僅僅使用股票的收盤價與開盤價的差值,就聚類得到了股票所屬類別的資訊,看來聚類的確可以用於股票內在結構的分類。
2,本專案先用GraphLassoCV得到股票原始資料之間的相關性圖,然後再用近鄰傳播演算法對GraphLassoCV的相關性進行聚類,這種方式和以前我們直接用資料集來訓練聚類演算法不一樣。
3,使用其他股票資料,比如漲幅,成交量,或換手率,也許可以挖掘出更多有用的股票結構資訊,從而為我們的股票投資帶來幫助。
4,股票市場風險太大,沒事還是好好工作,好好研究AI,奉勸各位一句:珍愛生命,遠離股市,切記,切記!!!
#################################################################
注:本部分程式碼已經全部上傳到(我的github)上,歡迎下載。
參考資料:
1, Python機器學習經典例項,Prateek Joshi著,陶俊傑,陳小莉譯