grpc的攔截器作為aop程式設計的利器,相信大家在使用grpc時,一定都體會過。這裡呢,主要說一下平時使用時的不足:
只能用於全域性,不能靈活的對每個或者一類方法進行攔截處理。如果只做一些業務無關的操作(記錄請求日誌,發起非業務錯誤重試),藉助grpc-multi-interceptor還是能很好實現。但跟業務相關的一些處理,則顯得力不從心了。
剛好最近在寫grpc server端的時候,需要做方法呼叫身份驗證,呼叫者許可權驗證。但是這兩類驗證並不是對所有方法都攔截,而是有好些方法是完全放開的。所以呢只能把兩者的驗證寫在一個攔截器裡面的,然後定義一個白名單切片,在白名單中就跳過。但是寫下來感覺不是很好,畢竟兩者功能不一樣,只是有共同的白名單。還有一些方法需要做一些統一的提前處理,這些處理也可能不一樣,有些少一點,有些多一點,頓時感覺grpc的攔截器有些不靈活。沒有其它框架裡面類似中介軟體,攔截器來的強大。
於是自己理了下思路,決定擴充套件下grpc的攔截器,主要想要實現:
支援設定多個攔截器
全域性攔截器上實現分組,可以設定白名單,同一分組的攔截器適用該組的白名單
對單一方法新增攔截器
在談實現之前,先簡單說一下grpc的四類攔截器吧
server端
unary interceptor 只要實現grpc.UnaryServerInterceptor:
type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)
例:
func UnaryServerInterceptorDemo(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
log.Printf("before handling. Info: %+v", info)
resp, err := handler(ctx, req)
log.Printf("after handling. resp: %+v", resp)
return resp, err
}
stream interceptor 實現grpc.StreamServerInterceptor
type StreamServerInterceptor func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error
例:
func StreamServerInterceptorDemo(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
log.Printf("before handling. Info: %+v", info)
err := handler(srv, ss)
log.Printf("after handling. err: %v", err)
return err
}
然後在服務初始化時:
srv := grpc.NewServer(
grpc.UnaryInterceptor(UnaryServerInterceptorDemo),
grpc.StreamInterceptor(StreamServerInterceptorDemo),
)
user.RegisterUserServiceServer(srv, &UserService{})
srv.Server(listen)
client端
unary interceptor 實現 grpc.UnaryClientInterceptor:
type UnaryClientInterceptor fuc(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
例:
func UnaryClientInterceptorDemo(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
log.Printf("before invoker. method: %+v, request:%+v", method, req)
err := invoker(ctx, method, req, reply, cc, opts...)
log.Printf("after invoker. reply: %+v", reply)
return err
}
stream interceptor 實現grpc.StreamClientInterceptor:
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
例:
func StreamServerInterceptorDemo(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
log.Printf("before handling. Info: %+v", info)
err := handler(srv, ss)
log.Printf("after handling. err: %v", err)
return err
}
在初始化客戶端時:
grpc.DialContext(ctx, target,
grpc.WithUnaryInterceptor(UnaryClientInterceptorDemo),
grpc.WithStreamInterceptor(StreamServerInterceptorDemo),
...
)
預設的攔截器是unary 跟 stream 每種只能配置一個,配置多個,需要藉助grpc-multi-interceptor](GitHub - kazegusuri/grpc-multi-interceptor)
首先擴充套件UnaryServerInterceptor, 定義如下結構體:
// 用於存放一組攔截器
type unaryServerInterceptorGroup struct {
handlers []grpc.UnaryServerInterceptor // 包含的攔截器
skip map[string]struct{} // 該組的白名單,存放不被攔截的方法名
}
// 用於配置所有的UnaryServerInterceptor
type UnaryServerInterceptors struct {
global []*unaryServerInterceptorGroup // 全域性攔截器組
part map[string][]grpc.UnaryServerInterceptor // 區域性攔截器
}
定義好結構以後,來實現新增攔截器方法,需要能夠新增全域性攔截器,新增單一方法的攔截器(區域性攔截器)。
新增全域性攔截器,約定呼叫一次該方法,則新增一組攔截器,方法引數需要傳入攔截器切片。如果需要新增單個全域性攔截器,則攔截器切片只放一個攔截器,即一個攔截器就是一組。 攔截器的執行順序按 新增方法的呼叫順序。
新增單一方法的攔截器,需要傳入攔截器針對的方法名,以及該方法的攔截器。區域性攔截器的執行順序為 新增方法的呼叫順序,然後同一個新增方法傳入的攔截器順序。
無論新增全域性攔截器跟區域性攔截器的順序怎麼樣,都是先執行全域性攔截器再執行區域性攔截器。
// 新增全域性攔截器
// @param interceptors 新增的攔截器切片
// @param skipMethods 改組攔截器需要忽略的方法名
func (usi *UnaryServerInterceptors) UseGlobal(interceptors []grpc.UnaryServerInterceptor, skipMethods ...string) {
skip := make(map[string]struct{}, len(skipMethods))
// 將白名單切片轉換為map
for _, method := range skipMethods {
skip[method] = struct{}{}
}
// 構造攔截器組放置到攔截器組切片末尾
usi.global = append(usi.global, &unaryServerInterceptorGroup{
handlers: interceptors,
skip: skip,
})
}
// 新增區域性攔截器
// @param method 針對的方法名
// @param interceptors 新增的攔截器
func (usi *UnaryServerInterceptors) UseMethod(method string, interceptors ...grpc.UnaryServerInterceptor) {
// 區域性攔截器用map存放,key為方法全名,判斷是否初始化
if usi.part == nil {
usi.part = make(map[string][]grpc.UnaryServerInterceptor)
usi.part[method] = interceptors
return
}
// 已經初始化,判斷該方法名的攔截器是否新增過,沒有直接賦值
if _, ok := usi.part[method]; !ok {
usi.part[method] = interceptors
return
}
// 已經存在,將新增的加至末尾
usi.part[method] = append(usi.part[method], interceptors...)
}
上面的skipMethods跟method引數,如果不知道grpc方法全名的命名規則,可以直接檢視生成的protoc.pb.go檔案
比如:
func (c *userServiceClient) GetUserInfo(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*UserInfo, error) {
out := new(UserInfo)
err := c.cc.Invoke(ctx, "/user.UserService/GetUserInfo", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func _UserService_GetUserInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserServiceServer).GetUserInfo(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/user.UserService/GetUserInfo",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserServiceServer).GetUserInfo(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
tips~ 上面是grpc生成的客戶端跟服務端程式碼。其中"/user.UserService/GetUserInfo"
就是方法名,以/
開頭,然後是protoc檔案定義的包名,定義的服務名,包名跟服務名用.
隔開,再跟/
和定義的方法名。
新增實現了,接下來就是實現匯出grpc.UnaryServerInterceptor方法了
func (usi *UnaryServerInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
// 返回grpc.UnaryServerInterceptor型別的匿名函式
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler)(interface{}, error) {
// 在該匿名函式中,需要依次呼叫設定的攔截器
// 獲取全域性攔截器組的數量
ggCount := len(usi.global)
// 獲取區域性攔截器的數量
mCount := len(usi.part[info.FullMethod])
// 如果沒有設定任何攔截器,直接呼叫後續處理函式
if ggCount + mCount == 0 {
return handler(ctx, req)
}
// 開始呼叫攔截器
var (
groupI, handlerI int // 全域性攔截器組執行計數,攔截器執行計數
chainHandler grpc.UnaryHandler // 呼叫函式名,遞迴呼叫,需要函式名
)
chainHandler = func(ctx context.Context, req interface{}) (i interface{}, err error) {
// 攔截器組執行計數小於全域性攔截器組的數量,說明當前呼叫仍然執行全域性攔截器
if groupI < ggCount {
group := usi.global[groupI]
// 判斷當前組白名單是否包含該方法
if _, ok := group.skip[info.FullMethod]; !ok {
// 不包含,得到當前攔截器組執行到哪個攔截器的索引,計數加1
index := handlerI
handlerI++
if index < len(group.handlers) {
// 執行當前攔截器
return group.handlers[index](ctx, req, info, chainHandler)
}
// 上步得到的索引大於該組數量,則該組已經執行完成,需要跳到下一組
// 先將攔截器計數歸0
handlerI = 0
}
// 攔截器組執行完或者方法在攔截器組白名單,都跳到下一組攔截器
// 攔截器組計數加1
groupI++
return chainHandler(ctx, req)
}
// 全域性攔截器組執行完以後,執行鍼對該方法的區域性攔截器
// 攔截器計數在執行完全域性攔截器組以後被歸0,複用來計數區域性攔截器
if handlerI < mCount {
special := usi.part[info.FullMethod]
index := handlerI
handlerI++
return special[index](ctx, req, info, chainHandler)
}
// 區域性攔截器執行完以後,執行後續處理函式,也是遞迴跳出點
return handler(ctx, req)
}
// 再次匯出
return chainHandler(ctx, req)
}
}
這樣,增強版UnaryServerInterceptor攔截器就實現了
但上面的程式碼其實還是有點小問題,就是在執行下一個全域性攔截器組時以及跳到區域性攔截器時,都是使用return chainHandler(ctx, req)
,本次執行其實沒有意義的,加深了無用的函式呼叫棧。所以可以採用goto或者for迴圈忽略無意義的呼叫
我們採用for迴圈來優化:
func (usi *UnaryServerInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
// 返回grpc.UnaryServerInterceptor型別的匿名函式
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler)(interface{}, error) {
...
chainHandler = func(ctx context.Context, req interface{}) (i interface{}, err error) {
// 攔截器組執行計數小於全域性攔截器組的數量,說明當前呼叫仍然執行全域性攔截器
if groupI < ggCount {
for {
group := usi.global[groupI]
// 判斷當前組白名單是否包含該方法
if _, ok := group.skip[info.FullMethod]; !ok {
// 不包含,得到當前攔截器組執行到哪個攔截器的索引,計數加1
index := handlerI
handlerI++
if index < len(group.handlers) {
// 執行當前攔截器
return group.handlers[index](ctx, req, info, chainHandler)
}
// 上步得到的索引大於該組數量,則該組已經執行完成,需要跳到下一組
// 先將攔截器計數歸0
handlerI = 0
}
// 攔截器組執行完或者方法在攔截器組白名單,都跳到下一組攔截器
// 攔截器組計數加1
groupI++
if groupI >= ggCount { // 攔截器組執行完跳出迴圈
break
}
}
}
...
}
...
}
}
到這,我們可以在程式碼中使用了
func LoadUnaryInterceptors() grpc.UnaryServerInterceptor {
md := &UnaryServerInterceptors{}
// 載入全域性攔截器
md.UseGlobal([]grpc.UnaryServerInterceptor{AuthGuard},
// 完全放開的api
"/user.UserService/RegisterRule",
"/user.UserService/SendRegisterCode",
...
)
...
// 載入區域性攔截器
md.UseMethod("/user.UserService/SendRegisterCode", DisableGuard, MsgCodeRateLimitGuard)
return md.UnaryServerInterceptor()
}
srv := grpc.NewServer(
grpc.UnaryInterceptor(LoadUnaryInterceptors()),
)
其它型別的攔截器擴充套件也跟此差不多,就不講解了。可以參考程式碼:github.com/welllog/grpc_intercepto...
本作品採用《CC 協議》,轉載必須註明作者和本文連結