TCN 一維預測の筆記

選西瓜專業戶發表於2020-12-29

來源

這個例子Support of tensorflow.keras instead of keras
https://github.com/philipperemy/keras-tcn/tree/master/tasks

資料

month milk_production_pounds
1962-01 589
1962-02 561
1962-03 640
1962-04 656
1962-05 727
1962-06 697
1962-07 640
1962-08 599
。。。。。。。。。。
1975-09 817
1975-10 827
1975-11 797
1975-12 843

程式碼

載入包

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense
from tcn import TCN

說明

這個例子非常簡單,
為了簡單起見一切都是訓練集
輸入輸入沒有normalization正則化。

讀取資料

milk = pd.read_csv('monthly-milk-production-pounds-p.csv', index_col=0, parse_dates=True)
print(milk.head())

lookback_window = 12  # 月
milk = milk.values  # 為了簡單起見,這裡保留np陣列
x, y = [], []
for i in range(lookback_window, len(milk)):
    x.append(milk[i - lookback_window:i])
    y.append(milk[i])
x = np.array(x)
y = np.array(y)
print(x.shape)
print(y.shape)

設定引數

i = Input(shape=(lookback_window, 1))
m = TCN()(i)
m = Dense(1, activation='linear')(m)

搭建模型

model = Model(inputs=[i], outputs=[m])
model.summary()

設定優化器

model.compile('adam', 'mae')

模型擬合

print('Train...')
model.fit(x, y, epochs=100, verbose=2)
p = model.predict(x)

視覺化

plt.plot(p)
plt.plot(y)
plt.title('Monthly Milk Production (in pounds)')
plt.legend(['predicted', 'actual'])
p```lt.show()

相關文章