003.01 梯度下降

Jason990420發表於2019-09-17

003.01 梯度下降

建檔日期: 2019/09/17
更新日期: None
語言: Python 3.7.4, numpy 1.16.4, matplotlib 3.1.0
系統: Win10 Ver. 10.0.17763

主題: 003.01 梯度下降

  • 如何找到y=f(x)曲線的最低點 ??

A. 地毯式搜尋的方法

  1. 格點搜尋

    給出一串固定距離的X值, 計算Y值, 找出在X0處, 有最小值的Y0, 得到最低點(X0, Y0)

  2. 隨機搜尋

    隨機產生一串X值, 計算Y值, 找出在X0處, 有最小值的Y0, 得到最低點(X0, Y0)


    搜尋式方法最大的缺點, 就是不能保證找到全域性的最低點, 甚至是區域性的最低點,

B. 解析的方法

基本的原則是斜率為0的點, 就是該區域性的最低點或最高點. 可以使用導數算出某點的斜率. d表示小變化, 導數dy/dx就是指x小變化所造成的y小變化, 這就是斜率的定義. 而dy/dx再次導數, 以d2y/dx2表示, 代表dy/dx的變化, 也就是斜率的變化. 如果斜率的變化為正, 表示該點是區域性最低點, 反之為區域性最高點.

C. 數值分析的方法

本方法採用沿著曲線, 一點一點地尋找到我們的目標點. 最常用的方法就是梯度下降. 大部份的作法都是採用該方法或其修改版. 其方法說明如下:

  1. 隨意取一個起點x
  2. 計算出y值
  3. 新的x點為x-(alpha*dy/dx)
    dy/dx代表曲線的變化, 變化越大, 離最低點越遠, 變化越小進最低點越近, 因此以dy/dx作為x變化的比例, 再給個常數alpha, 控制一下變化的比例, alpha就稱為學習速率, 用來控制步進的大小, 如果太小就必須花更久的時間找到最低點, 如果太大可能會錯過我們要找的最低點
  4. 重複步驟3, 直到y值收斂, 也就是說y值不會再因x的變化而變化, 因為dy/dx=0

記得, 我們的目標是要找出全域性的最低點, 這才是整個曲線的最低點. 所以光使用這個方法, 還不足達到我們追求的目標.

D. 數值分析的範例:

輸出:

003.01 梯度下降


import matplotlib.pyplot as plt

import numpy as np

def function(x):
    y = x*x + 5*x + 8
    return y

def slope(x):
    s = 2*x + 5
    return s

def gradient_descent(x, alpha, num_iterations):
    for i in range(num_iterations):
        y = function(x)
        y_history[i,:] = [x, y]
        update = slope(x)
        x = x - alpha * update

num_iteration =20
start_x = -7
start_y = function(start_x)
alpha = 0.3

y_history = np.zeros((num_iteration, 2))
result = gradient_descent(start_x, alpha, num_iteration)

y_min = np.min(y_history[:,1])
x_min = y_history[np.argmin(y_history[:,1]),0]

x_range = np.linspace(-10,5,100)

plt.plot(x_range, function(x_range), linestyle='-', color='blue', linewidth=2)
plt.axis([-10, 5, 0, 58])
plt.xlabel('x')
plt.ylabel('y=x^2+5x+8')
plt.annotate('Start Point ({:.2f},{:.2f})'.format(start_x, start_y), xy=(start_x, start_y), xytext=(-6, 50), arrowprops=dict(facecolor='black', shrink=0.05),)
plt.annotate('Local Min.({:.2f},{:.2f})'.format(x_min, y_min), xy=(x_min, y_min), xytext=(-2, 40), arrowprops=dict(facecolor='black', shrink=0.05),)
plt.scatter(y_history[:,0], y_history[:,1], s=100 ,color="red")
plt.show()

Jason Yang

相關文章