【2】Kaggle 醫學影像資料讀取

SXWisON發表於2024-08-21

賽題名稱:RSNA 2024 Lumbar Spine Degenerative Classification
中文:腰椎退行性病變分類

kaggle官網賽題連結:https://www.kaggle.com/competitions/rsna-2024-lumbar-spine-degenerative-classification/overview

文章安排


①、如何用python讀取dcm/dicom檔案
②、基於matplotlib視覺化
③、繪製頻率分佈直方圖
④、程式碼彙總

檔案依賴


# requirements.txt
# Python version 3.11.8
torch==2.3.1
torchvision==0.18.1
matplotlib==3.8.4
pydicom==2.4.4
numpy==1.26.4
pip install -r requirements.txt

讀取dicom影像並做預處理

概述

本文中採取pydicom包讀取dicom檔案,其關鍵程式碼格式為:

dcm_tensor = pydicom.dcmread(dcm_file)

注意資料集的路徑,其在train_images檔案下存放了每一患者的資料,對於每一患者包含三張MRI影像,每張MRI影像存放為一個資料夾。
需要注意的是,MRI影像為三維影像(dicom格式),一般習慣性將其每個切片分別儲存為一個dcm檔案,因此一張dicom影像將被存為一個資料夾,如下圖

我們可以採用如下路徑訪問該dicom檔案:

"./train_images/4003253/702807833"

讀取路徑


為了讀取dicom影像,我們需要寫程式碼讀取資料夾中的所有dcm檔案

# dicom檔案路徑
dicom_dir = "./train_images/4003253/702807833"
# 儲存所有dcm檔案的路徑
dicom_files = [os.path.join(dicom_dir, f) for f in os.listdir(dicom_dir) if f.endswith('.dcm')] 
  • os.listdir:返回dicom_dir路徑下的所有檔案
  • f.endswith('.dcm') :篩選所有dcm格式的檔案
  • os.path.join: 將dcm檔名新增到dicom_dir之後
    示意:"./hello"+“1.dcm”->"./hello/1.dcm"

路徑排序


這次的kaggle賽題所給的資料集中,檔名的迭代方式為:

1.dcm、2.dcm、...、9.dcm、10.dcm、11.dcm、...

這給我們帶來了一定的麻煩,因為在os的檔名排序規則中,首先檢索高位字母的ASCII碼大小做排序,也就是說10.dcm將被認為是2.dcm前面的檔案。
對此,本文采用正規表示式的方式,實現了依據檔名中數字大小排序。

def extract_number(filepath):
    # 獲取檔名(包括副檔名)
    filename = os.path.basename(filepath)
    # 提取檔名中的數字部分,假設檔名以數字結尾,如 '1.dcm'
    match = re.search(r'(\d+)\.dcm$', filename)
    return int(match.group(1)) if match else float('inf')

# 基於數字控制代碼排序
dicom_files.sort(key=extract_number)

該程式碼效果如下:

讀取影像


為讀取dicom影像,我們需要依次讀取每一個dcm檔案,並將其最終打包為3D tensor,下述程式碼實現了該功能:

# 建立空列表儲存所有dcm檔案
dcm_list= []

# 迭代每一個檔案
for dcm_file in dicom_files:
    # 讀取檔案
    dcm = pydicom.dcmread(dcm_file)
    # 將其轉為numpy格式
    image_data = dcm.pixel_array.astype(np.float32)
    # 加入檔案列表 
    dcm_list.append(image_data)

# 將圖片堆疊為3D張量
tensor_dcm = torch.stack([torch.tensor(image_data) for image_data in dcm_list])

資料預處理


常見的預處理方式有兩種,歸一化(Normalization)量化(Quantization)

  • 歸一化:將資料縮放到某個標準範圍內的過程。常見的歸一化方法包括最小-最大歸一化(Min-Max Normalization)和Z-score標準化(Z-score Normalization),前者將資料歸一化至[0,1]範圍,後者將資料轉化為標準正態分佈。本例中採用Min-Max方案。

  • 量化:量化是將資料的值域退化到離散值的過程。常用於減少儲存和計算成本,尤其在神經網路模型中。量化通常將浮點數值轉換為整數值。量化前一般先進行歸一化。

歸一化的實現如下:

def norm_tensor(tensor_dicom):
    # 查詢影像的最大值和最小值
    vmin, vmax = tensor_dicom.min(), tensor_dicom.max()
    # 歸一化
    tensor_dicom= (tensor_dicom- vmax ) / (max_val - vmin)
    
    return tensor_dicom

實現基於method控制代碼選擇預處理方式:

if method == "norm":
    # 歸一化
    tensor_dcm = norm_tensor(tensor_dcm)
elif method == "uint8":
    # 歸一化
    tensor_dcm = norm_tensor(tensor_dcm)
    # 量化
    tensor_dcm = (tensor_dcm * 255).clamp(0, 255).to(torch.uint8)

繪圖


由於dicom影像為三維資料,視覺化時我們一般將其在z軸上分為多個切片依次視覺化,本文采用的方式是,採用5*5網格視覺化至多25個切片。

def show_dciom(tensor_dicom):
    # 查詢影像的最大最小值
    vmin, vmax = tensor_dicom.min(), tensor_dicom.max()
    
    # 建立一個圖形視窗
    fig, axes = plt.subplots(5, 5, figsize=(15, 15))  # 5x5 網格佈局

    count = 0
    length = tensor_dicom.size()[0]
    for i in range(25):
        if count < length:
            count += 1
        else:
            return
        # 獲取當前影像的座標
        ax = axes[i // 5, i % 5]
        # 顯示圖片
        ax.imshow(tensor_dicom[i], cmap='gray') # , vmin=vmin, vmax=vmax
        ax.axis('off')  # 關閉座標軸
    
    plt.tight_layout() # 避免重疊
    plt.title(f"Layer {i}")
    plt.show()

這裡有一點需要比較注意,在ax.imshow()函式中,我們指定了vmin和vmax引數;這是因為當該引數未被指定時,imshow函式將會自動調整點的亮度,使值最大的點對應255亮度,值最小的點對應0亮度。鑑於相鄰切片最大、最小畫素值可能存在較大差異,這將使得相鄰切片的影像亮度較異常,如下圖:

這兩張圖的左上角區域實際上亮度相近,但從視覺化影像來看,存在較大差異,這將對觀察帶來誤解。

視覺化頻率分佈直方圖


視覺化MRI影像的頻率分佈直方圖在醫學影像處理中有重要意義,主要包括以下幾個方面:

  • 影像對比度分析:頻率分佈直方圖可以顯示MRI影像中不同灰度級別(或畫素強度)的分佈情況。透過分析直方圖的形狀和範圍,可以瞭解影像的對比度。例如,直方圖的分佈範圍較廣表示影像對比度較高,能夠更好地區分不同組織或結構。

  • 影像均衡化:透過直方圖均衡化,可以改善影像的對比度,使得低對比度的區域更加清晰。均衡化過程透過重新分配影像中的畫素值,使得直方圖的分佈更加均勻,從而增強影像的視覺效果。

  • 組織分割:頻率分佈直方圖可以幫助確定適當的閾值,以進行影像分割。透過分析直方圖,可以選擇合適的閾值將不同組織或病變從背景中分離出來。

  • 影像質量評估:直方圖分析可以揭示影像的質量問題,例如過暗或過亮的影像,或者影像噪聲的影響。透過直方圖的形態,可以評估影像是否需要進一步的處理或最佳化。

在繪製頻率分佈直方圖前,需要先將三維向量展平,本文采用plt.hist函式繪製

def show_hist(tensor_dicom):
    # 將所有圖片的畫素值展平為一個一維陣列
    pixel_values = tensor_dicom.numpy().flatten()

    # 繪製直方圖
    plt.figure(figsize=(10, 6))
    plt.hist(pixel_values, bins=50, color='gray', edgecolor='black')
    plt.title('Histogram of All Pixel Values')
    plt.xlabel('Pixel Value')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

直方圖呈現如下分步,在val=0附近有一高峰,這是因為MRI影像中大部分割槽域並不存在人體組織,為空值0。
倘若除零以外的點過分集中在較小值(<100),那麼很可能是因為MRI影像中出現了一個亮度極大的噪點,使得以該噪點亮度為最值歸一化質量較差,對於這種情形,可以用99%分位數代替最大值,並將99%分位數歸一化至亮度為200. (比起歸一化至255,這將允許亮度最大1%的畫素點亮度值有區分)。
本例中影像質量均較高,故不需要做特殊處理。

程式碼彙總


程式碼架構

主函式

# main.py
# Import custom utility functions
from utils import read_one_dicom, show_dciom, show_hist

# Define the directory containing the DICOM images
dicom_dir = "./train_images/4003253/1054713880"

# Read the DICOM image into a tensor with uint8 data type
tensor_dicom = read_one_dicom(dicom_dir, method="uint8")

# Display the DICOM image slices in a 5x5 grid layout
show_dciom(tensor_dicom)

# Plot the histogram of pixel values from the DICOM image slices
show_hist(tensor_dicom)

# Convert the tensor to a NumPy array for further processing or inspection
np_img = tensor_dicom.numpy()

包檔案

from .preprocess import read_one_dicom

from .show import show_dciom
from .show import show_hist

讀取&預處理

# preprocess.py
import numpy as np
import torch
import os
import re
import pydicom
from tqdm import tqdm

def norm_tensor(tensor_dicom):
    """
    Normalize the image tensor to the range [0, 1].

    Args:
        tensor_dicom (torch.Tensor): Tensor containing image data.

    Returns:
        torch.Tensor: Normalized image tensor.
    """
    # Calculate the maximum and minimum values of the image tensor
    vmin, vmax = tensor_dicom.min(), tensor_dicom.max()

    # Normalize the image tensor to the range [0, 1]
    tensor_dicom = (tensor_dicom - vmin) / (vmax - vmin)
    
    return tensor_dicom

def extract_number(filepath):
    """
    Extract the numeric part from the DICOM filename.

    Args:
        filepath (str): Path to the DICOM file.

    Returns:
        int: Extracted number from the filename. Returns float('inf') if not found.
    """
    # Get the filename (including extension)
    filename = os.path.basename(filepath)
    # Extract numeric part from filename, assuming filenames end with digits, e.g., '1.dcm'
    match = re.search(r'(\d+)\.dcm$', filename)
    return int(match.group(1)) if match else float('inf')

def read_one_dicom(dicom_dir, method = "", bar_title = ""):
    """
    Reads DICOM files from a directory and converts them into a PyTorch tensor.

    Args:
        dicom_dir (str): Directory containing DICOM files.
        method (str): Optional method to process the tensor ('norm' for normalization, 'uint8' for normalization and conversion to uint8).
        bar_title (str): Optional title for the progress bar.

    Returns:
        torch.Tensor: PyTorch tensor containing image data from DICOM files.
    """
    # Get all DICOM files and sort them based on numeric part of the filename
    dicom_files = [os.path.join(dicom_dir, f) for f in os.listdir(dicom_dir) if f.endswith('.dcm')]    
    dicom_files.sort(key=extract_number)

    # Create an empty list to store image data
    dcm_list = []

    # Initialize tqdm progress bar
    with tqdm(total=len(dicom_files), desc='Processing DICOM files', unit='dcm', unit_scale=True, unit_divisor=1000000) as pbar:
        # Iterate over each DICOM file and read image data
        for count, dcm_file in enumerate(dicom_files, start=1):
            # Read the DICOM file
            dcm = pydicom.dcmread(dcm_file)

            # Extract and convert image data to a NumPy array
            image_data = dcm.pixel_array.astype(np.float32)

            # Add the image data to the list
            dcm_list.append(image_data)

            # Update progress bar description
            pbar.set_description(bar_title + 'Reading')

            # Update progress bar
            pbar.update(1)

    # Convert the list of image data to a PyTorch tensor and stack into a 3D tensor
    tensor_dicom = torch.stack([torch.tensor(image_data) for image_data in dcm_list])

    if method == "norm":
        # Normalize the image tensor
        tensor_dicom = norm_tensor(tensor_dicom)
    elif method == "uint8":
        # Normalize the image tensor
        tensor_dicom = norm_tensor(tensor_dicom)
        # Scale the tensor values to the range [0, 255] and convert to uint8 type
        tensor_dicom = (tensor_dicom * 255).clamp(0, 255).to(torch.uint8)

    return tensor_dicom

視覺化、繪製直方圖

# show.py
import numpy as np
import torch
import matplotlib.pyplot as plt

def show_dciom(tensor_dicom):
    """
    Display MRI image slices in a 5x5 grid layout.

    Parameters:
    tensor_dicom (torch.Tensor): Tensor containing MRI image slices, expected shape is (N, H, W),
                                 where N is the number of slices, and H and W are the height and width of the images.
    """
    # Calculate the minimum and maximum pixel values in the tensor
    vmin, vmax = tensor_dicom.min(), tensor_dicom.max()
    
    # Create a figure with a 5x5 grid layout
    fig, axes = plt.subplots(5, 5, figsize=(15, 15))  # 5x5 grid layout

    count = 0
    length = tensor_dicom.size(0)
    for i in range(25):
        if count < length:
            count += 1
        else:
            return
        # Get the current subplot's axis
        ax = axes[i // 5, i % 5]
        # Display the image
        ax.imshow(tensor_dicom[count - 1], cmap='gray', vmin=vmin, vmax=vmax)
        ax.axis('off')  # Hide the axis
    
    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.title(f"Layer {i + 1}")  # Title indicating the last displayed slice
    plt.show()

def show_hist(tensor_dicom):
    """
    Plot the histogram of pixel values for all MRI image slices.

    Parameters:
    tensor_dicom (torch.Tensor): Tensor containing MRI image slices, expected shape is (N, H, W).
    """
    # Flatten all image pixel values into a single 1D array
    pixel_values = tensor_dicom.numpy().flatten()

    # Plot the histogram
    plt.figure(figsize=(10, 6))
    plt.hist(pixel_values, bins=50, color='gray', edgecolor='black')
    plt.title('Histogram of All Pixel Values')
    plt.xlabel('Pixel Value')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

下篇預告


討論本題的解題方法

製作不易,請幫我點一個免費的贊,謝謝!

相關文章