Reading the DreamBooth training code

Posted by ltlreanweb on 2024-08-13

Most DreamBooth material online only explains the paper; walkthroughs of the code are either nowhere to be found or behind a paywall, so I had no choice but to read it myself and take notes.
I haven't been doing machine learning for long and my level isn't high, so there may be mistakes; corrections are welcome, and treat this as a reference only.

In short, the DreamBooth workflow is: 1. fine-tune an existing diffusion model by adding a token you want, producing a new model. For example, if you train on photos of one particular "sks" dog, the new model has both the dog token and the sks dog token. 2. You can then use the dog token to generate an ordinary dog and sks dog to generate that specific dog. The first part is fine-tuning the model, the second part is generating images.
This page explains it well, in a bit more detail than the project README: https://huggingface.co/docs/diffusers/main/en/training/dreambooth
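Once the fine-tuned model has been saved, the second half of the workflow (generating images) is just loading that output directory as a normal diffusers pipeline. A minimal sketch; the path and prompt are the values used in this walkthrough, so adjust them to your own run:

import torch
from diffusers import DiffusionPipeline

# Load the weights that train_dreambooth.py saved (the OUTPUT_DIR from the launch script below).
pipe = DiffusionPipeline.from_pretrained("path_to_saved_model", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# "sks dog" is the instance token from training; plain "dog" should still give ordinary dogs
# if prior preservation did its job.
image = pipe("a photo of sks dog in a bucket", num_inference_steps=50).images[0]
image.save("sks_dog.png")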
After the environment is set up,
accelerate config configures how training runs, in particular distributed training. I don't know this part well; I stop at "configure it so it runs".

Then you run a script of the following form, which mainly launches train_dreambooth.py with its arguments:
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="./dog"
export OUTPUT_DIR="path_to_saved_model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400 \
  --push_to_hub

The real focus is the code in train_dreambooth.py.
Ignore the ~1400 lines for now and scroll to the very bottom:
if __name__ == "__main__":
    args = parse_args()
    main(args)
Just two statements: parse_args() runs first, then main(), so at least the structure is clear.
Let's start with the parse_args() function.
Its main job is to parse the arguments passed from the script or the command line and store them in args.

  def parse_args(input_args=None):
  ## Mainly parses the arguments passed in by the launch script above.
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
  # Create an argparse parser object for the command-line arguments.
    parser.add_argument(  # declare an accepted argument: name, type, default, help text, etc.
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )

......the rest of the arguments follow the same pattern; we'll look at them as they come up.
Next come some environment handling and argument-validation errors.

if input_args is not None:  # if provided, parse this list instead of sys.argv;
    args = parser.parse_args(input_args)
else:  # otherwise parse the command line the default way.
    args = parser.parse_args()

env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
# Read LOCAL_RANK from the environment and convert it to an int; -1 if it isn't set.
# For an explanation of local_rank see https://blog.csdn.net/shenjianhua005/article/details/127318594
if env_local_rank != -1 and env_local_rank != args.local_rank:
    args.local_rank = env_local_rank
# override args.local_rank with the environment value

if args.with_prior_preservation:
    # --with_prior_preservation: whether to use the prior preservation loss.
    # As I understand it, it keeps the semantics of the "dog" token from drifting during fine-tuning,
    # i.e. "dog" still means an ordinary dog, not the specific sks dog.
    if args.class_data_dir is None:  # error
        # --class_data_dir: path to the folder holding the generated class sample images
        raise ValueError("You must specify a data directory for class images.")
    if args.class_prompt is None:
        # --class_prompt: text prompt describing the class of the generated sample images
        raise ValueError("You must specify prompt for class images.")
else:
    # warnings only
    # logger is not available yet
    if args.class_data_dir is not None:
        warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
    if args.class_prompt is not None:
        warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
# Worth pointing out: --instance_prompt="a photo of sks dog" is the prompt containing the special word for the example images,
# while --class_prompt="a photo of dog" is the prompt describing the class of the generated sample images.
# Using the prior preservation loss requires class_prompt; without it, class_prompt is unnecessary.
# My understanding: with prior preservation the model also has to generate ordinary, diverse dog images, to keep the dog token from drifting.
# class_data_dir stores the ordinary dog images generated during the run; you'll see this further down, it is like that and I won't come back to change this.
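For reference, the prior preservation objective from the DreamBooth paper adds a second denoising term over the generated class images on top of the usual loss over the instance images, roughly (notation simplified from the paper, with $\lambda$ corresponding to --prior_loss_weight):

L = \mathbb{E}\big[\,\|\epsilon - \epsilon_\theta(z_t, t, c_\text{instance})\|^2\,\big] + \lambda\,\mathbb{E}\big[\,\|\epsilon' - \epsilon_\theta(z'_{t'}, t', c_\text{class})\|^2\,\big]

which is exactly what the training loop later computes as loss + args.prior_loss_weight * prior_loss.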



if args.train_text_encoder and args.pre_compute_text_embeddings:
    raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
# --train_text_encoder: whether to also train the text encoder.
# --pre_compute_text_embeddings: whether to pre-compute the text embeddings. If they are pre-computed, the text encoder
# is dropped before training, so it obviously can't be trained at the same time, hence this check.
return args  # return the parsed arguments. ~400 lines that are mostly argument definitions; a relief, if only the rest were this simple.
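Since parse_args() takes an optional input_args list, you can also exercise it without the shell wrapper, which is handy for checking defaults. A hypothetical quick test (for instance pasted temporarily at the bottom of train_dreambooth.py); only the required flags from the launch script are passed:

args = parse_args([
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
    "--instance_data_dir", "./dog",
    "--instance_prompt", "a photo of sks dog",
    "--output_dir", "path_to_saved_model",
])
print(args.resolution, args.learning_rate)  # anything not passed falls back to its default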

Next, the main() function.
It starts with errors and settings related to pushing the model to the Hub.

  if args.report_to == "wandb" and args.hub_token is not None:
      # --report_to: 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
      #    ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
      # --hub_token: "The token to use to push to the Model Hub."
      raise ValueError(
          "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
          " Please use `huggingface-cli login` to authenticate with the Hub."
      )
  # Passing the token on the command line risks leaking it, so the script insists on `huggingface-cli login` instead.
  # Not something I fully understand, but I don't plan to upload, so it probably doesn't matter.

  logging_dir = Path(args.output_dir, args.logging_dir)
  # Join output_dir with the TensorBoard log directory name; --logging_dir has a default, so it doesn't affect whether the script runs.
  # --output_dir stores the trained model, so the TensorBoard logs end up under <output_dir>/<logging_dir>.

Then some more configuration, parts of which I don't fully understand.

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
# Collects project-level settings (output and logging directories) for later use.
accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,  # number of gradient accumulation steps
    mixed_precision=args.mixed_precision,  # whether to use mixed-precision training
    log_with=args.report_to,  # where to send logs
    project_config=accelerator_project_config,  # project configuration
)  # instantiate the Accelerator class; for background see https://blog.csdn.net/qq_56591814/article/details/134200839


# Disable AMP for MPS.
if torch.backends.mps.is_available():
    accelerator.native_amp = False
# Checks whether the MPS backend (Apple's Metal Performance Shaders, GPU acceleration on Apple silicon) is available; native AMP is disabled for it.

if args.report_to == "wandb":#簡單的警告
    if not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
    raise ValueError(
        "Gradient accumulation is not supported when training the text encoder in distributed training. "
        "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
    )

Logging setup

# Make one log on every process with the configuration for debugging.
# (I didn't dig into the details; they're easy enough to look up.)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()

Random seed

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)
# Sets the random seed via accelerate.utils; see e.g. https://blog.csdn.net/qq_40206371/article/details/139466522

If the prior preservation loss is used, generate ordinary dog images so that the "dog" semantics don't drift.

# Generate class images if prior preservation is enabled.
# when --with_prior_preservation is set, class images are generated here
if args.with_prior_preservation:
    class_images_dir = Path(args.class_data_dir)#dog
    if not class_images_dir.exists():  # create the directory if it doesn't exist
        class_images_dir.mkdir(parents=True)
    cur_class_images = len(list(class_images_dir.iterdir()))  # number of files already in class_images_dir

    if cur_class_images < args.num_class_images:
        # load the pipeline that will generate the missing class images
        torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
        if args.prior_generation_precision == "fp32":
            torch_dtype = torch.float32
        elif args.prior_generation_precision == "fp16":
            torch_dtype = torch.float16
        elif args.prior_generation_precision == "bf16":
            torch_dtype = torch.bfloat16
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            safety_checker=None,
            revision=args.revision,
            variant=args.variant,
        )
        pipeline.set_progress_bar_config(disable=True)
        # disable the pipeline's progress bar

        num_new_images = args.num_class_images - cur_class_images  # how many images still need to be generated
        logger.info(f"Number of class images to sample: {num_new_images}.")  # log it

        sample_dataset = PromptDataset(args.class_prompt, num_new_images)  # build a prompt dataset for the class prompt
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)  # wrap it in a DataLoader

        sample_dataloader = accelerator.prepare(sample_dataloader)
        # let the accelerator (GPU/TPU) prepare sample_dataloader
        pipeline.to(accelerator.device)  # move the pipeline to the accelerator device

        for example in tqdm(
            sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
        ):  # iterate with a progress bar
            images = pipeline(example["prompt"]).images  # generate the images

            for i, image in enumerate(images):  # save each image under a hash-based file name
                hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
                image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                image.save(image_filename)

        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # free cached GPU memory if a GPU is present
# end of class ("dog") image generation

Repository (Hub) setup

# Handle the repository creation
# settings for the upload repository; I'm not familiar with this and didn't look closely
if accelerator.is_main_process:
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.push_to_hub:
        repo_id = create_repo(
            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
        ).repo_id

Loading the tokenizer

# Load the tokenizer
if args.tokenizer_name:  # if a separate tokenizer name is given (different from the model name)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )

Loading the scheduler and the models (text encoder, VAE, UNet)

# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# A helper the author defines that returns the text-encoder class, so the class's from_pretrained(...) can be used below.

# Load scheduler and models
# the DDPM noise scheduler for the diffusion process, and the text encoder
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)

if model_has_vae(args):  # helper that checks whether the model has a VAE; if so, load the autoencoder
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
    )
else:
    vae = None
# load the denoising UNet (the conditional UNet that predicts noise)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)

Hook setup; honestly I don't fully understand this part.

def unwrap_model(model):
    model = accelerator.unwrap_model(model)  # unwrap the model; again see https://blog.csdn.net/qq_56591814/article/details/134200839
    model = model._orig_mod if is_compiled_module(model) else model  # handle modules compiled with torch.compile
    return model

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        for model in models:
            sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"#不同情況目錄名
            model.save_pretrained(os.path.join(output_dir, sub_dir))#模型儲存到目錄

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()

def load_model_hook(models, input_dir):
    while len(models) > 0:
        # pop models so that they are not loaded again
        model = models.pop()

        # Use isinstance to decide how to load: if the model is the text encoder, load it transformers-style
        # and copy over its config; otherwise load it as a UNet2DConditionModel and register the config.
        if isinstance(model, type(unwrap_model(text_encoder))):
            # load transformers style into model
            load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
            model.config = load_model.config
        else:
            # load diffusers style into model
            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
            model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        # apply load_model's state dict (its parameters) to the current model
        del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Register hooks that run when state is saved/loaded; I'm not too familiar with this, see the accelerate docs if interested. Moving on.

Some more settings and error checks

if vae is not None:
    vae.requires_grad_(False)  # freeze the VAE (don't track gradients)

if not args.train_text_encoder:
    text_encoder.requires_grad_(False)

# xformers configuration warning. Last time it was exactly this library clashing with my PyTorch version that gave me grief.
# Also, couldn't this check come earlier? It only tells you something's wrong after the run has already taken a while.
if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warning(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

if args.gradient_checkpointing:  # use gradient checkpointing to save memory at the cost of slower backprop
    unet.enable_gradient_checkpointing()
    if args.train_text_encoder:
        text_encoder.gradient_checkpointing_enable()

# Check that all trainable models are in full precision
low_precision_error_string = (
    "Please make sure to always have all model weights in full float32 precision when starting training - even if"
    " doing mixed precision training. copy of the weights should still be float32."
)

if unwrap_model(unet).dtype != torch.float32:
    raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")

if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
    raise ValueError(
        f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
    )
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

if args.scale_lr:  # scale the learning rate by accumulation steps, batch size, and number of processes
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        )

    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

Creating the optimizer

# Optimizer creation
params_to_optimize = (
    itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
)
# Select which parameters go to the optimizer based on args.train_text_encoder: if it is set,
# params_to_optimize is itertools.chain(unet.parameters(), text_encoder.parameters());
# otherwise it is just unet.parameters().
optimizer = optimizer_class(
    params_to_optimize,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)  # the optimizer object

Whether to pre-compute the text embeddings

if args.pre_compute_text_embeddings:  # pre-compute the text embeddings (the text encoder can then be dropped from memory)
   # https://segmentfault.com/a/1190000044075300

    def compute_text_embeddings(prompt):
        with torch.no_grad():  # no gradients are tracked here
            text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
            # tokenize the input prompt,
            prompt_embeds = encode_prompt(
                text_encoder,
                text_inputs.input_ids,
                text_inputs.attention_mask,
                text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                # whether text_encoder should use the attention_mask: https://developer.baidu.com/article/details/3248780
            )  # encode the tokenized text with the author's helper, returning the embedding vector

        return prompt_embeds

    pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)  # embedding of the "sks dog" instance prompt
    validation_prompt_negative_prompt_embeds = compute_text_embeddings("")

    if args.validation_prompt is not None:  # validation_prompt is used during validation to confirm the model is learning
        validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
    else:
        validation_prompt_encoder_hidden_states = None

    if args.class_prompt is not None:
        pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
    else:
        pre_computed_class_prompt_encoder_hidden_states = None

    text_encoder = None
    tokenizer = None

    gc.collect()  # force garbage collection
    torch.cuda.empty_cache()
else:  # not pre-computing the text embeddings
    pre_computed_encoder_hidden_states = None
    validation_prompt_encoder_hidden_states = None
    validation_prompt_negative_prompt_embeds = None
    pre_computed_class_prompt_encoder_hidden_states = None

Dataset and DataLoader

# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
    instance_data_root=args.instance_data_dir,
    instance_prompt=args.instance_prompt,
    class_data_root=args.class_data_dir if args.with_prior_preservation else None,
    class_prompt=args.class_prompt,
    class_num=args.num_class_images,
    tokenizer=tokenizer,
    size=args.resolution,
    center_crop=args.center_crop,
    encoder_hidden_states=pre_computed_encoder_hidden_states,
    class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
    tokenizer_max_length=args.tokenizer_max_length,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
    num_workers=args.dataloader_num_workers,
)

Other configuration

# Scheduler and math around the number of training steps.
# training-step bookkeeping
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)

# Prepare everything with our `accelerator`.
if args.train_text_encoder:
    unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler
    )
else:
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Move vae and text_encoder to device and cast to weight_dtype
if vae is not None:
    vae.to(accelerator.device, dtype=weight_dtype)

if not args.train_text_encoder and text_encoder is not None:
    text_encoder.to(accelerator.device, dtype=weight_dtype)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
# recompute the step counts now that the prepared dataloader's length is known; worth a closer look sometime
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
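To make the step arithmetic concrete, a small sketch with made-up numbers (5 instance images is an assumption; the batch size, accumulation steps, and max_train_steps are the launch-script values):

import math

len_train_dataloader = 5            # e.g. 5 instance images, train_batch_size=1, no prior preservation
gradient_accumulation_steps = 1
num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # -> 5

max_train_steps = 400               # from the launch script
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)                  # -> 80 epochs
print(num_update_steps_per_epoch, num_train_epochs)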

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    tracker_config = vars(copy.deepcopy(args))
    tracker_config.pop("validation_images")
    accelerator.init_trackers("dreambooth", config=tracker_config)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
logger.info(f"  Num Epochs = {args.num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:  # whether to resume training from a previous checkpoint
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)  # not "latest": take the given checkpoint's name
    else:
        # Get the most recent checkpoint
        dirs = os.listdir(args.output_dir)  # list the files and directories in the output directory
        dirs = [d for d in dirs if d.startswith("checkpoint")]  # keep only names starting with "checkpoint"
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))  # sort by the integer after the "-"
        path = dirs[-1] if len(dirs) > 0 else None  # pick the latest (largest-numbered) checkpoint directory

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None
        initial_global_step = 0
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])
        # Split the path string on "-" and take the second part (index 1) as an integer.
        # For a path like 'checkpoint-100', this extracts the global step 100.

        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch
else:  # not resuming from a checkpoint
    initial_global_step = 0

progress_bar = tqdm(  # create a progress bar over the training steps
    range(0, args.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

Training


for epoch in range(first_epoch, args.num_train_epochs):  # loop over epochs, starting from first_epoch
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    # unet.train() and text_encoder.train() put the models into training mode.
    for step, batch in enumerate(train_dataloader):
        # load one batch (the sks instance images, plus class images if prior preservation is on)
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

            if vae is not None:
                # Convert images to latent space
                model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
            else:
                model_input = pixel_values

            # Sample noise that we'll add to the model input
            if args.offset_noise:
                noise = torch.randn_like(model_input) + 0.1 * torch.randn(
                    model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
                )
            else:
                noise = torch.randn_like(model_input)
            bsz, channels, height, width = model_input.shape  # batch size and latent/image dimensions

            # Sample a random timestep for each image
            # The timestep is the diffusion noise level for that image (not an RNN-style time step):
            # each image in the batch is trained at a randomly chosen amount of noise.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
            )
            timesteps = timesteps.long()
            # torch.randint draws a tensor of bsz random integers in [0, num_train_timesteps),
            # on model_input's device, and they are then cast to long.

            # Add noise to the model input according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
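            # What add_noise implements is the closed-form DDPM forward process (a sketch of the idea):
            #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
            # where alpha_bar_t comes from the scheduler's beta schedule; the UNet is then trained to recover `noise` from x_t.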

            # Get the text embedding for conditioning
            # get the text embedding used for conditioning, i.e. turn the prompt into a vector
            if args.pre_compute_text_embeddings:
                encoder_hidden_states = batch["input_ids"]
            else:
                encoder_hidden_states = encode_prompt(
                    text_encoder,
                    batch["input_ids"],
                    batch["attention_mask"],
                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                )

            if unwrap_model(unet).config.in_channels == channels * 2:  # if the UNet expects twice the channels, duplicate the noisy input along the channel dim
                noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

            if args.class_labels_conditioning == "timesteps":
                class_labels = timesteps
            else:
                class_labels = None

            # Predict the noise residual
            model_pred = unet(
                noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
            )[0]

            if model_pred.shape[1] == 6:
                model_pred, _ = torch.chunk(model_pred, 2, dim=1)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(model_input, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)
                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

            # Compute instance loss
            if args.snr_gamma is None:  # snr_gamma: "SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0."
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            else:
                # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                # This is discussed in Section 4.2 of the same paper.
                # compute the loss weights (the script adapts the original formulation slightly)
                snr = compute_snr(noise_scheduler, timesteps)  # signal-to-noise ratio at each sampled timestep
                base_weight = (
                    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                )

                if noise_scheduler.config.prediction_type == "v_prediction":
                    # Velocity objective needs to be floored to an SNR weight of one.
                    mse_loss_weights = base_weight + 1
                else:
                    # Epsilon and sample both use the same loss weights.
                    mse_loss_weights = base_weight
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                loss = loss.mean()
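                # In other words, this is Min-SNR weighting (https://arxiv.org/abs/2303.09556), roughly
                #     w_t = min(SNR_t, gamma) / SNR_t   (with +1 for v-prediction),
                # so nearly-clean timesteps with huge SNR stop dominating the average loss.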

            if args.with_prior_preservation:
                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss

            accelerator.backward(loss)  # backpropagate the loss
            if accelerator.sync_gradients:
            # If gradients are about to be synchronized, gather the trainable parameters and clip
            # their gradients so the norm does not exceed max_grad_norm.
                params_to_clip = (
                    itertools.chain(unet.parameters(), text_encoder.parameters())
                    if args.train_text_encoder
                    else unet.parameters()
                )
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            # one optimizer step: adjust the parameters using the computed gradients
            lr_scheduler.step()
            # advance the learning-rate scheduler
            optimizer.zero_grad(set_to_none=args.set_grads_to_none)  # reset the gradients

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:  # an actual optimizer update happened on this iteration
            progress_bar.update(1)  # advance the progress bar
            global_step += 1  # increment the global step

            if accelerator.is_main_process:  # only the main process runs the following, to avoid duplicated work
                if global_step % args.checkpointing_steps == 0:  # time to save a checkpoint? If so, save one.
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if args.checkpoints_total_limit is not None:
                        checkpoints = os.listdir(args.output_dir)
                        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                        if len(checkpoints) >= args.checkpoints_total_limit:  # if the checkpoint limit is reached, delete old ones so the count stays under the threshold
                            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                            removing_checkpoints = checkpoints[0:num_to_remove]

                            logger.info(
                                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                            )
                            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                            for removing_checkpoint in removing_checkpoints:
                                removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                shutil.rmtree(removing_checkpoint)

                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")
                    # save the current model state as a checkpoint and log the save path

                images = []

                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                    # if a validation prompt is set and it is time to run validation
                    images = log_validation(
                        unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
                        tokenizer,
                        unwrap_model(unet),
                        vae,
                        args,
                        accelerator,
                        weight_dtype,
                        global_step,
                        validation_prompt_encoder_hidden_states,
                        validation_prompt_negative_prompt_embeds,
                    )  # run validation and record the generated images

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)  # show the loss and learning rate next to the progress bar
        accelerator.log(logs, step=global_step)

        if global_step >= args.max_train_steps:  # once global_step reaches max_train_steps, break out of the loop and end training
            break

Creating the pipeline and uploading

# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()  # make sure every process has finished; synchronize them all
if accelerator.is_main_process:
    # Build a pipeline from the trained modules, configure it, and save it.
    pipeline_args = {}

    if text_encoder is not None:
        pipeline_args["text_encoder"] = unwrap_model(text_encoder)

    if args.skip_save_text_encoder:
        pipeline_args["text_encoder"] = None

    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unwrap_model(unet),
        revision=args.revision,
        variant=args.variant,
        **pipeline_args,
    )

    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
    scheduler_args = {}
    # an empty dict to hold scheduler arguments
    if "variance_type" in pipeline.scheduler.config:
        variance_type = pipeline.scheduler.config.variance_type

        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"

        scheduler_args["variance_type"] = variance_type

    pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
    pipeline.save_pretrained(args.output_dir)
    # finally, save the trained pipeline to args.output_dir

    if args.push_to_hub:
        save_model_card(
            repo_id,
            images=images,
            base_model=args.pretrained_model_name_or_path,
            train_text_encoder=args.train_text_encoder,
            prompt=args.instance_prompt,
            repo_folder=args.output_dir,
            pipeline=pipeline,
        )
        upload_folder(
            repo_id=repo_id,
            folder_path=args.output_dir,
            commit_message="End of training",
            ignore_patterns=["step_*", "epoch_*"],
        )

accelerator.end_training()

I honestly couldn't bring myself to write up the rest in the same detail.
The custom helper functions are attached below.

  def save_model_card(
      repo_id: str,
      images: list = None,
      base_model: str = None,
      train_text_encoder=False,
      prompt: str = None,
      repo_folder: str = None,
      pipeline: DiffusionPipeline = None,
  ):
      img_str = ""
      if images is not None:
          for i, image in enumerate(images):
              image.save(os.path.join(repo_folder, f"image_{i}.png"))
              img_str += f"![img_{i}](./image_{i}.png)\n"

      model_description = f"""
  # DreamBooth - {repo_id}

  This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
  You can find some example images in the following. \n
  {img_str}

  DreamBooth for the text encoder was enabled: {train_text_encoder}.
  """
      model_card = load_or_create_model_card(
          repo_id_or_path=repo_id,
          from_training=True,
          license="creativeml-openrail-m",
          base_model=base_model,
          prompt=prompt,
          model_description=model_description,
          inference=True,
      )

      tags = ["text-to-image", "dreambooth", "diffusers-training"]
      if isinstance(pipeline, StableDiffusionPipeline):
          tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
      else:
          tags.extend(["if", "if-diffusers"])
      model_card = populate_model_card(model_card, tags=tags)

      model_card.save(os.path.join(repo_folder, "README.md"))


  def log_validation(
      text_encoder,
      tokenizer,
      unet,
      vae,
      args,
      accelerator,
      weight_dtype,
      global_step,
      prompt_embeds,
      negative_prompt_embeds,
  ):  # run validation with the model and record the generated images
      logger.info(
          f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
          f" {args.validation_prompt}."
      )

      pipeline_args = {}

      if vae is not None:
          pipeline_args["vae"] = vae

      # create pipeline (note: unet and vae are loaded again in float32)
      pipeline = DiffusionPipeline.from_pretrained(
          args.pretrained_model_name_or_path,
          tokenizer=tokenizer,
          text_encoder=text_encoder,
          unet=unet,
          revision=args.revision,
          variant=args.variant,
          torch_dtype=weight_dtype,
          **pipeline_args,
      )

      # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
      scheduler_args = {}

      if "variance_type" in pipeline.scheduler.config:
          variance_type = pipeline.scheduler.config.variance_type

          if variance_type in ["learned", "learned_range"]:
              variance_type = "fixed_small"

          scheduler_args["variance_type"] = variance_type

      module = importlib.import_module("diffusers")
      scheduler_class = getattr(module, args.validation_scheduler)
      pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
      pipeline = pipeline.to(accelerator.device)
      pipeline.set_progress_bar_config(disable=True)

      if args.pre_compute_text_embeddings:
          pipeline_args = {
              "prompt_embeds": prompt_embeds,
              "negative_prompt_embeds": negative_prompt_embeds,
          }
      else:
          pipeline_args = {"prompt": args.validation_prompt}

      # run inference
      generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
      images = []
      if args.validation_images is None:
          for _ in range(args.num_validation_images):
              with torch.autocast("cuda"):
                  image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
              images.append(image)
      else:
          for image in args.validation_images:
              image = Image.open(image)
              image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
              images.append(image)

      for tracker in accelerator.trackers:
          if tracker.name == "tensorboard":
              np_images = np.stack([np.asarray(img) for img in images])
              tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
          if tracker.name == "wandb":
              tracker.log(
                  {
                      "validation": [
                          wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                      ]
                  }
              )

      del pipeline
      torch.cuda.empty_cache()

      return images


  def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):  # return the text-encoder class for this model
      text_encoder_config = PretrainedConfig.from_pretrained(
          pretrained_model_name_or_path,
          subfolder="text_encoder",
          revision=revision,
      )
      model_class = text_encoder_config.architectures[0]

      if model_class == "CLIPTextModel":
          from transformers import CLIPTextModel

          return CLIPTextModel
      elif model_class == "RobertaSeriesModelWithTransformation":
          from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

          return RobertaSeriesModelWithTransformation
      elif model_class == "T5EncoderModel":
          from transformers import T5EncoderModel

          return T5EncoderModel
      else:
          raise ValueError(f"{model_class} is not supported.")

And also:

class DreamBoothDataset(Dataset):

  # A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
  # It pre-processes the images and tokenizes the prompts.
  def __init__(
      self,
      instance_data_root,
      instance_prompt,
      tokenizer,
      class_data_root=None,
      class_prompt=None,
      class_num=None,
      size=512,
      center_crop=False,
      encoder_hidden_states=None,
      class_prompt_encoder_hidden_states=None,
      tokenizer_max_length=None,
  ):
      self.size = size
      self.center_crop = center_crop
      self.tokenizer = tokenizer
      self.encoder_hidden_states = encoder_hidden_states
      self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
      self.tokenizer_max_length = tokenizer_max_length

      self.instance_data_root = Path(instance_data_root)
      if not self.instance_data_root.exists():
          raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

      self.instance_images_path = list(Path(instance_data_root).iterdir())
      self.num_instance_images = len(self.instance_images_path)
      self.instance_prompt = instance_prompt
      self._length = self.num_instance_images

      if class_data_root is not None:
          self.class_data_root = Path(class_data_root)
          self.class_data_root.mkdir(parents=True, exist_ok=True)
          self.class_images_path = list(self.class_data_root.iterdir())
          if class_num is not None:
              self.num_class_images = min(len(self.class_images_path), class_num)
          else:
              self.num_class_images = len(self.class_images_path)
          self._length = max(self.num_class_images, self.num_instance_images)
          self.class_prompt = class_prompt
      else:
          self.class_data_root = None

      self.image_transforms = transforms.Compose(
          [
              transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
              transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
              transforms.ToTensor(),
              transforms.Normalize([0.5], [0.5]),
          ]
      )

  def __len__(self):
      return self._length

  def __getitem__(self, index):
      example = {}
      instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
      instance_image = exif_transpose(instance_image)

      if not instance_image.mode == "RGB":
          instance_image = instance_image.convert("RGB")
      example["instance_images"] = self.image_transforms(instance_image)

      if self.encoder_hidden_states is not None:
          example["instance_prompt_ids"] = self.encoder_hidden_states
      else:
          text_inputs = tokenize_prompt(
              self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
          )
          example["instance_prompt_ids"] = text_inputs.input_ids
          example["instance_attention_mask"] = text_inputs.attention_mask

      if self.class_data_root:
          class_image = Image.open(self.class_images_path[index % self.num_class_images])
          class_image = exif_transpose(class_image)

          if not class_image.mode == "RGB":
              class_image = class_image.convert("RGB")
          example["class_images"] = self.image_transforms(class_image)

          if self.class_prompt_encoder_hidden_states is not None:
              example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
          else:
              class_text_inputs = tokenize_prompt(
                  self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
              )
              example["class_prompt_ids"] = class_text_inputs.input_ids
              example["class_attention_mask"] = class_text_inputs.attention_mask

      return example


def collate_fn(examples, with_prior_preservation=False):
    # Returns a batch dict containing the input ids, pixel values, and (if present) the attention masks.
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

        if has_attention_mask:
            attention_mask += [example["class_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }

    if has_attention_mask:
        attention_mask = torch.cat(attention_mask, dim=0)
        batch["attention_mask"] = attention_mask

    return batch
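The concatenation above is what lets prior preservation run in a single forward pass: each batch element contributes one instance example and one class example. A rough sketch of the resulting shapes (using the launch-script batch size and resolution, and assuming CLIP's 77-token max length):

# With train_batch_size=1, resolution=512 and --with_prior_preservation:
# batch["pixel_values"].shape -> torch.Size([2, 3, 512, 512])   # instance image + class image
# batch["input_ids"].shape    -> torch.Size([2, 77])            # instance prompt + class prompt
# Later, torch.chunk(model_pred, 2, dim=0) in the training loop splits the two halves back apart.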


class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def model_has_vae(args):
    # what the name says: check whether the model has a VAE
    config_file_name = Path("vae", AutoencoderKL.config_name).as_posix()
    if os.path.isdir(args.pretrained_model_name_or_path):
        config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
        return os.path.isfile(config_file_name)
    else:
        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
        return any(file.rfilename == config_file_name for file in files_in_repo)


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    # tokenization; see e.g. https://www.cnblogs.com/carolsun/p/16903276.html
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

    return text_inputs


def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    # encode the tokenized text
    text_input_ids = input_ids.to(text_encoder.device)

    if text_encoder_use_attention_mask:
        attention_mask = attention_mask.to(text_encoder.device)
    else:
        attention_mask = None

    prompt_embeds = text_encoder(
        text_input_ids,
        attention_mask=attention_mask,
        return_dict=False,
    )
    prompt_embeds = prompt_embeds[0]
    # take the hidden states (the embedding tensor) from the encoder output
    return prompt_embeds
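For intuition, a minimal sketch of how these two helpers fit together outside the training loop, assuming a CLIP-based SD 1.5 checkpoint (the model name and prompt are just the values from this walkthrough):

from transformers import AutoTokenizer, CLIPTextModel

tokenizer = AutoTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer", use_fast=False)
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

text_inputs = tokenize_prompt(tokenizer, "a photo of sks dog")            # input_ids of shape [1, 77]
embeds = encode_prompt(text_encoder, text_inputs.input_ids, text_inputs.attention_mask)
print(embeds.shape)                                                        # e.g. torch.Size([1, 77, 768]) for SD 1.5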
