Machine Learning (1) - Linear Regression

Rachel發表於2019-04-14

Pandas 是學習 Machine Learning 的利器,這裡假設你已經對 Pandas 基礎 有所瞭解。

這一節主要以預測一個地區的房價為例,學習 ML 的模型之一 Linear Regression

1. 引入需要用的包以及資料檔案

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model

df = pd.read_csv('/Users/rachel/Downloads/py-master/ML/1_linear_reg/homeprices.csv')
df

輸出:

Machine Learning (1) - Linear Regression

2. 訓練 Linear Regression 模型

reg = linear_model.LinearRegression() // 初始化資料模型

// 訓練這個模型,第一個引數是已知的資料,第二個引數是未來要預測的值
reg.fit(df[['area']], df.price) 

現在就可以用模型來預測值

reg.predict([[5000]])

輸出:

array([859554.79452055])

這裡大概介紹一下 Linear Regression 的建模公式:

y = m * x + b

m 和 b 就是模型的係數, 通過提供大量的 x 和 y 的值,來求出最佳的 m 和 b 的值,這也就是訓練模型的過程。

m 被稱作 Coefficients
b 被稱作 Intercept
通過下面兩個命令就可以檢視 Linear Regression 模型的 m 和 b 的值:

reg.coef_  // m 的值
reg.intercept_  // b 的值

3. 輸出圖形資料:

%matplotlib inline
plt.xlabel('area(sqr ft)', fontsize=20) // x 軸
plt.ylabel('price(US$)', fontsize=20) // y 軸
plt.scatter(df.area, df.price, color='red', marker='+') // 以 “點” 輸出已知資料
plt.plot(df.area, reg.predict(df[['area']]), color='blue') // 以 “線” 輸出預測的資料,第二個引數是根據模型預測的值

Machine Learning (1) - Linear Regression

上面這條線,就是我們最終得到的 Linear Regression 模型,得到這條線,我們就可以輕鬆預測任何尺寸的房價,也就相當於模型訓練完成。

4. 應用模型

下面就用前面訓練好的模型來快速預測房價:

先建一個新的 csv 檔案,裡面填充一些房子的面積值:

df_new = pd.read_csv('/Users/rachel/Sites/py-master/ML/1_linear_reg/areas.csv')
df_new.head()

輸出(大家可以按照這個輸出格式,隨便建一個側表來做測試):

Machine Learning (1) - Linear Regression

用訓練好的模型 reg 做房價預測

p = reg.predict(df_new[['area']])
p

// 輸出
array([ 316404.10958904,  384297.94520548,  492928.08219178,
        661304.79452055,  740061.64383562,  799808.21917808,
        926090.75342466,  650441.78082192,  825607.87671233,
        492928.08219178, 1402705.47945205, 1348390.4109589 ,
       1144708.90410959])

// 用預測出來的房價資料完善原表     
df_new['price'] = p
df_new

// 輸出完善好的資料到 prediction.csv 檔案
//至於這個檔案生成在哪裡, 還是去終端看下, 你此時的 jupyter notebook 執行在哪裡
df_new.to_csv('prediction.csv', index = False)

輸出:

Machine Learning (1) - Linear Regression

今天開始第二次機器學習,對一些知識點有了更深入的瞭解,把之前的筆記完善一下。

本作品採用《CC 協議》,轉載必須註明作者和本文連結

相關文章