4. SIFT Feature Extraction and the RANSAC Algorithm

Posted by 真真夜夜 on 2024-06-27

No more preamble, let's go straight to the code.

import cv2
import matplotlib.pyplot as plt
import numpy as np
import random

plt.rcParams['figure.figsize'] = [15, 15]

# Read an image and return its grayscale, original (BGR) and RGB versions
def read_image(path):
    img = cv2.imread(path)
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # cv2.imread loads BGR, so convert BGR -> grayscale
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)    # convert BGR -> RGB for display with matplotlib
    return img_gray, img, img_rgb

# Extract keypoints and descriptors with SIFT
def SIFT(img):
    siftDetector = cv2.SIFT_create()
    kp, des = siftDetector.detectAndCompute(img, None)
    return kp, des

# Draw the SIFT keypoints
def plot_sift(gray, rgb, kp):
    tmp = rgb.copy()
    img = cv2.drawKeypoints(gray, kp, tmp, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    return img

# Match feature points between the two images
def matcher(kp1, des1, img1, kp2, des2, img2, threshold):
    bf = cv2.BFMatcher()  # brute-force matcher with default parameters
    matches = bf.knnMatch(des1, des2, k=2)  # k-nearest-neighbour matching with k=2 for the ratio test

    good = []
    for m, n in matches:
        if m.distance < threshold * n.distance:
            good.append([m])  # keep only matches that pass the ratio test

    matches = []
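    # Flatten each surviving match into a row [x1, y1, x2, y2] of pixel coordinates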
    for pair in good:
        matches.append(list(kp1[pair[0].queryIdx].pt + kp2[pair[0].trainIdx].pt))

    matches = np.array(matches)
    return matches

# Plot the matched feature points on the side-by-side image
def plot_matches(matches, total_img, filename=None):
    match_img = total_img.copy()
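    # The two images are concatenated horizontally and have equal width, so the
    # x-coordinates of the right image are shifted by half the total width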
    offset = total_img.shape[1] / 2

    plt.clf()
    fig, ax = plt.subplots()
    ax.set_aspect('equal')
    ax.imshow(np.array(match_img).astype('uint8'))

    ax.plot(matches[:, 0], matches[:, 1], 'xr')  # mark the matched keypoints in the left image
    ax.plot(matches[:, 2] + offset, matches[:, 3], 'xr')  # mark the matched keypoints in the right image

    ax.plot([matches[:, 0], matches[:, 2] + offset], [matches[:, 1], matches[:, 3]],
            'r', linewidth=0.5)  # draw a line connecting each matched pair

    plt.savefig(f'{filename}.png')

# Compute the homography matrix H from point correspondences (DLT)
def homography(pairs):
    rows = []
    for i in range(pairs.shape[0]):
        p1 = np.append(pairs[i][0:2], 1)
        p2 = np.append(pairs[i][2:4], 1)
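        # Each correspondence contributes two rows of the linear DLT system A·h = 0,
        # obtained from the constraint p2 × (H·p1) = 0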
        row1 = [0, 0, 0, p1[0], p1[1], p1[2], -p2[1] * p1[0], -p2[1] * p1[1], -p2[1] * p1[2]]
        row2 = [p1[0], p1[1], p1[2], 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1], -p2[0] * p1[2]]
        rows.append(row1)
        rows.append(row2)
    rows = np.array(rows)
    U, s, V = np.linalg.svd(rows)
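    # np.linalg.svd returns V transposed, so its last row is the right singular vector
    # for the smallest singular value, i.e. the least-squares solution of rows·h = 0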
    H = V[-1].reshape(3, 3)
    H = H / H[2, 2]  # normalize so that H[2, 2] = 1
    return H

# Randomly select four points from the match set
def random_point(matches, k=4):
    idx = random.sample(range(len(matches)), k)
    point = [matches[i] for i in idx]
    return np.array(point)

# Compute the reprojection error of each match under H
def get_error(points, H):
    num_points = len(points)
    all_p1 = np.concatenate((points[:, 0:2], np.ones((num_points, 1))), axis=1)
    all_p2 = points[:, 2:4]
    estimate_p2 = np.zeros((num_points, 2))
    for i in range(num_points):
        temp = np.dot(H, all_p1[i])
        estimate_p2[i] = (temp / temp[2])[0:2]  # convert from homogeneous back to Cartesian coordinates
    errors = np.linalg.norm(all_p2 - estimate_p2, axis=1) ** 2  # squared Euclidean reprojection error
    return errors

# RANSAC: find the best homography matrix and its inlier set
def ransac(matches, threshold, iters, min_error_prob=0.75):
    num_best_inliers = 0
    best_inliers, best_H = np.array([]), None  # fall-back values in case no valid model is found
    current_error_prob = 1.0

    for i in range(iters):
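        # Each iteration: fit a candidate H to a random minimal sample of 4 matches,
        # then count how many of all matches it explains within the error threshold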
        points = random_point(matches)
        H = homography(points)

        if np.linalg.matrix_rank(H) < 3:  # skip degenerate (singular) homographies
            continue

        errors = get_error(matches, H)
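        # matches whose squared reprojection error is below the threshold count as inliers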
        idx = np.where(errors < threshold)[0]
        inliers = matches[idx]

        num_inliers = len(inliers)
        if num_inliers > num_best_inliers:
            best_inliers = inliers.copy()
            num_best_inliers = num_inliers
            best_H = H.copy()

        # update the current error probability (1 - inlier ratio)
        current_error_prob = 1.0 - (num_best_inliers / len(matches))

        # stop early once the error probability drops below the required minimum
        if current_error_prob < min_error_prob:
            break

    print("內點數/總匹配數: {}/{}".format(num_best_inliers, len(matches)))
    return best_inliers, best_H

if __name__ == '__main__':
    # Read the input images
    left_gray, left_origin, left_rgb = read_image('image1.jpg')
    right_gray, right_origin, right_rgb = read_image('image2.jpg')

    # Determine the target size; here the left image's size is used
    target_height, target_width = left_gray.shape[:2]

    # Resize the right image to match the left image's size
    right_gray = cv2.resize(right_gray, (target_width, target_height))
    right_rgb = cv2.resize(right_rgb, (target_width, target_height))

    # Extract SIFT features from the grayscale images
    kp_left, des_left = SIFT(left_gray)
    kp_right, des_right = SIFT(right_gray)

    # Draw the SIFT keypoint images
    kp_left_img = plot_sift(left_gray, left_rgb, kp_left)
    kp_right_img = plot_sift(right_gray, right_rgb, kp_right)
    total_kp = np.concatenate((kp_left_img, kp_right_img), axis=1)
    plt.imshow(total_kp)
    plt.savefig('keypoints.png')
    plt.clf()

    # Match the feature points
    matches = matcher(kp_left, des_left, left_rgb, kp_right, des_right, right_rgb, 0.5)

    # Concatenate the two images side by side and plot the matched feature points
    total_img = np.concatenate((left_rgb, right_rgb), axis=1)
    plot_matches(matches, total_img, 'matches')

    # Run RANSAC to find the best homography and inlier set, then plot the inlier matches
    inliers, H = ransac(matches, 0.5, 2000)
    plot_matches(inliers, total_img, 'inliers')
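
A quick way to sanity-check the hand-rolled estimator (not part of the original script) is to compare it against OpenCV's built-in RANSAC homography fit. A minimal sketch, assuming matches is the N x 4 array produced by matcher() above; the 5.0-pixel reprojection threshold is just an illustrative choice:

# Cross-check: estimate H with cv2.findHomography using its RANSAC mode
src_pts = matches[:, 0:2].astype(np.float32)
dst_pts = matches[:, 2:4].astype(np.float32)
H_cv, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
print("OpenCV inliers / total matches: {}/{}".format(int(mask.sum()), len(matches)))
print(H_cv)

The two matrices will not be numerically identical (different thresholds and random sampling), but for a good match set they should describe nearly the same mapping.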
