拆分PPOCRLabel標註的資料集並生成識別資料集
說明
- 首次發表日期:2024-10-31
- 參考:
- https://github.com/PFCCLab/PPOCRLabel/blob/main/README_ch.md
關於PPOCRLabel以及本文緣起
PPOCRLabel是OCR領域的標註工具,其本身自帶匯出識別資料和拆分資料集的功能。其中:
PPOCRLabel本身自帶匯出識別資料的功能,但儲存檢測框圖片時會先對檢測框做透視校正,並在裁剪結果的高寬比不小於1.5時自動將圖片旋轉90度,具體見其saveRecResult函式的實現程式碼: https://github.com/PFCCLab/PPOCRLabel/blob/81a9c550b7b625bd003a16681fcc7d782184d1f4/PPOCRLabel.py#L3371
def saveRecResult(self):
    """Export recognition data: cropped box images plus a ``rec_gt.txt`` label file.

    For every key in ``self.fileStatedict`` the image is loaded and every label
    whose ``"difficult"`` flag is falsy is cropped with
    ``get_rotate_crop_image`` (perspective-rectified; rotated by 90 degrees when
    the crop's height/width ratio is >= 1.5). Crops are saved as
    ``<base_dir>/crop_img/<stem>_crop_<i>.jpg``. A tab-separated line
    ``crop_img/<name>`` / transcription is written to ``<base_dir>/rec_gt.txt``,
    where ``base_dir`` is the directory of ``self.PPlabelpath``.
    Images that raise a non-KeyError exception are reported in a message box.
    """
    # Abort when any of the three containers equals an empty dict
    # (presumably: nothing loaded/labelled yet -- confirm against the class).
    if {} in [self.PPlabelpath, self.PPlabel, self.fileStatedict]:
        QMessageBox.information(self, "Information", "Check the image first")
        return
    base_dir = os.path.dirname(self.PPlabelpath)
    rec_gt_dir = base_dir + "/rec_gt.txt"
    crop_img_dir = base_dir + "/crop_img/"
    ques_img = []  # keys of images that failed to export
    if not os.path.exists(crop_img_dir):
        os.mkdir(crop_img_dir)
    with open(rec_gt_dir, "w", encoding="utf-8") as f:
        for key in self.fileStatedict:
            idx = self.getImglabelidx(key)
            try:
                # Image paths are resolved relative to the parent of base_dir.
                img_path = os.path.dirname(base_dir) + "/" + key
                # imdecode(np.fromfile(...)) rather than imread, presumably to
                # support non-ASCII paths; -1 loads the image unchanged.
                img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
                for i, label in enumerate(self.PPlabel[idx]):
                    if label["difficult"]:
                        continue  # difficult boxes are not exported
                    # This call may rotate tall crops by 90 degrees.
                    img_crop = get_rotate_crop_image(
                        img, np.array(label["points"], np.float32)
                    )
                    # NOTE(review): idx is whatever getImglabelidx returns;
                    # presumably a path-like string -- confirm.
                    img_name = (
                        os.path.splitext(os.path.basename(idx))[0]
                        + "_crop_"
                        + str(i)
                        + ".jpg"
                    )
                    cv2.imencode(".jpg", img_crop)[1].tofile(
                        crop_img_dir + img_name
                    )
                    f.write("crop_img/" + img_name + "\t")
                    f.write(label["transcription"] + "\n")
            except KeyError as e:
                # Silently skips the rest of this image (e.g. idx missing from
                # self.PPlabel, or a label dict lacking a key); boxes already
                # written for this image stay in rec_gt.txt.
                pass
            except Exception as e:
                ques_img.append(key)
                traceback.print_exc()
    if ques_img:
        QMessageBox.information(
            self,
            "Information",
            "The following images can not be saved, please check the image path and labels.\n"
            + "".join(str(i) + "\n" for i in ques_img),
        )
    QMessageBox.information(
        self,
        "Information",
        "Cropped images have been saved in " + str(crop_img_dir),
    )
其中get_rotate_crop_image函式的定義: https://github.com/PFCCLab/PPOCRLabel/blob/81a9c550b7b625bd003a16681fcc7d782184d1f4/libs/utils.py#L137
def get_rotate_crop_image(img, points):
    """Cut the quadrilateral ``points`` out of ``img`` as an upright rectangle.

    ``points`` holds four corners. The caller in saveRecResult passes
    ``np.array(label["points"], np.float32)``. It is modified in place when the
    corner order is flipped. The result is rotated by 90 degrees
    (``np.rot90``) when its height/width ratio is >= 1.5. On any exception
    inside the ``try`` the error is printed and the function returns None.
    """
    # Use Green's theorem (signed polygon area) to judge clockwise or counterclockwise
    # author: biyanhua
    d = 0.0
    for index in range(-1, 3):  # edges (3->0), (0->1), (1->2), (2->3)
        d += (
            -0.5
            * (points[index + 1][1] + points[index][1])
            * (points[index + 1][0] - points[index][0])
        )
    if d < 0:  # counterclockwise
        # Copy first: swapping rows of a numpy array via tuple assignment would
        # otherwise read a row that was already overwritten.
        tmp = np.array(points)
        points[1], points[3] = tmp[3], tmp[1]
    try:
        # Output size: the longer of each pair of opposite edges.
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3]),
            )
        )
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2]),
            )
        )
        # Target corners in the same order as points: TL, TR, BR, BL.
        pts_std = np.float32(
            [
                [0, 0],
                [img_crop_width, 0],
                [img_crop_width, img_crop_height],
                [0, img_crop_height],
            ]
        )
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M,
            (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC,
        )
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        # Tall crops (h/w >= 1.5) are turned sideways -- this is the implicit
        # rotation some datasets do not want.
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img
    except Exception as e:
        # Errors are only printed; the function then falls through and returns None.
        print(e)
但是,有的場景並不需要將裁剪出的檢測框圖片旋轉後再儲存。
另外,PPOCRLabel官方自帶指令碼可以用於拆分資料集:
python gen_ocr_train_val_test.py --trainValTestRatio 9:1:0 --datasetRootPath dataset/handwritten_digits/images --detRootPath ./train_data/det --recRootPath ./train_data/rec
拆分資料集並生成識別資料集
標註檔案格式
假設我們有資料集及其標註檔案:
# Directory that the image paths inside the label file are relative to,
# and the PPOCRLabel detection label file itself.
data_dir, label_file = "data/", "data/Label_det.txt"
PPOCRLabel的標註檔案是 PaddleOCR 文字檢測資料格式。
PaddleOCR 中的文字檢測演算法支援的標註檔案格式如下,中間用"\t"分隔:
" 影像檔名 json.dumps編碼的影像標註資訊"
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]
這裡假設將data_dir與標註檔案中的影像檔名拼接,即可構成圖片的路徑。
讀取標註檔案
讀取標註檔案內容,並檢查下標註檔案中的圖片是否都存在:
def get_label_lines(data_dir: str, label_file: str) -> list:
    """Read a PaddleOCR detection label file and report missing images.

    Each non-blank line is ``<image path><TAB><json annotations>``, with the
    image path relative to ``data_dir``. A message is printed for every image
    that does not exist; such lines are still returned.

    Args:
        data_dir: Directory the image paths in the label file are relative to.
        label_file: Path of the label file (read as UTF-8, since transcriptions
            are often non-ASCII).

    Returns:
        The non-blank lines of the file, trailing newlines kept.
    """
    with open(label_file, "r", encoding="utf-8") as f:
        # Drop blank lines (e.g. a trailing empty line); they have no tab and
        # would break every consumer that splits on "\t".
        label_lines = [line for line in f if line.strip()]
    for line in label_lines:
        img_path, _ = line.split("\t", 1)
        img_full_path = os.path.join(data_dir, img_path)
        if not os.path.exists(img_full_path):
            print(f'{img_full_path} not exists!')
    return label_lines
# Image paths in the label file are relative to data_dir.
label_lines = get_label_lines(data_dir, label_file)
拆分標註資料並儲存
拆分label_lines:
from sklearn.model_selection import train_test_split
# Hold out 20% of the labelled images for testing; the fixed seed keeps the
# split reproducible between runs.
train_set_label_lines, test_set_label_lines = train_test_split(
    label_lines,
    test_size=0.2,
    random_state=42,
)
儲存為具體的資料集(圖片和標註檔案):
def save_split_data(
    split_label_lines,
    data_dir,
    dest_dir="dataset",
    split_name="train",
):
    """Copy one split's images into their own folder and write its label file.

    All images of the split must sit in the same folder, e.g. ``images`` for
    ``images/img_1.jpg``. They are copied to ``<dest_dir>/<folder>_<split_name>/``.
    The label file ``Label_<folder>_<split_name>.txt`` is written into that
    folder. Its image paths are rewritten to ``<folder>_<split_name>/<file name>``,
    i.e. relative to ``dest_dir``. An image path without a folder uses the name
    of ``dest_dir`` as its folder name.

    Args:
        split_label_lines: Label lines ``<image path><TAB><annotation json>`` of
            this split, image paths relative to ``data_dir``. Blank lines are
            ignored.
        data_dir: Directory the image paths in ``split_label_lines`` are
            relative to.
        dest_dir: Root output directory.
        split_name: Suffix used in the output folder and label file names.

    Returns:
        ``(split_img_dir, label_file_path)``: the folder the images were copied
        into and the path of the written label file (UTF-8).

    Raises:
        ValueError: If the split contains no label lines, or its images come
            from more than one folder.
    """
    lines = [line.rstrip("\n") for line in split_label_lines if line.strip()]
    if not lines:
        raise ValueError(f"split '{split_name}' has no label lines")

    def _folder_name(img_path):
        # Name of the folder containing img_path. Always resolved against the
        # *root* dest_dir so every line is judged the same way.
        return os.path.split(os.path.dirname(os.path.join(dest_dir, img_path)))[-1]

    parent_dir_name = _folder_name(lines[0].split("\t", 1)[0])
    rel_dest_img_path = "_".join([parent_dir_name, split_name])
    split_img_dir = os.path.join(dest_dir, rel_dest_img_path)
    os.makedirs(split_img_dir, exist_ok=True)

    new_label_lines = []
    for line in lines:
        img_path, label_text = line.split("\t", 1)
        # Images are flattened into one folder, so mixing source folders could
        # silently overwrite files with the same name.
        if _folder_name(img_path) != parent_dir_name:
            raise ValueError(
                f"{img_path} is not in folder '{parent_dir_name}' like the rest of the split"
            )
        img_name = os.path.basename(img_path)
        new_label_lines.append("\t".join([os.path.join(rel_dest_img_path, img_name), label_text]))
        shutil.copy2(os.path.join(data_dir, img_path), os.path.join(split_img_dir, img_name))

    label_file_path = os.path.join(split_img_dir, "_".join(["Label", parent_dir_name, split_name]) + ".txt")
    with open(label_file_path, "w", encoding="utf-8") as f:
        f.write("\n".join(new_label_lines))
    return split_img_dir, label_file_path
# Materialize the train split first, then the test split, under "dataset/".
(train_img_dir, train_det_label_file), (test_img_dir, test_det_label_file) = (
    save_split_data(
        split_lines,
        data_dir=data_dir,
        dest_dir="dataset",
        split_name=split_name,
    )
    for split_lines, split_name in (
        (train_set_label_lines, "train"),
        (test_set_label_lines, "test"),
    )
)
生成識別圖片和標籤
def generate_rec_img_label(label_file_path, parent_dir, do_crop=True):
    """Build a text-recognition dataset from a PaddleOCR detection label file.

    Each line of ``label_file_path`` is ``<image path><TAB><json list>``, where
    every list item has a ``"transcription"`` and its ``"points"``; image paths
    are relative to ``parent_dir``.

    For box ``idx`` of image ``<dir>/<stem>.*``:

    * the image ``<stem>_crop_<idx>.jpg`` is written to
      ``<parent_dir>/<dir>/crop_img/``;
    * the line ``<dir>/crop_img/<stem>_crop_<idx>.jpg<TAB><transcription>`` is
      recorded.

    All recorded lines are written, newline-separated and UTF-8 encoded, to
    ``<parent_dir>/<label file stem>_rec.txt``.

    Args:
        label_file_path: Detection label file (read as UTF-8).
        parent_dir: Directory the image paths are relative to; the recognition
            label file is written here too.
        do_crop: If True, save the axis-aligned bounding rectangle of each box's
            points, clipped to the image. No rotation is applied. Boxes that are
            empty after clipping are skipped with a message. If False, copy the
            whole source image file once per box instead.

    Raises:
        FileNotFoundError: If ``do_crop`` is True and OpenCV cannot read an image.
    """
    import math  # local import: only needed to round box coordinates outward

    with open(label_file_path, "r", encoding="utf-8") as f:
        label_lines = f.readlines()
    rec_label_lines = []
    for line in label_lines:
        line = line.rstrip("\n")
        if not line.strip():
            continue  # tolerate blank lines such as a trailing empty line
        img_path, label_text = line.split("\t", 1)
        label_list = json.loads(label_text)
        src_img_path = os.path.join(parent_dir, img_path)
        parent_img_dir = os.path.split(os.path.dirname(img_path))[-1]
        dest_img_dir = os.path.join(parent_dir, parent_img_dir, "crop_img")
        os.makedirs(dest_img_dir, exist_ok=True)
        img = None
        if do_crop:
            # Only decode the image when it is actually going to be cropped.
            img = cv2.imread(src_img_path)
            if img is None:
                # cv2.imread reports failure by returning None instead of raising.
                raise FileNotFoundError(f"cannot read image: {src_img_path}")
            img_h, img_w = img.shape[:2]
        img_stem = os.path.splitext(os.path.basename(img_path))[0]
        for idx, label in enumerate(label_list):
            crop_img_name = img_stem + "_crop_" + str(idx) + ".jpg"
            dest_img_path = os.path.join(dest_img_dir, crop_img_name)
            if do_crop:
                xs = [pt[0] for pt in label["points"]]
                ys = [pt[1] for pt in label["points"]]
                # Bounding rectangle of all points, independent of their order,
                # clipped to the image (negative slice indices would wrap around).
                x0, x1 = max(math.floor(min(xs)), 0), min(math.ceil(max(xs)), img_w)
                y0, y1 = max(math.floor(min(ys)), 0), min(math.ceil(max(ys)), img_h)
                crop_img = img[y0:y1, x0:x1]
                if crop_img.size == 0:
                    # Writing an empty image fails; skip the box and its label.
                    print(f"{src_img_path}: box {idx} is empty after clipping, skipped")
                    continue
                cv2.imwrite(dest_img_path, crop_img)
            else:
                # NOTE: the original file bytes are copied under a .jpg name
                # regardless of the source format.
                shutil.copy2(src_img_path, dest_img_path)
            rec_label_lines.append(
                "\t".join([os.path.join(parent_img_dir, "crop_img", crop_img_name), label["transcription"]])
            )
    rec_label_name = "_".join([os.path.splitext(os.path.basename(label_file_path))[0], "rec"]) + ".txt"
    with open(os.path.join(parent_dir, rec_label_name), "w", encoding="utf-8") as f:
        f.write("\n".join(rec_label_lines))
# Build the recognition data for the train split, then for the test split.
for det_label_file in (train_det_label_file, test_det_label_file):
    generate_rec_img_label(det_label_file, "dataset")