finally, everything works locally

- README.md +42 -17
- convert_mvdream_to_diffusers.py +2 -2
- mvdream/adaptor.py +0 -28
- mvdream/attention.py +2 -4
- mvdream/models.py +6 -8
- mvdream/pipeline_mvdream.py +19 -22
- requirements.lock.txt +6 -0
- requirements.txt +6 -0
- run_imagedream.py +3 -4
- run_mvdream.py +3 -4
    	
README.md
CHANGED

@@ -1,15 +1,27 @@
-# MVDream-…
-
-
-
+# MVDream-diffusers
+
+A **unified** diffusers implementation of [MVDream](https://github.com/bytedance/MVDream) and [ImageDream](https://github.com/bytedance/ImageDream).
+
+We provide converted `fp16` weights on [huggingface](TODO).
+
+### Usage
+
+```bash
+python run_mvdream.py "a cute owl"
+python run_imagedream.py data/anya_rgba.png
+```
+
+### Install
 ```bash
 # dependency
-pip install -…
-
+pip install -r requirements.txt
+```
+
+### Convert weights
+
+MVDream:
+```bash
+# download original ckpt (we only support the SD 2.1 version)
 cd models
 wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
 wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
@@ -21,18 +33,31 @@ python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4vi

 ImageDream:
 ```bash
-# download original ckpt
-
-wget https://…
+# download original ckpt (we only support the pixel-controller version)
+cd models
+wget https://huggingface.co/Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv.pt
+wget https://raw.githubusercontent.com/bytedance/ImageDream/main/extern/ImageDream/imagedream/configs/sd_v2_base_ipmv.yaml
+cd ..

 # convert
-python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view-ipmv…
+python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view-ipmv.pt --dump_path ./weights_imagedream --original_config_file models/sd_v2_base_ipmv.yaml --half --to_safetensors --test
 ```

-### …
-
-
-```
-
-
-
+### Acknowledgement
+
+* The original papers:
+    ```bibtex
+    @article{shi2023MVDream,
+        author = {Shi, Yichun and Wang, Peng and Ye, Jianglong and Mai, Long and Li, Kejie and Yang, Xiao},
+        title = {MVDream: Multi-view Diffusion for 3D Generation},
+        journal = {arXiv:2308.16512},
+        year = {2023},
+    }
+    @article{wang2023imagedream,
+        title={ImageDream: Image-Prompt Multi-view Diffusion for 3D Generation},
+        author={Wang, Peng and Shi, Yichun},
+        journal={arXiv preprint arXiv:2312.02201},
+        year={2023}
+    }
+    ```
+* This codebase is modified from [mvdream-hf](https://github.com/KokeCacao/mvdream-hf).
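
Tying the new Usage and Convert sections together: a minimal Python sketch of loading the converted weights through the diffusers-style pipeline, mirroring `run_mvdream.py` from this commit. The local `./weights_mvdream` path and the hub id in the comment are taken from that script; treat them as assumptions if your conversion used a different `--dump_path`.

```python
# Minimal sketch: load converted fp16 weights and save a 2x2 grid of the 4 views.
import numpy as np
import torch
import kiui
from mvdream.pipeline_mvdream import MVDreamPipeline

pipe = MVDreamPipeline.from_pretrained(
    "./weights_mvdream",            # or "ashawkey/mvdream-sd2.1-diffusers"
    torch_dtype=torch.float16,
)
# note: the pipeline moves its submodules to the target device inside __call__
# (see pipeline_mvdream.py below), so no explicit .to("cuda") is shown here.

image = pipe("a cute owl 3d model")  # 4 views as float image arrays
grid = np.concatenate(
    [
        np.concatenate([image[0], image[2]], axis=0),
        np.concatenate([image[1], image[3]], axis=0),
    ],
    axis=1,
)
kiui.write_image("owl_4view.jpg", grid)
```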
    	
convert_mvdream_to_diffusers.py
CHANGED

@@ -568,7 +568,7 @@ if __name__ == "__main__":
             images = pipe(
                 image=input_image,
                 prompt="",
-                negative_prompt="…
+                negative_prompt="",
                 output_type="pil",
                 guidance_scale=5.0,
                 num_inference_steps=50,
@@ -582,7 +582,7 @@ if __name__ == "__main__":
             images = loaded_pipe(
                 image=input_image,
                 prompt="",
-                negative_prompt="…
+                negative_prompt="",
                 output_type="pil",
                 guidance_scale=5.0,
                 num_inference_steps=50,
    	
mvdream/adaptor.py
CHANGED

@@ -73,34 +73,6 @@ class PerceiverAttention(nn.Module):
         return self.to_out(out)


-class ImageProjModel(torch.nn.Module):
-    """Projection Model"""
-
-    def __init__(
-        self,
-        cross_attention_dim=1024,
-        clip_embeddings_dim=1024,
-        clip_extra_context_tokens=4,
-    ):
-        super().__init__()
-        self.cross_attention_dim = cross_attention_dim
-        self.clip_extra_context_tokens = clip_extra_context_tokens
-
-        # from 1024 -> 4 * 1024
-        self.proj = torch.nn.Linear(
-            clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
-        )
-        self.norm = torch.nn.LayerNorm(cross_attention_dim)
-
-    def forward(self, image_embeds):
-        embeds = image_embeds
-        clip_extra_context_tokens = self.proj(embeds).reshape(
-            -1, self.clip_extra_context_tokens, self.cross_attention_dim
-        )
-        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
-        return clip_extra_context_tokens
-
-
 class Resampler(nn.Module):
     def __init__(
         self,
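
The deleted `ImageProjModel` was a plain IP-Adapter-style linear projector; `models.py` now imports only `Resampler` from this module. For reference, a standalone copy of the removed class with a quick shape check, kept only to document what the deletion drops:

```python
# Standalone copy of the removed ImageProjModel: a single linear layer mapping
# one CLIP image embedding to `clip_extra_context_tokens` extra context tokens.
import torch

class ImageProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # from 1024 -> 4 * 1024
        self.proj = torch.nn.Linear(clip_embeddings_dim, clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)

proj = ImageProjModel()
out = proj(torch.randn(2, 1024))  # [2, 1024] CLIP image embeddings
print(out.shape)                  # torch.Size([2, 4, 1024])
```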
    	
mvdream/attention.py
CHANGED

@@ -88,7 +88,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         context = default(context, x)

         if self.ip_dim > 0:
-            # context …
+            # context: [B, 77 + 16(ip), 1024]
             token_len = context.shape[1]
             context_ip = context[:, -self.ip_dim :, :]
             k_ip = self.to_k_ip(context_ip)
@@ -212,9 +212,7 @@ class SpatialTransformer3D(nn.Module):
         self.in_channels = in_channels

         inner_dim = n_heads * d_head
-        self.norm = nn.GroupNorm(
-            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
-        )
+        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
         self.proj_in = nn.Linear(in_channels, inner_dim)

         self.transformer_blocks = nn.ModuleList(
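
The new comment pins down the context layout for the image-prompt ("ip") branch: 77 CLIP text tokens followed by 16 ip tokens, each 1024-dimensional. Below is an illustrative sketch of that split, with a separate key projection in the spirit of `to_k_ip` above; this is shape bookkeeping only, not the repo's exact attention code.

```python
# Illustrative only: split a [B, 77 + 16, 1024] context into its text and
# image-prompt parts and project separate keys for the ip tokens.
import torch

B, D, ip_dim = 2, 1024, 16
context = torch.randn(B, 77 + ip_dim, D)

context_txt = context[:, : context.shape[1] - ip_dim, :]  # [B, 77, D] text tokens
context_ip = context[:, -ip_dim:, :]                      # [B, 16, D] ip tokens

to_k_ip = torch.nn.Linear(D, D, bias=False)               # analogous to self.to_k_ip
k_ip = to_k_ip(context_ip)

print(context_txt.shape, context_ip.shape, k_ip.shape)
# torch.Size([2, 77, 1024]) torch.Size([2, 16, 1024]) torch.Size([2, 16, 1024])
```

In the full module the text and ip tokens are attended to separately and then recombined, with the image branch presumably scaled by `ip_weight` (a constructor argument visible in `models.py` below).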
    	
mvdream/models.py
CHANGED

@@ -14,7 +14,7 @@ from .util import (
     timestep_embedding,
 )
 from .attention import SpatialTransformer3D
-from .adaptor import Resampler…
+from .adaptor import Resampler

 import kiui

@@ -266,15 +266,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
-        transformer_depth=1, …
-        context_dim=None, …
-        n_embed=None, …
-        disable_self_attentions=None,
+        transformer_depth=1,
+        context_dim=None,
+        n_embed=None,
         num_attention_blocks=None,
-        disable_middle_self_attn=False,
         adm_in_channels=None,
         camera_dim=None,
-        ip_dim=0,
+        ip_dim=0, # imagedream uses ip_dim > 0
         ip_weight=1.0,
         **kwargs,
     ):
@@ -604,7 +602,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):

         # imagedream variant
         if self.ip_dim > 0:
-            x[(num_frames - 1) :: num_frames, :, :, :] = ip_img
+            x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
             ip_emb = self.image_embed(ip)
             context = torch.cat((context, ip_emb), 1)
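
The added `# place at [4, 9]` comment describes the ImageDream batch layout: each prompt occupies `num_frames` consecutive slots and the conditioning image is written into the last slot of every group. Assuming `num_frames = 5` (four generated views plus the extra reference view the pipeline requests via `get_camera(..., extra_view=True)`) and two groups, e.g. the unconditional and conditional halves, the reference lands at flat indices 4 and 9. A toy check of that strided assignment under those assumptions:

```python
# Toy check of the strided assignment in MultiViewUNetModel.forward: the
# reference image goes into the last slot of every group of `num_frames` latents.
import torch

num_frames = 5                     # 4 views + 1 reference slot (ImageDream)
batch = 2 * num_frames             # e.g. unconditional + conditional halves
x = torch.zeros(batch, 4, 32, 32)
ip_img = torch.ones(2, 4, 32, 32)  # one reference latent per group

x[(num_frames - 1) :: num_frames, :, :, :] = ip_img
print(torch.nonzero(x.flatten(1).any(dim=1)).flatten())  # tensor([4, 9])
```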
    	
mvdream/pipeline_mvdream.py
CHANGED

@@ -405,29 +405,27 @@ class MVDreamPipeline(DiffusionPipeline):
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype

-        …
+        if image.dtype == np.float32:
+            image = (image * 255).astype(np.uint8)
+
         image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
         image = image.to(device=device, dtype=dtype)

-        …
-        uncond_image_enc_hidden_states = torch.zeros_like(image_enc_hidden_states)
-
-        return uncond_image_enc_hidden_states, image_enc_hidden_states
+        image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        return torch.zeros_like(image_embeds), image_embeds

     def encode_image_latents(self, image, device, num_images_per_prompt):

-        image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2) # [1, 3, H, W]
-        image = image.to(device=device)
-        image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
         dtype = next(self.image_encoder.parameters()).dtype
+
+        image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
+        image = 2 * image - 1
+        image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
         image = image.to(dtype=dtype)

         posterior = self.vae.encode(image).latent_dist
-
         latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)

@@ -436,13 +434,13 @@ class MVDreamPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: str = "…
+        prompt: str = "",
         image: Optional[np.ndarray] = None,
         height: int = 256,
         width: int = 256,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.0,
-        negative_prompt: str = "…
+        negative_prompt: str = "",
         num_images_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -454,7 +452,6 @@ class MVDreamPipeline(DiffusionPipeline):
     ):
         self.unet = self.unet.to(device=device)
         self.vae = self.vae.to(device=device)
-
         self.text_encoder = self.text_encoder.to(device=device)

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
@@ -466,10 +463,9 @@ class MVDreamPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps

-        # imagedream variant…
+        # imagedream variant
         if image is not None:
             assert isinstance(image, np.ndarray) and image.dtype == np.float32
-
             self.image_encoder = self.image_encoder.to(device=device)
             image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
             image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
@@ -496,7 +492,11 @@ class MVDreamPipeline(DiffusionPipeline):
             None,
         )

-        …
+        if image is not None:
+            camera = get_camera(num_frames, elevation=5, extra_view=True).to(dtype=latents.dtype, device=device)
+        else:
+            camera = get_camera(num_frames, elevation=15, extra_view=False).to(dtype=latents.dtype, device=device)
+        camera = camera.repeat_interleave(num_images_per_prompt, dim=0)

         # Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -508,10 +508,7 @@ class MVDreamPipeline(DiffusionPipeline):
             # expand the latents if we are doing classifier free guidance
             multiplier = 2 if do_classifier_free_guidance else 1
             latent_model_input = torch.cat([latents] * multiplier)
-            latent_model_input = self.scheduler.scale_model_input(
-                latent_model_input, t
-            )
-
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             unet_inputs = {
                 'x': latent_model_input,
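
The rewritten helpers make the two preprocessing paths explicit: `encode_image` converts the float32 `[0, 1]` input to uint8 before handing it to the CLIP feature extractor, while `encode_image_latents` rescales it to `[-1, 1]`, reshapes to NCHW, and resizes to 256x256 for the VAE. A standalone sketch of just those conversions (no models involved; the random array stands in for the image the run scripts load):

```python
# Sketch of the two preprocessing paths, starting from a float32 HxWx3 array
# in [0, 1] as produced by kiui.read_image(..., mode='float').
import numpy as np
import torch
import torch.nn.functional as F

image = np.random.rand(512, 512, 3).astype(np.float32)  # stand-in input image

# encode_image path: the feature extractor is fed uint8 pixels, so [0,1] -> [0,255]
clip_input = (image * 255).astype(np.uint8)

# encode_image_latents path: the VAE gets NCHW tensors in [-1, 1] at 256x256
vae_input = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)  # [1, 3, H, W]
vae_input = 2 * vae_input - 1
vae_input = F.interpolate(vae_input, (256, 256), mode="bilinear", align_corners=False)

print(clip_input.dtype, vae_input.shape, vae_input.min().item() >= -1.0)
```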
    	
requirements.lock.txt
ADDED

@@ -0,0 +1,6 @@
+omegaconf == 2.3.0
+diffusers == 0.23.1
+safetensors == 0.4.1
+huggingface_hub == 0.19.4
+transformers == 4.35.2
+accelerate == 0.25.0.dev0
    	
requirements.txt
ADDED

@@ -0,0 +1,6 @@
+omegaconf
+diffusers
+safetensors
+huggingface_hub
+transformers
+accelerate
    	
run_imagedream.py
CHANGED

@@ -17,9 +17,9 @@ parser.add_argument("image", type=str, default='data/anya_rgba.png')
 parser.add_argument("--prompt", type=str, default="")
 args = parser.parse_args()

-while True:
+for i in range(5):
     input_image = kiui.read_image(args.image, mode='float')
-    image = pipe(args.prompt, input_image)
+    image = pipe(args.prompt, input_image, guidance_scale=5)
     grid = np.concatenate(
         [
             np.concatenate([image[0], image[2]], axis=0),
@@ -28,5 +28,4 @@ while True:
         axis=1,
     )
     # kiui.vis.plot_image(grid)
-    kiui.write_image('…
-    break
+    kiui.write_image(f'test_imagedream_{i}.jpg', grid)
    	
run_mvdream.py
CHANGED

@@ -5,7 +5,7 @@ import argparse
 from mvdream.pipeline_mvdream import MVDreamPipeline

 pipe = MVDreamPipeline.from_pretrained(
-    "./…
+    "./weights_mvdream", # local weights
     # "ashawkey/mvdream-sd2.1-diffusers",
     torch_dtype=torch.float16
 )
@@ -16,7 +16,7 @@ parser = argparse.ArgumentParser(description="MVDream")
 parser.add_argument("prompt", type=str, default="a cute owl 3d model")
 args = parser.parse_args()

-while True:
+for i in range(5):
     image = pipe(args.prompt)
     grid = np.concatenate(
         [
@@ -26,5 +26,4 @@ while True:
         axis=1,
     )
     # kiui.vis.plot_image(grid)
-    kiui.write_image('…
-    break
+    kiui.write_image(f'test_mvdream_{i}.jpg', grid)

