ChatGPT 幫我最佳化程式碼-2024.06.20

Elina-Chang發表於2024-06-20

改成物件導向

  該程式的主要任務是從指定的文字檔案中提取 ROI(感興趣區域)資訊,統計不同 ROI 標籤(如 56,63,69 ……)的出現次數,並繪製統計結果的條形圖。透過將功能模組化到 RoiAnalyzer 類中,程式碼變得更加結構化和可維護。
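【文字檔案】:每一行的格式為 `Case <case_name> has [<標籤1> <標籤2> ...]`,例如 `Case case_001 has [56 63 69]`;不符合此格式的行會被列印出來並略過。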

  • 原始碼
    
    import re
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    def ret_roi_value_dict(txt_path):
        """Parse ``Case <name> has [<labels>]`` lines from a text file.

        Each line is matched from its start against the pattern
        ``Case <name> has [<labels>]``. The labels inside the brackets may be
        separated by whitespace and/or commas and must all be integers.

        Lines that do not match the pattern, or whose bracket content contains
        a token that is not an integer, are printed together with their
        1-based line number and skipped instead of aborting the whole file.

        Args:
            txt_path: Path of the text file to read.

        Returns:
            A list of dicts, one per parsed line, with the keys
            ``'case_name'`` (str) and ``'pixel_value'`` (list of int).
        """
        # Compile once instead of re-resolving the pattern for every line.
        pattern = re.compile(r'Case\s+(.*?)\s+has\s+\[(.*?)\]')
        output = []
        with open(txt_path, 'r') as file:
            for line_number, line in enumerate(file, start=1):
                match = pattern.match(line)
                pixel_value = None
                if match:
                    # Accept both "[56 63 69]" and "[56, 63, 69]".
                    tokens = match.group(2).replace(',', ' ').split()
                    try:
                        pixel_value = [int(token) for token in tokens]
                    except ValueError:
                        # A non-integer label makes the line unusable; report
                        # it like any other malformed line.
                        pixel_value = None
                if pixel_value is None:
                    print(f"Line {line_number} not matched: {line.strip()}")
                    continue
                output.append({'case_name': match.group(1), 'pixel_value': pixel_value})
        return output
    
    def show_roi_value(roi_value_dict, fig_save_path, save_dpi):
        """Draw a bar chart of how often each ROI label occurs and save it.

        All ``"pixel_value"`` lists of the given records are pooled, the
        occurrences of every distinct label are counted, and one bar per label
        is drawn (labels in ascending order, as returned by ``np.unique``).
        The labels ranked second to sixth by count are highlighted in red,
        blue, green, yellow and purple; every other bar, including the most
        frequent one, stays grey. Each bar is annotated with its count.

        Args:
            roi_value_dict: Iterable of dicts that each carry a
                ``"pixel_value"`` list of integer labels.
            fig_save_path: Destination path handed to ``plt.savefig``.
            save_dpi: Resolution handed to ``plt.savefig``.
        """
        # Pool the labels of every record into one flat list.
        all_pixel_values = [
            value for item in roi_value_dict for value in item["pixel_value"]
        ]
        # Distinct labels (ascending) and how often each one occurs.
        unique_elements, counts = np.unique(np.array(all_pixel_values), return_counts=True)
        # Positions into unique_elements, ordered from most to least frequent.
        sorted_indices = np.argsort(counts)[::-1]

        # Highlight colours for ranks 2..6; rank 1 deliberately stays grey.
        top_colors = ['red', 'blue', 'green', 'yellow', 'purple']
        bar_colors = ['grey'] * len(unique_elements)
        # sorted_indices already holds the positions inside unique_elements,
        # so no second lookup by value is needed.
        for color, original_index in zip(top_colors, sorted_indices[1:6]):
            bar_colors[original_index] = color

        # Draw the bar chart.
        plt.figure(figsize=(10, 6))
        # NOTE(review): newer seaborn releases reportedly warn when `palette`
        # is passed without `hue` - confirm against the installed version.
        bars = sns.barplot(x=unique_elements, y=counts, palette=bar_colors)
        plt.title('TB ROI Information')
        plt.xlabel('ROI label')
        plt.ylabel('Case count')
        plt.xticks(rotation=45)  # rotate the x tick labels for readability
        plt.tight_layout()  # keep the labels from being clipped
        plt.grid(axis="y")
        # Write the count above every bar.
        for bar in bars.patches:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2, int(height), f'{int(height)}', ha='center', va='bottom', color='black')
        # Save the figure.
        plt.savefig(fig_save_path, dpi=save_dpi)
    
    def main():
        """Pool the ROI records of all annotation files and save the bar chart."""
        dataset_names = (
            "MaXin_50_Newlabel_final",
            "MaXin_100",
            "ZhangRui_50_Newlabel_final",
            "ZhangRui_54",
        )
        records = []
        for name in dataset_names:
            # One text file per dataset; append its parsed records to the pool.
            records += ret_roi_value_dict(
                txt_path=f"/homes/xchang/Projects/Multi-center_Tuberculosis/segmentation/nnUNet_dl/{name}.txt"
            )
        show_roi_value(
            roi_value_dict=records,
            fig_save_path="/homes/xchang/Projects/Multi-center_Tuberculosis/roi_new_values_final.png",
            save_dpi=600,
        )
    
    # Run the analysis only when the file is executed as a script, not on import.
    if __name__ == "__main__":
        main()
    
  • 最佳化後
    
    import re
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    class RoiAnalyzer:
        """Collect ROI labels from ``Case ... has [...]`` text files and plot them.

        The analyzer reads one ``<base_path>/<keyword>.txt`` file per keyword,
        keeps the parsed records in ``roi_value_dict`` and can render a bar
        chart of how often each label occurs.
        """

        # Compiled once for the class instead of being re-resolved per line.
        _LINE_PATTERN = re.compile(r'Case\s+(.*?)\s+has\s+\[(.*?)\]')

        def __init__(self, keywords, base_path):
            """Store the file keywords and the directory that contains them.

            Args:
                keywords: Iterable of file names without the ``.txt`` suffix.
                base_path: Directory that holds the ``<keyword>.txt`` files.
            """
            self.keywords = keywords
            self.base_path = base_path
            # Parsed records; filled by collect_data().
            self.roi_value_dict = []

        def ret_roi_value_dict(self, txt_path):
            """Parse ``Case <name> has [<labels>]`` lines from a text file.

            The labels inside the brackets may be separated by whitespace
            and/or commas and must all be integers. Lines that do not match
            the pattern, or that contain a non-integer label, are printed with
            their 1-based line number and skipped.

            Args:
                txt_path: Path of the text file to read.

            Returns:
                A list of dicts with the keys ``'case_name'`` (str) and
                ``'pixel_value'`` (list of int), one per parsed line.
            """
            output = []
            with open(txt_path, 'r') as file:
                for line_number, line in enumerate(file, start=1):
                    match = self._LINE_PATTERN.match(line)
                    pixel_value = None
                    if match:
                        # Accept both "[56 63 69]" and "[56, 63, 69]".
                        tokens = match.group(2).replace(',', ' ').split()
                        try:
                            pixel_value = [int(token) for token in tokens]
                        except ValueError:
                            # A non-integer label makes the line unusable;
                            # report it like any other malformed line.
                            pixel_value = None
                    if pixel_value is None:
                        print(f"Line {line_number} not matched: {line.strip()}")
                        continue
                    output.append({'case_name': match.group(1), 'pixel_value': pixel_value})
            return output

        def collect_data(self):
            """Read every keyword file and store the parsed records.

            ``roi_value_dict`` is rebuilt from scratch on each call, so calling
            this method (or ``run``) more than once does not count the same
            cases twice.
            """
            collected = []
            for keyword in self.keywords:
                txt_path = f"{self.base_path}/{keyword}.txt"
                collected.extend(self.ret_roi_value_dict(txt_path))
            # Assign only after every file was read successfully.
            self.roi_value_dict = collected

        def show_roi_value(self, fig_save_path, save_dpi):
            """Draw a bar chart of the label counts and save it to a file.

            All ``"pixel_value"`` lists in ``roi_value_dict`` are pooled and
            one bar per distinct label is drawn (labels in ascending order).
            The labels ranked second to sixth by count are highlighted in red,
            blue, green, yellow and purple; every other bar, including the
            most frequent one, stays grey. Each bar is annotated with its
            count.

            Args:
                fig_save_path: Destination path handed to ``plt.savefig``.
                save_dpi: Resolution handed to ``plt.savefig``.
            """
            # Pool the labels of every record into one flat list.
            all_pixel_values = [
                value for item in self.roi_value_dict for value in item["pixel_value"]
            ]
            # Distinct labels (ascending) and how often each one occurs.
            unique_elements, counts = np.unique(np.array(all_pixel_values), return_counts=True)
            # Positions into unique_elements, ordered from most to least frequent.
            sorted_indices = np.argsort(counts)[::-1]

            # Highlight colours for ranks 2..6; rank 1 deliberately stays grey.
            top_colors = ['red', 'blue', 'green', 'yellow', 'purple']
            bar_colors = ['grey'] * len(unique_elements)
            # sorted_indices already holds the positions inside
            # unique_elements, so no second lookup by value is needed.
            for color, original_index in zip(top_colors, sorted_indices[1:6]):
                bar_colors[original_index] = color

            # Draw the bar chart.
            plt.figure(figsize=(10, 6))
            # NOTE(review): newer seaborn releases reportedly warn when
            # `palette` is passed without `hue` - confirm against the
            # installed version.
            bars = sns.barplot(x=unique_elements, y=counts, palette=bar_colors)
            plt.title('TB ROI Information')
            plt.xlabel('ROI label')
            plt.ylabel('Case count')
            plt.xticks(rotation=45)  # rotate the x tick labels for readability
            plt.tight_layout()  # keep the labels from being clipped
            plt.grid(axis="y")

            # Write the count above every bar.
            for bar in bars.patches:
                height = bar.get_height()
                plt.text(bar.get_x() + bar.get_width() / 2, int(height), f'{int(height)}', ha='center', va='bottom', color='black')

            # Save the figure.
            plt.savefig(fig_save_path, dpi=save_dpi)

        def run(self, fig_save_path, save_dpi=600):
            """Collect the data from all keyword files, then save the chart.

            Args:
                fig_save_path: Destination path of the saved figure.
                save_dpi: Resolution of the saved figure (default 600).
            """
            self.collect_data()
            self.show_roi_value(fig_save_path, save_dpi)
    
    def main():
        """Build a RoiAnalyzer for the four annotation files and run it."""
        analyzer = RoiAnalyzer(
            keywords=[
                "MaXin_50_Newlabel_final",
                "MaXin_100",
                "ZhangRui_50_Newlabel_final",
                "ZhangRui_54",
            ],
            base_path="/homes/xchang/Projects/Multi-center_Tuberculosis/segmentation/nnUNet_dl",
        )
        # save_dpi is left at the default declared by RoiAnalyzer.run.
        analyzer.run(
            fig_save_path="/homes/xchang/Projects/Multi-center_Tuberculosis/roi_new_values_final.png"
        )
    
    # Run the analysis only when the file is executed as a script, not on import.
    if __name__ == "__main__":
        main()
    

相關文章