import io import logging import math import os os.environ["CUDA_VISIBLE_DEVICES"] = "1" import random import shutil import sys sys.path.append('./') from pathlib import Path import accelerate import datasets import numpy as np from PIL import Image import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers.utils import ContextManagers import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel from models.unet_2d_condition import UNet2DLoRAConditionModel from models.lora import add_lora_to_model from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid from diffusers.utils.import_utils import is_xformers_available from MMP_Diffusion_Lora_config import parse_args, import_model_class_from_model_name_or_path from peft.utils import get_peft_model_state_dict from diffusers.utils import convert_state_dict_to_diffusers from models.visual_prompts import EmotionEmbedding, EmotionEmbedding2 import copy if is_wandb_available(): import wandb ## SDXL import functools import gc from torchvision.transforms.functional import crop from transformers import AutoTokenizer # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
check_min_version("0.20.0")

logger = get_logger(__name__, log_level="INFO")

# Emotion categories the visual prompts are built for. Shared by the prompt
# constructors and by random_sample_emotions so the lists cannot diverge.
EMOTIONS = ("amusement", "anger", "awe", "contentment", "disgust", "excitement", "fear", "sadness")


# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt_sdxl(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column,
                       is_train=True, device="cuda"):
    """Encode a batch of captions with the SDXL text-encoder pair.

    Args:
        batch: mapping holding the raw captions under ``caption_column``.
        text_encoders: text encoders, paired in order with ``tokenizers``.
        tokenizers: tokenizers matching ``text_encoders``.
        proportion_empty_prompts: probability of replacing a caption with ``""``.
        caption_column: key of the captions inside ``batch``.
        is_train: for list/array captions, pick a random entry when True,
            otherwise the first entry.
        device: device the token ids are moved to before encoding.

    Returns:
        dict with ``prompt_embeds`` (the penultimate hidden states of every
        encoder, concatenated on the last dim) and ``pooled_prompt_embeds``
        (output index 0 of the ``CLIPTextModelWithProjection`` encoder).

    Raises:
        ValueError: if a caption is neither a string nor a list/array, or if
            no ``CLIPTextModelWithProjection`` encoder is supplied.
    """
    captions = []
    for caption in batch[caption_column]:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            # Previously skipped silently, which shortened the batch and broke the
            # alignment with the image latents.
            raise ValueError(
                f"Caption column `{caption_column}` should contain either strings or lists of strings."
            )

    prompt_embeds_list = []
    pooled_prompt_embeds = None
    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_input_ids = tokenizer(
            captions,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        with torch.no_grad():
            encoder_output = text_encoder(
                text_input_ids.to(device),
                output_hidden_states=True,
            )

        # Only the projection encoder provides the pooled embedding; its output
        # keys are ['text_embeds', 'last_hidden_state', 'hidden_states'], so
        # index 0 is `text_embeds`. The plain CLIPTextModel contributes hidden
        # states only.
        if isinstance(text_encoder, CLIPTextModelWithProjection) and not isinstance(text_encoder, CLIPTextModel):
            pooled_prompt_embeds = encoder_output[0]

        # "-2" because SDXL always indexes from the penultimate layer.
        prompt_embeds = encoder_output.hidden_states[-2]
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds_list.append(prompt_embeds.view(bs_embed, seq_len, -1))

    if pooled_prompt_embeds is None:
        raise ValueError(
            "encode_prompt_sdxl needs a CLIPTextModelWithProjection text encoder to produce pooled embeddings."
        )

    # e.g. [bs, 77, 768 + 1280 = 2048] for SDXL-base
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
    }


def init_emotion_prompts(visual_prompts_dir, is_sdxl=True, prompt_len=16):
    """Build an ``EmotionEmbedding`` over EMOTIONS from CLIP/VGG/DINOv2 features.

    The output dim is 2048 for SDXL (matching its concatenated text embeddings)
    and 768 otherwise.
    """
    output_dim = 2048 if is_sdxl else 768
    feature_names = ["clip", "vgg", "dinov2"]
    return EmotionEmbedding(list(EMOTIONS), visual_prompts_dir, feature_names,
                            output_dim=output_dim, prompt_len=prompt_len)


def init_emotion_prompts2(is_sdxl=True):
    """Build an ``EmotionEmbedding2`` over EMOTIONS with a 2048-d input."""
    output_dim = 2048 if is_sdxl else 768
    input_dim = 2048
    return EmotionEmbedding2(list(EMOTIONS), input_dim, output_dim=output_dim)


def random_sample_emotions(anchor_emotions):
    """For each anchor emotion, draw a different emotion uniformly at random."""
    random_emotions = []
    for anchor in anchor_emotions:
        available_emotions = [emotion for emotion in EMOTIONS if emotion != anchor]
        random_emotions.append(random.choice(available_emotions))
    return random_emotions


def main():
    """Fine-tune a UNet with LoRA layers plus emotion visual prompts.

    All configuration comes from ``parse_args()``. The loss code supports
    ``args.train_method`` 'dpo' and 'sft'.

    NOTE(review): the data pipeline below only builds DPO triplets
    (jpg_0 / jpg_1 / label_0 -> 9-channel stacks), so 'sft' is not fed a
    matching dataset here.
    """
    args = parse_args()

    #### START ACCELERATOR BOILERPLATE ###
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now (offset per process so ranks differ).
    if args.seed is not None:
        set_seed(args.seed + accelerator.process_index)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    ### END ACCELERATOR BOILERPLATE

    ### START DIFFUSION BOILERPLATE ###
    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # SDXL has two text encoders
    if args.sdxl:
        tokenizer_and_encoder_name = args.pretrained_model_name_or_path
        tokenizer_one = AutoTokenizer.from_pretrained(
            tokenizer_and_encoder_name, subfolder="tokenizer", revision=args.revision, use_fast=False
        )
        tokenizer_two = AutoTokenizer.from_pretrained(
            tokenizer_and_encoder_name, subfolder="tokenizer_2", revision=args.revision, use_fast=False
        )
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
        )

    # Not sure if we're hitting this at all
    def deepspeed_zero_init_disabled_context_manager():
        """
        returns either a context list that includes one that will disable zero.Init or an empty context list
        """
        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
        if deepspeed_plugin is None:
            return []

        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        # SDXL has two text encoders
        if args.sdxl:
            # import correct text encoder classes
            text_encoder_cls_one = import_model_class_from_model_name_or_path(
                tokenizer_and_encoder_name, args.revision, subfolder="text_encoder"
            )
            text_encoder_cls_two = import_model_class_from_model_name_or_path(
                tokenizer_and_encoder_name, args.revision, subfolder="text_encoder_2"
            )
            text_encoder_one = text_encoder_cls_one.from_pretrained(
                tokenizer_and_encoder_name, revision=args.revision, subfolder="text_encoder"
            )
            text_encoder_two = text_encoder_cls_two.from_pretrained(
                tokenizer_and_encoder_name, revision=args.revision, subfolder="text_encoder_2"
            )
            text_encoders = [text_encoder_one, text_encoder_two]
            tokenizers = [tokenizer_one, tokenizer_two]
        else:
            text_encoder = CLIPTextModel.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
            )
        # Can custom-select VAE (used in original SDXL tuning)
        vae_path = (
            args.pretrained_model_name_or_path
            if args.pretrained_vae_model_name_or_path is None
            else args.pretrained_vae_model_name_or_path
        )
        vae = AutoencoderKL.from_pretrained(
            vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
        )
        # clone of model
        ref_unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
        )
    unet = UNet2DLoRAConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    print("======== init_emotion_prompts ================")
    visual_prompts = init_emotion_prompts(
        args.visual_prompts_dir, is_sdxl=args.sdxl, prompt_len=args.prompt_len
    ).to(accelerator.device)
    # visual_prompts = init_emotion_prompts2(is_sdxl=args.sdxl).to(accelerator.device)
    print("======== init_emotion_prompts done ================")

    # Freeze vae, text_encoder(s), reference unet
    vae.requires_grad_(False)
    if args.sdxl:
        text_encoder_one.requires_grad_(False)
        text_encoder_two.requires_grad_(False)
    else:
        text_encoder.requires_grad_(False)
    if args.train_method == 'dpo':
        ref_unet.requires_grad_(False)

    # LoRA layers are always added; the base UNet weights are NOT frozen here
    # (they are trained with args.learning_rate_unet). args.lora_rank default 32.
    lora_p, negation = add_lora_to_model(unet, dropout=0.1, lora_rank=args.lora_rank, scale=1.0)

    # xformers efficient attention (required)
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warning(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # Custom saving & loading hooks so that `accelerator.save_state(...)`
        # serializes in a nice format. `models` are the modules given to
        # accelerator.prepare below, presumably in that order:
        # [unet (UNet2DLoRAConditionModel), visual_prompts (EmotionEmbedding)].
        # ref_unet is never prepared, so it never appears here.
        def save_model_hook(models, weights, output_dir):
            print("save_model_hook")
            for model in models:
                print(model.__class__.__name__)
            if args.sdxl:
                # UNet2DLoRAConditionModel
                models[0].save_pretrained(os.path.join(output_dir, 'unet_with_lora'))
                weights.pop()
                # EmotionEmbedding
                torch.save(models[1].state_dict(), os.path.join(output_dir, "EmotionEmbedding.pth"))
                weights.pop()

        def load_model_hook(models, input_dir):
            print("load_model_hook")
            for model in models:
                print(model.__class__.__name__)
            if args.sdxl:
                from safetensors.torch import load_file

                # UNet2DLoRAConditionModel: merge every safetensors shard written by
                # save_pretrained (one file or several, depending on size).
                model = models.pop(0)
                unet_dir = Path(input_dir) / 'unet_with_lora'
                shard_paths = sorted(unet_dir.glob("*.safetensors"))
                if not shard_paths:
                    raise FileNotFoundError(f"No .safetensors files found in {unet_dir}")
                state_dict = {}
                for shard_path in shard_paths:
                    state_dict.update(load_file(str(shard_path)))
                model.load_state_dict(state_dict)

                # EmotionEmbedding
                model = models.pop(0)
                state_dict = torch.load(os.path.join(input_dir, "EmotionEmbedding.pth"), weights_only=True)
                model.load_state_dict(state_dict)

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing or args.sdxl:
        print("Enabling gradient checkpointing, either because you asked for this or because you're using SDXL")
        unet.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # The optimizer uses per-group learning rates, so --scale_lr must scale
    # those too (previously it only scaled args.learning_rate, which is unused).
    lr_scale = 1
    if args.scale_lr:
        lr_scale = args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        args.learning_rate = args.learning_rate * lr_scale

    unet_params = []
    lora_params = []
    for name, param in unet.named_parameters():
        if 'lora' in name.lower():
            lora_params.append(param)
        elif param.requires_grad:
            unet_params.append(param)

    print("Using Adafactor either because you asked for it or you're using SDXL")
    param_groups = [
        {"params": unet_params, "lr": args.learning_rate_unet * lr_scale},
        {"params": lora_params, "lr": args.learning_rate_lora * lr_scale},
        {"params": visual_prompts.parameters(), "lr": args.learning_rate_prompts * lr_scale},
    ]
    optimizer = transformers.Adafactor(
        param_groups,
        weight_decay=args.adam_weight_decay,
        clip_threshold=1.0,
        scale_parameter=False,
        relative_step=False,
    )

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    dataset = load_dataset(path='parquet', data_files=args.dataset_path)

    caption_column = args.caption_column

    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if random.random() < args.proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution),
            transforms.Lambda(lambda x: x) if args.no_hflip else transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    #### START PREPROCESSING/COLLATION ####
    if args.train_method == 'dpo':
        print("Ignoring image_column variable, reading from jpg_0 and jpg_1")

    def preprocess_train(examples):
        """Turn jpg_0/jpg_1/label_0 rows into [winner, loser, winner] 9-channel stacks.

        The winner is repeated as the third chunk so the training loop can score
        it again under a randomly sampled (different) emotion prompt.
        """
        # Transform all jpg_0 images first, then all jpg_1 images.
        images_0 = [train_transforms(Image.open(io.BytesIO(im_bytes)).convert("RGB")) for im_bytes in examples['jpg_0']]
        images_1 = [train_transforms(Image.open(io.BytesIO(im_bytes)).convert("RGB")) for im_bytes in examples['jpg_1']]

        combined_pixel_values = []
        for im_0, im_1, label_0 in zip(images_0, images_1, examples['label_0']):
            # label_0 == 0 means jpg_1 is preferred. Don't flip when a choice_model
            # decides preferences (AI feedback). The flip must be explicit: the
            # old reversal of (jpg_0, jpg_1, jpg_0) was a palindrome and a no-op.
            if label_0 == 0 and not args.choice_model:
                winner, loser = im_1, im_0
            else:
                winner, loser = im_0, im_1
            # [3+3+3, H, W], no batch dim; torch.cat copies, so reusing `winner` is safe.
            combined_pixel_values.append(torch.cat((winner, loser, winner), dim=0))

        examples["pixel_values"] = combined_pixel_values
        # SDXL takes raw prompts
        if not args.sdxl:
            examples["input_ids"] = tokenize_captions(examples)
        return examples

    def collate_fn(examples):
        # [bs, 9, H, W]
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        return_d = {"pixel_values": pixel_values}
        return_d["emotions"] = [example["emotion"] for example in examples]
        # SDXL takes raw prompts; they are exposed under the "caption" key
        # whatever the dataset column is called.
        if args.sdxl:
            return_d["caption"] = [example[caption_column] for example in examples]
        else:
            return_d["input_ids"] = torch.stack([example["input_ids"] for example in examples])

        if args.choice_model:
            # If using AIF then deliver image data for choice model to determine if should flip pixel values
            for k in ['jpg_0', 'jpg_1']:
                return_d[k] = [Image.open(io.BytesIO(example[k])).convert("RGB") for example in examples]
            return_d["caption"] = [example[caption_column] for example in examples]
        return return_d

    ### DATASET #####
    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset[args.split] = dataset[args.split].shuffle(seed=args.seed).select(range(args.max_train_samples))
        train_dataset = dataset[args.split].with_transform(preprocess_train)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=(args.split == 'train'),
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )
    ##### END BIG OLD DATASET BLOCK #####

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    #### START ACCELERATOR PREP ####
    unet, visual_prompts, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, visual_prompts, optimizer, train_dataloader, lr_scheduler
    )
    weight_dtype = torch.float32

    # Move text_encode and vae to gpu and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    if args.sdxl:
        text_encoder_one.to(accelerator.device, dtype=weight_dtype)
        text_encoder_two.to(accelerator.device, dtype=weight_dtype)
        # cpu_offload returns the same module objects, so `text_encoders` stays valid.
        text_encoder_one = accelerate.cpu_offload(text_encoder_one)
        text_encoder_two = accelerate.cpu_offload(text_encoder_two)
    else:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
    if args.train_method == 'dpo':
        ref_unet.to(accelerator.device, dtype=weight_dtype)
    ### END ACCELERATOR PREP ###

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        init_kwargs = {
            "wandb": {
                "name": args.tracker_run_name,
            }
        }
        accelerator.init_trackers(args.tracker_project_name, tracker_config, init_kwargs)

    # Training initialization
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    #### START MAIN TRAINING LOOP #####
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        implicit_acc_accumulated_d, implicit_acc_accumulated_c = 0.0, 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if (args.resume_from_checkpoint and epoch == first_epoch and step < resume_step
                    and (not args.hard_skip_resume)):
                if step % args.gradient_accumulation_steps == 0:
                    print(f"Dummy processing step {step}, will start training at {resume_step}")
                continue
            with accelerator.accumulate(unet):
                # Convert images to latent space
                if args.train_method == 'dpo':
                    # [bs, 9, H, W] => 3 x [bs, 3, H, W] => [3*bs, 3, H, W]
                    # ordered (all winners, all losers, all winners again)
                    feed_pixel_values = torch.cat(batch["pixel_values"].chunk(3, dim=1))
                elif args.train_method == 'sft':
                    feed_pixel_values = batch["pixel_values"]

                #### Diffusion Stuff ####
                # encode pixels --> latents
                with torch.no_grad():
                    latents = vae.encode(feed_pixel_values.to(weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)

                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                if args.train_method == 'dpo':
                    # make timesteps and noise same for the three chunks in DPO
                    timesteps = timesteps.chunk(3)[0].repeat(3)
                    noise = noise.chunk(3)[0].repeat(3, 1, 1, 1)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                ### START PREP BATCH ###
                if args.sdxl:
                    # Get the text embedding for conditioning
                    with torch.no_grad():
                        # SDXL-base time_ids: (orig size, crop top-left, target size)
                        add_time_ids = torch.tensor(
                            [args.resolution, args.resolution, 0, 0, args.resolution, args.resolution],
                            dtype=weight_dtype,
                            device=accelerator.device,
                        )[None, :].repeat(timesteps.size(0), 1)
                        # collate_fn always stores SDXL captions under "caption".
                        prompt_batch = encode_prompt_sdxl(
                            batch, text_encoders, tokenizers, args.proportion_empty_prompts,
                            "caption", is_train=True, device=accelerator.device,
                        )
                    if args.train_method == 'dpo':
                        prompt_batch["prompt_embeds"] = prompt_batch["prompt_embeds"].repeat(3, 1, 1)
                        prompt_batch["pooled_prompt_embeds"] = prompt_batch["pooled_prompt_embeds"].repeat(3, 1)
                    unet_added_conditions = {
                        "time_ids": add_time_ids,
                        "text_embeds": prompt_batch["pooled_prompt_embeds"],
                    }
                else:  # sd1.5
                    # Get the text embedding for conditioning
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                    if args.train_method == 'dpo':
                        # Three chunks (winner, loser, winner) -> repeat 3x, not 2x.
                        encoder_hidden_states = encoder_hidden_states.repeat(3, 1, 1)

                emotion_visual_prompts = visual_prompts(batch['emotions'])
                if args.train_method == 'dpo':
                    # Third chunk (the winner again) is conditioned on a wrong emotion.
                    random_emotions = random_sample_emotions(batch['emotions'])
                    random_emotion_visual_prompts = visual_prompts(random_emotions)
                    emotion_visual_prompts = torch.cat(
                        [emotion_visual_prompts, emotion_visual_prompts, random_emotion_visual_prompts], dim=0
                    )
                #### END PREP BATCH ####

                assert noise_scheduler.config.prediction_type == "epsilon"
                target = noise

                # Make the prediction from the model we're learning
                text_condition = prompt_batch["prompt_embeds"] if args.sdxl else encoder_hidden_states
                model_batch_args = (noisy_latents, timesteps, text_condition)
                lora_model_batch_args = (noisy_latents, timesteps, text_condition, emotion_visual_prompts)
                added_cond_kwargs = unet_added_conditions if args.sdxl else None

                model_pred = unet(
                    *lora_model_batch_args,
                    added_cond_kwargs=added_cond_kwargs,
                ).sample

                #### START LOSS COMPUTATION ####
                if args.train_method == 'sft':  # SFT, casting for F.mse_loss
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                elif args.train_method == 'dpo':
                    # Per-sample losses split into (winner, loser, winner-with-wrong-emotion).
                    model_losses = (model_pred - target).pow(2).mean(dim=[1, 2, 3])
                    model_losses_w, model_losses_l_d, model_losses_l_c = model_losses.chunk(3)
                    # below for logging purposes
                    raw_model_loss = (model_losses_w.mean() + model_losses_l_d.mean() + model_losses_l_c.mean()) / 3

                    model_diff_d = model_losses_w - model_losses_l_d
                    model_diff_c = model_losses_w - model_losses_l_c

                    with torch.no_grad():  # Get the reference policy (unet) prediction
                        ref_pred = ref_unet(
                            *model_batch_args,
                            added_cond_kwargs=added_cond_kwargs,
                        ).sample.detach()
                        ref_losses = (ref_pred - target).pow(2).mean(dim=[1, 2, 3])
                        ref_losses_w, ref_losses_l_d, ref_losses_l_c = ref_losses.chunk(3)
                        ref_diff = ref_losses_w - ref_losses_l_d
                        raw_ref_loss = ref_losses.mean()

                    scale_term = -0.5 * args.beta_dpo
                    inside_term_d = scale_term * (model_diff_d - ref_diff)
                    implicit_acc_d = (inside_term_d > 0).sum().float() / inside_term_d.size(0)
                    # the scale_term may need to be adjusted for the emotion-contrast term
                    inside_term_c = scale_term * model_diff_c
                    implicit_acc_c = (inside_term_c > 0).sum().float() / inside_term_c.size(0)
                    loss = -1 * 0.5 * (F.logsigmoid(inside_term_d).mean() + F.logsigmoid(inside_term_c).mean())
                #### END LOSS COMPUTATION ###

                # Gather the losses across all processes for logging
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps
                # Also gather model vs reference MSE and implicit accuracies
                if args.train_method == 'dpo':
                    avg_model_mse = accelerator.gather(raw_model_loss.repeat(args.train_batch_size)).mean().item()
                    avg_ref_mse = accelerator.gather(raw_ref_loss.repeat(args.train_batch_size)).mean().item()
                    avg_acc_d = accelerator.gather(implicit_acc_d).mean().item()
                    avg_acc_c = accelerator.gather(implicit_acc_c).mean().item()
                    implicit_acc_accumulated_d += avg_acc_d / args.gradient_accumulation_steps
                    implicit_acc_accumulated_c += avg_acc_c / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    if not args.use_adafactor:  # Adafactor clips updates itself
                        accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has just performed an optimization step, if so do "end of batch" logging
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                if args.train_method == 'dpo':
                    accelerator.log({"model_mse_unaccumulated": avg_model_mse}, step=global_step)
                    accelerator.log({"ref_mse_unaccumulated": avg_ref_mse}, step=global_step)
                    accelerator.log({"avg_acc_d": implicit_acc_accumulated_d}, step=global_step)
                    accelerator.log({"avg_acc_c": implicit_acc_accumulated_c}, step=global_step)
                train_loss = 0.0
                implicit_acc_accumulated_d, implicit_acc_accumulated_c = 0.0, 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")
                        logger.info("Pretty sure saving/loading is fixed but proceed cautiously")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            if args.train_method == 'dpo':
                logs["implicit_acc_d"] = avg_acc_d
                logs["implicit_acc_c"] = avg_acc_c
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    # This will save to top level of output_dir instead of a checkpoint directory
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)

        # Persist the trained emotion prompts; without this they only exist in
        # intermediate checkpoints. Same filename as the checkpoint hook.
        torch.save(
            accelerator.unwrap_model(visual_prompts).state_dict(),
            os.path.join(args.output_dir, "EmotionEmbedding.pth"),
        )

        if args.sdxl:
            # Serialize pipeline.
            if args.use_lora:
                unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
                lora_dir = os.path.join(args.output_dir, 'lora_weights_64')
                StableDiffusionXLPipeline.save_lora_weights(
                    save_directory=lora_dir,
                    unet_lora_layers=unet_lora_state_dict,
                    safe_serialization=True,
                )
                logger.info("Saved LoRA Model to {}".format(lora_dir))
            else:
                vae = AutoencoderKL.from_pretrained(
                    vae_path,
                    subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
                    revision=args.revision,
                    torch_dtype=weight_dtype,
                )
                pipeline = StableDiffusionXLPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=unet,
                    vae=vae,
                    revision=args.revision,
                    torch_dtype=weight_dtype,
                )
                pipeline.save_pretrained(args.output_dir)
                logger.info("Saved Model to {}".format(args.output_dir))
        elif not args.use_lora:
            # Only build the SD1.5 pipeline when it is actually saved.
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=text_encoder,
                vae=vae,
                unet=unet,
                revision=args.revision,
            )
            pipeline.save_pretrained(args.output_dir)

    accelerator.end_training()


if __name__ == "__main__":
    main()