一行命令為專案檔案新增開源協議頭

江湖十年發表於2024-11-04

公眾號首發地址:https://mp.weixin.qq.com/s/AmOq5yaDGbAerWGOiunMyQ

今天給大家介紹一款可以為專案檔案新增開源協議頭資訊的命令列工具 addlicense

如果一個現有的專案,想要開源,免不了要為專案中的檔案增加開源協議頭資訊。雖然很多 IDE 都可以為新建立的檔案自動增加頭資訊,但修改已有的檔案還是要麻煩些。好在我們有 addlicense 工具可以使用,一行命令就能搞定。並且 addlicense 是用 Go 語言開發的,本文不僅教你如何使用,還會對其原始碼進行分析講解。

安裝

使用如下命令安裝 addlicense

$ go install github.com/superproj/addlicense

使用 -h/--help 檢視幫助資訊:

$ addlicense -h
Usage: addlicense [flags] pattern [pattern ...]

The program ensures source code files have copyright license headers
by scanning directory patterns recursively.

It modifies all source files in place and avoids adding a license header
to any file that already has one.

The pattern argument can be provided multiple times, and may also refer
to single files.

Flags:

      --check                check only mode: verify presence of license headers and exit with non-zero code if missing
  -h, --help                 show this help message
  -c, --holder string        copyright holder (default "Google LLC")
  -l, --license string       license type: apache, bsd, mit, mpl (default "apache")
  -f, --licensef string      custom license file (no default)
      --skip-dirs strings    regexps of directories to skip
      --skip-files strings   regexps of files to skip
  -v, --verbose              verbose mode: print the name of the files that are modified
  -y, --year string          copyright year(s) (default "2024")

引數說明:

  • --check 只檢查檔案是否存在 License,執行後會列印所有不包含 License 版權頭資訊的檔名。
  • -h/--help 檢視 addlicense 使用幫助資訊,我們已經使用過了。
  • -c/--holder 指定 License 的版權所有者。
  • -l/--license 指定 License 的協議型別,目前內建支援了 Apache 2.0BSDMITMPL 2.0 協議。
  • -f/--licensef 指定自定義的 License 標頭檔案。
  • --skip-dirs 跳過指定的目錄。
  • --skip-files 跳過指定的檔案。
  • -v/--verbose 列印被更改的檔名。
  • -y/--year 指定 License 的版權起始年份。

使用

準備實驗的目錄如下:

$ tree data -a
data
├── a
│   ├── main.go
│   └── main_test.go
├── b
│   └── c
│       └── keep
├── c
│   └── main.py
├── d.go
└── d_test.go

5 directories, 6 files

使用內建 License

檢查 data 目錄下的所有檔案是否存在 License 頭資訊:

$ addlicense --check data
data/a/main_test.go
data/d_test.go
data/d.go
data/c/main.py
data/a/main.go

輸出了沒有 License 頭資訊的檔案。可以發現,這裡自動跳過了沒有字尾名的檔案 keep

NOTE:
因為 addlicense是併發操作多個目錄,所以每次執行列印結果順序可能不同。

為缺失 License 頭資訊的檔案新增 License 頭資訊:

$ addlicense -v -l mit -c 江湖十年 --skip-dirs=c data
data/a/main_test.go added license
data/a/main.go added license
data/d.go added license
data/d_test.go added license

輸出了所有本次命令增加了 License 頭資訊的檔案。

data/a/main.go 效果如下:

// Copyright (c) 2024 江湖十年
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package main

import "fmt"

...

指定自定義 License

我們也可以指定自定義的 License 檔案 boilerplate.txt 內容如下:

Copyright 2024 jianghushinian <jianghushinian007@outlook.com>. All rights reserved.
Use of this source code is governed by a MIT style
license that can be found in the LICENSE file. The original repo for
this file is https://github.com/jianghushinian/blog-go-example.

為缺失 License 頭資訊的檔案新增 License 頭資訊:

$ addlicense -v -f ./boilerplate.txt --skip-dirs=^a$ --skip-files=d.go,d_test.go data
data/c/main.py added license
NOTE:
注意這次的命令使用了正則 --skip-dirs=^a$ 來跳過目錄 a,沒有直接使用 --skip-dirs=a 是因為如果這樣做會跳過整個 data 目錄,不再進一步遍歷子目錄。稍後閱讀完 addlicense 原始碼就知道為什麼會這樣了。

data/c/main.py 效果如下:

# Copyright 2024 jianghushinian <jianghushinian007@outlook.com>. All rights reserved.
# Use of this source code is governed by a MIT style
# license that can be found in the LICENSE file. The original repo for
# this file is https://github.com/jianghushinian/blog-go-example.

def main():
    print("Hello Python")
...

原始碼解讀

我們學會了 addlicense 命令列工具如何使用,接下來可以深入其原始碼,來看看它是如何實現的。這樣在使用過程中如果出現任何問題,也方便排查。

addlicense 專案很小,專案原始檔如下:

$ tree addlicense                        
addlicense
├── Makefile
├── README.md
├── boilerplate.txt
├── go.mod
├── go.sum
└── main.go

1 directory, 6 files

addlicense 的程式碼邏輯,其實只有一個 main.go 檔案,我們來對其程式碼進行逐行分析。

開啟 main.go 檔案,首先映入眼簾的就是 License 頭資訊:

// Copyright 2020 Lingfei Kong <colin404@foxmail.com>. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.

// This program ensures source code files have copyright license headers.
// See usage with "addlicense -h".
package main

import (
    "bufio"
    "bytes"
    "errors"
    "fmt"
    "html/template"
    "io/ioutil"
    "os"
    "path/filepath"
    "regexp"
    "strings"
    "time"
    "unicode"

    "github.com/spf13/pflag"
    "golang.org/x/sync/errgroup"
)

License 頭資訊下面就是正常的 Go 包宣告和匯入資訊。

接下來是幾個常量的定義:

const helpText = `Usage: addlicense [flags] pattern [pattern ...]

The program ensures source code files have copyright license headers
by scanning directory patterns recursively.

It modifies all source files in place and avoids adding a license header
to any file that already has one.

The pattern argument can be provided multiple times, and may also refer
to single files.

Flags:
`

const tmplApache = `Copyright {{.Year}} {{.Holder}}

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.`

const tmplBSD = `Copyright (c) {{.Year}} {{.Holder}} All rights reserved.
Use of this source code is governed by a BSD-style
license that can be found in the LICENSE file.`

const tmplMIT = `Copyright (c) {{.Year}} {{.Holder}}

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.`

const tmplMPL = `This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at https://mozilla.org/MPL/2.0/.`

常量 helpText 就是使用 -h/--help 列印幫助資訊最上面的內容,回去看看是不是能對應上。

剩下的幾個常量就是內建支援的 License 頭資訊了,分別是 ApacheBSDMITMPL 協議。看到每個 License 頭資訊中的 { {.Year} } { {.Holder} } 就知道,這是 Go template 的模板語法。

然後,我們能看到定義的所有命令列標誌都在這裡了:

var (
    holder    = pflag.StringP("holder", "c", "Google LLC", "copyright holder")
    license   = pflag.StringP("license", "l", "apache", "license type: apache, bsd, mit, mpl")
    licensef  = pflag.StringP("licensef", "f", "", "custom license file (no default)")
    year      = pflag.StringP("year", "y", fmt.Sprint(time.Now().Year()), "copyright year(s)")
    verbose   = pflag.BoolP("verbose", "v", false, "verbose mode: print the name of the files that are modified")
    checkonly = pflag.BoolP(
        "check",
        "",
        false,
        "check only mode: verify presence of license headers and exit with non-zero code if missing",
    )
    skipDirs  = pflag.StringSliceP("skip-dirs", "", nil, "regexps of directories to skip")
    skipFiles = pflag.StringSliceP("skip-files", "", nil, "regexps of files to skip")
    help      = pflag.BoolP("help", "h", false, "show this help message")
)

這裡使用了 pflag 庫來定義所有命令列標誌,每個標誌的作用已經在前文講解過了,忘記的讀者可以翻上去回顧一下。

可以發現 --skip-dirs--skip-files 兩個標誌都是 slice 型別,傳入格式為 a,b,c

NOTE:
如果你不太熟悉 pflag 庫,可以參考我的另一篇文章《Go 命令列引數解析工具 pflag 使用》

下面就進入主邏輯 main 函式了:

func main() {
    pflag.Usage = usage
    pflag.Parse()

    if *help {
        pflag.Usage()
        os.Exit(1)
    }

    if pflag.NArg() == 0 {
        pflag.Usage()
        os.Exit(1)
    }

    if len(*skipDirs) != 0 {
        ps, err := getPatterns(*skipDirs)
        if err != nil {
            fmt.Println(err.Error())
            os.Exit(1)
        }

        patterns.dirs = ps
    }

    if len(*skipFiles) != 0 {
        ps, err := getPatterns(*skipFiles)
        if err != nil {
            fmt.Println(err.Error())
            os.Exit(1)
        }

        patterns.files = ps
    }

    data := &copyrightData{
        Year:   *year,
        Holder: *holder,
    }

    var t *template.Template
    if *licensef != "" {
        d, err := ioutil.ReadFile(*licensef)
        if err != nil {
            fmt.Printf("license file: %v\n", err)
            os.Exit(1)
        }
        t, err = template.New("").Parse(string(d))
        if err != nil {
            fmt.Printf("license file: %v\n", err)
            os.Exit(1)
        }
    } else {
        t = licenseTemplate[*license]
        if t == nil {
            fmt.Printf("unknown license: %s\n", *license)
            os.Exit(1)
        }
    }

    // process at most 1000 files in parallel
    ch := make(chan *file, 1000)
    done := make(chan struct{})
    go func() {
        var wg errgroup.Group
        for f := range ch {
            f := f // https://golang.org/doc/faq#closures_and_goroutines
            wg.Go(func() error {
                // nolint: nestif
                if *checkonly {
                    // Check if file extension is known
                    lic, err := licenseHeader(f.path, t, data)
                    if err != nil {
                        fmt.Printf("%s: %v\n", f.path, err)

                        return err
                    }
                    if lic == nil { // Unknown fileExtension
                        return nil
                    }
                    // Check if file has a license
                    isMissingLicenseHeader, err := fileHasLicense(f.path)
                    if err != nil {
                        fmt.Printf("%s: %v\n", f.path, err)

                        return err
                    }
                    if isMissingLicenseHeader {
                        fmt.Printf("%s\n", f.path)

                        return errors.New("missing license header")
                    }
                } else {
                    modified, err := addLicense(f.path, f.mode, t, data)
                    if err != nil {
                        fmt.Printf("%s: %v\n", f.path, err)

                        return err
                    }
                    if *verbose && modified {
                        fmt.Printf("%s added license\n", f.path)
                    }
                }

                return nil
            })
        }
        err := wg.Wait()
        close(done)
        if err != nil {
            os.Exit(1)
        }
    }()

    for _, d := range pflag.Args() {
        walk(ch, d)
    }
    close(ch)
    <-done
}

這裡邏輯很長,咱們一點點來拆解閱讀。

首先是對命令列標誌的處理:

pflag.Usage = usage
pflag.Parse()

if *help {
    pflag.Usage()
    os.Exit(1)
}

if pflag.NArg() == 0 {
    pflag.Usage()
    os.Exit(1)
}

pflag.Usage 欄位是一個函式,用來列印使用幫助資訊,變數 usage 定義如下:

var (
    ...
    usage           = func() {
        fmt.Println(helpText)
        pflag.PrintDefaults()
    }
)

if *help 就是對 -h/--help 標誌進行判斷,如果使用者輸入了此標誌,就列印幫助資訊,並直接退出程式。

pflag.NArg() 返回處理完標誌後剩餘的引數個數,用來指定需要處理的目錄。如果使用者沒傳,同樣列印幫助資訊並退出。

如果執行 addlicense -v -l mit -c 江湖十年 a b c 命令,pflag.NArg() 會返回 abc 三個目錄。我們至少要傳一個搜尋路徑,不然 addlicense 會不知道去找哪些檔案。

你可能想,這裡也可以設定為預設查詢當前目錄,即預設目錄為 .。但是我個人不推薦這種設計,因為 addlicense 會修改檔案,最好還是使用者明確傳了哪個目錄,再去操作。不然假如使用者不小心在家目錄下執行了這個命令,所有檔案都被改了。

顯然,在這個場景中,顯式勝於隱式。

接下來是對 --skip-dirs--skip-files 兩個命令列標誌的處理:

if len(*skipDirs) != 0 {
    ps, err := getPatterns(*skipDirs)
    if err != nil {
        fmt.Println(err.Error())
        os.Exit(1)
    }

    patterns.dirs = ps
}

if len(*skipFiles) != 0 {
    ps, err := getPatterns(*skipFiles)
    if err != nil {
        fmt.Println(err.Error())
        os.Exit(1)
    }

    patterns.files = ps
}

跳過的目錄和檔案都透過 getPatterns 函式來轉換成正規表示式,並賦值給 patterns 物件。

patternsgetPatterns 定義如下:

var patterns = struct {
    dirs  []*regexp.Regexp
    files []*regexp.Regexp
}{}

func getPatterns(patterns []string) ([]*regexp.Regexp, error) {
    patternsRe := make([]*regexp.Regexp, 0, len(patterns))
    for _, p := range patterns {
        patternRe, err := regexp.Compile(p)
        if err != nil {
            fmt.Printf("can't compile regexp %q\n", p)

            return nil, fmt.Errorf("compile regexp failed, %w", err)
        }
        patternsRe = append(patternsRe, patternRe)
    }

    return patternsRe, nil
}

接著又構建了一個 copyrightData 物件:

data := &copyrightData{
    Year:   *year,
    Holder: *holder,
}

其中 holder 是透過 -c/--holder 傳入的,year 是透過 -y--year 傳入的,year不傳預設值就是當前年份。

data 變數稍後將用於渲染模板。

而接下來就是構造模版邏輯:

var t *template.Template
if *licensef != "" {
    d, err := ioutil.ReadFile(*licensef)
    if err != nil {
        fmt.Printf("license file: %v\n", err)
        os.Exit(1)
    }
    t, err = template.New("").Parse(string(d))
    if err != nil {
        fmt.Printf("license file: %v\n", err)
        os.Exit(1)
    }
} else {
    t = licenseTemplate[*license]
    if t == nil {
        fmt.Printf("unknown license: %s\n", *license)
        os.Exit(1)
    }
}

if *licensef != "" 表示如果使用者使用-f/--licensef 指定了自定義的 License 標頭檔案,則進入此邏輯,讀取其中內容作為模板。

否則,使用預設支援的版權內容作為模板。licenseTemplate 是一個全域性變數,並在 init中被初始化:

var (
    licenseTemplate = make(map[string]*template.Template)
    ...
)

func init() {
    licenseTemplate["apache"] = template.Must(template.New("").Parse(tmplApache))
    licenseTemplate["mit"] = template.Must(template.New("").Parse(tmplMIT))
    licenseTemplate["bsd"] = template.Must(template.New("").Parse(tmplBSD))
    licenseTemplate["mpl"] = template.Must(template.New("").Parse(tmplMPL))
}

無論哪個分支,只要報錯,就會呼叫 os.Exit(1) 退出。

接下來就是程式的核心邏輯了:

// process at most 1000 files in parallel
ch := make(chan *file, 1000)
done := make(chan struct{})
go func() {
    var wg errgroup.Group
    for f := range ch {
        f := f // https://golang.org/doc/faq#closures_and_goroutines
        wg.Go(func() error {
            // nolint: nestif
            if *checkonly {
                // Check if file extension is known
                lic, err := licenseHeader(f.path, t, data)
                if err != nil {
                    fmt.Printf("%s: %v\n", f.path, err)

                    return err
                }
                if lic == nil { // Unknown fileExtension
                    return nil
                }
                // Check if file has a license
                isMissingLicenseHeader, err := fileHasLicense(f.path)
                if err != nil {
                    fmt.Printf("%s: %v\n", f.path, err)

                    return err
                }
                if isMissingLicenseHeader {
                    fmt.Printf("%s\n", f.path)

                    return errors.New("missing license header")
                }
            } else {
                modified, err := addLicense(f.path, f.mode, t, data)
                if err != nil {
                    fmt.Printf("%s: %v\n", f.path, err)

                    return err
                }
                if *verbose && modified {
                    fmt.Printf("%s added license\n", f.path)
                }
            }

            return nil
        })
    }
    err := wg.Wait()
    close(done)
    if err != nil {
        os.Exit(1)
    }
}()

for _, d := range pflag.Args() {
    walk(ch, d)
}
close(ch)
<-done

這段程式碼乍一看挺多,其實理清思路還是比較容易理解的。

我們先理清這段程式碼的整體脈絡:

// process at most 1000 files in parallel
ch := make(chan *file, 1000)
done := make(chan struct{})
go func() {
    var wg errgroup.Group
    for f := range ch {
        wg.Go(func() error {
            ...
            return nil
        })
    }
    err := wg.Wait()
    close(done)
    if err != nil {
        os.Exit(1)
    }
}()

for _, d := range pflag.Args() {
    walk(ch, d)
}
close(ch)
<-done

這段程式碼設計還是比較精妙的,主 goroutine 與子 goroutine 透過 chdone 進行協作。這也是典型的生產者消費者模型。

ch := make(chan *file, 1000) 建立了一個帶緩衝的通道,緩衝大小為 1000,即最大併發為 1000。它用於將遍歷到的檔案(透過 walk 函式找到的檔案)傳送到消費者 goroutine 中。

done := make(chan struct{}) 建立了一個無緩衝的通道,用於通知主 goroutine 所有併發任務(檢查或修改檔案)已經完成。

生產者 goroutine 遍歷 pflag.Args() 的返回值並呼叫 walk(ch, d) 來將生產的資料傳入 chpflag.Args() 呼叫會返回處理完標誌後剩餘的引數列表,型別為 []string,即傳進來的目錄或檔案。前面提到的 pflag.NArg() 返回幾,pflag.Args() 返回的切片中就有幾個值。

當生產者中的 walk 函式完成遍歷所有目錄併傳送所有檔案後,主 goroutine 會呼叫 close(ch) 關閉 ch 通道,通知接收方沒有更多的檔案需要處理。然後呼叫 <-done 阻塞,等待消費者 goroutine 傳送過來的完成訊號。

消費者 goroutine 中,for f := range ch { ... } 迴圈從 ch 通道接收檔案(*file 型別),併為每個檔案啟動一個新的 goroutine(透過 errgroupGo 方法管理併發任務)。如果你對 errgroup 不熟悉,可以參考後文附錄部分對 errgroup 的講解,瞭解其用法後再回過來接著分析程式碼。當 ch 通道被關閉,for 迴圈也就結束了。wg.Wait() 會等待所有消費 goroutine 處理完成並返回。然後呼叫 close(done) 關閉 done 通道。最後根據是否有 goroutine 返回 error 來決定是否呼叫 os.Exit(1) 進行異常退出。

當消費者 goroutine 關閉 done 通道時,生產者 <-done 會立即收到完成訊號,由於這是 main 函式的最後一行程式碼,<-done 返回也就意味著整個程式執行完成並退出。

兩個 goroutine 協同工作的主要邏輯已經解釋清楚,我們就來分別看下二者的具體邏輯實現。

生產者 goroutine 主要邏輯都在 walk 函式中:

func walk(ch chan<- *file, start string) {
    _ = filepath.Walk(start, func(path string, fi os.FileInfo, err error) error {
        if err != nil {
            fmt.Printf("%s error: %v\n", path, err)

            return nil
        }
        if fi.IsDir() {
            for _, pattern := range patterns.dirs {
                if pattern.MatchString(fi.Name()) {
                    return filepath.SkipDir
                }
            }

            return nil
        }

        for _, pattern := range patterns.files {
            if pattern.MatchString(fi.Name()) {
                return nil
            }
        }

        ch <- &file{path, fi.Mode()}

        return nil
    })
}

walk接收兩個引數 ch 通道以及遍歷的起始目錄 start

其中 ch 通道中的 file 型別定義如下:

type file struct {
    path string
    mode os.FileMode
}

path表示檔案路徑,mode 表示檔案操作模式。

walk函式內部使用 filepath.Walk 來從 start 開始遞迴的遍歷目錄,並對其進行處理。如果你對 filepath.Walk 不熟悉,可以參考後文附錄部分對 filepath.Walk 的講解,瞭解其用法後再回過來接著分析程式碼。

這裡處理邏輯也很簡單,就是透過正則匹配,來過濾使用者透過 --skip-dirs--skip-files 兩個標誌傳進來需要跳過的目錄和檔案。然後將需要處理的檔案傳遞給 ch 通道等待消費者去處理。

NOTE:

現在你知道為什麼前文示例中的命令使用了正則 --skip-dirs=^a$ 來跳過目錄 a,而沒有直接使用 --skip-dirs=a 了嗎?對字串 apattern.MatchString 會匹配到 data,所以程式才會跳過整個 data 目錄,不再進一步遍歷子目錄。

*file 物件被傳入 ch 通道,消費者就要開始工作了。

消費 goroutine 中主邏輯分兩種情況:

  1. 使用者執行命令時輸入了 --check 標誌,只檢查檔案是否存在 License。
  2. 需要新增 License 頭資訊的邏輯。

我們一個一個來看。

  1. 使用者執行命令時輸入了 --check 標誌,只檢查檔案是否存在 License,處理邏輯如下:
if *checkonly {
    // Check if file extension is known
    lic, err := licenseHeader(f.path, t, data)
    if err != nil {
        fmt.Printf("%s: %v\n", f.path, err)

        return err
    }
    if lic == nil { // Unknown fileExtension
        return nil
    }
    // Check if file has a license
    isMissingLicenseHeader, err := fileHasLicense(f.path)
    if err != nil {
        fmt.Printf("%s: %v\n", f.path, err)

        return err
    }
    if isMissingLicenseHeader {
        fmt.Printf("%s\n", f.path)

        return errors.New("missing license header")
    }
}

首先呼叫 licenseHeader 函式來檢查副檔名是否支援,它接收三個引數,分別是檔案路徑、License 模板、和 data,還記得 data 的內容嗎?包含 holderyear,用來渲染模板。

licenseHeader 函式實現如下:

func licenseHeader(path string, tmpl *template.Template, data *copyrightData) ([]byte, error) {
    var lic []byte
    var err error
    switch fileExtension(path) {
    default:
        return nil, nil
    case ".c", ".h":
        lic, err = prefix(tmpl, data, "/*", " * ", " */")
    case ".js", ".mjs", ".cjs", ".jsx", ".tsx", ".css", ".tf", ".ts":
        lic, err = prefix(tmpl, data, "/**", " * ", " */")
    case ".cc",
        ".cpp",
        ".cs",
        ".go",
        ".hh",
        ".hpp",
        ".java",
        ".m",
        ".mm",
        ".proto",
        ".rs",
        ".scala",
        ".swift",
        ".dart",
        ".groovy",
        ".kt",
        ".kts":
        lic, err = prefix(tmpl, data, "", "// ", "")
    case ".py", ".sh", ".yaml", ".yml", ".dockerfile", "dockerfile", ".rb", "gemfile":
        lic, err = prefix(tmpl, data, "", "# ", "")
    case ".el", ".lisp":
        lic, err = prefix(tmpl, data, "", ";; ", "")
    case ".erl":
        lic, err = prefix(tmpl, data, "", "% ", "")
    case ".hs", ".sql":
        lic, err = prefix(tmpl, data, "", "-- ", "")
    case ".html", ".xml", ".vue":
        lic, err = prefix(tmpl, data, "<!--", " ", "-->")
    case ".php":
        lic, err = prefix(tmpl, data, "", "// ", "")
    case ".ml", ".mli", ".mll", ".mly":
        lic, err = prefix(tmpl, data, "(**", "   ", "*)")
    }

    return lic, err
}

裡面邏輯看起來 case 比較多,不過主要是為了支援各種程式語言的檔案。

函式 fileExtension 用來獲取副檔名:

func fileExtension(name string) string {
    if v := filepath.Ext(name); v != "" {
        return strings.ToLower(v)
    }

    return strings.ToLower(filepath.Base(name))
}

然後根據不同的副檔名呼叫 prefix 函式渲染模板。

prefix 函式定義如下:

func prefix(t *template.Template, d *copyrightData, top, mid, bot string) ([]byte, error) {
    var buf bytes.Buffer
    if err := t.Execute(&buf, d); err != nil {
        return nil, fmt.Errorf("render template failed, err: %w", err)
    }
    var out bytes.Buffer
    if top != "" {
        fmt.Fprintln(&out, top)
    }
    s := bufio.NewScanner(&buf)
    for s.Scan() {
        fmt.Fprintln(&out, strings.TrimRightFunc(mid+s.Text(), unicode.IsSpace))
    }
    if bot != "" {
        fmt.Fprintln(&out, bot)
    }
    fmt.Fprintln(&out)

    return out.Bytes(), nil
}

prefix 函式會根據不同程式語言的註釋風格生成版權宣告頭資訊。它需要傳入 License 模板、版權資訊(年份、作者)、開頭、中間、結尾識別符號。

所以我們呼叫 lic, err := licenseHeader(f.path, t, data),最終得到的 lic 實際上內容根據檔案型別是渲染後的 License 資訊。

比如同一個 License 頭資訊,在不同程式語言檔案中都要寫成對應的註釋形式,所以要入鄉隨俗。

在 Go 檔案中 License 頭資訊長這樣:

// Copyright 2024 jianghushinian <jianghushinian007@outlook.com>. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file. The original repo for
// this file is https://github.com/jianghushinian/blog-go-example.

在 Python 檔案中 License 頭資訊則要長這樣:

# Copyright 2024 jianghushinian <jianghushinian007@outlook.com>. All rights reserved.
# Use of this source code is governed by a MIT style
# license that can be found in the LICENSE file. The original repo for
# this file is https://github.com/jianghushinian/blog-go-example.

接下來判斷如果沒拿到結果,說明是不支援的副檔名,直接返回不做進一步處理,邏輯如下:

if lic == nil { // Unknown fileExtension
    return nil
}

之後呼叫 fileHasLicense檢查檔案是否包含授權頭資訊。fileHasLicense 函式實現如下:

func fileHasLicense(path string) (bool, error) {
    b, err := ioutil.ReadFile(path)
    if err != nil {
        return false, err
    }

    if hasLicense(b) {
        return false, nil
    }

    return true, nil
}

func hasLicense(b []byte) bool {
    n := 1000
    if len(b) < 1000 {
        n = len(b)
    }

    return bytes.Contains(bytes.ToLower(b[:n]), []byte("copyright")) ||
        bytes.Contains(bytes.ToLower(b[:n]), []byte("mozilla public"))
}

這裡實現比較簡單,就是讀取檔案內容,然後判斷前 1000 個字元中是否包含 copyrightmozilla public 關鍵字。

fileHasLicense 函式返回後,如果其返回值為 true,則說明檔案中不包含 License 頭資訊,直接返回一個 error

if isMissingLicenseHeader {
    fmt.Printf("%s\n", f.path)

    return errors.New("missing license header")
}

這裡返回的 error 會被 err := wg.Wait() 拿到,最終呼叫 os.Exit(1) 異常退出。

  1. 處理需要新增 License 頭資訊的邏輯如下:
else {
    modified, err := addLicense(f.path, f.mode, t, data)
    if err != nil {
        fmt.Printf("%s: %v\n", f.path, err)

        return err
    }
    if *verbose && modified {
        fmt.Printf("%s added license\n", f.path)
    }
}

這裡呼叫 addLicense 函式為指定檔案插入 License 頭資訊。

addLicense 函式實現如下:

func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *copyrightData) (bool, error) {
    var lic []byte
    var err error
    lic, err = licenseHeader(path, tmpl, data)
    if err != nil || lic == nil {
        return false, err
    }

    b, err := ioutil.ReadFile(path)
    if err != nil {
        return false, err
    }
    if hasLicense(b) {
        return false, nil
    }

    line := hashBang(b)
    if len(line) > 0 {
        b = b[len(line):]
        if line[len(line)-1] != '\n' {
            line = append(line, '\n')
        }
        line = append(line, '\n')
        lic = append(line, lic...)
    }
    b = append(lic, b...)

    return true, ioutil.WriteFile(path, b, fmode)
}

首先這裡也呼叫了 licenseHeader 來判斷副檔名是否被支援,並渲染 License 模板。

然後呼叫 hasLicense 來判斷檔案是否已經存在 License 頭資訊。

如果檔案不存在 License 頭資訊,接下來的邏輯就是正式準備寫入 License 頭資訊了。

接下來這段邏輯分兩種情況,首先呼叫 hashBang 函式用來判斷檔案是否存在 Shebang 行,如果有 Shebang 行,則原始檔內容為 Shebang 行 + 程式碼,新內容為 Shebang 行 + License 頭資訊 + 程式碼。如果沒有 Shebang 行存在,則原始檔內容只包含程式碼,新內容為 License 頭資訊 + 程式碼。

hashBang 函式內容如下:

var head = []string{
    "#!",                       // shell script
    "<?xml",                    // XML declaratioon
    "<!doctype",                // HTML doctype
    "# encoding:",              // Ruby encoding
    "# frozen_string_literal:", // Ruby interpreter instruction
    "<?php",                    // PHP opening tag
}

func hashBang(b []byte) []byte {
    line := make([]byte, 0, len(b))
    for _, c := range b {
        line = append(line, c)
        if c == '\n' {
            break
        }
    }
    first := strings.ToLower(string(line))
    for _, h := range head {
        if strings.HasPrefix(first, h) {
            return line
        }
    }

    return nil
}

最後這段邏輯就簡單了:

if *verbose && modified {
    fmt.Printf("%s added license\n", f.path)
}

這裡用來處理 -v/--verbose 引數。

至此,addlicense 所有原始碼就都解讀完成了。

總結

本文介紹可一行命令為專案檔案新增開源協議頭的工具 addlicense,並且還對其原始碼進行了逐行解讀,讓你能夠知其然,也能知其所以然。

不過 addlicense 工具能力還比較有限,比如不支援跳過 a/b/c 這種巢狀目錄,再比如 hashBang 函式支援有限,不支援 Python3 的 # -*- coding:utf-8 -*- 等。

如果感興趣,你可以一起投入到專案建設中來,為這個工具提供更強大的能力,歡迎提交 PR

本文示例原始碼我都放在了 GitHub 中,歡迎點選檢視。

希望此文能對你有所啟發。

附錄

filepath.Walk

filepath.Walk 是 Go 標準庫中的一個函式,用來遞迴遍歷檔案系統中的目錄和檔案。它可以遍歷指定目錄下的所有檔案和子目錄,並對每個檔案或目錄執行使用者提供的回撥函式。

基本語法

Walk 函式簽名如下:

func Walk(root string, fn WalkFunc) error

Walk 接收兩個引數:

  • root:需要遞迴遍歷的起始目錄路徑。
  • fn:每次遍歷到一個檔案或目錄時呼叫的回撥函式。

Walk 遍歷以 root 為根的檔案樹,併為樹中的每個檔案或目錄(包括 root)呼叫 fn 函式。

WalkFunc 函式簽名如下:

type WalkFunc func(path string, info fs.FileInfo, err error) error

WalkFunc 接收三個引數:

  • path:當前檔案或目錄的路徑。
  • info:當前檔案或目錄的 fs.FileInfo,這裡包含了檔案的元資訊,如是否為目錄、檔案大小等。
  • err:錯誤資訊,如許可權問題。

該函式返回的錯誤結果會控制 Walk 是否繼續執行。如果函式返回特殊值 filepath.SkipDir,則 Walk 會跳過當前目錄(如果 path 是目錄跳過當前目錄,否則跳過 path 的父目錄)但繼續遍歷其他內容。如果函式返回特殊值 filepath.SkipAll,則 Walk 將跳過所有剩餘的檔案和目錄。否則,如果函式返回非 nil 錯誤,則 Walk 將完全停止並返回該錯誤。

使用示例

現在我們準備如下用來測試的目錄:

$ tree data -a
data
├── .git
├── a
│   ├── main.go
│   └── main_test.go
├── b
│   └── c
│       └── keep
├── d.go
└── d_test.go

5 directories, 5 files

我們來使用 Walk 遍歷 data 目錄,並且輸出每個檔案或目錄的路徑。此外,需要跳過名為 .git 的目錄和以 test.go 結尾的 Go 測試檔案。

示例程式碼如下:

package main

import (
    "fmt"
    "os"
    "path/filepath"
    "strings"
)

func main() {
    // 定義起始目錄
    root := "./data"

    // 呼叫 Walk 函式遍歷目錄
    err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
        if err != nil {
            // 如果發生錯誤,則輸出錯誤並繼續遍歷
            fmt.Printf("Error accessing path %s: %v\n", path, err)
            return nil
        }

        // 跳過名為 `.git` 的目錄
        if info.IsDir() && info.Name() == ".git" {
            fmt.Printf("Skipping directory: %s\n", path)
            return filepath.SkipDir
        }

        // 跳過 Go 測試檔案
        if !info.IsDir() && strings.HasSuffix(info.Name(), "test.go") {
            fmt.Println("Skipping file:", path)
            return nil
        }

        // 輸出每個檔案或目錄的路徑
        fmt.Println(path)
        return nil
    })

    if err != nil {
        fmt.Printf("Error walking the path %v\n", err)
    }
}

透過 info.IsDir() 可以判斷是否為目錄,info.Name() 可以獲取檔案或目錄名。

使用 strings.HasSuffix() 函式可以過濾出 .go 檔案。

執行示例程式碼,得到輸出如下:

$ go run main.go 
./data
Skipping directory: data/.git
data/a
data/a/main.go
Skipping file: data/a/main_test.go
data/b
data/b/c
data/b/c/keep
data/d.go
Skipping file: data/d_test.go

errgroup

errgroupGo 官方庫 x 中提供的一個非常實用的工具,用於併發執行多個 goroutine,並且方便的處理錯誤。

使用場景
  1. 併發處理多個任務:當需要併發執行多個任務時,errgroup 有助於管理這些任務。
  2. 收集錯誤errgroup 會在任何一個 goroutine 出現錯誤時收集並返回這個錯誤,避免手動處理 goroutine 的錯誤。
  3. 等待所有 goroutine 完成errgroup 提供了一個簡便的方法等待所有併發的 goroutine 完成。
基本使用

errgroup 基本使用套路如下:

  1. 匯入 errgroup 包。
  2. 建立一個 errgroup.Group 例項。
  3. 使用 Group.Go 方法啟動多個併發任務。
  4. 使用 Group.Wait 方法等待所有 goroutine 完成或有一個返回錯誤。
使用示例

我們有 10 個併發任務用 errgroup 來管理,示例程式碼如下:

package main

import (
    "errors"
    "fmt"

    "golang.org/x/sync/errgroup"
)

func main() {
    var g errgroup.Group
    for i := 0; i < 10; i++ {
        i := i
        g.Go(func() error {
            if i == 3 {
                return errors.New("task 3 failed")
            }
            if i == 5 {
                return errors.New("task 5 failed")
            }

            // 其他任務繼續執行
            fmt.Printf("run task %d\n", i)

            return nil // 正常返回 nil 表示成功
        })
    }
    if err := g.Wait(); err != nil {
        fmt.Printf("Error: %v\n", err)
    }
}

程式碼解析:

  1. var g errgroup.Group: 建立了一個 errgroup.Group 物件,它用於管理多個 goroutine 並跟蹤它們的狀態。
  2. g.Go(func() error {...}): 每次呼叫 g.Go,都會啟動一個新的 goroutine,傳入的匿名函式是任務的執行內容。Go 方法會記錄這個任務的返回值(error 型別)。
  3. 併發執行任務:在 g.Go 內部執行的 func() error 都會併發執行。
  4. g.Wait(): g.Wait 會等待所有的 goroutine 執行完成。如果所有任務都執行成功,它會返回 nil,否則,無論有幾個 goroutine 執行失敗,它會返回第一個出現的錯誤。示例中第 3 個任務和第 5 個任務出錯,其他的 8 個任務不會受到影響,它們依然會繼續執行並完成

執行示例程式碼,得到輸出如下:

$ go run main.go 
run task 9
run task 4
run task 2
run task 6
run task 7
run task 1
run task 8
run task 0
Error: task 3 failed

由於任務是併發執行,所以多次執行輸出結果順序可能不同。

並且,輸出錯誤可能是 Error: task 3 failed,也有可能是 Error: task 5 failed

這裡還有一個更加真實的改編自 errgroup 官方文件的示例,用來併發請求多個 URL 並輸出響應狀態碼。

你可以再來感受下 errgroup 的用法,程式碼如下:

package main

import (
    "fmt"
    "net/http"
    "sync"

    "golang.org/x/sync/errgroup"
)

func main() {
    g := new(errgroup.Group)
    var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.somestupidname.com/", // 這是一個錯誤的 URL,會導致任務失敗
    }

    // 建立一個 map 來儲存結果
    var result sync.Map

    // 啟動多個 goroutine,併發處理多個 URL
    for _, url := range urls {
        // NOTE: 注意這裡的 url 需要傳遞給閉包函式,避免閉包共享變數問題,自 Go 1.22 開始無需考慮此問題
        url := url // https://golang.org/doc/faq#closures_and_goroutines

        // 啟動一個 goroutine 來獲取 URL
        g.Go(func() error {
            resp, err := http.Get(url)
            if err != nil {
                return err // 發生錯誤,返回該錯誤
            }
            defer resp.Body.Close()

            // 儲存每個 URL 的響應狀態碼
            result.Store(url, resp.Status)
            return nil
        })
    }

    // 等待所有 goroutine 完成
    if err := g.Wait(); err != nil {
        // 如果有任何一個 goroutine 返回了錯誤,這裡會得到該錯誤
        fmt.Println("Error: ", err)
    }

    // 所有 goroutine 都執行完成,遍歷並列印成功的結果
    result.Range(func(key, value any) bool {
        fmt.Printf("%s: %s\n", key, value)
        return true
    })
}

你可能注意到示例程式碼中有一句 url := url,這是由於在 Go 1.22 以前,由於 for 迴圈宣告的變數只會被建立一次,並在每次迭代時更新。所以為了避免多個 goroutine 中拿到相同的 url 值,而進行的複製操作。

在 Go 1.22 中,迴圈的每次迭代都會建立新的變數,以避免意外的共享錯誤。這在 Go 1.22 Release Notes 中有說明。

執行示例程式碼,得到輸出如下:

$ go run main.go
Error:  Get "http://www.somestupidname.com/": dial tcp: lookup www.somestupidname.com: no such host
http://www.google.com/: 200 OK
http://www.golang.org/: 200 OK

聯絡我

  • 公眾號:Go程式設計世界
  • 微信:jianghushinian
  • 郵箱:jianghushinian007@outlook.com
  • 部落格:https://jianghushinian.cn

相關文章