C++基於armadillo im2col的實現

c艹使用者發表於2021-05-23

col2im的實現,這是im2col的逆過程
最近學習CNN,需要用到im2col這個函式,無奈網上沒有多少使用armadillo的例子,而且armadillo庫中似乎也沒有這個函式,因此自己寫了。
im2col的原理網上一大把,我懶得寫了。

1. field<某類>

field<class oT> 是armadillo庫中的類,類似於矩陣, 不過這個“矩陣”的每一個元素都是向量或者矩陣。因此用field可以作為四維輸入資料使用。

2. 矩陣展開

這個其實還挺簡單,使用reshape函式將矩陣變形。不過,armadillo中變形是按照豎向變形的。比如:

1 2 3
4 5 6
7 8 9

這樣的矩陣變形成1×9的向量的話:

1 4 7 2 5 8 3 6 9

會成這樣?。。。
但是也不影響,濾波器也是這麼變得,相對位置沒變唄。。

3. 排列組合

鄙人才疏學淺,只會用一堆for迴圈來排列組合。。。貌似沒找到更好的辦法。

4. 其他細節

像是步數、填充什麼的,多注意一下就行了。

5. 實現程式碼

mat im2col(field<mat> input_data, int filter_h, int filter_w, int stride, int pad)
{
	int N, C, H, W;
	N = input_data.n_rows;
	C = input_data.n_cols;
	H = input_data(0, 0).n_rows;
	W = input_data(0, 0).n_cols;
	int out_h = (H + 2 * pad - filter_h) / stride + 1;
	int out_w = (W + 2 * pad - filter_w) / stride + 1;
	field<mat> img = input_data;
	img.for_each([H, W, pad](mat& X) {X.insert_rows(0, pad); X.insert_rows(H + pad, pad); X.insert_cols(0, pad); X.insert_cols(W + pad, pad); });
	mat col(out_h * out_w * N, C * filter_h * filter_w, fill::zeros);
	for (int n = 0, z = 0; n < N; n++)
	{
		for (int i = 0; i < out_h; i++)
		{
			for (int j = 0; j < out_w; j++, z++)
			{
				for (int k = 0; k < C; k++)
				{
					mat filter(filter_h, filter_w, fill::zeros);
					filter = img(n, k)(span(i * stride, i * stride + filter_h - 1), span(j * stride, j * stride + filter_w - 1));
					filter.reshape(1, filter_h * filter_w);
					int x = z;
					int y0 = filter_h * filter_w * k;
					int y1 = filter_h * filter_w * k + filter_h * filter_w - 1;
					col(span(x, x), span(y0, y1)) = filter;
				}
			}
		}
	}
	return col;
}

標頭檔案就是宣告和引用。

相關文章