Go 寫一個內網穿透工具

pibigstar發表於2019-12-10

系統架構

系統分為兩個部分,client 和 server,client執行在內網伺服器中,server執行在公網伺服器中,當我們想訪問內網中的服務,我們通過公網伺服器做一箇中繼。

下面是展示我靈魂畫手的時刻了

user傳送請求給 server,server和client建立連線,將請求發給client,client再將請求發給本地程式處理(內網中),然後本地程式將處理結果返回給client,client將結果返回給server,server再將結果返回給使用者,這樣使用者就訪問到了內網中的程式了。

程式碼流程

  1. server端監聽兩個埠,一個用來和user通訊,一個和client通訊
  2. client啟動時連線server端,並啟動一個埠監聽本地某程式
  3. 當User連線到server埠,將User請求內容發給client
  4. client將從server收到的請求發給本地程式
  5. client將從本地程式收到的內容發給server
  6. server將從client收到的內容發給User即可

  1. 當Server與client沒有訊息通訊,連線會斷開
  2. client斷開後,再啟動會連線不到Server
  3. Server端會因為client斷開而引發panic

為了解決這種坑點,加入了心跳包機制,通過5s傳送一次心跳包,保持client與server的連線,同時建立一個重連通道,監聽該通道,如果當Client被斷開後,則往重連通道放一個值,告訴Server端,等待新的Client連線,而避免引發Panic

程式碼

更詳細的我就不說了,直接看程式碼,程式碼裡面有詳細的註釋, 排版有問題,直接去github看吧。。。

程式碼倉庫地址: https://github.com/pibigstar/go-proxy

Server端

執行在具有公網IP地址的伺服器端

package main
import (
    "flag"
    "fmt"
    "io"
    "net"
    "runtime"
    "strings"
    "time"
)
var (
    localPort  int
    remotePort int
)
func init() {
    flag.IntVar(&localPort, "l", 5200, "the user link port")
    flag.IntVar(&remotePort, "r", 3333, "client listen port")
}
type client struct {
    conn net.Conn
    // 資料傳輸通道
    read  chan []byte
    write chan []byte
    // 異常退出通道
    exit chan error
    // 重連通道
    reConn chan bool
}
// 從Client端讀取資料
func (c *client) Read() {
    // 如果10秒鐘內沒有訊息傳輸,則Read函式會返回一個timeout的錯誤
    _ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
    for {
        data := make([]byte, 10240)
        n, err := c.conn.Read(data)
        if err != nil && err != io.EOF {
            if strings.Contains(err.Error(), "timeout") {
                // 設定讀取時間為3秒,3秒後若讀取不到, 則err會丟擲timeout,然後傳送心跳
                _ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
                c.conn.Write([]byte("pi"))
                continue
            }
            fmt.Println("讀取出現錯誤...")
            c.exit <- err
        }
        // 收到心跳包,則跳過
        if data[0] == 'p' && data[1] == 'i' {
            fmt.Println("server收到心跳包")
            continue
        }
        c.read <- data[:n]
    }
}
// 將資料寫入到Client端
func (c *client) Write() {
    for {
        select {
        case data := <-c.write:
            _, err := c.conn.Write(data)
            if err != nil && err != io.EOF {
                c.exit <- err
            }
        }
    }
}
type user struct {
    conn net.Conn
    // 資料傳輸通道
    read  chan []byte
    write chan []byte
    // 異常退出通道
    exit chan error
}
// 從User端讀取資料
func (u *user) Read() {
    _ = u.conn.SetReadDeadline(time.Now().Add(time.Second * 200))
    for {
        data := make([]byte, 10240)
        n, err := u.conn.Read(data)
        if err != nil && err != io.EOF {
            u.exit <- err
        }
        u.read <- data[:n]
    }
}
// 將資料寫給User端
func (u *user) Write() {
    for {
        select {
        case data := <-u.write:
            _, err := u.conn.Write(data)
            if err != nil && err != io.EOF {
                u.exit <- err
            }
        }
    }
}
func main() {
    flag.Parse()
    defer func() {
        err := recover()
        if err != nil {
            fmt.Println(err)
        }
    }()
    clientListener, err := net.Listen("tcp", fmt.Sprintf(":%d", remotePort))
    if err != nil {
        panic(err)
    }
    fmt.Printf("監聽:%d埠, 等待client連線... \n", remotePort)
    // 監聽User來連線
    userListener, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort))
    if err != nil {
        panic(err)
    }
    fmt.Printf("監聽:%d埠, 等待user連線.... \n", localPort)
    for {
        // 有Client來連線了
        clientConn, err := clientListener.Accept()
        if err != nil {
            panic(err)
        }
        fmt.Printf("有Client連線: %s \n", clientConn.RemoteAddr())
        client := &client{
            conn:   clientConn,
            read:   make(chan []byte),
            write:  make(chan []byte),
            exit:   make(chan error),
            reConn: make(chan bool),
        }
        userConnChan := make(chan net.Conn)
        go AcceptUserConn(userListener, userConnChan)
        go HandleClient(client, userConnChan)
        <-client.reConn
        fmt.Println("重新等待新的client連線..")
    }
}
func HandleClient(client *client, userConnChan chan net.Conn) {
    go client.Read()
    go client.Write()
    for {
        select {
        case err := <-client.exit:
            fmt.Printf("client出現錯誤, 開始重試, err: %s \n", err.Error())
            client.reConn <- true
            runtime.Goexit()
        case userConn := <-userConnChan:
            user := &user{
                conn:  userConn,
                read:  make(chan []byte),
                write: make(chan []byte),
                exit:  make(chan error),
            }
            go user.Read()
            go user.Write()
            go handle(client, user)
        }
    }
}
// 將兩個Socket通道連結
// 1. 將從user收到的資訊發給client
// 2. 將從client收到資訊發給user
func handle(client *client, user *user) {
    for {
        select {
        case userRecv := <-user.read:
            // 收到從user發來的資訊
            client.write <- userRecv
        case clientRecv := <-client.read:
            // 收到從client發來的資訊
            user.write <- clientRecv
        case err := <-client.exit:
            fmt.Println("client出現錯誤, 關閉連線", err.Error())
            _ = client.conn.Close()
            _ = user.conn.Close()
            client.reConn <- true
            // 結束當前goroutine
            runtime.Goexit()
        case err := <-user.exit:
            fmt.Println("user出現錯誤,關閉連線", err.Error())
            _ = user.conn.Close()
        }
    }
}
// 等待user連線
func AcceptUserConn(userListener net.Listener, connChan chan net.Conn) {
    userConn, err := userListener.Accept()
    if err != nil {
        panic(err)
    }
    fmt.Printf("user connect: %s \n", userConn.RemoteAddr())
    connChan <- userConn
}

Client端

執行在需要內網穿透的客戶端中

package main
import (
    "flag"
    "fmt"
    "io"
    "net"
    "runtime"
    "strings"
    "time"
)
var (
    host       string
    localPort  int
    remotePort int
)
func init() {
    flag.StringVar(&host, "h", "127.0.0.1", "remote server ip")
    flag.IntVar(&localPort, "l", 8080, "the local port")
    flag.IntVar(&remotePort, "r", 3333, "remote server port")
}
type server struct {
    conn net.Conn
    // 資料傳輸通道
    read  chan []byte
    write chan []byte
    // 異常退出通道
    exit chan error
    // 重連通道
    reConn chan bool
}
// 從Server端讀取資料
func (s *server) Read() {
    // 如果10秒鐘內沒有訊息傳輸,則Read函式會返回一個timeout的錯誤
    _ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
    for {
        data := make([]byte, 10240)
        n, err := s.conn.Read(data)
        if err != nil && err != io.EOF {
            // 讀取超時,傳送一個心跳包過去
            if strings.Contains(err.Error(), "timeout") {
                // 3秒發一次心跳
                _ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
                s.conn.Write([]byte("pi"))
                continue
            }
            fmt.Println("從server讀取資料失敗, ", err.Error())
            s.exit <- err
            runtime.Goexit()
        }
        // 如果收到心跳包, 則跳過
        if data[0] == 'p' && data[1] == 'i' {
            fmt.Println("client收到心跳包")
            continue
        }
        s.read <- data[:n]
    }
}
// 將資料寫入到Server端
func (s *server) Write() {
    for {
        select {
        case data := <-s.write:
            _, err := s.conn.Write(data)
            if err != nil && err != io.EOF {
                s.exit <- err
            }
        }
    }
}
type local struct {
    conn net.Conn
    // 資料傳輸通道
    read  chan []byte
    write chan []byte
    // 有異常退出通道
    exit chan error
}
func (l *local) Read() {
    for {
        data := make([]byte, 10240)
        n, err := l.conn.Read(data)
        if err != nil {
            l.exit <- err
        }
        l.read <- data[:n]
    }
}
func (l *local) Write() {
    for {
        select {
        case data := <-l.write:
            _, err := l.conn.Write(data)
            if err != nil {
                l.exit <- err
            }
        }
    }
}
func main() {
    flag.Parse()
    target := net.JoinHostPort(host, fmt.Sprintf("%d", remotePort))
    for {
        serverConn, err := net.Dial("tcp", target)
        if err != nil {
            panic(err)
        }
        fmt.Printf("已連線server: %s \n", serverConn.RemoteAddr())
        server := &server{
            conn:   serverConn,
            read:   make(chan []byte),
            write:  make(chan []byte),
            exit:   make(chan error),
            reConn: make(chan bool),
        }
        go server.Read()
        go server.Write()
        go handle(server)
        <-server.reConn
        _ = server.conn.Close()
    }
}
func handle(server *server) {
    // 等待server端發來的資訊,也就是說user來請求server了
    data := <-server.read
    localConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
    if err != nil {
        panic(err)
    }
    local := &local{
        conn:  localConn,
        read:  make(chan []byte),
        write: make(chan []byte),
        exit:  make(chan error),
    }
    go local.Read()
    go local.Write()
    local.write <- data
    for {
        select {
        case data := <-server.read:
            local.write <- data
        case data := <-local.read:
            server.write <- data
        case err := <-server.exit:
            fmt.Printf("server have err: %s", err.Error())
            _ = server.conn.Close()
            _ = local.conn.Close()
            server.reConn <- true
        case err := <-local.exit:
            fmt.Printf("server have err: %s", err.Error())
            _ = local.conn.Close()
        }
    }
}

相關文章