golang 防SQL隱碼攻擊 基於反射、TAG標記實現的不定引數檢查器

morning_sun發表於2018-07-09

  收到一個任務,所有http的handler要對入參檢查,防止SQL隱碼攻擊。剛開始笨笨的,打算為所有的結構體寫一個方法,後來統計了下,要寫幾十上百,隨著業務增加,以後還會重複這個無腦力的機械勞作。想想就low。

  直接做一個不定引數的自動檢測函式不就ok了麼?

  磨刀不誤砍柴工,用了一個下午的時間,調教出一個演算法:把不定結構體物件扔進去,這個函式自動檢查。

  普通場景還好,不比電信級業務,比如FRR快切,要求50ms以內重新整理百萬路由。

  先說說我的想法,然後把程式碼貼後面。

 

  這裡猶豫,要不要做併發?就要看需求了。
  需求0:呼叫者傳入一個結構體物件,要檢查這個物件有沒有變數注入指令碼,沒必要併發;
  需求1:呼叫者傳入多個結構體物件,要檢查這些物件有沒有變數注入指令碼,不能明確是哪個物件的變數有誤,需要併發。這種情況的話,感覺簡單的SQL檢查用不上,牛刀殺雞了;
  擴充0:如果想對TAG的長度做限制,沒必要在const裡面多定義幾個限制長度的變數,直接把TAG改成TAG.len就ok了,在迭代器裡面把小數點後面的長度提出來,扔到具體的檢查函式裡去。
  擴充1:這個函式可以改裝成自動對映器:工具自動生成對映程式碼,插入檢查器中對basic type的switch裡,根據tag自動對映。省去程式設計師的機械編碼,對映部分全自動。
  擴充2:只要一個結構體中,對變數進行tag自定義,就可以對這個結構體的所有變數進行任意處理。生產工具發展生產力!

 

  函式缺陷:高併發場景下,可能會有效能瓶頸,畢竟用了遞迴,棧空間吃緊,並且影響程式的可理解性,對程式的測試也有一定影響。

 

  個人的想象力總歸是有限的,讀者如果有什麼更燒腦,異想天開的想法,可以留言,一起分析,一起進步。

   

  好了,貼程式碼吧:

定義了一個三層的結構體。

type ZZZStu struct {
	ggg int    `sql:"int"`
	hhh string `sql:"email"`
}

type YYYStu struct {
	ddd int    `sql:"int"`
	eee string `sql:"alphaandnum"`
	zzz ZZZStu
}

type XXXStu struct {
	aaa int      `sql:"int"`
	bbb []string `sql:"num"`
	yyy YYYStu
}

 

然後在main裡面定義了這個結構體物件例項,在裡面隨意新增一些非法字元,除錯測試使用。

func main() {
	var temp XXXStu
	temp.aaa = 1
	//temp.bbb[0] = "123"
	bbb_tmp := "1"
	temp.bbb = append(temp.bbb, bbb_tmp)
	bbb_tmp = "2"
	temp.bbb = append(temp.bbb, bbb_tmp)
	bbb_tmp = "3"
	temp.bbb = append(temp.bbb, bbb_tmp)
	temp.yyy.ddd = 3
	temp.yyy.eee = "123qwe"
	temp.yyy.zzz.ggg = 5
	temp.yyy.zzz.hhh = `123456789@xxxxxxx.com`
	addMsg, ret := CheckSqlInject(temp)
	fmt.Println("main:", addMsg, ret)
	return
}

 

 下面是這個防SQL檢查器的最外層封裝。

/*****************************************************************************
* author          pxx
* date             2018/07/05
* rief		防sql注入檢查遞迴迭代器
* param[in]	不定引數
* 
eturn		給定結構體物件內的變數,有非法字元
* ingroup
* 
emarks
******************************************************************************/
func CheckSqlInject(args ...interface{}) (addMsg string, ret int) {

	for _, arg := range args {
		name := reflect.TypeOf(arg).Name()
		fmt.Printf("Recursioner %s (%T):
", name, arg)
		addMsg, ret = Recursioner(reflect.ValueOf(arg), name, name)
	}

	return
}

  

  

下面的Recursioner就是整個遞迴檢查器的核心部分了。可以看出來,我是從struct型別起始的,因為reflect包裡面只有structfield有TAG。

如果想擴充,就得在這個本包裡面實現,或者在公司內部的庫包裡面做。這裡不用三方庫,有開原始碼安全性問題的考量。

基本把所有型別涵蓋了:指標,介面,channel,陣列,切片,結構體,map,baisc type

func Recursioner(FieldValue reflect.Value, Path, FieldName string) (addMsg string, ret int) {
	switch FieldValue.Kind() {
	case reflect.Invalid:
		fmt.Printf("%s = invalid
", Path)

	//struct為起點(只有StructField有TAG),暫時滿足需求。如果以其他型別起始,需要對底層庫函式擴充,有時間再搞
	case reflect.Struct:
		for i := 0; i < FieldValue.NumField(); i++ {
			fieldInfo := FieldValue.Type().Field(i)
			tag := fieldInfo.Tag //  reflect.StructTag(string)
			name := tag.Get("sql")
			fieldPath := fmt.Sprintf("%s.%s (%s)", Path, FieldValue.Type().Field(i).Name, FieldValue.Type().Field(i).Type)
			addMsg, ret = Recursioner(FieldValue.Field(i), fieldPath, name)
			if ret != 0 {
				return
			}
		}

	case reflect.Slice, reflect.Array:
		for i := 0; i < FieldValue.Len(); i++ {
			addMsg, ret = Recursioner(FieldValue.Index(i), fmt.Sprintf("%s[%d]", Path, i), FieldName)
			if ret != 0 {
				return
			}
		}

	case reflect.Map:
		for _, key := range FieldValue.MapKeys() {
			addMsg, ret = Recursioner(FieldValue.MapIndex(key), fmt.Sprintf("%s[%s]", Path,
				formatAtom(key)), FieldName)
			if ret != 0 {
				return
			}
		}

	case reflect.Ptr:
		if FieldValue.IsNil() {
			fmt.Printf("%s = nil
", Path)
		} else {
			addMsg, ret = Recursioner(FieldValue.Elem(), fmt.Sprintf("(*%s)", Path), FieldName)
			if ret != 0 {
				return
			}
		}

	case reflect.Interface:
		if FieldValue.IsNil() {
			fmt.Printf("%s = nil
", Path)
		} else {
			fmt.Printf("%s.type = %s
", Path, FieldValue.Elem().Type())
			addMsg, ret = Recursioner(FieldValue.Elem(), Path+".value", FieldName)
			if ret != 0 {
				return
			}
		}

	default: // basic types, channels, funcs
		fmt.Printf("%s = %s
", Path, formatAtom(FieldValue))

		field_name := FieldValue.Type().Name()
		if field_name == "string" {

			//獲取該屬性的tag
			fmt.Println("tag_value=", FieldName)
			switch FieldName {
			case "alphaandnum":
				addMsg, ret = CheckAlphaAndNum(CheckAlphaAndNumLen, formatAtom(FieldValue))
				if ret != 0 {
					return
				}

			case "email":
				addMsg, ret = CheckEmail(CheckEmailLen, formatAtom(FieldValue))
				if ret != 0 {
					return addMsg, ret
				}

			case "num":
				addMsg, ret = CheckNum(CheckNumLen, formatAtom(FieldValue))
				if ret != 0 {
					return
				}
			}
			fmt.Println()
		}
	}
	return
}

  

 

  

下面這個函式,格式化資料。

func formatAtom(FieldValue reflect.Value) string {
	switch FieldValue.Kind() {
	case reflect.Invalid:
		return "invalid"

	case reflect.String:
		return FieldValue.String()

	case reflect.Int, reflect.Int8, reflect.Int16,
		reflect.Int32, reflect.Int64:
		return strconv.FormatInt(FieldValue.Int(), 10)

	case reflect.Uint, reflect.Uint8, reflect.Uint16,
		reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return strconv.FormatUint(FieldValue.Uint(), 10)

	// ...floating-point and complex cases omitted for brevity...
	case reflect.Bool:
		return strconv.FormatBool(FieldValue.Bool())

	case reflect.Chan, reflect.Func, reflect.Ptr, reflect.Slice, reflect.Map:
		return FieldValue.Type().String() + " 0x" +
			strconv.FormatUint(uint64(FieldValue.Pointer()), 16)

	default: // reflect.Array, reflect.Struct, reflect.Interface
		return FieldValue.Type().String() + " value"
	}
}

  

具體的欄位檢查函式,我貼一個就好,意思到了就行。原理也很簡單,用golang自帶的map(java裡是hashmap,python裡的dict,之前做的晶片驅動層表項管理,也是hash)。

func CheckAlphaAndNum(lenLimit int, str string) (addMsg string, ret int) {
	ret = 0
	var lenStr int = len(str)
	if lenStr > lenLimit {
		ret = -1
		return
	}
	for i := 0; i < lenStr; i++ {
		r := str[i]
		if _, ok := CheckAlphaAndNumMap[r]; !ok {
			ret = -1
			addMsg = "字母數字組合型別字串包含非法字元,請檢查!"
			return
		}
	}
	return
}

  

相關文章