如何增強grpc的攔截器

yuandian發表於2020-09-27

grpc的攔截器作為aop程式設計的利器,相信大家在使用grpc時,一定都體會過。這裡呢,主要說一下平時使用時的不足:

只能用於全域性,不能靈活的對每個或者一類方法進行攔截處理。如果只做一些業務無關的操作(記錄請求日誌,發起非業務錯誤重試),藉助grpc-multi-interceptor還是能很好實現。但跟業務相關的一些處理,則顯得力不從心了。


剛好最近在寫grpc server端的時候,需要做方法呼叫身份驗證,呼叫者許可權驗證。但是這兩類驗證並不是對所有方法都攔截,而是有好些方法是完全放開的。所以呢只能把兩者的驗證寫在一個攔截器裡面的,然後定義一個白名單切片,在白名單中就跳過。但是寫下來感覺不是很好,畢竟兩者功能不一樣,只是有共同的白名單。還有一些方法需要做一些統一的提前處理,這些處理也可能不一樣,有些少一點,有些多一點,頓時感覺grpc的攔截器有些不靈活。沒有其它框架裡面類似中介軟體,攔截器來的強大。

於是自己理了下思路,決定擴充套件下grpc的攔截器,主要想要實現:

  1. 支援設定多個攔截器

  2. 全域性攔截器上實現分組,可以設定白名單,同一分組的攔截器適用該組的白名單

  3. 對單一方法新增攔截器

在談實現之前,先簡單說一下grpc的四類攔截器吧

server端

unary interceptor 只要實現grpc.UnaryServerInterceptor:

type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)

例:

func UnaryServerInterceptorDemo(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    log.Printf("before handling. Info: %+v", info)
    resp, err := handler(ctx, req)
    log.Printf("after handling. resp: %+v", resp)
    return resp, err
}

stream interceptor 實現grpc.StreamServerInterceptor

type StreamServerInterceptor func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error

例:

func StreamServerInterceptorDemo(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    log.Printf("before handling. Info: %+v", info)
    err := handler(srv, ss)
    log.Printf("after handling. err: %v", err)
    return err
}

然後在服務初始化時:

srv := grpc.NewServer(
    grpc.UnaryInterceptor(UnaryServerInterceptorDemo),
    grpc.StreamInterceptor(StreamServerInterceptorDemo),
)

user.RegisterUserServiceServer(srv, &UserService{})

srv.Server(listen)

client端

unary interceptor 實現 grpc.UnaryClientInterceptor:

type UnaryClientInterceptor fuc(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error

例:

func UnaryClientInterceptorDemo(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    log.Printf("before invoker. method: %+v, request:%+v", method, req)
    err := invoker(ctx, method, req, reply, cc, opts...)
    log.Printf("after invoker. reply: %+v", reply)
    return err
}

stream interceptor 實現grpc.StreamClientInterceptor:

type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)

例:

func StreamServerInterceptorDemo(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    log.Printf("before handling. Info: %+v", info)
    err := handler(srv, ss)
    log.Printf("after handling. err: %v", err)
    return err
}

在初始化客戶端時:

grpc.DialContext(ctx, target, 
    grpc.WithUnaryInterceptor(UnaryClientInterceptorDemo),
    grpc.WithStreamInterceptor(StreamServerInterceptorDemo),
    ...
)

預設的攔截器是unary 跟 stream 每種只能配置一個,配置多個,需要藉助grpc-multi-interceptor](GitHub - kazegusuri/grpc-multi-interceptor)

首先擴充套件UnaryServerInterceptor, 定義如下結構體:

// 用於存放一組攔截器
type unaryServerInterceptorGroup struct {
    handlers []grpc.UnaryServerInterceptor  // 包含的攔截器
    skip map[string]struct{}                // 該組的白名單,存放不被攔截的方法名
}

// 用於配置所有的UnaryServerInterceptor
type UnaryServerInterceptors struct {
    global []*unaryServerInterceptorGroup  // 全域性攔截器組
    part map[string][]grpc.UnaryServerInterceptor // 區域性攔截器
}

定義好結構以後,來實現新增攔截器方法,需要能夠新增全域性攔截器,新增單一方法的攔截器(區域性攔截器)。

  1. 新增全域性攔截器,約定呼叫一次該方法,則新增一組攔截器,方法引數需要傳入攔截器切片。如果需要新增單個全域性攔截器,則攔截器切片只放一個攔截器,即一個攔截器就是一組。 攔截器的執行順序按 新增方法的呼叫順序。

  2. 新增單一方法的攔截器,需要傳入攔截器針對的方法名,以及該方法的攔截器。區域性攔截器的執行順序為 新增方法的呼叫順序,然後同一個新增方法傳入的攔截器順序。

  3. 無論新增全域性攔截器跟區域性攔截器的順序怎麼樣,都是先執行全域性攔截器再執行區域性攔截器。

// 新增全域性攔截器
// @param interceptors 新增的攔截器切片
// @param skipMethods  改組攔截器需要忽略的方法名
func (usi *UnaryServerInterceptors) UseGlobal(interceptors []grpc.UnaryServerInterceptor, skipMethods ...string) {
    skip := make(map[string]struct{}, len(skipMethods))
    // 將白名單切片轉換為map
    for _, method := range skipMethods {
        skip[method] = struct{}{}
    }

    // 構造攔截器組放置到攔截器組切片末尾
    usi.global = append(usi.global, &unaryServerInterceptorGroup{
        handlers: interceptors,
        skip:     skip,
    })
}

// 新增區域性攔截器
// @param method 針對的方法名
// @param interceptors 新增的攔截器
func (usi *UnaryServerInterceptors) UseMethod(method string, interceptors ...grpc.UnaryServerInterceptor) {
    // 區域性攔截器用map存放,key為方法全名,判斷是否初始化
    if usi.part == nil {
        usi.part = make(map[string][]grpc.UnaryServerInterceptor)
        usi.part[method] = interceptors
        return
    }

    // 已經初始化,判斷該方法名的攔截器是否新增過,沒有直接賦值
    if _, ok := usi.part[method]; !ok {
        usi.part[method] = interceptors
        return
    }

    // 已經存在,將新增的加至末尾
    usi.part[method] = append(usi.part[method], interceptors...)
}

上面的skipMethods跟method引數,如果不知道grpc方法全名的命名規則,可以直接檢視生成的protoc.pb.go檔案

比如:

func (c *userServiceClient) GetUserInfo(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*UserInfo, error) {
    out := new(UserInfo)
    err := c.cc.Invoke(ctx, "/user.UserService/GetUserInfo", in, out, opts...)
    if err != nil {
        return nil, err
    }
    return out, nil
}


func _UserService_GetUserInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
    in := new(Empty)
    if err := dec(in); err != nil {
        return nil, err
    }
    if interceptor == nil {
        return srv.(UserServiceServer).GetUserInfo(ctx, in)
    }
    info := &grpc.UnaryServerInfo{
        Server:     srv,
        FullMethod: "/user.UserService/GetUserInfo",
    }
    handler := func(ctx context.Context, req interface{}) (interface{}, error) {
        return srv.(UserServiceServer).GetUserInfo(ctx, req.(*Empty))
    }
    return interceptor(ctx, in, info, handler)
}

tips~ 上面是grpc生成的客戶端跟服務端程式碼。其中"/user.UserService/GetUserInfo"就是方法名,以/開頭,然後是protoc檔案定義的包名,定義的服務名,包名跟服務名用.隔開,再跟/和定義的方法名。

新增實現了,接下來就是實現匯出grpc.UnaryServerInterceptor方法了

func (usi *UnaryServerInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
    // 返回grpc.UnaryServerInterceptor型別的匿名函式
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
    handler grpc.UnaryHandler)(interface{}, error) {

        // 在該匿名函式中,需要依次呼叫設定的攔截器

        // 獲取全域性攔截器組的數量
        ggCount := len(usi.global)
        // 獲取區域性攔截器的數量
        mCount := len(usi.part[info.FullMethod])

        // 如果沒有設定任何攔截器,直接呼叫後續處理函式
        if ggCount + mCount == 0 {
            return handler(ctx, req)
        }

        // 開始呼叫攔截器
        var (
            groupI, handlerI int  // 全域性攔截器組執行計數,攔截器執行計數
            chainHandler grpc.UnaryHandler // 呼叫函式名,遞迴呼叫,需要函式名
        )
        chainHandler = func(ctx context.Context, req interface{}) (i interface{}, err error) {
            // 攔截器組執行計數小於全域性攔截器組的數量,說明當前呼叫仍然執行全域性攔截器
            if groupI < ggCount {
                group := usi.global[groupI]
                // 判斷當前組白名單是否包含該方法
                if _, ok := group.skip[info.FullMethod]; !ok {
                    // 不包含,得到當前攔截器組執行到哪個攔截器的索引,計數加1
                    index := handlerI
                    handlerI++
                    if index < len(group.handlers) {
                        // 執行當前攔截器
                        return group.handlers[index](ctx, req, info, chainHandler)
                    }
                    // 上步得到的索引大於該組數量,則該組已經執行完成,需要跳到下一組
                    // 先將攔截器計數歸0
                    handlerI = 0
                }
                // 攔截器組執行完或者方法在攔截器組白名單,都跳到下一組攔截器
                // 攔截器組計數加1
                groupI++
                return chainHandler(ctx, req)
            }

            // 全域性攔截器組執行完以後,執行鍼對該方法的區域性攔截器
            // 攔截器計數在執行完全域性攔截器組以後被歸0,複用來計數區域性攔截器
            if handlerI < mCount {
                special := usi.part[info.FullMethod]
                index := handlerI
                handlerI++
                return special[index](ctx, req, info, chainHandler)
            }

            // 區域性攔截器執行完以後,執行後續處理函式,也是遞迴跳出點
            return handler(ctx, req)
        }

        // 再次匯出
        return chainHandler(ctx, req)
    }
}

這樣,增強版UnaryServerInterceptor攔截器就實現了

但上面的程式碼其實還是有點小問題,就是在執行下一個全域性攔截器組時以及跳到區域性攔截器時,都是使用return chainHandler(ctx, req),本次執行其實沒有意義的,加深了無用的函式呼叫棧。所以可以採用goto或者for迴圈忽略無意義的呼叫

我們採用for迴圈來優化:

func (usi *UnaryServerInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
    // 返回grpc.UnaryServerInterceptor型別的匿名函式
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
    handler grpc.UnaryHandler)(interface{}, error) {
        ...

        chainHandler = func(ctx context.Context, req interface{}) (i interface{}, err error) {
            // 攔截器組執行計數小於全域性攔截器組的數量,說明當前呼叫仍然執行全域性攔截器
            if groupI < ggCount {
                for {
                    group := usi.global[groupI]
                    // 判斷當前組白名單是否包含該方法
                    if _, ok := group.skip[info.FullMethod]; !ok {
                        // 不包含,得到當前攔截器組執行到哪個攔截器的索引,計數加1
                        index := handlerI
                        handlerI++
                        if index < len(group.handlers) {
                            // 執行當前攔截器
                            return group.handlers[index](ctx, req, info, chainHandler)
                        }
                        // 上步得到的索引大於該組數量,則該組已經執行完成,需要跳到下一組
                        // 先將攔截器計數歸0
                        handlerI = 0
                    }
                    // 攔截器組執行完或者方法在攔截器組白名單,都跳到下一組攔截器
                    // 攔截器組計數加1
                    groupI++
                    if groupI >= ggCount { // 攔截器組執行完跳出迴圈
                        break
                    }
                }
            }

            ...
        }

        ...
    }
}

到這,我們可以在程式碼中使用了

func LoadUnaryInterceptors() grpc.UnaryServerInterceptor  {
    md := &UnaryServerInterceptors{}

    // 載入全域性攔截器
    md.UseGlobal([]grpc.UnaryServerInterceptor{AuthGuard},
        // 完全放開的api
        "/user.UserService/RegisterRule",
        "/user.UserService/SendRegisterCode",
        ...
        )

    ...

    // 載入區域性攔截器
    md.UseMethod("/user.UserService/SendRegisterCode", DisableGuard, MsgCodeRateLimitGuard)


    return md.UnaryServerInterceptor()
}


srv := grpc.NewServer(
    grpc.UnaryInterceptor(LoadUnaryInterceptors()),
)

其它型別的攔截器擴充套件也跟此差不多,就不講解了。可以參考程式碼:github.com/welllog/grpc_intercepto...

本作品採用《CC 協議》,轉載必須註明作者和本文連結
~by orinfy

相關文章