線性迴歸 go 語言實現
先附上原文地址: https://github.com/liyue201/note/blob/master/linear-regression-go.md
線性迴歸 Go 語言實現
《動手學習深度學習》是一本非常經典的深度學習入門教程。 這個例子是其中第 3.2 節《線性迴歸從零開始》 的 go 語言實現。 目標是擬合函式
y = 2*x1 - 3.4*x2 + 4.2
這裡用到了機器學習框架gorgonia。 這個框架的思想類似 TensorFlow,通過構建圖來實現,每個計算單元是一個節點。
package main
import (
"fmt"
"gorgonia.org/gorgonia"
"gorgonia.org/tensor"
"log"
"math/rand"
"time"
)
// 這是一組線性函式的矩陣形式, y=x1*w1+x2*w2+...+wn*wn+b
// Tx 是mxn的特徵矩陣矩陣, m是樣本數,n是特徵數
// Tw 是nx1的權重矩陣
// b 是偏移量
// Ty 是一個mx1的標籤矩陣
func linear(Tx tensor.Tensor, Tw tensor.Tensor, b float64) tensor.Tensor {
temp, _ := tensor.MatMul(Tx, Tw)
Ty, _ := tensor.Add(temp, b, tensor.WithReuse(temp))
return Ty
}
// 生成特徵和標籤
// Tx:特徵
// Ty: 標籤
func genSamples(samplesNum int, trueW []float64, trueB float64) (Tx tensor.Tensor, Ty tensor.Tensor) {
rand.Seed(time.Now().Unix())
Tw := tensor.New(tensor.WithShape(len(trueW), 1), tensor.WithBacking(trueW[:]))
// 隨機生成特徵
Tx = tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(samplesNum, len(trueW)), tensor.WithBacking(tensor.Random(tensor.Float64, samplesNum*2)))
// 生成標籤
Ty = linear(Tx, Tw, trueB)
// 給標籤新增噪聲
for i := 0; i < samplesNum; i++ {
v, _ := Ty.At(i, 0)
Ty.SetAt(v.(float64) + float64(rand.Int31n(10)+1)/100)
}
return
}
// 預測函式
func predictFun(nX, w, b *gorgonia.Node) *gorgonia.Node {
return gorgonia.Must(gorgonia.Add(gorgonia.Must(gorgonia.Mul(nX, w)), b))
}
// 損失函式
func lossFun(nodePred, nodeY *gorgonia.Node) *gorgonia.Node {
cost := gorgonia.Must(gorgonia.Square(gorgonia.Must(gorgonia.Sub(nodePred, nodeY))))
cost = gorgonia.Must(gorgonia.Mean(cost))
return cost
}
// 構建gorgonia的計算圖
func setup(X, Y tensor.Tensor) (nodeW, nodeB *gorgonia.Node, machine gorgonia.VM) {
g := gorgonia.NewGraph()
// 將資料轉換成圖中的節點
nodeX := gorgonia.NodeFromAny(g, X)
nodeY := gorgonia.NodeFromAny(g, Y)
nodeW = gorgonia.NewMatrix(g, gorgonia.Float64, gorgonia.WithShape(2, 1), gorgonia.WithInit(gorgonia.ValuesOf(float64(0))))
nodeB = gorgonia.NewScalar(g, gorgonia.Float64, gorgonia.WithValue(float64(0)))
// 預測函式
nodePred := predictFun(nodeX, nodeW, nodeB)
// 損失函式
cost := lossFun(nodePred, nodeY)
// 梯度
if _, err := gorgonia.Grad(cost, nodeW, nodeB); err != nil {
log.Fatalf("Failed to backpropagate: %v", err)
}
machine = gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(nodeW, nodeB))
return nodeW, nodeB, machine
}
// 訓練模型
func train(batchSize int, w, b *gorgonia.Node, machine gorgonia.VM, iter int) {
model := []gorgonia.ValueGrad{w, b}
stepSize := 0.01
clip := 5.0
solver := gorgonia.NewVanillaSolver(gorgonia.WithLearnRate(stepSize), gorgonia.WithBatchSize(float64(batchSize)), gorgonia.WithClip(clip))
var err error
for i := 0; i < iter; i++ {
if err = machine.RunAll(); err != nil {
fmt.Printf("Error during iteration: %v: %v\n", i, err)
break
}
if err = solver.Step(model); err != nil {
log.Fatal(err)
}
machine.Reset()
if i%10 == 0 {
fmt.Printf("%v: %v %v\n", i, w.Value().Data(), b.Value())
}
}
}
func main() {
samplesNum := 10000
trueW := [2]float64{2, -3.4}
trueB := float64(4.2)
Tx, Ty := genSamples(samplesNum, trueW[:], trueB)
nodeW, nodeB, machine := setup(Tx, Ty)
defer machine.Close()
train(5, nodeW, nodeB, machine, 1000)
fmt.Printf("w = %v\nb = %v\n", nodeW.Value().Data(), nodeB.Value())
}
經過訓練後輸出的 w 和 b 如下,跟實際函式引數很接近了,繼續增加迭代次數可以更接近目標值。但是跟書中的 python 實現還是有些差距。可能是引數配置的差異,或者框架的問題,繼續研究中。
w = [1.9633147377650029 -3.3382087877519835]
b = 4.122876433412727
更多原創文章乾貨分享,請關注公眾號
- 加微信實戰群請加微信(註明:實戰群):gocnio
相關文章
- go語言使用切片實現線性表Go
- TensorFlow實現線性迴歸
- 【機器學習】線性迴歸sklearn實現機器學習
- pytorch實現線性迴歸PyTorch
- python實現線性迴歸之簡單迴歸Python
- 【機器學習】線性迴歸python實現機器學習Python
- 線性迴歸實戰
- 資料分析與挖掘 - R語言:多元線性迴歸R語言
- Pytorch 實現簡單線性迴歸PyTorch
- 機器學習--線性迴歸--梯度下降的實現機器學習梯度
- 利用TensorFlow實現多元線性迴歸
- 利用TensorFlow實現線性迴歸模型模型
- 線性迴歸
- 機器學習之線性迴歸(純python實現)機器學習Python
- 線性迴歸:最小二乘法實現
- 【pytorch_5】線性迴歸的實現PyTorch
- 機器學習實戰(一)—— 線性迴歸機器學習
- 【深度學習 01】線性迴歸+PyTorch實現深度學習PyTorch
- 機器學習-線性迴歸機器學習
- 1.3 - 線性迴歸
- 機器學習:線性迴歸機器學習
- 機器學習 | 線性迴歸與邏輯迴歸機器學習邏輯迴歸
- 機器學習之線性迴歸機器學習
- 多元線性迴歸模型模型
- 機器學習整理(線性迴歸)機器學習
- 線性迴歸總結
- 4-線性迴歸
- Java: 實現自迴歸分析/線性迴歸分析/基金各項指標計算等Java指標
- 詳解Go語言排程迴圈原始碼實現Go原始碼
- 梯度下降法實現最簡單線性迴歸問題python實現梯度Python
- 線性表-順序表C語言實現C語言
- 線性迴歸演算法演算法
- 【機器學習】線性迴歸預測機器學習
- 資料分析:線性迴歸
- PRML 迴歸的線性模型模型
- 機器學習5-線性迴歸機器學習
- 線性迴歸原理小結
- 線性迴歸-程式碼庫