文字識別(四)--大批量生成文字訓練集

Eason.wxd發表於2019-02-18

轉自:https://www.cnblogs.com/skyfsm/p/8436820.html

上次談到文字的切割,今天打算總結一下我們怎麼得到用於訓練的文字資料集。如果是想訓練一個手寫體識別的模型,用一些前人收集好的手寫文字集就好了,比如中科院的這些資料集。但是如果我們只是想要訓練一個專門用於識別印刷漢字的模型,那麼我們就需要各種印刷字型的訓練集,那怎麼獲取呢?藉助強大的影象庫,自己生成就行了!

先捋一捋思路,生成文字集需要什麼步驟:

  1. 確定你要生成多少字型,生成一個記錄著漢字與label的對應表。
  2. 確定和收集需要用到的字型檔案。
  3. 生成字型影象,儲存在規定的目錄下。
  4. 適當的資料增強。

第三步的生成字型影象最為重要,如果僅僅是生成很正規的文字,那麼用這個正規文字集去訓練模型,第一影象數目有點少,第二模型泛化能力比較差,所以我們需要對字型影象做大量的影象處理工作,以增大我們的印刷體文字資料集。

我總結了一下,我們可以做的一些影象增強工作有這些:

  1. 文字扭曲
  2. 背景噪聲(椒鹽)
  3. 文字位置(設定文字的中心點)
  4. 筆畫粘連(膨脹來模擬)
  5. 筆畫斷裂(腐蝕來模擬)
  6. 文字傾斜(文字旋轉)
  7. 多種字型

做完以上增強後,我們得到的資料集已經非常龐大了。

現在開始一步一步生成我們的3755個漢字的印刷體文字資料集。

一、生成漢字與label的對應表

這裡的漢字、label對映表的生成我使用了pickel模組,藉助它生成一個id:漢字的對映檔案儲存下來。
這裡舉個小例子說明怎麼生成這個“漢字:id”對映表。

首先在一個txt檔案裡寫入你想要的漢字,如果對漢字對應的ID沒有要求的話,我們不妨使用該漢字的排位作為其ID,比如“一二三四五”中,五的ID就是00005。如此類推,把漢字讀入記憶體,建立一個字典,把這個關係記錄下來,再使用pickle.dump存入檔案儲存。

二、收集字型檔案

字型檔案上網收集就好了,但是值得注意的是,不是每一種字型都支援漢字,所以我們需要篩選出真正適合漢字生成的字型檔案才可以。我一共使用了十三種漢字字型作為我們接下來漢字資料集用到的字型,具體如下圖:

當然,如果需要進一步擴大資料集來增強訓練得到的模型的泛化能力,可以花更多的時間去收集各類漢字字型,那麼模型在面對各種字型時也能從容應對,給出準確的預測。

三、文字影象生成

首先是定義好輸入引數,其中包括輸出目錄、字型目錄、測試集大小、影象尺寸、影象旋轉幅度等等。

def args_parse():
    #解析輸入引數
    parser = argparse.ArgumentParser(
        description=description, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--out_dir', dest='out_dir',
                        default=None, required=True,
                        help='write a caffe dir')
    parser.add_argument('--font_dir', dest='font_dir',
                        default=None, required=True,
                        help='font dir to to produce images')
    parser.add_argument('--test_ratio', dest='test_ratio',
                        default=0.2, required=False,
                        help='test dataset size')
    parser.add_argument('--width', dest='width',
                        default=None, required=True,
                        help='width')
    parser.add_argument('--height', dest='height',
                        default=None, required=True,
                        help='height')
    parser.add_argument('--no_crop', dest='no_crop',
                        default=True, required=False,
                        help='', action='store_true')
    parser.add_argument('--margin', dest='margin',
                        default=0, required=False,
                        help='', )
    parser.add_argument('--rotate', dest='rotate',
                        default=0, required=False,
                        help='max rotate degree 0-45')
    parser.add_argument('--rotate_step', dest='rotate_step',
                        default=0, required=False,
                        help='rotate step for the rotate angle')
    parser.add_argument('--need_aug', dest='need_aug',
                        default=False, required=False,
                        help='need data augmentation', action='store_true')   
    args = vars(parser.parse_args()) 
    return args

接下來需要將我們第一步得到的對應表讀入記憶體,因為這個表示ID到漢字的對映,我們在做一下轉換,改成漢字到ID的對映,用於後面的字型生成。

#將漢字的label讀入,得到(ID:漢字)的對映表label_dict
label_dict = get_label_dict()

char_list=[]  # 漢字列表
value_list=[] # label列表
for (value,chars) in label_dict.items():
    print (value,chars)
    char_list.append(chars)
    value_list.append(value)

# 合併成新的對映關係表:(漢字:ID)
lang_chars = dict(zip(char_list,value_list)) 
font_check = FontCheck(lang_chars) 

我們對旋轉的角度儲存到列表中,旋轉角度的範圍是[-rotate,rotate].

if rotate < 0:
    roate = - rotate

if rotate > 0 and rotate <= 45:
    all_rotate_angles = []
    for i in range(0, rotate+1, rotate_step):  
        all_rotate_angles.append(i)
    for i in range(-rotate, 0, rotate_step):
        all_rotate_angles.append(i)
    #print(all_rotate_angles)

現在說一下字型影象是怎麼生成的,首先我們使用的工具是PIL。PIL裡面有很好用的漢字生成函式,我們用這個函式再結合我們提供的字型檔案,就可以生成我們想要的數字化的漢字了。我們先設定好我們生成的字型顏色為黑底白色,字型尺寸由輸入引數來動態設定。

# 生成字型影象
class Font2Image(object):

    def __init__(self,
                 width, height,
                 need_crop, margin):
        self.width = width
        self.height = height
        self.need_crop = need_crop
        self.margin = margin

    def do(self, font_path, char, rotate=0):
        find_image_bbox = FindImageBBox()
        # 黑色背景
        img = Image.new("RGB", (self.width, self.height), "black")
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype(font_path, int(self.width * 0.7),)
        # 白色字型
        draw.text((0, 0), char, (255, 255, 255),
                  font=font)
        if rotate != 0:
            img = img.rotate(rotate)
        data = list(img.getdata())
        sum_val = 0
        for i_data in data:
            sum_val += sum(i_data)
        if sum_val > 2:
            np_img = np.asarray(data, dtype='uint8')
            np_img = np_img[:, 0]
            np_img = np_img.reshape((self.height, self.width))
            cropped_box = find_image_bbox.do(np_img)
            left, upper, right, lower = cropped_box
            np_img = np_img[upper: lower + 1, left: right + 1]
            if not self.need_crop:
                preprocess_resize_keep_ratio_fill_bg = \
                    PreprocessResizeKeepRatioFillBG(self.width, self.height,
                                                    fill_bg=False,
                                                    margin=self.margin)
                np_img = preprocess_resize_keep_ratio_fill_bg.do(
                    np_img)
            # cv2.imwrite(path_img, np_img)
            return np_img
        else:
            print("img doesn't exist.")

我們寫兩個迴圈,外層迴圈是漢字列表,內層迴圈是字型列表,對於每個漢字會得到一個image_list列表,裡面儲存著這個漢字的所有影象。

for (char, value) in lang_chars.items():  # 外層迴圈是字
    image_list = []
    print (char,value)
    #char_dir = os.path.join(images_dir, "%0.5d" % value)
    for j, verified_font_path in enumerate(verified_font_paths):    # 內層迴圈是字型   
        if rotate == 0:
            image = font2image.do(verified_font_path, char)
            image_list.append(image)
        else:
            for k in all_rotate_angles: 
                image = font2image.do(verified_font_path, char, rotate=k)
                image_list.append(image)

我們將image_list中影象按照比例分為訓練集和測試集儲存。

        test_num = len(image_list) * test_ratio
        random.shuffle(image_list)  # 影象列表打亂
        count = 0
        for i in range(len(image_list)):
            img = image_list[i]
            #print(img.shape)
            if count < test_num :
                char_dir = os.path.join(test_images_dir, "%0.5d" % value)
            else:
                char_dir = os.path.join(train_images_dir, "%0.5d" % value)

            if not os.path.isdir(char_dir):
                os.makedirs(char_dir)

            path_image = os.path.join(char_dir,"%d.png" % count)
            cv2.imwrite(path_image,img)
            count += 1

寫好程式碼後,我們執行如下指令,開始生成印刷體文字漢字集。

 python gen_printed_char.py --out_dir ./dataset --font_dir ./chinese_fonts --width 30 --height 30 --margin 4 --rotate 30 --rotate_step 1

解析一下上述指令的附屬引數:

  1. --out_dir 表示生成的漢字影象的儲存目錄
  2. --font_dir 表示放置漢字字型檔案的路徑
  3. --width --height 表示生成影象的高度和寬度
  4. --margin 表示字型與邊緣的間隔
  5. --rotate 表示字型旋轉的範圍,[-rotate,rotate]
  6. --rotate_step 表示每次旋轉的間隔

生成這麼一個3755個漢字的資料集的所需的時間還是很久的,估計接近一個小時。其實這個生成過程可以用多執行緒、多程式並行加速,但是考慮到這種文字資料集只需生成一次就好,所以就沒做這方面的優化了。資料集生成完我們可以發現,在dataset資料夾下得到train和test兩個資料夾,train和test資料夾下都有3755個子資料夾,分別儲存著生成的3755個漢字對應的影象,每個子檔案的名字就是該漢字對應的id。隨便選擇一個train資料夾下的一個子資料夾開啟,可以看到所獲得的漢字影象,一共634個。

dataset下自動生成測試集和訓練集

測試集和訓練集下都有3755個子資料夾,用於儲存每個漢字的影象。

生成出來的漢字影象

額外的影象增強

第三步生成的漢字影象是最基本的資料集,它所做的影象處理僅有旋轉這麼一項,如果我們想在資料增強上再做多點東西,想必我們最終訓練出來的OCR模型的效能會更加優秀。我們使用opencv來完成我們定製的漢字影象增強任務。

因為生成的影象比較小,僅僅是30*30,如果對這麼小的影象加噪聲或者形態學處理,得到的字型影象會很糟糕,所以我們在做資料增強時,把圖片尺寸適當增加,比如設定為100×100,再進行相應的資料增強,效果會更好。

噪點增加

def add_noise(cls,img):
    for i in range(20): #新增點噪聲
        temp_x = np.random.randint(0,img.shape[0])
        temp_y = np.random.randint(0,img.shape[1])
        img[temp_x][temp_y] = 255
    return img

適當腐蝕

def add_erode(cls,img):
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3, 3))    
    img = cv2.erode(img,kernel) 
    return img

適當膨脹

def add_dilate(cls,img):
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3, 3))    
    img = cv2.dilate(img,kernel) 
    return img

然後做隨機擾動

def do(self,img_list=[]):
    aug_list= copy.deepcopy(img_list)
    for i in range(len(img_list)):
        im = img_list[i]
        if self.noise and random.random()<0.5:
            im = self.add_noise(im)
        if self.dilate and random.random()<0.25:
            im = self.add_dilate(im)
        if self.erode and random.random()<0.25:
            im = self.add_erode(im)    
        aug_list.append(im)
    return aug_list

輸入指令

python gen_printed_char.py --out_dir ./dataset2 --font_dir ./chinese_fonts --width 100 --height 100 --margin 10 --rotate 30 --rotate_step 1 --need_aug

使用這種生成的影象如下圖所示,第一資料集擴大了兩倍,第二影象的豐富性進一步提高,效果還是明顯的。當然,如果要獲得最好的效果,還需要調一下里面的引數,這裡就不再詳細說明了。

至此,我們所需的印刷體漢字資料集已經成功生成完畢,下一步要做的就是利用這些資料集設計一個卷積神經網路做文字識別了!

相關文章