import torch
import torch.nn.functional as F
import inspect
import numpy as np
from typing import Callable, List, Optional, Union
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.utils import (
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.schedulers import DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor

from mv_unet import MultiViewUNetModel, get_camera

logger = logging.get_logger(__name__)


class MVDreamPipeline(DiffusionPipeline):
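    """
    Pipeline for generating a set of multi-view images with a multi-view UNet
    (`MultiViewUNetModel`), driven by a text prompt and, optionally, a single reference
    image (the `image` argument of `__call__`). When an image is given, the CLIP image
    encoder and feature extractor provide image conditioning and one extra reference view
    is denoised alongside the `num_frames` requested views.
    """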
    _optional_components = ["feature_extractor", "image_encoder"]

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: MultiViewUNetModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        scheduler: DDIMScheduler,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModel,
        requires_safety_checker: bool = False,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate(
                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate(
                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError(
                "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
            )

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError(
                "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
            )

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(
                cpu_offloaded_model, device, prev_module_hook=hook
            )

        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance: bool,
        negative_prompt=None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier-free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
            )

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
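        # Duplicate the text embeddings for each generation per prompt.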
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder.dtype, device=device
            )

            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1
            )
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

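            # For classifier-free guidance the unconditional and text embeddings are
            # concatenated into a single batch so the UNet only needs one forward pass.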
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
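        # Prepare extra kwargs for the scheduler step, since not all schedulers share the
        # same signature. `eta` (η in the DDIM paper, https://arxiv.org/abs/2010.02502) is
        # only used by the DDIMScheduler and should be in [0, 1]; other schedulers ignore it.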
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

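        # Scale the initial noise by the standard deviation required by the scheduler.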
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def encode_image(self, image, device, num_images_per_prompt):
        dtype = next(self.image_encoder.parameters()).dtype

        if image.dtype == np.float32:
            image = (image * 255).astype(np.uint8)

        image = self.feature_extractor(image, return_tensors="pt").pixel_values
        image = image.to(device=device, dtype=dtype)
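        # Use the penultimate hidden state of the CLIP vision encoder as the image embedding;
        # a zero tensor of the same shape serves as the unconditional (negative) embedding.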
        image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        return torch.zeros_like(image_embeds), image_embeds

    def encode_image_latents(self, image, device, num_images_per_prompt):
        dtype = next(self.image_encoder.parameters()).dtype

        image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
        image = 2 * image - 1
        image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
        image = image.to(dtype=dtype)
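        # Encode the resized, [-1, 1]-normalized image with the VAE and scale the latents;
        # zeros of the same shape act as the unconditional (negative) latents.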
        posterior = self.vae.encode(image).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)

        return torch.zeros_like(latents), latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: str = "",
        image: Optional[np.ndarray] = None,
        height: int = 256,
        width: int = 256,
        elevation: float = 0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.0,
        negative_prompt: str = "",
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "numpy",
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        num_frames: int = 4,
        device=torch.device("cuda:0"),
    ):
        self.unet = self.unet.to(device=device)
        self.vae = self.vae.to(device=device)
        self.text_encoder = self.text_encoder.to(device=device)

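        # Classifier-free guidance is enabled whenever guidance_scale > 1; the conditional and
        # unconditional noise predictions are combined with this weight later in the loop.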
        do_classifier_free_guidance = guidance_scale > 1.0

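        # Set up the scheduler timesteps for this run.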
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

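        # If a reference image is given, encode it both with the CLIP image encoder (for the
        # UNet's image-prompt conditioning) and with the VAE (as a latent reference view).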
        if image is not None:
            assert isinstance(image, np.ndarray) and image.dtype == np.float32
            self.image_encoder = self.image_encoder.to(device=device)
            image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
            image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)

        _prompt_embeds = self._encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
        )
        prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)

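        # Prepare the initial latents: one latent per view (plus one extra reference view when
        # an input image is given), each with 4 channels at the VAE-downscaled resolution.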
        actual_num_frames = num_frames if image is None else num_frames + 1
        latents: torch.Tensor = self.prepare_latents(
            actual_num_frames * num_images_per_prompt,
            4,
            height,
            width,
            prompt_embeds_pos.dtype,
            device,
            generator,
            None,
        )

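        # Build the per-view camera embeddings fed to the multi-view UNet.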
        camera = get_camera(num_frames, elevation=elevation, extra_view=(image is not None)).to(dtype=latents.dtype, device=device)
        camera = camera.repeat_interleave(num_images_per_prompt, dim=0)

        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

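        # Denoising loop.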
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
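                # Expand the latents (and camera poses) into a doubled batch when doing
                # classifier-free guidance, so one UNet call covers both branches.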
                multiplier = 2 if do_classifier_free_guidance else 1
                latent_model_input = torch.cat([latents] * multiplier)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                unet_inputs = {
                    'x': latent_model_input,
                    'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
                    'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
                    'num_frames': actual_num_frames,
                    'camera': torch.cat([camera] * multiplier),
                }

                if image is not None:
                    unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
                    unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos])

                noise_pred = self.unet.forward(**unet_inputs)

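                # Perform classifier-free guidance: combine the unconditional and conditional
                # noise predictions.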
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

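                # Compute the previous noisy sample x_t -> x_{t-1}.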
                latents: torch.Tensor = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

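        # Post-processing: return raw latents, PIL images, or (by default) a numpy array.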
						        if output_type == "latent": | 
					
					
						
						| 
							 | 
						            image = latents | 
					
					
						
						| 
							 | 
						        elif output_type == "pil": | 
					
					
						
						| 
							 | 
						            image = self.decode_latents(latents) | 
					
					
						
						| 
							 | 
						            image = self.numpy_to_pil(image) | 
					
					
						
						| 
							 | 
						        else:  | 
					
					
						
						| 
							 | 
						            image = self.decode_latents(latents) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: | 
					
					
						
						| 
							 | 
						            self.final_offload_hook.offload() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        return image | 
					
					
						
						| 
							 | 
						
 |
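

# Minimal usage sketch (assumptions): the checkpoint path below is a placeholder, and loading
# the pipeline this way assumes the checkpoint folder was saved with
# `MVDreamPipeline.save_pretrained` (i.e. it contains the expected vae/unet/text_encoder/
# tokenizer/scheduler subfolders). Adjust the path, device, and dtype to your setup.
if __name__ == "__main__":
    pipe = MVDreamPipeline.from_pretrained(
        "./path/to/mvdream-checkpoint",  # hypothetical local checkpoint directory
        torch_dtype=torch.float16,
    )

    # Text-to-multi-view: with the default `output_type="numpy"` this returns an array of
    # `num_frames` views, e.g. shape (4, 256, 256, 3) at the default 256x256 resolution.
    views = pipe(
        prompt="a photo of a corgi wearing a party hat",
        num_frames=4,
        num_inference_steps=30,
        guidance_scale=7.0,
    )
    print(views.shape)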