import numpy as np
import pandas as pd


def _one_hot(labels, num_classes):
    """Return a float64 one-hot matrix of shape (labels.size, num_classes).

    ``labels`` must hold integer class indices in ``[0, num_classes)``; they
    are used directly as column indices.
    """
    encoded = np.zeros((labels.size, num_classes))
    encoded[np.arange(labels.size), labels] = 1
    return encoded


def load_csv(path, train_ratio=0.8, num_classes=10):
    """Load a labelled image CSV and split it into train/test sets.

    The CSV must contain a ``label`` column; every other column is treated as
    a pixel value and divided by 255 (so 8-bit pixels end up in [0, 1]).
    Rows are split in file order (no shuffling): the first
    ``int(n_rows * train_ratio)`` rows are used for training, the rest for test.

    Args:
        path: Anything accepted by ``pandas.read_csv`` (file path or file-like).
        train_ratio: Fraction of rows used for training (default 0.8).
        num_classes: Width of the one-hot label encoding (default 10).

    Returns:
        ``(train_images, train_one_hot, test_images, test_one_hot)``.
    """
    csv_df = pd.read_csv(path)
    labels = csv_df['label'].values
    images = csv_df.drop('label', axis=1).values / 255.0  # normalise pixel values
    split = int(labels.size * train_ratio)
    return (
        images[:split],
        _one_hot(labels[:split], num_classes),
        images[split:],
        _one_hot(labels[split:], num_classes),
    )


class Model:
    """Fully connected 3-layer classifier (ReLU, ReLU, softmax) trained with Adam.

    With the sizes used by the script below this matches the PyTorch model
    ``nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 128),
    nn.ReLU(), nn.Linear(128, 10))``.
    """

    # Trainable parameters, in the same order as the gradients returned by backward().
    _PARAM_NAMES = ('W1', 'b1', 'W2', 'b2', 'W3', 'b3')

    def __init__(self, input_size, hidden1_size, hidden2_size, output_size):
        # Small random weights (scaled by 0.01) and zero biases.
        self.W1 = np.random.randn(input_size, hidden1_size) * 0.01
        self.b1 = np.zeros((1, hidden1_size))
        self.W2 = np.random.randn(hidden1_size, hidden2_size) * 0.01
        self.b2 = np.zeros((1, hidden2_size))
        self.W3 = np.random.randn(hidden2_size, output_size) * 0.01
        self.b3 = np.zeros((1, output_size))

        # Adam state: first-moment (mW1, mb1, ...) and second-moment
        # (vW1, vb1, ...) estimates, one zero array per parameter.
        for name in self._PARAM_NAMES:
            param = getattr(self, name)
            setattr(self, 'm' + name, np.zeros_like(param))
            setattr(self, 'v' + name, np.zeros_like(param))

    def relu(self, z):
        """Element-wise ReLU: max(0, z)."""
        return np.maximum(0, z)

    def relu_derivative(self, z):
        """ReLU (sub)gradient: 1 where z > 0, else 0 (including z == 0)."""
        return np.where(z > 0, 1, 0)

    def softmax(self, z):
        """Row-wise softmax; the row max is subtracted first to avoid overflow in exp."""
        exp_scores = np.exp(z - np.max(z, axis=1, keepdims=True))
        return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

    def forward(self, X):
        """Run the network on ``X`` and return the softmax probabilities.

        Intermediate activations (Z1, A1, Z2, A2, Z3, A3) are cached on
        ``self`` for use by backward().
        """
        self.Z1 = np.dot(X, self.W1) + self.b1
        self.A1 = self.relu(self.Z1)
        self.Z2 = np.dot(self.A1, self.W2) + self.b2
        self.A2 = self.relu(self.Z2)
        self.Z3 = np.dot(self.A2, self.W3) + self.b3
        self.A3 = self.softmax(self.Z3)
        return self.A3

    def backward(self, X, y, output):
        """Back-propagate softmax cross-entropy and return batch-mean gradients.

        ``output - y`` is the gradient of cross-entropy w.r.t. the logits when
        ``y`` is one-hot. Must be called right after forward(X), whose cached
        activations it reads.

        Returns:
            ``(dW1, db1, dW2, db2, dW3, db3)``.
        """
        m = y.shape[0]
        dZ3 = output - y
        dW3 = np.dot(self.A2.T, dZ3) / m
        db3 = np.sum(dZ3, axis=0, keepdims=True) / m
        dZ2 = np.dot(dZ3, self.W3.T) * self.relu_derivative(self.Z2)
        dW2 = np.dot(self.A1.T, dZ2) / m
        db2 = np.sum(dZ2, axis=0, keepdims=True) / m
        dZ1 = np.dot(dZ2, self.W2.T) * self.relu_derivative(self.Z1)
        dW1 = np.dot(X.T, dZ1) / m
        db1 = np.sum(dZ1, axis=0, keepdims=True) / m

        return dW1, db1, dW2, db2, dW3, db3

    def update_parameters_with_adam(self, grads, t, learning_rate, beta1, beta2, epsilon):
        """Apply one Adam step to every parameter, in place.

        Args:
            grads: ``(dW1, db1, dW2, db2, dW3, db3)`` as returned by backward().
            t: 1-based timestep, used for bias correction of the moment estimates.
            learning_rate, beta1, beta2, epsilon: standard Adam hyper-parameters.
        """
        bias_fix1 = 1 - beta1 ** t
        bias_fix2 = 1 - beta2 ** t
        for name, grad in zip(self._PARAM_NAMES, grads):
            # Exponential moving averages of the gradient and squared gradient.
            m = beta1 * getattr(self, 'm' + name) + (1 - beta1) * grad
            v = beta2 * getattr(self, 'v' + name) + (1 - beta2) * (grad ** 2)
            setattr(self, 'm' + name, m)
            setattr(self, 'v' + name, v)
            # Bias-corrected estimates (the moments start at zero).
            m_hat = m / bias_fix1
            v_hat = v / bias_fix2
            param = getattr(self, name)
            param -= learning_rate * m_hat / (np.sqrt(v_hat) + epsilon)

    @staticmethod
    def _cross_entropy(probs, one_hot):
        """Mean cross-entropy over the batch; 1e-8 guards against log(0)."""
        return -np.sum(one_hot * np.log(probs + 1e-8)) / one_hot.shape[0]

    def train(self, epochs, learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8,
              *, X_train=None, y_train=None, X_test=None, y_test=None,
              log_every=10):
        """Full-batch training with Adam for ``epochs`` iterations.

        Args:
            epochs: Number of full-batch gradient steps.
            learning_rate, beta1, beta2, epsilon: Adam hyper-parameters.
            X_train, y_train: Training inputs and one-hot labels. If both are
                omitted, the module-level ``train_images`` / ``train_labels``
                (and ``test_images`` / ``test_labels``) set up by the script
                are used, as in the original implementation.
            X_test, y_test: Optional held-out set; evaluated only when both
                are available.
            log_every: Print metrics every ``log_every`` epochs; 0 disables it.

        Raises:
            ValueError: If only one of ``X_train`` / ``y_train`` is given.
        """
        if X_train is None and y_train is None:
            # Backward-compatible default: read the script's globals.
            X_train, y_train = train_images, train_labels
            if X_test is None and y_test is None:
                X_test, y_test = test_images, test_labels
        if X_train is None or y_train is None:
            raise ValueError('X_train and y_train must be given together')
        has_test = X_test is not None and y_test is not None

        for epoch in range(epochs):
            output = self.forward(X_train)
            grads = self.backward(X_train, y_train, output)
            # Adam's timestep is 1-based for bias correction.
            self.update_parameters_with_adam(grads, epoch + 1, learning_rate,
                                             beta1, beta2, epsilon)
            if not log_every or (epoch + 1) % log_every:
                continue
            # Training metrics use `output`, i.e. the parameters *before* this
            # epoch's update.
            loss = self._cross_entropy(output, y_train)  # cross-entropy loss
            accuracy = np.mean(np.argmax(y_train, axis=1) == np.argmax(output, axis=1))
            print(f'train_epoch {epoch + 1}, loss: {loss:.4f}, acc: {accuracy:.4f}')
            if has_test:
                test_output = self.forward(X_test)
                test_loss = self._cross_entropy(test_output, y_test)
                accuracy = np.mean(np.argmax(y_test, axis=1) == np.argmax(test_output, axis=1))
                print(f'test_epoch {epoch + 1}, loss: {test_loss:.4f}, acc: {accuracy:.4f}')

    def predict(self, X):
        """Return predicted class indices, shape (n_samples,).

        A single 1-D sample also works: the bias broadcast in forward() turns
        it into a (1, n_features) batch, so the result then has shape (1,).
        """
        return np.argmax(self.forward(X), axis=1)


if __name__ == "__main__":
    import cv2

    # Load data: pixels normalised to [0, 1], labels one-hot encoded.
    train_images, train_labels, test_images, test_labels = load_csv('train.csv')

    nn = Model(28 * 28, 256, 128, 10)
    nn.train(30, 0.01)  # a learning rate of 0.001 also works

    # Interactive check: show a random test digit and the model's prediction.
    # Press ESC or 'q' to quit; any other key shows the next sample.
    while True:
        idx = np.random.randint(0, test_images.shape[0])
        sample = test_images[idx]  # normalised copy, same scale as training data
        # cv2.imshow treats float images as [0, 1]; convert to uint8 [0, 255].
        image = np.rint(sample.reshape(28, 28) * 255).astype(np.uint8)
        image = cv2.resize(image, (128, 128))
        cv2.imshow('test', image)
        print('當前識別數字為: ', nn.predict(sample)[0])
        key = cv2.waitKey(0) & 0xFF
        if key in (27, ord('q')):
            break
    cv2.destroyAllWindows()
只訓練 30 輪(epochs)效果就相當不錯。
資料集連結:https://pan.baidu.com/s/1inAI-tVnCLETZ_18Y8QV0w?pwd=6666