col2im的實現,這是im2col的逆過程
最近學習CNN,需要用到im2col這個函式,無奈網上沒有多少使用armadillo的例子,而且armadillo庫中似乎也沒有這個函式,因此自己寫了。
im2col的原理網上一大把,我懶得寫了。
1. field<某類>
field<class oT> 是armadillo庫中的類,類似於矩陣, 不過這個“矩陣”的每一個元素都是向量或者矩陣。因此用field可以作為四維輸入資料使用。
2. 矩陣展開
這個其實還挺簡單,使用reshape函式將矩陣變形。不過,armadillo中變形是按照豎向變形的。比如:
1 2 3
4 5 6
7 8 9
這樣的矩陣變形成1×9的向量的話:
1 4 7 2 5 8 3 6 9
會成這樣?。。。
但是也不影響,濾波器也是這麼變得,相對位置沒變唄。。
3. 排列組合
鄙人才疏學淺,只會用一堆for迴圈來排列組合。。。貌似沒找到更好的辦法。
4. 其他細節
像是步數、填充什麼的,多注意一下就行了。
5. 實現程式碼
mat im2col(field<mat> input_data, int filter_h, int filter_w, int stride, int pad)
{
int N, C, H, W;
N = input_data.n_rows;
C = input_data.n_cols;
H = input_data(0, 0).n_rows;
W = input_data(0, 0).n_cols;
int out_h = (H + 2 * pad - filter_h) / stride + 1;
int out_w = (W + 2 * pad - filter_w) / stride + 1;
field<mat> img = input_data;
img.for_each([H, W, pad](mat& X) {X.insert_rows(0, pad); X.insert_rows(H + pad, pad); X.insert_cols(0, pad); X.insert_cols(W + pad, pad); });
mat col(out_h * out_w * N, C * filter_h * filter_w, fill::zeros);
for (int n = 0, z = 0; n < N; n++)
{
for (int i = 0; i < out_h; i++)
{
for (int j = 0; j < out_w; j++, z++)
{
for (int k = 0; k < C; k++)
{
mat filter(filter_h, filter_w, fill::zeros);
filter = img(n, k)(span(i * stride, i * stride + filter_h - 1), span(j * stride, j * stride + filter_w - 1));
filter.reshape(1, filter_h * filter_w);
int x = z;
int y0 = filter_h * filter_w * k;
int y1 = filter_h * filter_w * k + filter_h * filter_w - 1;
col(span(x, x), span(y0, y1)) = filter;
}
}
}
}
return col;
}
標頭檔案就是宣告和引用。