智慧電網的電能預估及價值分析
一、實驗目的與要求
1、掌握使用pandas庫處理資料的基本方法。
2、掌握對時間序列類資料預處理的基本方法。
3、掌握使用matplotlib結合pandas庫對資料分析視覺化處理的基本方法。
二、實驗內容
1、利用python中pandas等庫讀取資料,並完成資料的預處理。
2、利用matplotlib等庫完成對資料的視覺化。
3、使用Sklearn庫的相關係數建立決策樹模型,對模型進行訓練,使用測試集測試後對模型的效果進行評價。
三、實驗步驟
1.資料預處理。讀取所提供的資料檔案,檢查檔案中時間序列是否完整,有無缺失值,重複值。
(1)匯入所需要使用的包
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdate
from datetime import datetime, timedelta
from matplotlib.dates import DateFormatter, WeekdayLocator, DayLocator, MONDAY,YEARLY
import dateutil.relativedelta
import time
from matplotlib.pyplot import MultipleLocator
from pandas.core.common import SettingWithCopyWarning
from sklearn import tree#決策樹模型
from sklearn.model_selection import train_test_split#劃分測試集合與訓練集合
from sklearn.model_selection import GridSearchCV#用於找到最優模型
from scipy.stats import pearsonr
from sklearn.tree import DecisionTreeRegressor
#設定字型
plt.rcParams['font.sans-serif']=['SimHei']
(2)讀取檔案
file_path='/data/bigfiles/data3.csv'
data=pd.read_csv(file_path)
(3)檢視資料的基本統計資訊
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 87648 entries, 0 to 87647
Data columns (total 8 columns):
日期 87648 non-null object
小時 87648 non-null float64
乾球溫度 87648 non-null float64
露點溫度 87648 non-null float64
溼球溫度 87648 non-null float64
溼度 87648 non-null float64
電價 87648 non-null float64
電力負荷 87648 non-null float64
dtypes: float64(7), object(1)
memory usage: 5.3+ MB
(4)檢查資料是否完整
# 檢視資料長度
std_rng=pd.date_range(start='2006/1/1',end='2011/1/1',freq='D')
len(std_rng)
1827
(5)求日平均資料並新增日型別
data['溼度']=data['溼度'].map(lambda x :float(x))
#溼度轉換成 浮點型別
data['日最高溼度']=data["溼度"]
data['日最低溼度']=data["溼度"]
data['日均電力負荷']=data['電力負荷']
data['日最高電力負荷']=data['電力負荷']
data['日最低電力負荷']=data['電力負荷']
# 將日期設為行索引
data=data.set_index(['日期'])
data
小時 | 乾球溫度 | 露點溫度 | 溼球溫度 | 溼度 | 電價 | 電力負荷 | 日最高溼度 | 日最低溼度 | 日均電力負荷 | 日最高電力負荷 | 日最低電力負荷 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
日期 | ||||||||||||
2006/1/1 | 0.5 | 23.90 | 21.65 | 22.40 | 87.5 | 19.67 | 8013.27833 | 87.5 | 87.5 | 8013.27833 | 8013.27833 | 8013.27833 |
2006/1/1 | 1.0 | 23.90 | 21.70 | 22.40 | 88.0 | 18.56 | 7726.89167 | 88.0 | 88.0 | 7726.89167 | 7726.89167 | 7726.89167 |
2006/1/1 | 1.5 | 23.80 | 21.65 | 22.35 | 88.0 | 19.09 | 7372.85833 | 88.0 | 88.0 | 7372.85833 | 7372.85833 | 7372.85833 |
2006/1/1 | 2.0 | 23.70 | 21.60 | 22.30 | 88.0 | 17.40 | 7071.83333 | 88.0 | 88.0 | 7071.83333 | 7071.83333 | 7071.83333 |
2006/1/1 | 2.5 | 23.70 | 21.60 | 22.30 | 88.0 | 17.00 | 6865.44000 | 88.0 | 88.0 | 6865.44000 | 6865.44000 | 6865.44000 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
2010/12/31 | 22.0 | 22.60 | 19.10 | 20.40 | 81.0 | 23.86 | 8449.54000 | 81.0 | 81.0 | 8449.54000 | 8449.54000 | 8449.54000 |
2010/12/31 | 22.5 | 22.45 | 19.05 | 20.30 | 81.5 | 26.49 | 8508.16000 | 81.5 | 81.5 | 8508.16000 | 8508.16000 | 8508.16000 |
2010/12/31 | 23.0 | 22.30 | 19.00 | 20.20 | 82.0 | 25.18 | 8413.14000 | 82.0 | 82.0 | 8413.14000 | 8413.14000 | 8413.14000 |
2010/12/31 | 23.5 | 22.05 | 19.05 | 20.15 | 83.5 | 26.19 | 8173.79000 | 83.5 | 83.5 | 8173.79000 | 8173.79000 | 8173.79000 |
2011/1/1 | 0.0 | 21.80 | 19.10 | 20.10 | 85.0 | 24.62 | 8063.36000 | 85.0 | 85.0 | 8063.36000 | 8063.36000 | 8063.36000 |
87648 rows × 12 columns
# 將日期型別轉換成datatime
data.index=data.index.map(lambda x :datetime.strptime(x,'%Y/%m/%d'))
data['日']=data.index
# 定義中文包
week_cn = [1,2,3,4,5,6,7]
data['日']=data['日'].map(lambda x: week_cn[pd.to_datetime(x).weekday()])
datas =data.groupby('日期').agg({"乾球溫度":"mean","露點溫度":"mean","溼球溫度":"mean","溼度":"mean","日最高溼度":"max","日最低溼度":"min","電價":"mean","日最高電力負荷":"max","日最低電力負荷":"min","日均電力負荷":"mean","日":"mean"})
# 對缺失的值進行填補
datas.head(30)
乾球溫度 | 露點溫度 | 溼球溫度 | 溼度 | 日最高溼度 | 日最低溼度 | 電價 | 日最高電力負荷 | 日最低電力負荷 | 日均電力負荷 | 日 | |
---|---|---|---|---|---|---|---|---|---|---|---|
日期 | |||||||||||
2006-01-01 | 32.790426 | 16.987234 | 22.910638 | 50.244681 | 90.0 | 15.0 | 43.721702 | 11112.76333 | 6388.27833 | 9188.743865 | 7 |
2006-01-02 | 20.855208 | 17.447917 | 18.757292 | 81.072917 | 92.0 | 73.0 | 16.273542 | 8456.72833 | 6364.80333 | 7686.274860 | 1 |
2006-01-03 | 23.122917 | 16.632292 | 19.172917 | 69.770833 | 92.0 | 33.0 | 43.863958 | 10674.23500 | 6035.49833 | 8664.448403 | 2 |
2006-01-04 | 20.312500 | 15.073958 | 17.218750 | 71.864583 | 80.5 | 67.0 | 16.958333 | 9115.16167 | 6232.96667 | 8060.226216 | 3 |
2006-01-05 | 19.586458 | 18.015625 | 18.608333 | 90.718750 | 96.0 | 76.0 | 16.535625 | 9145.73833 | 6062.90500 | 8081.884756 | 4 |
2006-01-06 | 22.277083 | 18.770833 | 20.056250 | 81.177083 | 98.0 | 64.0 | 16.927500 | 9673.71000 | 6100.96833 | 8292.417466 | 5 |
2006-01-07 | 22.555208 | 17.836458 | 19.612500 | 75.677083 | 93.0 | 59.0 | 15.888958 | 8541.36500 | 5979.96667 | 7620.753992 | 6 |
2006-01-08 | 23.533333 | 17.887500 | 20.000000 | 71.250000 | 85.0 | 58.0 | 16.264167 | 8504.99667 | 5841.69333 | 7552.765243 | 7 |
2006-01-09 | 24.022917 | 20.504167 | 21.719792 | 81.364583 | 95.0 | 64.0 | 19.757292 | 10643.02833 | 6045.60667 | 8844.571910 | 1 |
2006-01-10 | 25.386458 | 21.983333 | 23.097917 | 81.875000 | 93.0 | 67.0 | 32.939375 | 12174.29000 | 6493.04667 | 9754.662396 | 2 |
2006-01-11 | 25.367708 | 21.119792 | 22.548958 | 78.750000 | 95.0 | 47.0 | 31.764167 | 12221.79500 | 6850.93167 | 9746.585139 | 3 |
2006-01-12 | 23.495833 | 20.695833 | 21.661458 | 84.604167 | 95.0 | 69.0 | 25.126042 | 10858.03167 | 6574.69333 | 9164.433056 | 4 |
2006-01-13 | 23.608333 | 20.312500 | 21.480208 | 82.031250 | 90.0 | 72.0 | 19.868750 | 10301.34000 | 6581.40500 | 8850.319861 | 5 |
2006-01-14 | 24.253125 | 20.995833 | 22.080208 | 82.739583 | 94.0 | 67.0 | 21.000417 | 10652.51833 | 6256.08833 | 8635.826492 | 6 |
2006-01-15 | 22.350000 | 17.861458 | 19.575000 | 76.479167 | 96.0 | 61.0 | 15.725417 | 8478.99333 | 6044.62167 | 7662.240695 | 7 |
2006-01-16 | 22.698958 | 21.400000 | 21.826042 | 92.614583 | 98.0 | 80.0 | 17.917500 | 10843.35667 | 6218.35667 | 9027.497256 | 1 |
2006-01-17 | 23.360417 | 20.847917 | 21.706250 | 86.572917 | 98.0 | 63.0 | 18.510833 | 10589.74500 | 6515.38500 | 9099.179757 | 2 |
2006-01-18 | 20.894792 | 17.382292 | 18.737500 | 80.614583 | 93.0 | 69.0 | 16.334583 | 9977.47167 | 6450.04000 | 8585.024723 | 3 |
2006-01-19 | 21.169792 | 15.673958 | 17.893750 | 71.177083 | 83.0 | 58.0 | 20.896667 | 9719.08833 | 6203.51333 | 8484.019445 | 4 |
2006-01-20 | 23.046875 | 16.936458 | 19.292708 | 68.937500 | 82.0 | 54.0 | 22.438750 | 9970.94667 | 6319.94000 | 8672.169306 | 5 |
2006-01-21 | 24.168750 | 19.107292 | 20.911458 | 73.968750 | 84.5 | 62.0 | 46.051458 | 10521.93167 | 6199.52333 | 8693.989653 | 6 |
2006-01-22 | 24.412500 | 19.813542 | 21.429167 | 76.468750 | 94.0 | 59.0 | 36.657500 | 10510.95667 | 6259.67333 | 8707.004375 | 7 |
2006-01-23 | 24.558333 | 20.515625 | 21.896875 | 78.666667 | 90.0 | 66.0 | 106.587292 | 12674.64333 | 6590.63167 | 10124.172847 | 1 |
2006-01-24 | 21.555208 | 18.565625 | 19.681250 | 83.406250 | 93.0 | 69.0 | 22.543125 | 10597.25833 | 6839.21667 | 9087.451111 | 2 |
2006-01-25 | 20.404167 | 17.433333 | 18.590625 | 83.593750 | 97.0 | 68.0 | 22.963333 | 9902.47667 | 6404.41167 | 8647.444028 | 3 |
2006-01-26 | 22.982292 | 19.335417 | 20.641667 | 80.083333 | 90.0 | 70.0 | 23.195625 | 9404.52500 | 6215.56000 | 8193.167396 | 4 |
2006-01-27 | 24.482292 | 20.520833 | 21.882292 | 79.229167 | 94.0 | 66.0 | 39.246042 | 11952.90000 | 6412.55167 | 9484.763229 | 5 |
2006-01-28 | 24.717708 | 18.366667 | 20.701042 | 69.104167 | 87.0 | 50.0 | 19.693750 | 10059.30333 | 6389.71833 | 8544.133229 | 6 |
2006-01-29 | 24.348958 | 18.191667 | 20.469792 | 69.854167 | 90.0 | 47.0 | 21.477917 | 9810.24500 | 6177.64000 | 8231.146563 | 7 |
2006-01-30 | 24.290625 | 19.226042 | 21.026042 | 74.427083 | 91.0 | 58.0 | 25.206458 | 12201.74667 | 6403.69833 | 9689.812222 | 1 |
(6)儲存預處理後的檔案
data.to_csv('/data/bigfiles/預處理後檔案.csv')
2、資料視覺化。
(1)讀取預處理後的檔案
data=pd.read_csv('/data/bigfiles/預處理後檔案.csv')
(2)繪製各氣象資訊的時間序列曲線
charts=datas["2006/01/01":"2010/02/01"]
charts.index=charts.index.map(lambda x:str(x)[:10])
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 將x軸的刻度進行格式化
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line1,=ax.plot(charts.index,charts['溼度'],'r-',label='溼度')
plt.legend()
<matplotlib.legend.Legend at 0x7f056d061f98>
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 將x軸的刻度進行格式化
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line2,=ax.plot(charts.index,charts['日最高溼度'],'b--',label='日最高溼度')
plt.legend()
<matplotlib.legend.Legend at 0x7f06514f3908>
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 將x軸的刻度進行格式化
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line3,=ax.plot(charts.index,charts['日最低溼度'],'g-',label='日最低溼度')
plt.legend()
<matplotlib.legend.Legend at 0x7f056cf01a58>
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 將x軸的刻度進行格式化
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line4,=ax.plot(charts.index,charts['露點溫度'],'c--',label='露點溫度')
plt.legend()
<matplotlib.legend.Legend at 0x7f056ce78c88>
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 將x軸的刻度進行格式化
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line5,=ax.plot(charts.index,charts['溼球溫度'],'y-',label='溼球溫度')
plt.legend()
<matplotlib.legend.Legend at 0x7f056cdfb518>
(3)繪製電價和電力負荷的時間序列曲線
chart=datas["2006/01/01":"2010/02/01"]
chart.index=chart.index.map(lambda x:str(x)[:10])
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
#將x軸的刻度進行格式化
# ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y/%m/%d'))
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line1,=ax.plot(chart.index,chart['日均電力負荷'],'r-',label='日均電力負荷')
plt.legend()
<matplotlib.legend.Legend at 0x7f056cd7e160>
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
#將x軸的刻度進行格式化
# ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y/%m/%d'))
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line2,=ax.plot(chart.index,chart['日最高電力負荷'],'b--',label='日最高電力負荷')
plt.legend()
<matplotlib.legend.Legend at 0x7f056cc59a58>
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
#將x軸的刻度進行格式化
# ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y/%m/%d'))
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line3,=ax.plot(chart.index,chart['日最低電力負荷'],'g-',label='日最低電力負荷')
plt.legend()
<matplotlib.legend.Legend at 0x7f056436af60>
chart=datas["2006/01/01":"2010/02/01"]
chart.index=chart.index.map(lambda x:str(x)[:10])
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 將x軸的刻度進行格式化
# ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y/%m/%d'))
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line=ax.plot(chart.index,chart['電價'],'y-',label='電價')
plt.legend()
<matplotlib.legend.Legend at 0x7f05642b8d68>
(4)編寫中值濾波函式消除明顯的噪音
def median_filter(data,sizes):
filter_data=np.zeros_like(data)
halfwindow=sizes
for i in range (halfwindow,len(data)-halfwindow):
window=data[i - halfwindow:i + halfwindow + 1]
filter_data[i]=np.median(window)
return filter_data
pd.options.mode.chained_assignment = None # 預設模式
chart['乾球溫度']=median_filter(chart['乾球溫度'],5)
chart['日均電力負荷']=median_filter(chart['日均電力負荷'],5)
chart['日最高電力負荷']=median_filter(chart['日最高電力負荷'],5)
chart['日最低電力負荷']=median_filter(chart['日最低電力負荷'],5)
chart['露點溫度']=median_filter(chart['露點溫度'],5)
(5)繪製濾波後的時序曲線
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 這裡很重要 需要 將 x軸的刻度 進行格式化
# ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y/%m/%d'))
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line1,=ax.plot(chart.index,chart['日均電力負荷'],'r-',label='日均電力負荷')
plt.legend()
<matplotlib.legend.Legend at 0x7f056402a7f0>
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 這裡很重要 需要 將 x軸的刻度 進行格式化
# ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y/%m/%d'))
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line2,=ax.plot(chart.index,chart['日最高電力負荷'],'b--',label='日最高電力負荷')
plt.legend()
<matplotlib.legend.Legend at 0x7f056cdd9550>
# 建立一個畫布
fig = plt.figure(figsize=(15.6,7.2))
# 在畫布上新增一個子檢視
ax = plt.subplot(111)
# 這裡很重要 需要 將 x軸的刻度 進行格式化
# ax.xaxis.set_major_formatter(mdate.DateFormatter('%Y/%m/%d'))
x_major_locator=MultipleLocator(300) # 把x軸的刻度間隔設定為原來的2倍
ax.xaxis.set_major_locator(x_major_locator) # 把x軸的主刻度設定為1的倍數
# 畫折線
line3,=ax.plot(chart.index,chart['日最低電力負荷'],'g-',label='日最低電力負荷')
plt.legend()
<matplotlib.legend.Legend at 0x7f056cde05c0>
3.相關係數。求出各量與電力負荷之間的皮爾遜相關係數,選擇相關係數絕對值前3高的屬性作為特徵屬性,用於下一步進行模型訓練。
#通常情況下透過以下取值範圍判斷變數的相關強度: 相關係數 0.8-1.0 極強相關
#0.6-0.8 強相關 0.4-0.6 中等程度相關 0.2-0.4 弱相關 0.0-0.2 極弱相關或無相關
x=np.array([1,3,5])
y=np.array([1,3,4])
pc = pearsonr(x,y)
print("相關係數:",pc[0])
print("顯著性水平:",pc[1])
相關係數: 0.9819805060619655
顯著性水平: 0.1210377183236774
pccs = pearsonr(chart['溼度'],chart['日均電力負荷'])
print('溼度')
print("相關係數:",pccs[0])
print("顯著性水平:",pccs[1])
溼度
相關係數: 0.008366098656828478
顯著性水平: 0.7466985868452606
pccs = pearsonr(chart['乾球溫度'],chart['日均電力負荷'])
print('乾球溫度')
print("相關係數:",pccs[0])
print("顯著性水平:",pccs[1])
乾球溫度
相關係數: 0.13553654698519382
顯著性水平: 1.464772478253176e-07
pccs = pearsonr(chart['溼球溫度'],chart['日均電力負荷'])
print('溼球溫度')
print("相關係數:",pccs[0])
print("顯著性水平:",pccs[1])
溼球溫度
相關係數: -0.2234887241034933
顯著性水平: 2.3520368633274703e-18
pccs = pearsonr(chart['露點溫度'],chart['日均電力負荷'])
print('露點溫度')
print("相關係數:",pccs[0])
print("顯著性水平:",pccs[1])
露點溫度
相關係數: 0.11306119537985258
顯著性水平: 1.192249332196542e-05
pccs = pearsonr(chart['電價'],chart['日均電力負荷'])
print('電價')
print("相關係數:",pccs[0])
print("顯著性水平:",pccs[1])
電價
相關係數: 0.12163330223784043
顯著性水平: 2.436096002707336e-06
pccs = pearsonr(chart['日最高電力負荷'],chart['日均電力負荷'])
print('日最高電力負荷')
print("相關係數:",pccs[0])
print("顯著性水平:",pccs[1])
日最高電力負荷
相關係數: 0.9766023264436509
顯著性水平: 0.0
pccs = pearsonr(chart['日最低電力負荷'],chart['日均電力負荷'])
print('日最低電力負荷')
print("相關係數:",pccs[0])
print("顯著性水平:",pccs[1])
日最低電力負荷
相關係數: 0.9390036936914716
顯著性水平: 0.0
4.資料分析。使用上一步選擇的3個特徵屬性作為輸入屬性,電力負荷作為輸出屬性,合理劃分訓練集與測試集比例,選擇適合的引數,使用Sklearn建立決策樹模型,並對模型進行測試。
(1)建立決策樹模型
X=pd.concat([chart['溼球溫度'],chart['乾球溫度'],chart['電價']],axis=1)
Y=chart['日均電力負荷']
# 劃分測試與訓練集
Xtrain,Xtest,Ytrain,Ytest=train_test_split(X,Y,test_size=0.1,random_state=420)
# 選擇最優引數
tree_param={'criterion':['mse','friedman_mse','mae'],'max_depth':list(range(10))}
# GridSearchCV網格搜尋,搜尋的是引數,即在指定的引數範圍內,按步長依次調整引數,利用調整的引數訓練學習器,從所有的引數中找到在驗證集上精度最高的引數,這其實是一個訓練和比較的過程。k折交叉驗證將所有資料集分成k份,不重複地每次取其中一份做測試集,
# 用其餘k-1份做訓練集訓練模型,之後計算該模型在測試集上的得分,將k次的得分取平均得到最後的得分。
#例項化物件
grid=GridSearchCV(tree.DecisionTreeRegressor(),param_grid=tree_param,cv=3)
regressor = DecisionTreeRegressor(max_depth=5, min_samples_split=10)
grid = GridSearchCV(regressor, param_grid={'max_depth': [3, 5, 7], 'min_samples_split': [2, 5, 10]})
grid.fit(Xtrain, Ytrain)
#最優引數,最優分數
grid.best_params_,grid.best_score_
#建立迴歸樹
dtr=tree.DecisionTreeRegressor(criterion='mae',max_depth =5)
#訓練決策樹
#預測訓練結果
dtr.fit(Xtrain,Ytrain)
pred=dtr.predict(Xtest)
(2)繪製預測結果
fig=plt.figure(figsize=(15.6,7.2))
ax=fig.add_subplot(111)
s1=ax.scatter(range(len(pred)),pred,facecolors="red",label='預測')
s2=ax.scatter(range(len(Ytest)),Ytest,facecolors="blue",label='實際')
plt.legend()
<matplotlib.legend.Legend at 0x7f055ef94080>
實驗總結:
自己總結一哈!