Chao Xu
		
	commited on
		
		
					Commit 
							
							·
						
						0e93edd
	
1
								Parent(s):
							
							c534da1
								
test rm taming
Browse files- taming-transformers/.gitignore +0 -2
- taming-transformers/License.txt +0 -19
- taming-transformers/README.md +0 -410
- taming-transformers/configs/coco_cond_stage.yaml +0 -49
- taming-transformers/configs/coco_scene_images_transformer.yaml +0 -80
- taming-transformers/configs/custom_vqgan.yaml +0 -43
- taming-transformers/configs/drin_transformer.yaml +0 -77
- taming-transformers/configs/faceshq_transformer.yaml +0 -61
- taming-transformers/configs/faceshq_vqgan.yaml +0 -42
- taming-transformers/configs/imagenet_vqgan.yaml +0 -42
- taming-transformers/configs/imagenetdepth_vqgan.yaml +0 -41
- taming-transformers/configs/open_images_scene_images_transformer.yaml +0 -86
- taming-transformers/configs/sflckr_cond_stage.yaml +0 -43
- taming-transformers/environment.yaml +0 -25
- taming-transformers/main.py +0 -585
- taming-transformers/scripts/extract_depth.py +0 -112
- taming-transformers/scripts/extract_segmentation.py +0 -130
- taming-transformers/scripts/extract_submodel.py +0 -17
- taming-transformers/scripts/make_samples.py +0 -292
- taming-transformers/scripts/make_scene_samples.py +0 -198
- taming-transformers/scripts/sample_conditional.py +0 -355
- taming-transformers/scripts/sample_fast.py +0 -260
- taming-transformers/setup.py +0 -13
- taming-transformers/taming/lr_scheduler.py +0 -34
- taming-transformers/taming/models/cond_transformer.py +0 -352
- taming-transformers/taming/models/dummy_cond_stage.py +0 -22
- taming-transformers/taming/models/vqgan.py +0 -404
- taming-transformers/taming/modules/diffusionmodules/model.py +0 -776
- taming-transformers/taming/modules/discriminator/model.py +0 -67
- taming-transformers/taming/modules/losses/__init__.py +0 -2
- taming-transformers/taming/modules/losses/lpips.py +0 -123
- taming-transformers/taming/modules/losses/segmentation.py +0 -22
- taming-transformers/taming/modules/losses/vqperceptual.py +0 -136
- taming-transformers/taming/modules/misc/coord.py +0 -31
- taming-transformers/taming/modules/transformer/mingpt.py +0 -415
- taming-transformers/taming/modules/transformer/permuter.py +0 -248
- taming-transformers/taming/modules/util.py +0 -130
- taming-transformers/taming/modules/vqvae/quantize.py +0 -445
- taming-transformers/taming/util.py +0 -157
- taming-transformers/taming_transformers.egg-info/PKG-INFO +0 -10
- taming-transformers/taming_transformers.egg-info/SOURCES.txt +0 -7
- taming-transformers/taming_transformers.egg-info/dependency_links.txt +0 -1
- taming-transformers/taming_transformers.egg-info/requires.txt +0 -3
- taming-transformers/taming_transformers.egg-info/top_level.txt +0 -1
    	
        taming-transformers/.gitignore
    DELETED
    
    | @@ -1,2 +0,0 @@ | |
| 1 | 
            -
            assets/
         | 
| 2 | 
            -
            data/
         | 
|  | |
|  | |
|  | 
    	
        taming-transformers/License.txt
    DELETED
    
    | @@ -1,19 +0,0 @@ | |
| 1 | 
            -
            Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            Permission is hereby granted, free of charge, to any person obtaining a copy
         | 
| 4 | 
            -
            of this software and associated documentation files (the "Software"), to deal
         | 
| 5 | 
            -
            in the Software without restriction, including without limitation the rights
         | 
| 6 | 
            -
            to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         | 
| 7 | 
            -
            copies of the Software, and to permit persons to whom the Software is
         | 
| 8 | 
            -
            furnished to do so, subject to the following conditions:
         | 
| 9 | 
            -
             | 
| 10 | 
            -
            The above copyright notice and this permission notice shall be included in all
         | 
| 11 | 
            -
            copies or substantial portions of the Software.
         | 
| 12 | 
            -
             | 
| 13 | 
            -
            THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
         | 
| 14 | 
            -
            EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
         | 
| 15 | 
            -
            MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
         | 
| 16 | 
            -
            IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
         | 
| 17 | 
            -
            DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
         | 
| 18 | 
            -
            OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
         | 
| 19 | 
            -
            OR OTHER DEALINGS IN THE SOFTWARE./
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/README.md
    DELETED
    
    | @@ -1,410 +0,0 @@ | |
| 1 | 
            -
            # Taming Transformers for High-Resolution Image Synthesis
         | 
| 2 | 
            -
            ##### CVPR 2021 (Oral)
         | 
| 3 | 
            -
            
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            [**Taming Transformers for High-Resolution Image Synthesis**](https://compvis.github.io/taming-transformers/)<br/>
         | 
| 6 | 
            -
            [Patrick Esser](https://github.com/pesser)\*,
         | 
| 7 | 
            -
            [Robin Rombach](https://github.com/rromb)\*,
         | 
| 8 | 
            -
            [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
         | 
| 9 | 
            -
            \* equal contribution
         | 
| 10 | 
            -
             | 
| 11 | 
            -
            **tl;dr** We combine the efficiancy of convolutional approaches with the expressivity of transformers by introducing a convolutional VQGAN, which learns a codebook of context-rich visual parts, whose composition is modeled with an autoregressive transformer.
         | 
| 12 | 
            -
             | 
| 13 | 
            -
            
         | 
| 14 | 
            -
            [arXiv](https://arxiv.org/abs/2012.09841) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/taming-transformers/)
         | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
            ### News
         | 
| 18 | 
            -
            #### 2022
         | 
| 19 | 
            -
            - More pretrained VQGANs (e.g. a f8-model with only 256 codebook entries) are available in our new work on [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion).
         | 
| 20 | 
            -
            - Added scene synthesis models as proposed in the paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458), see [this section](#scene-image-synthesis).
         | 
| 21 | 
            -
            #### 2021
         | 
| 22 | 
            -
            - Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data).
         | 
| 23 | 
            -
            - Included a bugfix for the quantizer. For backward compatibility it is
         | 
| 24 | 
            -
              disabled by default (which corresponds to always training with `beta=1.0`).
         | 
| 25 | 
            -
              Use `legacy=False` in the quantizer config to enable it.
         | 
| 26 | 
            -
              Thanks [richcmwang](https://github.com/richcmwang) and [wcshin-git](https://github.com/wcshin-git)!
         | 
| 27 | 
            -
            - Our paper received an update: See https://arxiv.org/abs/2012.09841v3 and the corresponding changelog.
         | 
| 28 | 
            -
            - Added a pretrained, [1.4B transformer model](https://k00.fr/s511rwcv) trained for class-conditional ImageNet synthesis, which obtains state-of-the-art FID scores among autoregressive approaches and outperforms BigGAN.
         | 
| 29 | 
            -
            - Added pretrained, unconditional models on [FFHQ](https://k00.fr/yndvfu95) and [CelebA-HQ](https://k00.fr/2xkmielf).
         | 
| 30 | 
            -
            - Added accelerated sampling via caching of keys/values in the self-attention operation, used in `scripts/sample_fast.py`.
         | 
| 31 | 
            -
            - Added a checkpoint of a [VQGAN](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) trained with f8 compression and Gumbel-Quantization. 
         | 
| 32 | 
            -
              See also our updated [reconstruction notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb). 
         | 
| 33 | 
            -
            - We added a [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) which compares two VQGANs and OpenAI's [DALL-E](https://github.com/openai/DALL-E). See also [this section](#more-resources).
         | 
| 34 | 
            -
            - We now include an overview of pretrained models in [Tab.1](#overview-of-pretrained-models). We added models for [COCO](#coco) and [ADE20k](#ade20k).
         | 
| 35 | 
            -
            - The streamlit demo now supports image completions.
         | 
| 36 | 
            -
            - We now include a couple of examples from the D-RIN dataset so you can run the
         | 
| 37 | 
            -
              [D-RIN demo](#d-rin) without preparing the dataset first.
         | 
| 38 | 
            -
            - You can now jump right into sampling with our [Colab quickstart notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
         | 
| 39 | 
            -
             | 
| 40 | 
            -
            ## Requirements
         | 
| 41 | 
            -
            A suitable [conda](https://conda.io/) environment named `taming` can be created
         | 
| 42 | 
            -
            and activated with:
         | 
| 43 | 
            -
             | 
| 44 | 
            -
            ```
         | 
| 45 | 
            -
            conda env create -f environment.yaml
         | 
| 46 | 
            -
            conda activate taming
         | 
| 47 | 
            -
            ```
         | 
| 48 | 
            -
            ## Overview of pretrained models
         | 
| 49 | 
            -
            The following table provides an overview of all models that are currently available. 
         | 
| 50 | 
            -
            FID scores were evaluated using [torch-fidelity](https://github.com/toshas/torch-fidelity).
         | 
| 51 | 
            -
            For reference, we also include a link to the recently released autoencoder of the [DALL-E](https://github.com/openai/DALL-E) model. 
         | 
| 52 | 
            -
            See the corresponding [colab
         | 
| 53 | 
            -
            notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb)
         | 
| 54 | 
            -
            for a comparison and discussion of reconstruction capabilities.
         | 
| 55 | 
            -
             | 
| 56 | 
            -
            | Dataset  | FID vs train | FID vs val | Link |  Samples (256x256) | Comments
         | 
| 57 | 
            -
            | ------------- | ------------- | ------------- |-------------  | -------------  |-------------  |
         | 
| 58 | 
            -
            | FFHQ (f=16) | 9.6 | -- | [ffhq_transformer](https://k00.fr/yndvfu95) |  [ffhq_samples](https://k00.fr/j626x093) |
         | 
| 59 | 
            -
            | CelebA-HQ (f=16) | 10.2 | -- | [celebahq_transformer](https://k00.fr/2xkmielf) | [celebahq_samples](https://k00.fr/j626x093) |
         | 
| 60 | 
            -
            | ADE20K (f=16) | -- | 35.5 | [ade20k_transformer](https://k00.fr/ot46cksa) | [ade20k_samples.zip](https://heibox.uni-heidelberg.de/f/70bb78cbaf844501b8fb/) [2k] | evaluated on val split (2k images)
         | 
| 61 | 
            -
            | COCO-Stuff (f=16) | -- | 20.4  | [coco_transformer](https://k00.fr/2zz6i2ce) | [coco_samples.zip](https://heibox.uni-heidelberg.de/f/a395a9be612f4a7a8054/) [5k] | evaluated on val split (5k images)
         | 
| 62 | 
            -
            | ImageNet (cIN) (f=16) | 15.98/15.78/6.59/5.88/5.20 | -- | [cin_transformer](https://k00.fr/s511rwcv) | [cin_samples](https://k00.fr/j626x093) | different decoding hyperparameters |  
         | 
| 63 | 
            -
            | |  | | || |
         | 
| 64 | 
            -
            | FacesHQ (f=16) | -- |  -- | [faceshq_transformer](https://k00.fr/qqfl2do8)
         | 
| 65 | 
            -
            | S-FLCKR (f=16) | -- | -- | [sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/) 
         | 
| 66 | 
            -
            | D-RIN (f=16) | -- | -- | [drin_transformer](https://k00.fr/39jcugc5)
         | 
| 67 | 
            -
            | | |  | | || |
         | 
| 68 | 
            -
            | VQGAN ImageNet (f=16), 1024 |  10.54 | 7.94 | [vqgan_imagenet_f16_1024](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
         | 
| 69 | 
            -
            | VQGAN ImageNet (f=16), 16384 | 7.41 | 4.98 |[vqgan_imagenet_f16_16384](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/)  |  [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
         | 
| 70 | 
            -
            | VQGAN OpenImages (f=8), 256 | -- | 1.49 |https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip |  ---  | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion).
         | 
| 71 | 
            -
            | VQGAN OpenImages (f=8), 16384 | -- | 1.14 |https://ommer-lab.com/files/latent-diffusion/vq-f8.zip  |  ---  | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion)
         | 
| 72 | 
            -
            | VQGAN OpenImages (f=8), 8192, GumbelQuantization | 3.24 | 1.49 |[vqgan_gumbel_f8](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/)  |  ---  | Reconstruction-FIDs.
         | 
| 73 | 
            -
            | | |  | | || |
         | 
| 74 | 
            -
            | DALL-E dVAE (f=8), 8192, GumbelQuantization | 33.88 | 32.01 | https://github.com/openai/DALL-E | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
         | 
| 75 | 
            -
             | 
| 76 | 
            -
             | 
| 77 | 
            -
            ## Running pretrained models
         | 
| 78 | 
            -
             | 
| 79 | 
            -
            The commands below will start a streamlit demo which supports sampling at
         | 
| 80 | 
            -
            different resolutions and image completions. To run a non-interactive version
         | 
| 81 | 
            -
            of the sampling process, replace `streamlit run scripts/sample_conditional.py --`
         | 
| 82 | 
            -
            by `python scripts/make_samples.py --outdir <path_to_write_samples_to>` and
         | 
| 83 | 
            -
            keep the remaining command line arguments. 
         | 
| 84 | 
            -
             | 
| 85 | 
            -
            To sample from unconditional or class-conditional models, 
         | 
| 86 | 
            -
            run `python scripts/sample_fast.py -r <path/to/config_and_checkpoint>`.
         | 
| 87 | 
            -
            We describe below how to use this script to sample from the ImageNet, FFHQ, and CelebA-HQ models, 
         | 
| 88 | 
            -
            respectively.
         | 
| 89 | 
            -
             | 
| 90 | 
            -
            ### S-FLCKR
         | 
| 91 | 
            -
            
         | 
| 92 | 
            -
             | 
| 93 | 
            -
            You can also [run this model in a Colab
         | 
| 94 | 
            -
            notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb),
         | 
| 95 | 
            -
            which includes all necessary steps to start sampling.
         | 
| 96 | 
            -
             | 
| 97 | 
            -
            Download the
         | 
| 98 | 
            -
            [2020-11-09T13-31-51_sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
         | 
| 99 | 
            -
            folder and place it into `logs`. Then, run
         | 
| 100 | 
            -
            ```
         | 
| 101 | 
            -
            streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/
         | 
| 102 | 
            -
            ```
         | 
| 103 | 
            -
             | 
| 104 | 
            -
            ### ImageNet
         | 
| 105 | 
            -
            
         | 
| 106 | 
            -
             | 
| 107 | 
            -
            Download the [2021-04-03T19-39-50_cin_transformer](https://k00.fr/s511rwcv)
         | 
| 108 | 
            -
            folder and place it into logs.  Sampling from the class-conditional ImageNet
         | 
| 109 | 
            -
            model does not require any data preparation. To produce 50 samples for each of
         | 
| 110 | 
            -
            the 1000 classes of ImageNet, with k=600 for top-k sampling, p=0.92 for nucleus
         | 
| 111 | 
            -
            sampling and temperature t=1.0, run
         | 
| 112 | 
            -
             | 
| 113 | 
            -
            ```
         | 
| 114 | 
            -
            python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25   
         | 
| 115 | 
            -
            ```
         | 
| 116 | 
            -
             | 
| 117 | 
            -
            To restrict the model to certain classes, provide them via the `--classes` argument, separated by 
         | 
| 118 | 
            -
            commas. For example, to sample 50 *ostriches*, *border collies* and *whiskey jugs*, run
         | 
| 119 | 
            -
             | 
| 120 | 
            -
            ```
         | 
| 121 | 
            -
            python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901   
         | 
| 122 | 
            -
            ```
         | 
| 123 | 
            -
            We recommended to experiment with the autoregressive decoding parameters (top-k, top-p and temperature) for best results.  
         | 
| 124 | 
            -
             | 
| 125 | 
            -
            ### FFHQ/CelebA-HQ
         | 
| 126 | 
            -
             | 
| 127 | 
            -
            Download the [2021-04-23T18-19-01_ffhq_transformer](https://k00.fr/yndvfu95) and 
         | 
| 128 | 
            -
            [2021-04-23T18-11-19_celebahq_transformer](https://k00.fr/2xkmielf) 
         | 
| 129 | 
            -
            folders and place them into logs. 
         | 
| 130 | 
            -
            Again, sampling from these unconditional models does not require any data preparation.
         | 
| 131 | 
            -
            To produce 50000 samples, with k=250 for top-k sampling,
         | 
| 132 | 
            -
            p=1.0 for nucleus sampling and temperature t=1.0, run
         | 
| 133 | 
            -
             | 
| 134 | 
            -
            ```
         | 
| 135 | 
            -
            python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/   
         | 
| 136 | 
            -
            ```
         | 
| 137 | 
            -
            for FFHQ and  
         | 
| 138 | 
            -
             | 
| 139 | 
            -
            ```
         | 
| 140 | 
            -
            python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/   
         | 
| 141 | 
            -
            ```
         | 
| 142 | 
            -
            to sample from the CelebA-HQ model.
         | 
| 143 | 
            -
            For both models it can be advantageous to vary the top-k/top-p parameters for sampling.
         | 
| 144 | 
            -
             | 
| 145 | 
            -
            ### FacesHQ
         | 
| 146 | 
            -
            
         | 
| 147 | 
            -
             | 
| 148 | 
            -
            Download [2020-11-13T21-41-45_faceshq_transformer](https://k00.fr/qqfl2do8) and
         | 
| 149 | 
            -
            place it into `logs`. Follow the data preparation steps for
         | 
| 150 | 
            -
            [CelebA-HQ](#celeba-hq) and [FFHQ](#ffhq). Run
         | 
| 151 | 
            -
            ```
         | 
| 152 | 
            -
            streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
         | 
| 153 | 
            -
            ```
         | 
| 154 | 
            -
             | 
| 155 | 
            -
            ### D-RIN
         | 
| 156 | 
            -
            
         | 
| 157 | 
            -
             | 
| 158 | 
            -
            Download [2020-11-20T12-54-32_drin_transformer](https://k00.fr/39jcugc5) and
         | 
| 159 | 
            -
            place it into `logs`. To run the demo on a couple of example depth maps
         | 
| 160 | 
            -
            included in the repository, run
         | 
| 161 | 
            -
             | 
| 162 | 
            -
            ```
         | 
| 163 | 
            -
            streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"
         | 
| 164 | 
            -
            ```
         | 
| 165 | 
            -
             | 
| 166 | 
            -
            To run the demo on the complete validation set, first follow the data preparation steps for
         | 
| 167 | 
            -
            [ImageNet](#imagenet) and then run
         | 
| 168 | 
            -
            ```
         | 
| 169 | 
            -
            streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
         | 
| 170 | 
            -
            ```
         | 
| 171 | 
            -
             | 
| 172 | 
            -
            ### COCO
         | 
| 173 | 
            -
            Download [2021-01-20T16-04-20_coco_transformer](https://k00.fr/2zz6i2ce) and
         | 
| 174 | 
            -
            place it into `logs`. To run the demo on a couple of example segmentation maps
         | 
| 175 | 
            -
            included in the repository, run
         | 
| 176 | 
            -
             | 
| 177 | 
            -
            ```
         | 
| 178 | 
            -
            streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"
         | 
| 179 | 
            -
            ```
         | 
| 180 | 
            -
             | 
| 181 | 
            -
            ### ADE20k
         | 
| 182 | 
            -
            Download [2020-11-20T21-45-44_ade20k_transformer](https://k00.fr/ot46cksa) and
         | 
| 183 | 
            -
            place it into `logs`. To run the demo on a couple of example segmentation maps
         | 
| 184 | 
            -
            included in the repository, run
         | 
| 185 | 
            -
             | 
| 186 | 
            -
            ```
         | 
| 187 | 
            -
            streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
         | 
| 188 | 
            -
            ```
         | 
| 189 | 
            -
             | 
| 190 | 
            -
            ## Scene Image Synthesis
         | 
| 191 | 
            -
            
         | 
| 192 | 
            -
            Scene image generation based on bounding box conditionals as done in our CVPR2021 AI4CC workshop paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458) (see talk on [workshop page](https://visual.cs.brown.edu/workshops/aicc2021/#awards)). Supporting the datasets COCO and Open Images.
         | 
| 193 | 
            -
             | 
| 194 | 
            -
            ### Training
         | 
| 195 | 
            -
            Download first-stage models [COCO-8k-VQGAN](https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/) for COCO or [COCO/Open-Images-8k-VQGAN](https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/) for Open Images.
         | 
| 196 | 
            -
            Change `ckpt_path` in `data/coco_scene_images_transformer.yaml` and `data/open_images_scene_images_transformer.yaml` to point to the downloaded first-stage models.
         | 
| 197 | 
            -
            Download the full COCO/OI datasets and adapt `data_path` in the same files, unless working with the 100 files provided for training and validation suits your needs already.
         | 
| 198 | 
            -
             | 
| 199 | 
            -
            Code can be run with
         | 
| 200 | 
            -
            `python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,`
         | 
| 201 | 
            -
            or
         | 
| 202 | 
            -
            `python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,`
         | 
| 203 | 
            -
             | 
| 204 | 
            -
            ### Sampling 
         | 
| 205 | 
            -
            Train a model as described above or download a pre-trained model:
         | 
| 206 | 
            -
             - [Open Images 1 billion parameter model](https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing) available that trained 100 epochs. On 256x256 pixels, FID 41.48±0.21, SceneFID 14.60±0.15, Inception Score 18.47±0.27. The model was trained with 2d crops of images and is thus well-prepared for the task of generating high-resolution images, e.g. 512x512.
         | 
| 207 | 
            -
             - [Open Images distilled version of the above model with 125 million parameters](https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO) allows for sampling on smaller GPUs (4 GB is enough for sampling 256x256 px images). Model was trained for 60 epochs with 10% soft loss, 90% hard loss. On 256x256 pixels, FID 43.07±0.40, SceneFID 15.93±0.19, Inception Score 17.23±0.11.
         | 
| 208 | 
            -
             - [COCO 30 epochs](https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/)
         | 
| 209 | 
            -
             - [COCO 60 epochs](https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/) (find model statistics for both COCO versions in `assets/coco_scene_images_training.svg`)
         | 
| 210 | 
            -
             | 
| 211 | 
            -
            When downloading a pre-trained model, remember to change `ckpt_path` in `configs/*project.yaml` to point to your downloaded first-stage model (see ->Training).
         | 
| 212 | 
            -
             | 
| 213 | 
            -
            Scene image generation can be run with
         | 
| 214 | 
            -
            `python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512`
         | 
| 215 | 
            -
             | 
| 216 | 
            -
             | 
| 217 | 
            -
            ## Training on custom data
         | 
| 218 | 
            -
             | 
| 219 | 
            -
            Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
         | 
| 220 | 
            -
            Those are the steps to follow to make this work:
         | 
| 221 | 
            -
            1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`
         | 
| 222 | 
            -
            1. put your .jpg files in a folder `your_folder`
         | 
| 223 | 
            -
            2. create 2 text files a `xx_train.txt` and `xx_test.txt` that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`)
         | 
| 224 | 
            -
            3. adapt `configs/custom_vqgan.yaml` to point to these 2 files
         | 
| 225 | 
            -
            4. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to
         | 
| 226 | 
            -
               train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU.
         | 
| 227 | 
            -
             | 
| 228 | 
            -
            ## Data Preparation
         | 
| 229 | 
            -
             | 
| 230 | 
            -
            ### ImageNet
         | 
| 231 | 
            -
            The code will try to download (through [Academic
         | 
| 232 | 
            -
            Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
         | 
| 233 | 
            -
            is used. However, since ImageNet is quite large, this requires a lot of disk
         | 
| 234 | 
            -
            space and time. If you already have ImageNet on your disk, you can speed things
         | 
| 235 | 
            -
            up by putting the data into
         | 
| 236 | 
            -
            `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
         | 
| 237 | 
            -
            `~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
         | 
| 238 | 
            -
            of `train`/`validation`. It should have the following structure:
         | 
| 239 | 
            -
             | 
| 240 | 
            -
            ```
         | 
| 241 | 
            -
            ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
         | 
| 242 | 
            -
            ├── n01440764
         | 
| 243 | 
            -
            │   ├── n01440764_10026.JPEG
         | 
| 244 | 
            -
            │   ├── n01440764_10027.JPEG
         | 
| 245 | 
            -
            │   ├── ...
         | 
| 246 | 
            -
            ├── n01443537
         | 
| 247 | 
            -
            │   ├── n01443537_10007.JPEG
         | 
| 248 | 
            -
            │   ├── n01443537_10014.JPEG
         | 
| 249 | 
            -
            │   ├── ...
         | 
| 250 | 
            -
            ├── ...
         | 
| 251 | 
            -
            ```
         | 
| 252 | 
            -
             | 
| 253 | 
            -
            If you haven't extracted the data, you can also place
         | 
| 254 | 
            -
            `ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
         | 
| 255 | 
            -
            `${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
         | 
| 256 | 
            -
            `${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
         | 
| 257 | 
            -
            extracted into above structure without downloading it again.  Note that this
         | 
| 258 | 
            -
            will only happen if neither a folder
         | 
| 259 | 
            -
            `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
         | 
| 260 | 
            -
            `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
         | 
| 261 | 
            -
            if you want to force running the dataset preparation again.
         | 
| 262 | 
            -
             | 
| 263 | 
            -
            You will then need to prepare the depth data using
         | 
| 264 | 
            -
            [MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink
         | 
| 265 | 
            -
            `data/imagenet_depth` pointing to a folder with two subfolders `train` and
         | 
| 266 | 
            -
            `val`, each mirroring the structure of the corresponding ImageNet folder
         | 
| 267 | 
            -
            described above and containing a `png` file for each of ImageNet's `JPEG`
         | 
| 268 | 
            -
            files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA
         | 
| 269 | 
            -
            images. We provide the script `scripts/extract_depth.py` to generate this data.
         | 
| 270 | 
            -
            **Please note** that this script uses [MiDaS via PyTorch
         | 
| 271 | 
            -
            Hub](https://pytorch.org/hub/intelisl_midas_v2/). When we prepared the data,
         | 
| 272 | 
            -
            the hub provided the [MiDaS
         | 
| 273 | 
            -
            v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2) version, but now it
         | 
| 274 | 
            -
            provides a v2.1 version. We haven't tested our models with depth maps obtained
         | 
| 275 | 
            -
            via v2.1 and if you want to make sure that things work as expected, you must
         | 
| 276 | 
            -
            adjust the script to make sure it explicitly uses
         | 
| 277 | 
            -
            [v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)!
         | 
| 278 | 
            -
             | 
| 279 | 
            -
            ### CelebA-HQ
         | 
| 280 | 
            -
            Create a symlink `data/celebahq` pointing to a folder containing the `.npy`
         | 
| 281 | 
            -
            files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN
         | 
| 282 | 
            -
            repository](https://github.com/tkarras/progressive_growing_of_gans)).
         | 
| 283 | 
            -
             | 
| 284 | 
            -
            ### FFHQ
         | 
| 285 | 
            -
            Create a symlink `data/ffhq` pointing to the `images1024x1024` folder obtained
         | 
| 286 | 
            -
            from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset).
         | 
| 287 | 
            -
             | 
| 288 | 
            -
            ### S-FLCKR
         | 
| 289 | 
            -
            Unfortunately, we are not allowed to distribute the images we collected for the
         | 
| 290 | 
            -
            S-FLCKR dataset and can therefore only give a description how it was produced.
         | 
| 291 | 
            -
            There are many resources on [collecting images from the
         | 
| 292 | 
            -
            web](https://github.com/adrianmrit/flickrdatasets) to get started.
         | 
| 293 | 
            -
            We collected sufficiently large images from [flickr](https://www.flickr.com)
         | 
| 294 | 
            -
            (see `data/flickr_tags.txt` for a full list of tags used to find images)
         | 
| 295 | 
            -
            and various [subreddits](https://www.reddit.com/r/sfwpornnetwork/wiki/network)
         | 
| 296 | 
            -
            (see `data/subreddits.txt` for all subreddits that were used).
         | 
| 297 | 
            -
            Overall, we collected 107625 images, and split them randomly into 96861
         | 
| 298 | 
            -
            training images and 10764 validation images. We then obtained segmentation
         | 
| 299 | 
            -
            masks for each image using [DeepLab v2](https://arxiv.org/abs/1606.00915)
         | 
| 300 | 
            -
            trained on [COCO-Stuff](https://arxiv.org/abs/1612.03716). We used a [PyTorch
         | 
| 301 | 
            -
            reimplementation](https://github.com/kazuto1011/deeplab-pytorch) and include an
         | 
| 302 | 
            -
            example script for this process in `scripts/extract_segmentation.py`.
         | 
| 303 | 
            -
             | 
| 304 | 
            -
            ### COCO
         | 
| 305 | 
            -
            Create a symlink `data/coco` containing the images from the 2017 split in
         | 
| 306 | 
            -
            `train2017` and `val2017`, and their annotations in `annotations`. Files can be
         | 
| 307 | 
            -
            obtained from the [COCO webpage](https://cocodataset.org/). In addition, we use
         | 
| 308 | 
            -
            the [Stuff+thing PNG-style annotations on COCO 2017
         | 
| 309 | 
            -
            trainval](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip)
         | 
| 310 | 
            -
            annotations from [COCO-Stuff](https://github.com/nightrome/cocostuff), which
         | 
| 311 | 
            -
            should be placed under `data/cocostuffthings`.
         | 
| 312 | 
            -
             | 
| 313 | 
            -
            ### ADE20k
         | 
| 314 | 
            -
            Create a symlink `data/ade20k_root` containing the contents of
         | 
| 315 | 
            -
            [ADEChallengeData2016.zip](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)
         | 
| 316 | 
            -
            from the [MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).
         | 
| 317 | 
            -
             | 
| 318 | 
            -
            ## Training models
         | 
| 319 | 
            -
             | 
| 320 | 
            -
            ### FacesHQ
         | 
| 321 | 
            -
             | 
| 322 | 
            -
            Train a VQGAN with
         | 
| 323 | 
            -
            ```
         | 
| 324 | 
            -
            python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
         | 
| 325 | 
            -
            ```
         | 
| 326 | 
            -
             | 
| 327 | 
            -
            Then, adjust the checkpoint path of the config key
         | 
| 328 | 
            -
            `model.params.first_stage_config.params.ckpt_path` in
         | 
| 329 | 
            -
            `configs/faceshq_transformer.yaml` (or download
         | 
| 330 | 
            -
            [2020-11-09T13-33-36_faceshq_vqgan](https://k00.fr/uxy5usa9) and place into `logs`, which
         | 
| 331 | 
            -
            corresponds to the preconfigured checkpoint path), then run
         | 
| 332 | 
            -
            ```
         | 
| 333 | 
            -
            python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
         | 
| 334 | 
            -
            ```
         | 
| 335 | 
            -
             | 
| 336 | 
            -
            ### D-RIN
         | 
| 337 | 
            -
             | 
| 338 | 
            -
            Train a VQGAN on ImageNet with
         | 
| 339 | 
            -
            ```
         | 
| 340 | 
            -
            python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
         | 
| 341 | 
            -
            ```
         | 
| 342 | 
            -
             | 
| 343 | 
            -
            or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](https://k00.fr/u0j2dtac)
         | 
| 344 | 
            -
            and place under `logs`. If you trained your own, adjust the path in the config
         | 
| 345 | 
            -
            key `model.params.first_stage_config.params.ckpt_path` of
         | 
| 346 | 
            -
            `configs/drin_transformer.yaml`.
         | 
| 347 | 
            -
             | 
| 348 | 
            -
            Train a VQGAN on Depth Maps of ImageNet with
         | 
| 349 | 
            -
            ```
         | 
| 350 | 
            -
            python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
         | 
| 351 | 
            -
            ```
         | 
| 352 | 
            -
             | 
| 353 | 
            -
            or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](https://k00.fr/55rlxs6i)
         | 
| 354 | 
            -
            and place under `logs`. If you trained your own, adjust the path in the config
         | 
| 355 | 
            -
            key `model.params.cond_stage_config.params.ckpt_path` of
         | 
| 356 | 
            -
            `configs/drin_transformer.yaml`.
         | 
| 357 | 
            -
             | 
| 358 | 
            -
            To train the transformer, run
         | 
| 359 | 
            -
            ```
         | 
| 360 | 
            -
            python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
         | 
| 361 | 
            -
            ```
         | 
| 362 | 
            -
             | 
| 363 | 
            -
            ## More Resources
         | 
| 364 | 
            -
            ### Comparing Different First Stage Models
         | 
| 365 | 
            -
            The reconstruction and compression capabilities of different fist stage models can be analyzed in this [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb). 
         | 
| 366 | 
            -
            In particular, the notebook compares two VQGANs with a downsampling factor of f=16 for each and codebook dimensionality of 1024 and 16384, 
         | 
| 367 | 
            -
            a VQGAN with f=8 and 8192 codebook entries and the discrete autoencoder of OpenAI's [DALL-E](https://github.com/openai/DALL-E) (which has f=8 and 8192 
         | 
| 368 | 
            -
            codebook entries).
         | 
| 369 | 
            -
            
         | 
| 370 | 
            -
            
         | 
| 371 | 
            -
             | 
| 372 | 
            -
            ### Other
         | 
| 373 | 
            -
            - A [video summary](https://www.youtube.com/watch?v=o7dqGcLDf0A&feature=emb_imp_woyt) by [Two Minute Papers](https://www.youtube.com/channel/UCbfYPyITQ-7l4upoX8nvctg).
         | 
| 374 | 
            -
            - A [video summary](https://www.youtube.com/watch?v=-wDSDtIAyWQ) by [Gradient Dude](https://www.youtube.com/c/GradientDude/about).
         | 
| 375 | 
            -
            - A [weights and biases report summarizing the paper](https://wandb.ai/ayush-thakur/taming-transformer/reports/-Overview-Taming-Transformers-for-High-Resolution-Image-Synthesis---Vmlldzo0NjEyMTY)
         | 
| 376 | 
            -
            by [ayulockin](https://github.com/ayulockin).
         | 
| 377 | 
            -
            - A [video summary](https://www.youtube.com/watch?v=JfUTd8fjtX8&feature=emb_imp_woyt) by [What's AI](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg).
         | 
| 378 | 
            -
            - Take a look at [ak9250's notebook](https://github.com/ak9250/taming-transformers/blob/master/tamingtransformerscolab.ipynb) if you want to run the streamlit demos on Colab.
         | 
| 379 | 
            -
             | 
| 380 | 
            -
            ### Text-to-Image Optimization via CLIP
         | 
| 381 | 
            -
            VQGAN has been successfully used as an image generator guided by the [CLIP](https://github.com/openai/CLIP) model, both for pure image generation
         | 
| 382 | 
            -
            from scratch and image-to-image translation. We recommend the following notebooks/videos/resources:
         | 
| 383 | 
            -
             | 
| 384 | 
            -
             - [Advadnouns](https://twitter.com/advadnoun/status/1389316507134357506) Patreon and corresponding LatentVision notebooks: https://www.patreon.com/patronizeme
         | 
| 385 | 
            -
             - The [notebook]( https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN) of [Rivers Have Wings](https://twitter.com/RiversHaveWings).
         | 
| 386 | 
            -
             - A [video](https://www.youtube.com/watch?v=90QDe6DQXF4&t=12s) explanation by [Dot CSV](https://www.youtube.com/channel/UCy5znSnfMsDwaLlROnZ7Qbg) (in Spanish, but English subtitles are available)
         | 
| 387 | 
            -
             | 
| 388 | 
            -
            
         | 
| 389 | 
            -
             | 
| 390 | 
            -
            Text prompt: *'A bird drawn by a child'*
         | 
| 391 | 
            -
             | 
| 392 | 
            -
            ## Shout-outs
         | 
| 393 | 
            -
            Thanks to everyone who makes their code and models available. In particular,
         | 
| 394 | 
            -
             | 
| 395 | 
            -
            - The architecture of our VQGAN is inspired by [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion)
         | 
| 396 | 
            -
            - The very hackable transformer implementation [minGPT](https://github.com/karpathy/minGPT)
         | 
| 397 | 
            -
            - The good ol' [PatchGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [Learned Perceptual Similarity (LPIPS)](https://github.com/richzhang/PerceptualSimilarity)
         | 
| 398 | 
            -
             | 
| 399 | 
            -
            ## BibTeX
         | 
| 400 | 
            -
             | 
| 401 | 
            -
            ```
         | 
| 402 | 
            -
            @misc{esser2020taming,
         | 
| 403 | 
            -
                  title={Taming Transformers for High-Resolution Image Synthesis}, 
         | 
| 404 | 
            -
                  author={Patrick Esser and Robin Rombach and Björn Ommer},
         | 
| 405 | 
            -
                  year={2020},
         | 
| 406 | 
            -
                  eprint={2012.09841},
         | 
| 407 | 
            -
                  archivePrefix={arXiv},
         | 
| 408 | 
            -
                  primaryClass={cs.CV}
         | 
| 409 | 
            -
            }
         | 
| 410 | 
            -
            ```
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/coco_cond_stage.yaml
    DELETED
    
    | @@ -1,49 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-06
         | 
| 3 | 
            -
              target: taming.models.vqgan.VQSegmentationModel
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                embed_dim: 256
         | 
| 6 | 
            -
                n_embed: 1024
         | 
| 7 | 
            -
                image_key: "segmentation"
         | 
| 8 | 
            -
                n_labels: 183
         | 
| 9 | 
            -
                ddconfig:
         | 
| 10 | 
            -
                  double_z: false
         | 
| 11 | 
            -
                  z_channels: 256
         | 
| 12 | 
            -
                  resolution: 256
         | 
| 13 | 
            -
                  in_channels: 183
         | 
| 14 | 
            -
                  out_ch: 183
         | 
| 15 | 
            -
                  ch: 128
         | 
| 16 | 
            -
                  ch_mult:
         | 
| 17 | 
            -
                  - 1
         | 
| 18 | 
            -
                  - 1
         | 
| 19 | 
            -
                  - 2
         | 
| 20 | 
            -
                  - 2
         | 
| 21 | 
            -
                  - 4
         | 
| 22 | 
            -
                  num_res_blocks: 2
         | 
| 23 | 
            -
                  attn_resolutions:
         | 
| 24 | 
            -
                  - 16
         | 
| 25 | 
            -
                  dropout: 0.0
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                lossconfig:
         | 
| 28 | 
            -
                  target: taming.modules.losses.segmentation.BCELossWithQuant
         | 
| 29 | 
            -
                  params:
         | 
| 30 | 
            -
                    codebook_weight: 1.0
         | 
| 31 | 
            -
             | 
| 32 | 
            -
            data:
         | 
| 33 | 
            -
              target: main.DataModuleFromConfig
         | 
| 34 | 
            -
              params:
         | 
| 35 | 
            -
                batch_size: 12
         | 
| 36 | 
            -
                train:
         | 
| 37 | 
            -
                  target: taming.data.coco.CocoImagesAndCaptionsTrain
         | 
| 38 | 
            -
                  params:
         | 
| 39 | 
            -
                    size: 296
         | 
| 40 | 
            -
                    crop_size: 256
         | 
| 41 | 
            -
                    onehot_segmentation: true
         | 
| 42 | 
            -
                    use_stuffthing: true
         | 
| 43 | 
            -
                validation:
         | 
| 44 | 
            -
                  target: taming.data.coco.CocoImagesAndCaptionsValidation
         | 
| 45 | 
            -
                  params:
         | 
| 46 | 
            -
                    size: 256
         | 
| 47 | 
            -
                    crop_size: 256
         | 
| 48 | 
            -
                    onehot_segmentation: true
         | 
| 49 | 
            -
                    use_stuffthing: true
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/coco_scene_images_transformer.yaml
    DELETED
    
    | @@ -1,80 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-06
         | 
| 3 | 
            -
              target: taming.models.cond_transformer.Net2NetTransformer
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                cond_stage_key: objects_bbox
         | 
| 6 | 
            -
                transformer_config:
         | 
| 7 | 
            -
                  target: taming.modules.transformer.mingpt.GPT
         | 
| 8 | 
            -
                  params:
         | 
| 9 | 
            -
                    vocab_size: 8192
         | 
| 10 | 
            -
                    block_size: 348  # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
         | 
| 11 | 
            -
                    n_layer: 40
         | 
| 12 | 
            -
                    n_head: 16
         | 
| 13 | 
            -
                    n_embd: 1408
         | 
| 14 | 
            -
                    embd_pdrop: 0.1
         | 
| 15 | 
            -
                    resid_pdrop: 0.1
         | 
| 16 | 
            -
                    attn_pdrop: 0.1
         | 
| 17 | 
            -
                first_stage_config:
         | 
| 18 | 
            -
                  target: taming.models.vqgan.VQModel
         | 
| 19 | 
            -
                  params:
         | 
| 20 | 
            -
                    ckpt_path: /path/to/coco_epoch117.ckpt  # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
         | 
| 21 | 
            -
                    embed_dim: 256
         | 
| 22 | 
            -
                    n_embed: 8192
         | 
| 23 | 
            -
                    ddconfig:
         | 
| 24 | 
            -
                      double_z: false
         | 
| 25 | 
            -
                      z_channels: 256
         | 
| 26 | 
            -
                      resolution: 256
         | 
| 27 | 
            -
                      in_channels: 3
         | 
| 28 | 
            -
                      out_ch: 3
         | 
| 29 | 
            -
                      ch: 128
         | 
| 30 | 
            -
                      ch_mult:
         | 
| 31 | 
            -
                      - 1
         | 
| 32 | 
            -
                      - 1
         | 
| 33 | 
            -
                      - 2
         | 
| 34 | 
            -
                      - 2
         | 
| 35 | 
            -
                      - 4
         | 
| 36 | 
            -
                      num_res_blocks: 2
         | 
| 37 | 
            -
                      attn_resolutions:
         | 
| 38 | 
            -
                      - 16
         | 
| 39 | 
            -
                      dropout: 0.0
         | 
| 40 | 
            -
                    lossconfig:
         | 
| 41 | 
            -
                      target: taming.modules.losses.DummyLoss
         | 
| 42 | 
            -
                cond_stage_config:
         | 
| 43 | 
            -
                  target: taming.models.dummy_cond_stage.DummyCondStage
         | 
| 44 | 
            -
                  params:
         | 
| 45 | 
            -
                    conditional_key: objects_bbox
         | 
| 46 | 
            -
             | 
| 47 | 
            -
            data:
         | 
| 48 | 
            -
              target: main.DataModuleFromConfig
         | 
| 49 | 
            -
              params:
         | 
| 50 | 
            -
                batch_size: 6
         | 
| 51 | 
            -
                train:
         | 
| 52 | 
            -
                  target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
         | 
| 53 | 
            -
                  params:
         | 
| 54 | 
            -
                    data_path: data/coco_annotations_100  # substitute with path to full dataset
         | 
| 55 | 
            -
                    split: train
         | 
| 56 | 
            -
                    keys: [image, objects_bbox, file_name, annotations]
         | 
| 57 | 
            -
                    no_tokens: 8192
         | 
| 58 | 
            -
                    target_image_size: 256
         | 
| 59 | 
            -
                    min_object_area: 0.00001
         | 
| 60 | 
            -
                    min_objects_per_image: 2
         | 
| 61 | 
            -
                    max_objects_per_image: 30
         | 
| 62 | 
            -
                    crop_method: random-1d
         | 
| 63 | 
            -
                    random_flip: true
         | 
| 64 | 
            -
                    use_group_parameter: true
         | 
| 65 | 
            -
                    encode_crop: true
         | 
| 66 | 
            -
                validation:
         | 
| 67 | 
            -
                  target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
         | 
| 68 | 
            -
                  params:
         | 
| 69 | 
            -
                    data_path: data/coco_annotations_100  # substitute with path to full dataset
         | 
| 70 | 
            -
                    split: validation
         | 
| 71 | 
            -
                    keys: [image, objects_bbox, file_name, annotations]
         | 
| 72 | 
            -
                    no_tokens: 8192
         | 
| 73 | 
            -
                    target_image_size: 256
         | 
| 74 | 
            -
                    min_object_area: 0.00001
         | 
| 75 | 
            -
                    min_objects_per_image: 2
         | 
| 76 | 
            -
                    max_objects_per_image: 30
         | 
| 77 | 
            -
                    crop_method: center
         | 
| 78 | 
            -
                    random_flip: false
         | 
| 79 | 
            -
                    use_group_parameter: true
         | 
| 80 | 
            -
                    encode_crop: true
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/custom_vqgan.yaml
    DELETED
    
    | @@ -1,43 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-6
         | 
| 3 | 
            -
              target: taming.models.vqgan.VQModel
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                embed_dim: 256
         | 
| 6 | 
            -
                n_embed: 1024
         | 
| 7 | 
            -
                ddconfig:
         | 
| 8 | 
            -
                  double_z: False
         | 
| 9 | 
            -
                  z_channels: 256
         | 
| 10 | 
            -
                  resolution: 256
         | 
| 11 | 
            -
                  in_channels: 3
         | 
| 12 | 
            -
                  out_ch: 3
         | 
| 13 | 
            -
                  ch: 128
         | 
| 14 | 
            -
                  ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
         | 
| 15 | 
            -
                  num_res_blocks: 2
         | 
| 16 | 
            -
                  attn_resolutions: [16]
         | 
| 17 | 
            -
                  dropout: 0.0
         | 
| 18 | 
            -
             | 
| 19 | 
            -
                lossconfig:
         | 
| 20 | 
            -
                  target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
         | 
| 21 | 
            -
                  params:
         | 
| 22 | 
            -
                    disc_conditional: False
         | 
| 23 | 
            -
                    disc_in_channels: 3
         | 
| 24 | 
            -
                    disc_start: 10000
         | 
| 25 | 
            -
                    disc_weight: 0.8
         | 
| 26 | 
            -
                    codebook_weight: 1.0
         | 
| 27 | 
            -
             | 
| 28 | 
            -
            data:
         | 
| 29 | 
            -
              target: main.DataModuleFromConfig
         | 
| 30 | 
            -
              params:
         | 
| 31 | 
            -
                batch_size: 5
         | 
| 32 | 
            -
                num_workers: 8
         | 
| 33 | 
            -
                train:
         | 
| 34 | 
            -
                  target: taming.data.custom.CustomTrain
         | 
| 35 | 
            -
                  params:
         | 
| 36 | 
            -
                    training_images_list_file: some/training.txt
         | 
| 37 | 
            -
                    size: 256
         | 
| 38 | 
            -
                validation:
         | 
| 39 | 
            -
                  target: taming.data.custom.CustomTest
         | 
| 40 | 
            -
                  params:
         | 
| 41 | 
            -
                    test_images_list_file: some/test.txt
         | 
| 42 | 
            -
                    size: 256
         | 
| 43 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/drin_transformer.yaml
    DELETED
    
    | @@ -1,77 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-06
         | 
| 3 | 
            -
              target: taming.models.cond_transformer.Net2NetTransformer
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                cond_stage_key: depth
         | 
| 6 | 
            -
                transformer_config:
         | 
| 7 | 
            -
                  target: taming.modules.transformer.mingpt.GPT
         | 
| 8 | 
            -
                  params:
         | 
| 9 | 
            -
                    vocab_size: 1024
         | 
| 10 | 
            -
                    block_size: 512
         | 
| 11 | 
            -
                    n_layer: 24
         | 
| 12 | 
            -
                    n_head: 16
         | 
| 13 | 
            -
                    n_embd: 1024
         | 
| 14 | 
            -
                first_stage_config:
         | 
| 15 | 
            -
                  target: taming.models.vqgan.VQModel
         | 
| 16 | 
            -
                  params:
         | 
| 17 | 
            -
                    ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt
         | 
| 18 | 
            -
                    embed_dim: 256
         | 
| 19 | 
            -
                    n_embed: 1024
         | 
| 20 | 
            -
                    ddconfig:
         | 
| 21 | 
            -
                      double_z: false
         | 
| 22 | 
            -
                      z_channels: 256
         | 
| 23 | 
            -
                      resolution: 256
         | 
| 24 | 
            -
                      in_channels: 3
         | 
| 25 | 
            -
                      out_ch: 3
         | 
| 26 | 
            -
                      ch: 128
         | 
| 27 | 
            -
                      ch_mult:
         | 
| 28 | 
            -
                      - 1
         | 
| 29 | 
            -
                      - 1
         | 
| 30 | 
            -
                      - 2
         | 
| 31 | 
            -
                      - 2
         | 
| 32 | 
            -
                      - 4
         | 
| 33 | 
            -
                      num_res_blocks: 2
         | 
| 34 | 
            -
                      attn_resolutions:
         | 
| 35 | 
            -
                      - 16
         | 
| 36 | 
            -
                      dropout: 0.0
         | 
| 37 | 
            -
                    lossconfig:
         | 
| 38 | 
            -
                      target: taming.modules.losses.DummyLoss
         | 
| 39 | 
            -
                cond_stage_config:
         | 
| 40 | 
            -
                  target: taming.models.vqgan.VQModel
         | 
| 41 | 
            -
                  params:
         | 
| 42 | 
            -
                    ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt
         | 
| 43 | 
            -
                    embed_dim: 256
         | 
| 44 | 
            -
                    n_embed: 1024
         | 
| 45 | 
            -
                    ddconfig:
         | 
| 46 | 
            -
                      double_z: false
         | 
| 47 | 
            -
                      z_channels: 256
         | 
| 48 | 
            -
                      resolution: 256
         | 
| 49 | 
            -
                      in_channels: 1
         | 
| 50 | 
            -
                      out_ch: 1
         | 
| 51 | 
            -
                      ch: 128
         | 
| 52 | 
            -
                      ch_mult:
         | 
| 53 | 
            -
                      - 1
         | 
| 54 | 
            -
                      - 1
         | 
| 55 | 
            -
                      - 2
         | 
| 56 | 
            -
                      - 2
         | 
| 57 | 
            -
                      - 4
         | 
| 58 | 
            -
                      num_res_blocks: 2
         | 
| 59 | 
            -
                      attn_resolutions:
         | 
| 60 | 
            -
                      - 16
         | 
| 61 | 
            -
                      dropout: 0.0
         | 
| 62 | 
            -
                    lossconfig:
         | 
| 63 | 
            -
                      target: taming.modules.losses.DummyLoss
         | 
| 64 | 
            -
             | 
| 65 | 
            -
            data:
         | 
| 66 | 
            -
              target: main.DataModuleFromConfig
         | 
| 67 | 
            -
              params:
         | 
| 68 | 
            -
                batch_size: 2
         | 
| 69 | 
            -
                num_workers: 8
         | 
| 70 | 
            -
                train:
         | 
| 71 | 
            -
                  target: taming.data.imagenet.RINTrainWithDepth
         | 
| 72 | 
            -
                  params:
         | 
| 73 | 
            -
                    size: 256
         | 
| 74 | 
            -
                validation:
         | 
| 75 | 
            -
                  target: taming.data.imagenet.RINValidationWithDepth
         | 
| 76 | 
            -
                  params:
         | 
| 77 | 
            -
                    size: 256
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/faceshq_transformer.yaml
    DELETED
    
    | @@ -1,61 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-06
         | 
| 3 | 
            -
              target: taming.models.cond_transformer.Net2NetTransformer
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                cond_stage_key: coord
         | 
| 6 | 
            -
                transformer_config:
         | 
| 7 | 
            -
                  target: taming.modules.transformer.mingpt.GPT
         | 
| 8 | 
            -
                  params:
         | 
| 9 | 
            -
                    vocab_size: 1024
         | 
| 10 | 
            -
                    block_size: 512
         | 
| 11 | 
            -
                    n_layer: 24
         | 
| 12 | 
            -
                    n_head: 16
         | 
| 13 | 
            -
                    n_embd: 1024
         | 
| 14 | 
            -
                first_stage_config:
         | 
| 15 | 
            -
                  target: taming.models.vqgan.VQModel
         | 
| 16 | 
            -
                  params:
         | 
| 17 | 
            -
                    ckpt_path: logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt
         | 
| 18 | 
            -
                    embed_dim: 256
         | 
| 19 | 
            -
                    n_embed: 1024
         | 
| 20 | 
            -
                    ddconfig:
         | 
| 21 | 
            -
                      double_z: false
         | 
| 22 | 
            -
                      z_channels: 256
         | 
| 23 | 
            -
                      resolution: 256
         | 
| 24 | 
            -
                      in_channels: 3
         | 
| 25 | 
            -
                      out_ch: 3
         | 
| 26 | 
            -
                      ch: 128
         | 
| 27 | 
            -
                      ch_mult:
         | 
| 28 | 
            -
                      - 1
         | 
| 29 | 
            -
                      - 1
         | 
| 30 | 
            -
                      - 2
         | 
| 31 | 
            -
                      - 2
         | 
| 32 | 
            -
                      - 4
         | 
| 33 | 
            -
                      num_res_blocks: 2
         | 
| 34 | 
            -
                      attn_resolutions:
         | 
| 35 | 
            -
                      - 16
         | 
| 36 | 
            -
                      dropout: 0.0
         | 
| 37 | 
            -
                    lossconfig:
         | 
| 38 | 
            -
                      target: taming.modules.losses.DummyLoss
         | 
| 39 | 
            -
                cond_stage_config:
         | 
| 40 | 
            -
                  target: taming.modules.misc.coord.CoordStage
         | 
| 41 | 
            -
                  params:
         | 
| 42 | 
            -
                    n_embed: 1024
         | 
| 43 | 
            -
                    down_factor: 16
         | 
| 44 | 
            -
             | 
| 45 | 
            -
            data:
         | 
| 46 | 
            -
              target: main.DataModuleFromConfig
         | 
| 47 | 
            -
              params:
         | 
| 48 | 
            -
                batch_size: 2
         | 
| 49 | 
            -
                num_workers: 8
         | 
| 50 | 
            -
                train:
         | 
| 51 | 
            -
                  target: taming.data.faceshq.FacesHQTrain
         | 
| 52 | 
            -
                  params:
         | 
| 53 | 
            -
                    size: 256
         | 
| 54 | 
            -
                    crop_size: 256
         | 
| 55 | 
            -
                    coord: True
         | 
| 56 | 
            -
                validation:
         | 
| 57 | 
            -
                  target: taming.data.faceshq.FacesHQValidation
         | 
| 58 | 
            -
                  params:
         | 
| 59 | 
            -
                    size: 256
         | 
| 60 | 
            -
                    crop_size: 256
         | 
| 61 | 
            -
                    coord: True
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/faceshq_vqgan.yaml
    DELETED
    
    | @@ -1,42 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-6
         | 
| 3 | 
            -
              target: taming.models.vqgan.VQModel
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                embed_dim: 256
         | 
| 6 | 
            -
                n_embed: 1024
         | 
| 7 | 
            -
                ddconfig:
         | 
| 8 | 
            -
                  double_z: False
         | 
| 9 | 
            -
                  z_channels: 256
         | 
| 10 | 
            -
                  resolution: 256
         | 
| 11 | 
            -
                  in_channels: 3
         | 
| 12 | 
            -
                  out_ch: 3
         | 
| 13 | 
            -
                  ch: 128
         | 
| 14 | 
            -
                  ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
         | 
| 15 | 
            -
                  num_res_blocks: 2
         | 
| 16 | 
            -
                  attn_resolutions: [16]
         | 
| 17 | 
            -
                  dropout: 0.0
         | 
| 18 | 
            -
             | 
| 19 | 
            -
                lossconfig:
         | 
| 20 | 
            -
                  target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
         | 
| 21 | 
            -
                  params:
         | 
| 22 | 
            -
                    disc_conditional: False
         | 
| 23 | 
            -
                    disc_in_channels: 3
         | 
| 24 | 
            -
                    disc_start: 30001
         | 
| 25 | 
            -
                    disc_weight: 0.8
         | 
| 26 | 
            -
                    codebook_weight: 1.0
         | 
| 27 | 
            -
             | 
| 28 | 
            -
            data:
         | 
| 29 | 
            -
              target: main.DataModuleFromConfig
         | 
| 30 | 
            -
              params:
         | 
| 31 | 
            -
                batch_size: 3
         | 
| 32 | 
            -
                num_workers: 8
         | 
| 33 | 
            -
                train:
         | 
| 34 | 
            -
                  target: taming.data.faceshq.FacesHQTrain
         | 
| 35 | 
            -
                  params:
         | 
| 36 | 
            -
                    size: 256
         | 
| 37 | 
            -
                    crop_size: 256
         | 
| 38 | 
            -
                validation:
         | 
| 39 | 
            -
                  target: taming.data.faceshq.FacesHQValidation
         | 
| 40 | 
            -
                  params:
         | 
| 41 | 
            -
                    size: 256
         | 
| 42 | 
            -
                    crop_size: 256
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/imagenet_vqgan.yaml
    DELETED
    
    | @@ -1,42 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-6
         | 
| 3 | 
            -
              target: taming.models.vqgan.VQModel
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                embed_dim: 256
         | 
| 6 | 
            -
                n_embed: 1024
         | 
| 7 | 
            -
                ddconfig:
         | 
| 8 | 
            -
                  double_z: False
         | 
| 9 | 
            -
                  z_channels: 256
         | 
| 10 | 
            -
                  resolution: 256
         | 
| 11 | 
            -
                  in_channels: 3
         | 
| 12 | 
            -
                  out_ch: 3
         | 
| 13 | 
            -
                  ch: 128
         | 
| 14 | 
            -
                  ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
         | 
| 15 | 
            -
                  num_res_blocks: 2
         | 
| 16 | 
            -
                  attn_resolutions: [16]
         | 
| 17 | 
            -
                  dropout: 0.0
         | 
| 18 | 
            -
             | 
| 19 | 
            -
                lossconfig:
         | 
| 20 | 
            -
                  target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
         | 
| 21 | 
            -
                  params:
         | 
| 22 | 
            -
                    disc_conditional: False
         | 
| 23 | 
            -
                    disc_in_channels: 3
         | 
| 24 | 
            -
                    disc_start: 250001
         | 
| 25 | 
            -
                    disc_weight: 0.8
         | 
| 26 | 
            -
                    codebook_weight: 1.0
         | 
| 27 | 
            -
             | 
| 28 | 
            -
            data:
         | 
| 29 | 
            -
              target: main.DataModuleFromConfig
         | 
| 30 | 
            -
              params:
         | 
| 31 | 
            -
                batch_size: 12
         | 
| 32 | 
            -
                num_workers: 24
         | 
| 33 | 
            -
                train:
         | 
| 34 | 
            -
                  target: taming.data.imagenet.ImageNetTrain
         | 
| 35 | 
            -
                  params:
         | 
| 36 | 
            -
                    config:
         | 
| 37 | 
            -
                      size: 256
         | 
| 38 | 
            -
                validation:
         | 
| 39 | 
            -
                  target: taming.data.imagenet.ImageNetValidation
         | 
| 40 | 
            -
                  params:
         | 
| 41 | 
            -
                    config:
         | 
| 42 | 
            -
                      size: 256
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/imagenetdepth_vqgan.yaml
    DELETED
    
    | @@ -1,41 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-6
         | 
| 3 | 
            -
              target: taming.models.vqgan.VQModel
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                embed_dim: 256
         | 
| 6 | 
            -
                n_embed: 1024
         | 
| 7 | 
            -
                image_key: depth
         | 
| 8 | 
            -
                ddconfig:
         | 
| 9 | 
            -
                  double_z: False
         | 
| 10 | 
            -
                  z_channels: 256
         | 
| 11 | 
            -
                  resolution: 256
         | 
| 12 | 
            -
                  in_channels: 1
         | 
| 13 | 
            -
                  out_ch: 1
         | 
| 14 | 
            -
                  ch: 128
         | 
| 15 | 
            -
                  ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
         | 
| 16 | 
            -
                  num_res_blocks: 2
         | 
| 17 | 
            -
                  attn_resolutions: [16]
         | 
| 18 | 
            -
                  dropout: 0.0
         | 
| 19 | 
            -
             | 
| 20 | 
            -
                lossconfig:
         | 
| 21 | 
            -
                  target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
         | 
| 22 | 
            -
                  params:
         | 
| 23 | 
            -
                    disc_conditional: False
         | 
| 24 | 
            -
                    disc_in_channels: 1
         | 
| 25 | 
            -
                    disc_start: 50001
         | 
| 26 | 
            -
                    disc_weight: 0.75
         | 
| 27 | 
            -
                    codebook_weight: 1.0
         | 
| 28 | 
            -
             | 
| 29 | 
            -
            data:
         | 
| 30 | 
            -
              target: main.DataModuleFromConfig
         | 
| 31 | 
            -
              params:
         | 
| 32 | 
            -
                batch_size: 3
         | 
| 33 | 
            -
                num_workers: 8
         | 
| 34 | 
            -
                train:
         | 
| 35 | 
            -
                  target: taming.data.imagenet.ImageNetTrainWithDepth
         | 
| 36 | 
            -
                  params:
         | 
| 37 | 
            -
                    size: 256
         | 
| 38 | 
            -
                validation:
         | 
| 39 | 
            -
                  target: taming.data.imagenet.ImageNetValidationWithDepth
         | 
| 40 | 
            -
                  params:
         | 
| 41 | 
            -
                    size: 256
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/open_images_scene_images_transformer.yaml
    DELETED
    
    | @@ -1,86 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-06
         | 
| 3 | 
            -
              target: taming.models.cond_transformer.Net2NetTransformer
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                cond_stage_key: objects_bbox
         | 
| 6 | 
            -
                transformer_config:
         | 
| 7 | 
            -
                  target: taming.modules.transformer.mingpt.GPT
         | 
| 8 | 
            -
                  params:
         | 
| 9 | 
            -
                    vocab_size: 8192
         | 
| 10 | 
            -
                    block_size: 348  # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
         | 
| 11 | 
            -
                    n_layer: 36
         | 
| 12 | 
            -
                    n_head: 16
         | 
| 13 | 
            -
                    n_embd: 1536
         | 
| 14 | 
            -
                    embd_pdrop: 0.1
         | 
| 15 | 
            -
                    resid_pdrop: 0.1
         | 
| 16 | 
            -
                    attn_pdrop: 0.1
         | 
| 17 | 
            -
                first_stage_config:
         | 
| 18 | 
            -
                  target: taming.models.vqgan.VQModel
         | 
| 19 | 
            -
                  params:
         | 
| 20 | 
            -
                    ckpt_path: /path/to/coco_oi_epoch12.ckpt  # https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/
         | 
| 21 | 
            -
                    embed_dim: 256
         | 
| 22 | 
            -
                    n_embed: 8192
         | 
| 23 | 
            -
                    ddconfig:
         | 
| 24 | 
            -
                      double_z: false
         | 
| 25 | 
            -
                      z_channels: 256
         | 
| 26 | 
            -
                      resolution: 256
         | 
| 27 | 
            -
                      in_channels: 3
         | 
| 28 | 
            -
                      out_ch: 3
         | 
| 29 | 
            -
                      ch: 128
         | 
| 30 | 
            -
                      ch_mult:
         | 
| 31 | 
            -
                      - 1
         | 
| 32 | 
            -
                      - 1
         | 
| 33 | 
            -
                      - 2
         | 
| 34 | 
            -
                      - 2
         | 
| 35 | 
            -
                      - 4
         | 
| 36 | 
            -
                      num_res_blocks: 2
         | 
| 37 | 
            -
                      attn_resolutions:
         | 
| 38 | 
            -
                      - 16
         | 
| 39 | 
            -
                      dropout: 0.0
         | 
| 40 | 
            -
                    lossconfig:
         | 
| 41 | 
            -
                      target: taming.modules.losses.DummyLoss
         | 
| 42 | 
            -
                cond_stage_config:
         | 
| 43 | 
            -
                  target: taming.models.dummy_cond_stage.DummyCondStage
         | 
| 44 | 
            -
                  params:
         | 
| 45 | 
            -
                    conditional_key: objects_bbox
         | 
| 46 | 
            -
             | 
| 47 | 
            -
            data:
         | 
| 48 | 
            -
              target: main.DataModuleFromConfig
         | 
| 49 | 
            -
              params:
         | 
| 50 | 
            -
                batch_size: 6
         | 
| 51 | 
            -
                train:
         | 
| 52 | 
            -
                  target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
         | 
| 53 | 
            -
                  params:
         | 
| 54 | 
            -
                    data_path: data/open_images_annotations_100  # substitute with path to full dataset
         | 
| 55 | 
            -
                    split: train
         | 
| 56 | 
            -
                    keys: [image, objects_bbox, file_name, annotations]
         | 
| 57 | 
            -
                    no_tokens: 8192
         | 
| 58 | 
            -
                    target_image_size: 256
         | 
| 59 | 
            -
                    category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
         | 
| 60 | 
            -
                    category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
         | 
| 61 | 
            -
                    min_object_area: 0.0001
         | 
| 62 | 
            -
                    min_objects_per_image: 2
         | 
| 63 | 
            -
                    max_objects_per_image: 30
         | 
| 64 | 
            -
                    crop_method: random-2d
         | 
| 65 | 
            -
                    random_flip: true
         | 
| 66 | 
            -
                    use_group_parameter: true
         | 
| 67 | 
            -
                    use_additional_parameters: true
         | 
| 68 | 
            -
                    encode_crop: true
         | 
| 69 | 
            -
                validation:
         | 
| 70 | 
            -
                  target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
         | 
| 71 | 
            -
                  params:
         | 
| 72 | 
            -
                    data_path: data/open_images_annotations_100  # substitute with path to full dataset
         | 
| 73 | 
            -
                    split: validation
         | 
| 74 | 
            -
                    keys: [image, objects_bbox, file_name, annotations]
         | 
| 75 | 
            -
                    no_tokens: 8192
         | 
| 76 | 
            -
                    target_image_size: 256
         | 
| 77 | 
            -
                    category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
         | 
| 78 | 
            -
                    category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
         | 
| 79 | 
            -
                    min_object_area: 0.0001
         | 
| 80 | 
            -
                    min_objects_per_image: 2
         | 
| 81 | 
            -
                    max_objects_per_image: 30
         | 
| 82 | 
            -
                    crop_method: center
         | 
| 83 | 
            -
                    random_flip: false
         | 
| 84 | 
            -
                    use_group_parameter: true
         | 
| 85 | 
            -
                    use_additional_parameters: true
         | 
| 86 | 
            -
                    encode_crop: true
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/configs/sflckr_cond_stage.yaml
    DELETED
    
    | @@ -1,43 +0,0 @@ | |
| 1 | 
            -
            model:
         | 
| 2 | 
            -
              base_learning_rate: 4.5e-06
         | 
| 3 | 
            -
              target: taming.models.vqgan.VQSegmentationModel
         | 
| 4 | 
            -
              params:
         | 
| 5 | 
            -
                embed_dim: 256
         | 
| 6 | 
            -
                n_embed: 1024
         | 
| 7 | 
            -
                image_key: "segmentation"
         | 
| 8 | 
            -
                n_labels: 182
         | 
| 9 | 
            -
                ddconfig:
         | 
| 10 | 
            -
                  double_z: false
         | 
| 11 | 
            -
                  z_channels: 256
         | 
| 12 | 
            -
                  resolution: 256
         | 
| 13 | 
            -
                  in_channels: 182
         | 
| 14 | 
            -
                  out_ch: 182
         | 
| 15 | 
            -
                  ch: 128
         | 
| 16 | 
            -
                  ch_mult:
         | 
| 17 | 
            -
                  - 1
         | 
| 18 | 
            -
                  - 1
         | 
| 19 | 
            -
                  - 2
         | 
| 20 | 
            -
                  - 2
         | 
| 21 | 
            -
                  - 4
         | 
| 22 | 
            -
                  num_res_blocks: 2
         | 
| 23 | 
            -
                  attn_resolutions:
         | 
| 24 | 
            -
                  - 16
         | 
| 25 | 
            -
                  dropout: 0.0
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                lossconfig:
         | 
| 28 | 
            -
                  target: taming.modules.losses.segmentation.BCELossWithQuant
         | 
| 29 | 
            -
                  params:
         | 
| 30 | 
            -
                    codebook_weight: 1.0
         | 
| 31 | 
            -
             | 
| 32 | 
            -
            data:
         | 
| 33 | 
            -
              target: cutlit.DataModuleFromConfig
         | 
| 34 | 
            -
              params:
         | 
| 35 | 
            -
                batch_size: 12
         | 
| 36 | 
            -
                train:
         | 
| 37 | 
            -
                  target: taming.data.sflckr.Examples # adjust
         | 
| 38 | 
            -
                  params:
         | 
| 39 | 
            -
                    size: 256
         | 
| 40 | 
            -
                validation:
         | 
| 41 | 
            -
                  target: taming.data.sflckr.Examples # adjust
         | 
| 42 | 
            -
                  params:
         | 
| 43 | 
            -
                    size: 256
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/environment.yaml
    DELETED
    
    | @@ -1,25 +0,0 @@ | |
| 1 | 
            -
            name: taming
         | 
| 2 | 
            -
            channels:
         | 
| 3 | 
            -
              - pytorch
         | 
| 4 | 
            -
              - defaults
         | 
| 5 | 
            -
            dependencies:
         | 
| 6 | 
            -
              - python=3.8.5
         | 
| 7 | 
            -
              - pip=20.3
         | 
| 8 | 
            -
              - cudatoolkit=10.2
         | 
| 9 | 
            -
              - pytorch=1.7.0
         | 
| 10 | 
            -
              - torchvision=0.8.1
         | 
| 11 | 
            -
              - numpy=1.19.2
         | 
| 12 | 
            -
              - pip:
         | 
| 13 | 
            -
                - albumentations==0.4.3
         | 
| 14 | 
            -
                - opencv-python==4.1.2.30
         | 
| 15 | 
            -
                - pudb==2019.2
         | 
| 16 | 
            -
                - imageio==2.9.0
         | 
| 17 | 
            -
                - imageio-ffmpeg==0.4.2
         | 
| 18 | 
            -
                - pytorch-lightning==1.0.8
         | 
| 19 | 
            -
                - omegaconf==2.0.0
         | 
| 20 | 
            -
                - test-tube>=0.7.5
         | 
| 21 | 
            -
                - streamlit>=0.73.1
         | 
| 22 | 
            -
                - einops==0.3.0
         | 
| 23 | 
            -
                - more-itertools>=8.0.0
         | 
| 24 | 
            -
                - transformers==4.3.1
         | 
| 25 | 
            -
                - -e .
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/main.py
    DELETED
    
    | @@ -1,585 +0,0 @@ | |
| 1 | 
            -
            import argparse, os, sys, datetime, glob, importlib
         | 
| 2 | 
            -
            from omegaconf import OmegaConf
         | 
| 3 | 
            -
            import numpy as np
         | 
| 4 | 
            -
            from PIL import Image
         | 
| 5 | 
            -
            import torch
         | 
| 6 | 
            -
            import torchvision
         | 
| 7 | 
            -
            from torch.utils.data import random_split, DataLoader, Dataset
         | 
| 8 | 
            -
            import pytorch_lightning as pl
         | 
| 9 | 
            -
            from pytorch_lightning import seed_everything
         | 
| 10 | 
            -
            from pytorch_lightning.trainer import Trainer
         | 
| 11 | 
            -
            from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
         | 
| 12 | 
            -
            from pytorch_lightning.utilities import rank_zero_only
         | 
| 13 | 
            -
             | 
| 14 | 
            -
            from taming.data.utils import custom_collate
         | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
            def get_obj_from_str(string, reload=False):
         | 
| 18 | 
            -
                module, cls = string.rsplit(".", 1)
         | 
| 19 | 
            -
                if reload:
         | 
| 20 | 
            -
                    module_imp = importlib.import_module(module)
         | 
| 21 | 
            -
                    importlib.reload(module_imp)
         | 
| 22 | 
            -
                return getattr(importlib.import_module(module, package=None), cls)
         | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
            def get_parser(**parser_kwargs):
         | 
| 26 | 
            -
                def str2bool(v):
         | 
| 27 | 
            -
                    if isinstance(v, bool):
         | 
| 28 | 
            -
                        return v
         | 
| 29 | 
            -
                    if v.lower() in ("yes", "true", "t", "y", "1"):
         | 
| 30 | 
            -
                        return True
         | 
| 31 | 
            -
                    elif v.lower() in ("no", "false", "f", "n", "0"):
         | 
| 32 | 
            -
                        return False
         | 
| 33 | 
            -
                    else:
         | 
| 34 | 
            -
                        raise argparse.ArgumentTypeError("Boolean value expected.")
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                parser = argparse.ArgumentParser(**parser_kwargs)
         | 
| 37 | 
            -
                parser.add_argument(
         | 
| 38 | 
            -
                    "-n",
         | 
| 39 | 
            -
                    "--name",
         | 
| 40 | 
            -
                    type=str,
         | 
| 41 | 
            -
                    const=True,
         | 
| 42 | 
            -
                    default="",
         | 
| 43 | 
            -
                    nargs="?",
         | 
| 44 | 
            -
                    help="postfix for logdir",
         | 
| 45 | 
            -
                )
         | 
| 46 | 
            -
                parser.add_argument(
         | 
| 47 | 
            -
                    "-r",
         | 
| 48 | 
            -
                    "--resume",
         | 
| 49 | 
            -
                    type=str,
         | 
| 50 | 
            -
                    const=True,
         | 
| 51 | 
            -
                    default="",
         | 
| 52 | 
            -
                    nargs="?",
         | 
| 53 | 
            -
                    help="resume from logdir or checkpoint in logdir",
         | 
| 54 | 
            -
                )
         | 
| 55 | 
            -
                parser.add_argument(
         | 
| 56 | 
            -
                    "-b",
         | 
| 57 | 
            -
                    "--base",
         | 
| 58 | 
            -
                    nargs="*",
         | 
| 59 | 
            -
                    metavar="base_config.yaml",
         | 
| 60 | 
            -
                    help="paths to base configs. Loaded from left-to-right. "
         | 
| 61 | 
            -
                    "Parameters can be overwritten or added with command-line options of the form `--key value`.",
         | 
| 62 | 
            -
                    default=list(),
         | 
| 63 | 
            -
                )
         | 
| 64 | 
            -
                parser.add_argument(
         | 
| 65 | 
            -
                    "-t",
         | 
| 66 | 
            -
                    "--train",
         | 
| 67 | 
            -
                    type=str2bool,
         | 
| 68 | 
            -
                    const=True,
         | 
| 69 | 
            -
                    default=False,
         | 
| 70 | 
            -
                    nargs="?",
         | 
| 71 | 
            -
                    help="train",
         | 
| 72 | 
            -
                )
         | 
| 73 | 
            -
                parser.add_argument(
         | 
| 74 | 
            -
                    "--no-test",
         | 
| 75 | 
            -
                    type=str2bool,
         | 
| 76 | 
            -
                    const=True,
         | 
| 77 | 
            -
                    default=False,
         | 
| 78 | 
            -
                    nargs="?",
         | 
| 79 | 
            -
                    help="disable test",
         | 
| 80 | 
            -
                )
         | 
| 81 | 
            -
                parser.add_argument("-p", "--project", help="name of new or path to existing project")
         | 
| 82 | 
            -
                parser.add_argument(
         | 
| 83 | 
            -
                    "-d",
         | 
| 84 | 
            -
                    "--debug",
         | 
| 85 | 
            -
                    type=str2bool,
         | 
| 86 | 
            -
                    nargs="?",
         | 
| 87 | 
            -
                    const=True,
         | 
| 88 | 
            -
                    default=False,
         | 
| 89 | 
            -
                    help="enable post-mortem debugging",
         | 
| 90 | 
            -
                )
         | 
| 91 | 
            -
                parser.add_argument(
         | 
| 92 | 
            -
                    "-s",
         | 
| 93 | 
            -
                    "--seed",
         | 
| 94 | 
            -
                    type=int,
         | 
| 95 | 
            -
                    default=23,
         | 
| 96 | 
            -
                    help="seed for seed_everything",
         | 
| 97 | 
            -
                )
         | 
| 98 | 
            -
                parser.add_argument(
         | 
| 99 | 
            -
                    "-f",
         | 
| 100 | 
            -
                    "--postfix",
         | 
| 101 | 
            -
                    type=str,
         | 
| 102 | 
            -
                    default="",
         | 
| 103 | 
            -
                    help="post-postfix for default name",
         | 
| 104 | 
            -
                )
         | 
| 105 | 
            -
             | 
| 106 | 
            -
                return parser
         | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
            def nondefault_trainer_args(opt):
         | 
| 110 | 
            -
                parser = argparse.ArgumentParser()
         | 
| 111 | 
            -
                parser = Trainer.add_argparse_args(parser)
         | 
| 112 | 
            -
                args = parser.parse_args([])
         | 
| 113 | 
            -
                return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
         | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
            def instantiate_from_config(config):
         | 
| 117 | 
            -
                if not "target" in config:
         | 
| 118 | 
            -
                    raise KeyError("Expected key `target` to instantiate.")
         | 
| 119 | 
            -
                return get_obj_from_str(config["target"])(**config.get("params", dict()))
         | 
| 120 | 
            -
             | 
| 121 | 
            -
             | 
| 122 | 
            -
            class WrappedDataset(Dataset):
         | 
| 123 | 
            -
                """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
         | 
| 124 | 
            -
                def __init__(self, dataset):
         | 
| 125 | 
            -
                    self.data = dataset
         | 
| 126 | 
            -
             | 
| 127 | 
            -
                def __len__(self):
         | 
| 128 | 
            -
                    return len(self.data)
         | 
| 129 | 
            -
             | 
| 130 | 
            -
                def __getitem__(self, idx):
         | 
| 131 | 
            -
                    return self.data[idx]
         | 
| 132 | 
            -
             | 
| 133 | 
            -
             | 
| 134 | 
            -
            class DataModuleFromConfig(pl.LightningDataModule):
         | 
| 135 | 
            -
                def __init__(self, batch_size, train=None, validation=None, test=None,
         | 
| 136 | 
            -
                             wrap=False, num_workers=None):
         | 
| 137 | 
            -
                    super().__init__()
         | 
| 138 | 
            -
                    self.batch_size = batch_size
         | 
| 139 | 
            -
                    self.dataset_configs = dict()
         | 
| 140 | 
            -
                    self.num_workers = num_workers if num_workers is not None else batch_size*2
         | 
| 141 | 
            -
                    if train is not None:
         | 
| 142 | 
            -
                        self.dataset_configs["train"] = train
         | 
| 143 | 
            -
                        self.train_dataloader = self._train_dataloader
         | 
| 144 | 
            -
                    if validation is not None:
         | 
| 145 | 
            -
                        self.dataset_configs["validation"] = validation
         | 
| 146 | 
            -
                        self.val_dataloader = self._val_dataloader
         | 
| 147 | 
            -
                    if test is not None:
         | 
| 148 | 
            -
                        self.dataset_configs["test"] = test
         | 
| 149 | 
            -
                        self.test_dataloader = self._test_dataloader
         | 
| 150 | 
            -
                    self.wrap = wrap
         | 
| 151 | 
            -
             | 
| 152 | 
            -
                def prepare_data(self):
         | 
| 153 | 
            -
                    for data_cfg in self.dataset_configs.values():
         | 
| 154 | 
            -
                        instantiate_from_config(data_cfg)
         | 
| 155 | 
            -
             | 
| 156 | 
            -
                def setup(self, stage=None):
         | 
| 157 | 
            -
                    self.datasets = dict(
         | 
| 158 | 
            -
                        (k, instantiate_from_config(self.dataset_configs[k]))
         | 
| 159 | 
            -
                        for k in self.dataset_configs)
         | 
| 160 | 
            -
                    if self.wrap:
         | 
| 161 | 
            -
                        for k in self.datasets:
         | 
| 162 | 
            -
                            self.datasets[k] = WrappedDataset(self.datasets[k])
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                def _train_dataloader(self):
         | 
| 165 | 
            -
                    return DataLoader(self.datasets["train"], batch_size=self.batch_size,
         | 
| 166 | 
            -
                                      num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
         | 
| 167 | 
            -
             | 
| 168 | 
            -
                def _val_dataloader(self):
         | 
| 169 | 
            -
                    return DataLoader(self.datasets["validation"],
         | 
| 170 | 
            -
                                      batch_size=self.batch_size,
         | 
| 171 | 
            -
                                      num_workers=self.num_workers, collate_fn=custom_collate)
         | 
| 172 | 
            -
             | 
| 173 | 
            -
                def _test_dataloader(self):
         | 
| 174 | 
            -
                    return DataLoader(self.datasets["test"], batch_size=self.batch_size,
         | 
| 175 | 
            -
                                      num_workers=self.num_workers, collate_fn=custom_collate)
         | 
| 176 | 
            -
             | 
| 177 | 
            -
             | 
| 178 | 
            -
            class SetupCallback(Callback):
         | 
| 179 | 
            -
                def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
         | 
| 180 | 
            -
                    super().__init__()
         | 
| 181 | 
            -
                    self.resume = resume
         | 
| 182 | 
            -
                    self.now = now
         | 
| 183 | 
            -
                    self.logdir = logdir
         | 
| 184 | 
            -
                    self.ckptdir = ckptdir
         | 
| 185 | 
            -
                    self.cfgdir = cfgdir
         | 
| 186 | 
            -
                    self.config = config
         | 
| 187 | 
            -
                    self.lightning_config = lightning_config
         | 
| 188 | 
            -
             | 
| 189 | 
            -
                def on_pretrain_routine_start(self, trainer, pl_module):
         | 
| 190 | 
            -
                    if trainer.global_rank == 0:
         | 
| 191 | 
            -
                        # Create logdirs and save configs
         | 
| 192 | 
            -
                        os.makedirs(self.logdir, exist_ok=True)
         | 
| 193 | 
            -
                        os.makedirs(self.ckptdir, exist_ok=True)
         | 
| 194 | 
            -
                        os.makedirs(self.cfgdir, exist_ok=True)
         | 
| 195 | 
            -
             | 
| 196 | 
            -
                        print("Project config")
         | 
| 197 | 
            -
                        print(self.config.pretty())
         | 
| 198 | 
            -
                        OmegaConf.save(self.config,
         | 
| 199 | 
            -
                                       os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
         | 
| 200 | 
            -
             | 
| 201 | 
            -
                        print("Lightning config")
         | 
| 202 | 
            -
                        print(self.lightning_config.pretty())
         | 
| 203 | 
            -
                        OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
         | 
| 204 | 
            -
                                       os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
         | 
| 205 | 
            -
             | 
| 206 | 
            -
                    else:
         | 
| 207 | 
            -
                        # ModelCheckpoint callback created log directory --- remove it
         | 
| 208 | 
            -
                        if not self.resume and os.path.exists(self.logdir):
         | 
| 209 | 
            -
                            dst, name = os.path.split(self.logdir)
         | 
| 210 | 
            -
                            dst = os.path.join(dst, "child_runs", name)
         | 
| 211 | 
            -
                            os.makedirs(os.path.split(dst)[0], exist_ok=True)
         | 
| 212 | 
            -
                            try:
         | 
| 213 | 
            -
                                os.rename(self.logdir, dst)
         | 
| 214 | 
            -
                            except FileNotFoundError:
         | 
| 215 | 
            -
                                pass
         | 
| 216 | 
            -
             | 
| 217 | 
            -
             | 
| 218 | 
            -
            class ImageLogger(Callback):
         | 
| 219 | 
            -
                def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
         | 
| 220 | 
            -
                    super().__init__()
         | 
| 221 | 
            -
                    self.batch_freq = batch_frequency
         | 
| 222 | 
            -
                    self.max_images = max_images
         | 
| 223 | 
            -
                    self.logger_log_images = {
         | 
| 224 | 
            -
                        pl.loggers.WandbLogger: self._wandb,
         | 
| 225 | 
            -
                        pl.loggers.TestTubeLogger: self._testtube,
         | 
| 226 | 
            -
                    }
         | 
| 227 | 
            -
                    self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
         | 
| 228 | 
            -
                    if not increase_log_steps:
         | 
| 229 | 
            -
                        self.log_steps = [self.batch_freq]
         | 
| 230 | 
            -
                    self.clamp = clamp
         | 
| 231 | 
            -
             | 
| 232 | 
            -
                @rank_zero_only
         | 
| 233 | 
            -
                def _wandb(self, pl_module, images, batch_idx, split):
         | 
| 234 | 
            -
                    raise ValueError("No way wandb")
         | 
| 235 | 
            -
                    grids = dict()
         | 
| 236 | 
            -
                    for k in images:
         | 
| 237 | 
            -
                        grid = torchvision.utils.make_grid(images[k])
         | 
| 238 | 
            -
                        grids[f"{split}/{k}"] = wandb.Image(grid)
         | 
| 239 | 
            -
                    pl_module.logger.experiment.log(grids)
         | 
| 240 | 
            -
             | 
| 241 | 
            -
                @rank_zero_only
         | 
| 242 | 
            -
                def _testtube(self, pl_module, images, batch_idx, split):
         | 
| 243 | 
            -
                    for k in images:
         | 
| 244 | 
            -
                        grid = torchvision.utils.make_grid(images[k])
         | 
| 245 | 
            -
                        grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
         | 
| 246 | 
            -
             | 
| 247 | 
            -
                        tag = f"{split}/{k}"
         | 
| 248 | 
            -
                        pl_module.logger.experiment.add_image(
         | 
| 249 | 
            -
                            tag, grid,
         | 
| 250 | 
            -
                            global_step=pl_module.global_step)
         | 
| 251 | 
            -
             | 
| 252 | 
            -
                @rank_zero_only
         | 
| 253 | 
            -
                def log_local(self, save_dir, split, images,
         | 
| 254 | 
            -
                              global_step, current_epoch, batch_idx):
         | 
| 255 | 
            -
                    root = os.path.join(save_dir, "images", split)
         | 
| 256 | 
            -
                    for k in images:
         | 
| 257 | 
            -
                        grid = torchvision.utils.make_grid(images[k], nrow=4)
         | 
| 258 | 
            -
             | 
| 259 | 
            -
                        grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
         | 
| 260 | 
            -
                        grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
         | 
| 261 | 
            -
                        grid = grid.numpy()
         | 
| 262 | 
            -
                        grid = (grid*255).astype(np.uint8)
         | 
| 263 | 
            -
                        filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
         | 
| 264 | 
            -
                            k,
         | 
| 265 | 
            -
                            global_step,
         | 
| 266 | 
            -
                            current_epoch,
         | 
| 267 | 
            -
                            batch_idx)
         | 
| 268 | 
            -
                        path = os.path.join(root, filename)
         | 
| 269 | 
            -
                        os.makedirs(os.path.split(path)[0], exist_ok=True)
         | 
| 270 | 
            -
                        Image.fromarray(grid).save(path)
         | 
| 271 | 
            -
             | 
| 272 | 
            -
                def log_img(self, pl_module, batch, batch_idx, split="train"):
         | 
| 273 | 
            -
                    if (self.check_frequency(batch_idx) and  # batch_idx % self.batch_freq == 0
         | 
| 274 | 
            -
                            hasattr(pl_module, "log_images") and
         | 
| 275 | 
            -
                            callable(pl_module.log_images) and
         | 
| 276 | 
            -
                            self.max_images > 0):
         | 
| 277 | 
            -
                        logger = type(pl_module.logger)
         | 
| 278 | 
            -
             | 
| 279 | 
            -
                        is_train = pl_module.training
         | 
| 280 | 
            -
                        if is_train:
         | 
| 281 | 
            -
                            pl_module.eval()
         | 
| 282 | 
            -
             | 
| 283 | 
            -
                        with torch.no_grad():
         | 
| 284 | 
            -
                            images = pl_module.log_images(batch, split=split, pl_module=pl_module)
         | 
| 285 | 
            -
             | 
| 286 | 
            -
                        for k in images:
         | 
| 287 | 
            -
                            N = min(images[k].shape[0], self.max_images)
         | 
| 288 | 
            -
                            images[k] = images[k][:N]
         | 
| 289 | 
            -
                            if isinstance(images[k], torch.Tensor):
         | 
| 290 | 
            -
                                images[k] = images[k].detach().cpu()
         | 
| 291 | 
            -
                                if self.clamp:
         | 
| 292 | 
            -
                                    images[k] = torch.clamp(images[k], -1., 1.)
         | 
| 293 | 
            -
             | 
| 294 | 
            -
                        self.log_local(pl_module.logger.save_dir, split, images,
         | 
| 295 | 
            -
                                       pl_module.global_step, pl_module.current_epoch, batch_idx)
         | 
| 296 | 
            -
             | 
| 297 | 
            -
                        logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
         | 
| 298 | 
            -
                        logger_log_images(pl_module, images, pl_module.global_step, split)
         | 
| 299 | 
            -
             | 
| 300 | 
            -
                        if is_train:
         | 
| 301 | 
            -
                            pl_module.train()
         | 
| 302 | 
            -
             | 
| 303 | 
            -
                def check_frequency(self, batch_idx):
         | 
| 304 | 
            -
                    if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
         | 
| 305 | 
            -
                        try:
         | 
| 306 | 
            -
                            self.log_steps.pop(0)
         | 
| 307 | 
            -
                        except IndexError:
         | 
| 308 | 
            -
                            pass
         | 
| 309 | 
            -
                        return True
         | 
| 310 | 
            -
                    return False
         | 
| 311 | 
            -
             | 
| 312 | 
            -
                def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         | 
| 313 | 
            -
                    self.log_img(pl_module, batch, batch_idx, split="train")
         | 
| 314 | 
            -
             | 
| 315 | 
            -
                def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         | 
| 316 | 
            -
                    self.log_img(pl_module, batch, batch_idx, split="val")
         | 
| 317 | 
            -
             | 
| 318 | 
            -
             | 
| 319 | 
            -
             | 
| 320 | 
            -
            if __name__ == "__main__":
         | 
| 321 | 
            -
                # custom parser to specify config files, train, test and debug mode,
         | 
| 322 | 
            -
                # postfix, resume.
         | 
| 323 | 
            -
                # `--key value` arguments are interpreted as arguments to the trainer.
         | 
| 324 | 
            -
                # `nested.key=value` arguments are interpreted as config parameters.
         | 
| 325 | 
            -
                # configs are merged from left-to-right followed by command line parameters.
         | 
| 326 | 
            -
             | 
| 327 | 
            -
                # model:
         | 
| 328 | 
            -
                #   base_learning_rate: float
         | 
| 329 | 
            -
                #   target: path to lightning module
         | 
| 330 | 
            -
                #   params:
         | 
| 331 | 
            -
                #       key: value
         | 
| 332 | 
            -
                # data:
         | 
| 333 | 
            -
                #   target: main.DataModuleFromConfig
         | 
| 334 | 
            -
                #   params:
         | 
| 335 | 
            -
                #      batch_size: int
         | 
| 336 | 
            -
                #      wrap: bool
         | 
| 337 | 
            -
                #      train:
         | 
| 338 | 
            -
                #          target: path to train dataset
         | 
| 339 | 
            -
                #          params:
         | 
| 340 | 
            -
                #              key: value
         | 
| 341 | 
            -
                #      validation:
         | 
| 342 | 
            -
                #          target: path to validation dataset
         | 
| 343 | 
            -
                #          params:
         | 
| 344 | 
            -
                #              key: value
         | 
| 345 | 
            -
                #      test:
         | 
| 346 | 
            -
                #          target: path to test dataset
         | 
| 347 | 
            -
                #          params:
         | 
| 348 | 
            -
                #              key: value
         | 
| 349 | 
            -
                # lightning: (optional, has sane defaults and can be specified on cmdline)
         | 
| 350 | 
            -
                #   trainer:
         | 
| 351 | 
            -
                #       additional arguments to trainer
         | 
| 352 | 
            -
                #   logger:
         | 
| 353 | 
            -
                #       logger to instantiate
         | 
| 354 | 
            -
                #   modelcheckpoint:
         | 
| 355 | 
            -
                #       modelcheckpoint to instantiate
         | 
| 356 | 
            -
                #   callbacks:
         | 
| 357 | 
            -
                #       callback1:
         | 
| 358 | 
            -
                #           target: importpath
         | 
| 359 | 
            -
                #           params:
         | 
| 360 | 
            -
                #               key: value
         | 
| 361 | 
            -
             | 
| 362 | 
            -
                now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
         | 
| 363 | 
            -
             | 
| 364 | 
            -
                # add cwd for convenience and to make classes in this file available when
         | 
| 365 | 
            -
                # running as `python main.py`
         | 
| 366 | 
            -
                # (in particular `main.DataModuleFromConfig`)
         | 
| 367 | 
            -
                sys.path.append(os.getcwd())
         | 
| 368 | 
            -
             | 
| 369 | 
            -
                parser = get_parser()
         | 
| 370 | 
            -
                parser = Trainer.add_argparse_args(parser)
         | 
| 371 | 
            -
             | 
| 372 | 
            -
                opt, unknown = parser.parse_known_args()
         | 
| 373 | 
            -
                if opt.name and opt.resume:
         | 
| 374 | 
            -
                    raise ValueError(
         | 
| 375 | 
            -
                        "-n/--name and -r/--resume cannot be specified both."
         | 
| 376 | 
            -
                        "If you want to resume training in a new log folder, "
         | 
| 377 | 
            -
                        "use -n/--name in combination with --resume_from_checkpoint"
         | 
| 378 | 
            -
                    )
         | 
| 379 | 
            -
                if opt.resume:
         | 
| 380 | 
            -
                    if not os.path.exists(opt.resume):
         | 
| 381 | 
            -
                        raise ValueError("Cannot find {}".format(opt.resume))
         | 
| 382 | 
            -
                    if os.path.isfile(opt.resume):
         | 
| 383 | 
            -
                        paths = opt.resume.split("/")
         | 
| 384 | 
            -
                        idx = len(paths)-paths[::-1].index("logs")+1
         | 
| 385 | 
            -
                        logdir = "/".join(paths[:idx])
         | 
| 386 | 
            -
                        ckpt = opt.resume
         | 
| 387 | 
            -
                    else:
         | 
| 388 | 
            -
                        assert os.path.isdir(opt.resume), opt.resume
         | 
| 389 | 
            -
                        logdir = opt.resume.rstrip("/")
         | 
| 390 | 
            -
                        ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
         | 
| 391 | 
            -
             | 
| 392 | 
            -
                    opt.resume_from_checkpoint = ckpt
         | 
| 393 | 
            -
                    base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
         | 
| 394 | 
            -
                    opt.base = base_configs+opt.base
         | 
| 395 | 
            -
                    _tmp = logdir.split("/")
         | 
| 396 | 
            -
                    nowname = _tmp[_tmp.index("logs")+1]
         | 
| 397 | 
            -
                else:
         | 
| 398 | 
            -
                    if opt.name:
         | 
| 399 | 
            -
                        name = "_"+opt.name
         | 
| 400 | 
            -
                    elif opt.base:
         | 
| 401 | 
            -
                        cfg_fname = os.path.split(opt.base[0])[-1]
         | 
| 402 | 
            -
                        cfg_name = os.path.splitext(cfg_fname)[0]
         | 
| 403 | 
            -
                        name = "_"+cfg_name
         | 
| 404 | 
            -
                    else:
         | 
| 405 | 
            -
                        name = ""
         | 
| 406 | 
            -
                    nowname = now+name+opt.postfix
         | 
| 407 | 
            -
                    logdir = os.path.join("logs", nowname)
         | 
| 408 | 
            -
             | 
| 409 | 
            -
                ckptdir = os.path.join(logdir, "checkpoints")
         | 
| 410 | 
            -
                cfgdir = os.path.join(logdir, "configs")
         | 
| 411 | 
            -
                seed_everything(opt.seed)
         | 
| 412 | 
            -
             | 
| 413 | 
            -
                try:
         | 
| 414 | 
            -
                    # init and save configs
         | 
| 415 | 
            -
                    configs = [OmegaConf.load(cfg) for cfg in opt.base]
         | 
| 416 | 
            -
                    cli = OmegaConf.from_dotlist(unknown)
         | 
| 417 | 
            -
                    config = OmegaConf.merge(*configs, cli)
         | 
| 418 | 
            -
                    lightning_config = config.pop("lightning", OmegaConf.create())
         | 
| 419 | 
            -
                    # merge trainer cli with config
         | 
| 420 | 
            -
                    trainer_config = lightning_config.get("trainer", OmegaConf.create())
         | 
| 421 | 
            -
                    # default to ddp
         | 
| 422 | 
            -
                    trainer_config["distributed_backend"] = "ddp"
         | 
| 423 | 
            -
                    for k in nondefault_trainer_args(opt):
         | 
| 424 | 
            -
                        trainer_config[k] = getattr(opt, k)
         | 
| 425 | 
            -
                    if not "gpus" in trainer_config:
         | 
| 426 | 
            -
                        del trainer_config["distributed_backend"]
         | 
| 427 | 
            -
                        cpu = True
         | 
| 428 | 
            -
                    else:
         | 
| 429 | 
            -
                        gpuinfo = trainer_config["gpus"]
         | 
| 430 | 
            -
                        print(f"Running on GPUs {gpuinfo}")
         | 
| 431 | 
            -
                        cpu = False
         | 
| 432 | 
            -
                    trainer_opt = argparse.Namespace(**trainer_config)
         | 
| 433 | 
            -
                    lightning_config.trainer = trainer_config
         | 
| 434 | 
            -
             | 
| 435 | 
            -
                    # model
         | 
| 436 | 
            -
                    model = instantiate_from_config(config.model)
         | 
| 437 | 
            -
             | 
| 438 | 
            -
                    # trainer and callbacks
         | 
| 439 | 
            -
                    trainer_kwargs = dict()
         | 
| 440 | 
            -
             | 
| 441 | 
            -
                    # default logger configs
         | 
| 442 | 
            -
                    # NOTE wandb < 0.10.0 interferes with shutdown
         | 
| 443 | 
            -
                    # wandb >= 0.10.0 seems to fix it but still interferes with pudb
         | 
| 444 | 
            -
                    # debugging (wrongly sized pudb ui)
         | 
| 445 | 
            -
                    # thus prefer testtube for now
         | 
| 446 | 
            -
                    default_logger_cfgs = {
         | 
| 447 | 
            -
                        "wandb": {
         | 
| 448 | 
            -
                            "target": "pytorch_lightning.loggers.WandbLogger",
         | 
| 449 | 
            -
                            "params": {
         | 
| 450 | 
            -
                                "name": nowname,
         | 
| 451 | 
            -
                                "save_dir": logdir,
         | 
| 452 | 
            -
                                "offline": opt.debug,
         | 
| 453 | 
            -
                                "id": nowname,
         | 
| 454 | 
            -
                            }
         | 
| 455 | 
            -
                        },
         | 
| 456 | 
            -
                        "testtube": {
         | 
| 457 | 
            -
                            "target": "pytorch_lightning.loggers.TestTubeLogger",
         | 
| 458 | 
            -
                            "params": {
         | 
| 459 | 
            -
                                "name": "testtube",
         | 
| 460 | 
            -
                                "save_dir": logdir,
         | 
| 461 | 
            -
                            }
         | 
| 462 | 
            -
                        },
         | 
| 463 | 
            -
                    }
         | 
| 464 | 
            -
                    default_logger_cfg = default_logger_cfgs["testtube"]
         | 
| 465 | 
            -
                    logger_cfg = lightning_config.logger or OmegaConf.create()
         | 
| 466 | 
            -
                    logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
         | 
| 467 | 
            -
                    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
         | 
| 468 | 
            -
             | 
| 469 | 
            -
                    # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
         | 
| 470 | 
            -
                    # specify which metric is used to determine best models
         | 
| 471 | 
            -
                    default_modelckpt_cfg = {
         | 
| 472 | 
            -
                        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
         | 
| 473 | 
            -
                        "params": {
         | 
| 474 | 
            -
                            "dirpath": ckptdir,
         | 
| 475 | 
            -
                            "filename": "{epoch:06}",
         | 
| 476 | 
            -
                            "verbose": True,
         | 
| 477 | 
            -
                            "save_last": True,
         | 
| 478 | 
            -
                        }
         | 
| 479 | 
            -
                    }
         | 
| 480 | 
            -
                    if hasattr(model, "monitor"):
         | 
| 481 | 
            -
                        print(f"Monitoring {model.monitor} as checkpoint metric.")
         | 
| 482 | 
            -
                        default_modelckpt_cfg["params"]["monitor"] = model.monitor
         | 
| 483 | 
            -
                        default_modelckpt_cfg["params"]["save_top_k"] = 3
         | 
| 484 | 
            -
             | 
| 485 | 
            -
                    modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
         | 
| 486 | 
            -
                    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
         | 
| 487 | 
            -
                    trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
         | 
| 488 | 
            -
             | 
| 489 | 
            -
                    # add callback which sets up log directory
         | 
| 490 | 
            -
                    default_callbacks_cfg = {
         | 
| 491 | 
            -
                        "setup_callback": {
         | 
| 492 | 
            -
                            "target": "main.SetupCallback",
         | 
| 493 | 
            -
                            "params": {
         | 
| 494 | 
            -
                                "resume": opt.resume,
         | 
| 495 | 
            -
                                "now": now,
         | 
| 496 | 
            -
                                "logdir": logdir,
         | 
| 497 | 
            -
                                "ckptdir": ckptdir,
         | 
| 498 | 
            -
                                "cfgdir": cfgdir,
         | 
| 499 | 
            -
                                "config": config,
         | 
| 500 | 
            -
                                "lightning_config": lightning_config,
         | 
| 501 | 
            -
                            }
         | 
| 502 | 
            -
                        },
         | 
| 503 | 
            -
                        "image_logger": {
         | 
| 504 | 
            -
                            "target": "main.ImageLogger",
         | 
| 505 | 
            -
                            "params": {
         | 
| 506 | 
            -
                                "batch_frequency": 750,
         | 
| 507 | 
            -
                                "max_images": 4,
         | 
| 508 | 
            -
                                "clamp": True
         | 
| 509 | 
            -
                            }
         | 
| 510 | 
            -
                        },
         | 
| 511 | 
            -
                        "learning_rate_logger": {
         | 
| 512 | 
            -
                            "target": "main.LearningRateMonitor",
         | 
| 513 | 
            -
                            "params": {
         | 
| 514 | 
            -
                                "logging_interval": "step",
         | 
| 515 | 
            -
                                #"log_momentum": True
         | 
| 516 | 
            -
                            }
         | 
| 517 | 
            -
                        },
         | 
| 518 | 
            -
                    }
         | 
| 519 | 
            -
                    callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
         | 
| 520 | 
            -
                    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
         | 
| 521 | 
            -
                    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
         | 
| 522 | 
            -
             | 
| 523 | 
            -
                    trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         | 
| 524 | 
            -
             | 
| 525 | 
            -
                    # data
         | 
| 526 | 
            -
                    data = instantiate_from_config(config.data)
         | 
| 527 | 
            -
                    # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
         | 
| 528 | 
            -
                    # calling these ourselves should not be necessary but it is.
         | 
| 529 | 
            -
                    # lightning still takes care of proper multiprocessing though
         | 
| 530 | 
            -
                    data.prepare_data()
         | 
| 531 | 
            -
                    data.setup()
         | 
| 532 | 
            -
             | 
| 533 | 
            -
                    # configure learning rate
         | 
| 534 | 
            -
                    bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
         | 
| 535 | 
            -
                    if not cpu:
         | 
| 536 | 
            -
                        ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
         | 
| 537 | 
            -
                    else:
         | 
| 538 | 
            -
                        ngpu = 1
         | 
| 539 | 
            -
                    accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
         | 
| 540 | 
            -
                    print(f"accumulate_grad_batches = {accumulate_grad_batches}")
         | 
| 541 | 
            -
                    lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
         | 
| 542 | 
            -
                    model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
         | 
| 543 | 
            -
                    print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
         | 
| 544 | 
            -
                        model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
         | 
| 545 | 
            -
             | 
| 546 | 
            -
                    # allow checkpointing via USR1
         | 
| 547 | 
            -
                    def melk(*args, **kwargs):
         | 
| 548 | 
            -
                        # run all checkpoint hooks
         | 
| 549 | 
            -
                        if trainer.global_rank == 0:
         | 
| 550 | 
            -
                            print("Summoning checkpoint.")
         | 
| 551 | 
            -
                            ckpt_path = os.path.join(ckptdir, "last.ckpt")
         | 
| 552 | 
            -
                            trainer.save_checkpoint(ckpt_path)
         | 
| 553 | 
            -
             | 
| 554 | 
            -
                    def divein(*args, **kwargs):
         | 
| 555 | 
            -
                        if trainer.global_rank == 0:
         | 
| 556 | 
            -
                            import pudb; pudb.set_trace()
         | 
| 557 | 
            -
             | 
| 558 | 
            -
                    import signal
         | 
| 559 | 
            -
                    signal.signal(signal.SIGUSR1, melk)
         | 
| 560 | 
            -
                    signal.signal(signal.SIGUSR2, divein)
         | 
| 561 | 
            -
             | 
| 562 | 
            -
                    # run
         | 
| 563 | 
            -
                    if opt.train:
         | 
| 564 | 
            -
                        try:
         | 
| 565 | 
            -
                            trainer.fit(model, data)
         | 
| 566 | 
            -
                        except Exception:
         | 
| 567 | 
            -
                            melk()
         | 
| 568 | 
            -
                            raise
         | 
| 569 | 
            -
                    if not opt.no_test and not trainer.interrupted:
         | 
| 570 | 
            -
                        trainer.test(model, data)
         | 
| 571 | 
            -
                except Exception:
         | 
| 572 | 
            -
                    if opt.debug and trainer.global_rank==0:
         | 
| 573 | 
            -
                        try:
         | 
| 574 | 
            -
                            import pudb as debugger
         | 
| 575 | 
            -
                        except ImportError:
         | 
| 576 | 
            -
                            import pdb as debugger
         | 
| 577 | 
            -
                        debugger.post_mortem()
         | 
| 578 | 
            -
                    raise
         | 
| 579 | 
            -
                finally:
         | 
| 580 | 
            -
                    # move newly created debug project to debug_runs
         | 
| 581 | 
            -
                    if opt.debug and not opt.resume and trainer.global_rank==0:
         | 
| 582 | 
            -
                        dst, name = os.path.split(logdir)
         | 
| 583 | 
            -
                        dst = os.path.join(dst, "debug_runs", name)
         | 
| 584 | 
            -
                        os.makedirs(os.path.split(dst)[0], exist_ok=True)
         | 
| 585 | 
            -
                        os.rename(logdir, dst)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/scripts/extract_depth.py
    DELETED
    
    | @@ -1,112 +0,0 @@ | |
| 1 | 
            -
            import os
         | 
| 2 | 
            -
            import torch
         | 
| 3 | 
            -
            import numpy as np
         | 
| 4 | 
            -
            from tqdm import trange
         | 
| 5 | 
            -
            from PIL import Image
         | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
            def get_state(gpu):
         | 
| 9 | 
            -
                import torch
         | 
| 10 | 
            -
                midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
         | 
| 11 | 
            -
                if gpu:
         | 
| 12 | 
            -
                    midas.cuda()
         | 
| 13 | 
            -
                midas.eval()
         | 
| 14 | 
            -
             | 
| 15 | 
            -
                midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
         | 
| 16 | 
            -
                transform = midas_transforms.default_transform
         | 
| 17 | 
            -
             | 
| 18 | 
            -
                state = {"model": midas,
         | 
| 19 | 
            -
                         "transform": transform}
         | 
| 20 | 
            -
                return state
         | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
            def depth_to_rgba(x):
         | 
| 24 | 
            -
                assert x.dtype == np.float32
         | 
| 25 | 
            -
                assert len(x.shape) == 2
         | 
| 26 | 
            -
                y = x.copy()
         | 
| 27 | 
            -
                y.dtype = np.uint8
         | 
| 28 | 
            -
                y = y.reshape(x.shape+(4,))
         | 
| 29 | 
            -
                return np.ascontiguousarray(y)
         | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
            def rgba_to_depth(x):
         | 
| 33 | 
            -
                assert x.dtype == np.uint8
         | 
| 34 | 
            -
                assert len(x.shape) == 3 and x.shape[2] == 4
         | 
| 35 | 
            -
                y = x.copy()
         | 
| 36 | 
            -
                y.dtype = np.float32
         | 
| 37 | 
            -
                y = y.reshape(x.shape[:2])
         | 
| 38 | 
            -
                return np.ascontiguousarray(y)
         | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
            def run(x, state):
         | 
| 42 | 
            -
                model = state["model"]
         | 
| 43 | 
            -
                transform = state["transform"]
         | 
| 44 | 
            -
                hw = x.shape[:2]
         | 
| 45 | 
            -
                with torch.no_grad():
         | 
| 46 | 
            -
                    prediction = model(transform((x + 1.0) * 127.5).cuda())
         | 
| 47 | 
            -
                    prediction = torch.nn.functional.interpolate(
         | 
| 48 | 
            -
                        prediction.unsqueeze(1),
         | 
| 49 | 
            -
                        size=hw,
         | 
| 50 | 
            -
                        mode="bicubic",
         | 
| 51 | 
            -
                        align_corners=False,
         | 
| 52 | 
            -
                    ).squeeze()
         | 
| 53 | 
            -
                    output = prediction.cpu().numpy()
         | 
| 54 | 
            -
                return output
         | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
            def get_filename(relpath, level=-2):
         | 
| 58 | 
            -
                # save class folder structure and filename:
         | 
| 59 | 
            -
                fn = relpath.split(os.sep)[level:]
         | 
| 60 | 
            -
                folder = fn[-2]
         | 
| 61 | 
            -
                file   = fn[-1].split('.')[0]
         | 
| 62 | 
            -
                return folder, file
         | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 | 
            -
            def save_depth(dataset, path, debug=False):
         | 
| 66 | 
            -
                os.makedirs(path)
         | 
| 67 | 
            -
                N = len(dset)
         | 
| 68 | 
            -
                if debug:
         | 
| 69 | 
            -
                    N = 10
         | 
| 70 | 
            -
                state = get_state(gpu=True)
         | 
| 71 | 
            -
                for idx in trange(N, desc="Data"):
         | 
| 72 | 
            -
                    ex = dataset[idx]
         | 
| 73 | 
            -
                    image, relpath = ex["image"], ex["relpath"]
         | 
| 74 | 
            -
                    folder, filename = get_filename(relpath)
         | 
| 75 | 
            -
                    # prepare
         | 
| 76 | 
            -
                    folderabspath = os.path.join(path, folder)
         | 
| 77 | 
            -
                    os.makedirs(folderabspath, exist_ok=True)
         | 
| 78 | 
            -
                    savepath = os.path.join(folderabspath, filename)
         | 
| 79 | 
            -
                    # run model
         | 
| 80 | 
            -
                    xout = run(image, state)
         | 
| 81 | 
            -
                    I = depth_to_rgba(xout)
         | 
| 82 | 
            -
                    Image.fromarray(I).save("{}.png".format(savepath))
         | 
| 83 | 
            -
             | 
| 84 | 
            -
             | 
| 85 | 
            -
            if __name__ == "__main__":
         | 
| 86 | 
            -
                from taming.data.imagenet import ImageNetTrain, ImageNetValidation
         | 
| 87 | 
            -
                out = "data/imagenet_depth"
         | 
| 88 | 
            -
                if not os.path.exists(out):
         | 
| 89 | 
            -
                    print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
         | 
| 90 | 
            -
                          "(be prepared that the output size will be larger than ImageNet itself).")
         | 
| 91 | 
            -
                    exit(1)
         | 
| 92 | 
            -
             | 
| 93 | 
            -
                # go
         | 
| 94 | 
            -
                dset = ImageNetValidation()
         | 
| 95 | 
            -
                abspath = os.path.join(out, "val")
         | 
| 96 | 
            -
                if os.path.exists(abspath):
         | 
| 97 | 
            -
                    print("{} exists - not doing anything.".format(abspath))
         | 
| 98 | 
            -
                else:
         | 
| 99 | 
            -
                    print("preparing {}".format(abspath))
         | 
| 100 | 
            -
                    save_depth(dset, abspath)
         | 
| 101 | 
            -
                    print("done with validation split")
         | 
| 102 | 
            -
             | 
| 103 | 
            -
                dset = ImageNetTrain()
         | 
| 104 | 
            -
                abspath = os.path.join(out, "train")
         | 
| 105 | 
            -
                if os.path.exists(abspath):
         | 
| 106 | 
            -
                    print("{} exists - not doing anything.".format(abspath))
         | 
| 107 | 
            -
                else:
         | 
| 108 | 
            -
                    print("preparing {}".format(abspath))
         | 
| 109 | 
            -
                    save_depth(dset, abspath)
         | 
| 110 | 
            -
                    print("done with train split")
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                print("done done.")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/scripts/extract_segmentation.py
    DELETED
    
    | @@ -1,130 +0,0 @@ | |
| 1 | 
            -
            import sys, os
         | 
| 2 | 
            -
            import numpy as np
         | 
| 3 | 
            -
            import scipy
         | 
| 4 | 
            -
            import torch
         | 
| 5 | 
            -
            import torch.nn as nn
         | 
| 6 | 
            -
            from scipy import ndimage
         | 
| 7 | 
            -
            from tqdm import tqdm, trange
         | 
| 8 | 
            -
            from PIL import Image
         | 
| 9 | 
            -
            import torch.hub
         | 
| 10 | 
            -
            import torchvision
         | 
| 11 | 
            -
            import torch.nn.functional as F
         | 
| 12 | 
            -
             | 
| 13 | 
            -
            # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
         | 
| 14 | 
            -
            # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
         | 
| 15 | 
            -
            # and put the path here
         | 
| 16 | 
            -
            CKPT_PATH = "TODO"
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            rescale = lambda x: (x + 1.) / 2.
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            def rescale_bgr(x):
         | 
| 21 | 
            -
                x = (x+1)*127.5
         | 
| 22 | 
            -
                x = torch.flip(x, dims=[0])
         | 
| 23 | 
            -
                return x
         | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
            class COCOStuffSegmenter(nn.Module):
         | 
| 27 | 
            -
                def __init__(self, config):
         | 
| 28 | 
            -
                    super().__init__()
         | 
| 29 | 
            -
                    self.config = config
         | 
| 30 | 
            -
                    self.n_labels = 182
         | 
| 31 | 
            -
                    model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
         | 
| 32 | 
            -
                    ckpt_path = CKPT_PATH
         | 
| 33 | 
            -
                    model.load_state_dict(torch.load(ckpt_path))
         | 
| 34 | 
            -
                    self.model = model
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                    normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
         | 
| 37 | 
            -
                    self.image_transform = torchvision.transforms.Compose([
         | 
| 38 | 
            -
                        torchvision.transforms.Lambda(lambda image: torch.stack(
         | 
| 39 | 
            -
                            [normalize(rescale_bgr(x)) for x in image]))
         | 
| 40 | 
            -
                    ])
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                def forward(self, x, upsample=None):
         | 
| 43 | 
            -
                    x = self._pre_process(x)
         | 
| 44 | 
            -
                    x = self.model(x)
         | 
| 45 | 
            -
                    if upsample is not None:
         | 
| 46 | 
            -
                        x = torch.nn.functional.upsample_bilinear(x, size=upsample)
         | 
| 47 | 
            -
                    return x
         | 
| 48 | 
            -
             | 
| 49 | 
            -
                def _pre_process(self, x):
         | 
| 50 | 
            -
                    x = self.image_transform(x)
         | 
| 51 | 
            -
                    return x
         | 
| 52 | 
            -
             | 
| 53 | 
            -
                @property
         | 
| 54 | 
            -
                def mean(self):
         | 
| 55 | 
            -
                    # bgr
         | 
| 56 | 
            -
                    return [104.008, 116.669, 122.675]
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                @property
         | 
| 59 | 
            -
                def std(self):
         | 
| 60 | 
            -
                    return [1.0, 1.0, 1.0]
         | 
| 61 | 
            -
             | 
| 62 | 
            -
                @property
         | 
| 63 | 
            -
                def input_size(self):
         | 
| 64 | 
            -
                    return [3, 224, 224]
         | 
| 65 | 
            -
             | 
| 66 | 
            -
             | 
| 67 | 
            -
            def run_model(img, model):
         | 
| 68 | 
            -
                model = model.eval()
         | 
| 69 | 
            -
                with torch.no_grad():
         | 
| 70 | 
            -
                    segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
         | 
| 71 | 
            -
                    segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
         | 
| 72 | 
            -
                return segmentation.detach().cpu()
         | 
| 73 | 
            -
             | 
| 74 | 
            -
             | 
| 75 | 
            -
            def get_input(batch, k):
         | 
| 76 | 
            -
                x = batch[k]
         | 
| 77 | 
            -
                if len(x.shape) == 3:
         | 
| 78 | 
            -
                    x = x[..., None]
         | 
| 79 | 
            -
                x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
         | 
| 80 | 
            -
                return x.float()
         | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
            def save_segmentation(segmentation, path):
         | 
| 84 | 
            -
                # --> class label to uint8, save as png
         | 
| 85 | 
            -
                os.makedirs(os.path.dirname(path), exist_ok=True)
         | 
| 86 | 
            -
                assert len(segmentation.shape)==4
         | 
| 87 | 
            -
                assert segmentation.shape[0]==1
         | 
| 88 | 
            -
                for seg in segmentation:
         | 
| 89 | 
            -
                    seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
         | 
| 90 | 
            -
                    seg = Image.fromarray(seg)
         | 
| 91 | 
            -
                    seg.save(path)
         | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 | 
            -
            def iterate_dataset(dataloader, destpath, model):
         | 
| 95 | 
            -
                os.makedirs(destpath, exist_ok=True)
         | 
| 96 | 
            -
                num_processed = 0
         | 
| 97 | 
            -
                for i, batch in tqdm(enumerate(dataloader), desc="Data"):
         | 
| 98 | 
            -
                    try:
         | 
| 99 | 
            -
                        img = get_input(batch, "image")
         | 
| 100 | 
            -
                        img = img.cuda()
         | 
| 101 | 
            -
                        seg = run_model(img, model)
         | 
| 102 | 
            -
             | 
| 103 | 
            -
                        path = batch["relative_file_path_"][0]
         | 
| 104 | 
            -
                        path = os.path.splitext(path)[0]
         | 
| 105 | 
            -
             | 
| 106 | 
            -
                        path = os.path.join(destpath, path + ".png")
         | 
| 107 | 
            -
                        save_segmentation(seg, path)
         | 
| 108 | 
            -
                        num_processed += 1
         | 
| 109 | 
            -
                    except Exception as e:
         | 
| 110 | 
            -
                        print(e)
         | 
| 111 | 
            -
                        print("but anyhow..")
         | 
| 112 | 
            -
             | 
| 113 | 
            -
                print("Processed {} files. Bye.".format(num_processed))
         | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
            from taming.data.sflckr import Examples
         | 
| 117 | 
            -
            from torch.utils.data import DataLoader
         | 
| 118 | 
            -
             | 
| 119 | 
            -
            if __name__ == "__main__":
         | 
| 120 | 
            -
                dest = sys.argv[1]
         | 
| 121 | 
            -
                batchsize = 1
         | 
| 122 | 
            -
                print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
         | 
| 123 | 
            -
             | 
| 124 | 
            -
                model = COCOStuffSegmenter({}).cuda()
         | 
| 125 | 
            -
                print("Instantiated model.")
         | 
| 126 | 
            -
             | 
| 127 | 
            -
                dataset = Examples()
         | 
| 128 | 
            -
                dloader = DataLoader(dataset, batch_size=batchsize)
         | 
| 129 | 
            -
                iterate_dataset(dataloader=dloader, destpath=dest, model=model)
         | 
| 130 | 
            -
                print("done.")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/scripts/extract_submodel.py
    DELETED
    
    | @@ -1,17 +0,0 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            import sys
         | 
| 3 | 
            -
             | 
| 4 | 
            -
            if __name__ == "__main__":
         | 
| 5 | 
            -
                inpath = sys.argv[1]
         | 
| 6 | 
            -
                outpath = sys.argv[2]
         | 
| 7 | 
            -
                submodel = "cond_stage_model"
         | 
| 8 | 
            -
                if len(sys.argv) > 3:
         | 
| 9 | 
            -
                    submodel = sys.argv[3]
         | 
| 10 | 
            -
             | 
| 11 | 
            -
                print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
         | 
| 12 | 
            -
             | 
| 13 | 
            -
                sd = torch.load(inpath, map_location="cpu")
         | 
| 14 | 
            -
                new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
         | 
| 15 | 
            -
                                             for k,v in sd["state_dict"].items()
         | 
| 16 | 
            -
                                             if k.startswith("cond_stage_model"))}
         | 
| 17 | 
            -
                torch.save(new_sd, outpath)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/scripts/make_samples.py
    DELETED
    
    | @@ -1,292 +0,0 @@ | |
| 1 | 
            -
            import argparse, os, sys, glob, math, time
         | 
| 2 | 
            -
            import torch
         | 
| 3 | 
            -
            import numpy as np
         | 
| 4 | 
            -
            from omegaconf import OmegaConf
         | 
| 5 | 
            -
            from PIL import Image
         | 
| 6 | 
            -
            from main import instantiate_from_config, DataModuleFromConfig
         | 
| 7 | 
            -
            from torch.utils.data import DataLoader
         | 
| 8 | 
            -
            from torch.utils.data.dataloader import default_collate
         | 
| 9 | 
            -
            from tqdm import trange
         | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
            def save_image(x, path):
         | 
| 13 | 
            -
                c,h,w = x.shape
         | 
| 14 | 
            -
                assert c==3
         | 
| 15 | 
            -
                x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
         | 
| 16 | 
            -
                Image.fromarray(x).save(path)
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
            @torch.no_grad()
         | 
| 20 | 
            -
            def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
         | 
| 21 | 
            -
                if len(dsets.datasets) > 1:
         | 
| 22 | 
            -
                    split = sorted(dsets.datasets.keys())[0]
         | 
| 23 | 
            -
                    dset = dsets.datasets[split]
         | 
| 24 | 
            -
                else:
         | 
| 25 | 
            -
                    dset = next(iter(dsets.datasets.values()))
         | 
| 26 | 
            -
                print("Dataset: ", dset.__class__.__name__)
         | 
| 27 | 
            -
                for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
         | 
| 28 | 
            -
                    indices = list(range(start_idx, start_idx+batch_size))
         | 
| 29 | 
            -
                    example = default_collate([dset[i] for i in indices])
         | 
| 30 | 
            -
             | 
| 31 | 
            -
                    x = model.get_input("image", example).to(model.device)
         | 
| 32 | 
            -
                    for i in range(x.shape[0]):
         | 
| 33 | 
            -
                        save_image(x[i], os.path.join(outdir, "originals",
         | 
| 34 | 
            -
                                                      "{:06}.png".format(indices[i])))
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                    cond_key = model.cond_stage_key
         | 
| 37 | 
            -
                    c = model.get_input(cond_key, example).to(model.device)
         | 
| 38 | 
            -
             | 
| 39 | 
            -
                    scale_factor = 1.0
         | 
| 40 | 
            -
                    quant_z, z_indices = model.encode_to_z(x)
         | 
| 41 | 
            -
                    quant_c, c_indices = model.encode_to_c(c)
         | 
| 42 | 
            -
             | 
| 43 | 
            -
                    cshape = quant_z.shape
         | 
| 44 | 
            -
             | 
| 45 | 
            -
                    xrec = model.first_stage_model.decode(quant_z)
         | 
| 46 | 
            -
                    for i in range(xrec.shape[0]):
         | 
| 47 | 
            -
                        save_image(xrec[i], os.path.join(outdir, "reconstructions",
         | 
| 48 | 
            -
                                                         "{:06}.png".format(indices[i])))
         | 
| 49 | 
            -
             | 
| 50 | 
            -
                    if cond_key == "segmentation":
         | 
| 51 | 
            -
                        # get image from segmentation mask
         | 
| 52 | 
            -
                        num_classes = c.shape[1]
         | 
| 53 | 
            -
                        c = torch.argmax(c, dim=1, keepdim=True)
         | 
| 54 | 
            -
                        c = torch.nn.functional.one_hot(c, num_classes=num_classes)
         | 
| 55 | 
            -
                        c = c.squeeze(1).permute(0, 3, 1, 2).float()
         | 
| 56 | 
            -
                        c = model.cond_stage_model.to_rgb(c)
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                    idx = z_indices
         | 
| 59 | 
            -
             | 
| 60 | 
            -
                    half_sample = False
         | 
| 61 | 
            -
                    if half_sample:
         | 
| 62 | 
            -
                        start = idx.shape[1]//2
         | 
| 63 | 
            -
                    else:
         | 
| 64 | 
            -
                        start = 0
         | 
| 65 | 
            -
             | 
| 66 | 
            -
                    idx[:,start:] = 0
         | 
| 67 | 
            -
                    idx = idx.reshape(cshape[0],cshape[2],cshape[3])
         | 
| 68 | 
            -
                    start_i = start//cshape[3]
         | 
| 69 | 
            -
                    start_j = start %cshape[3]
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                    cidx = c_indices
         | 
| 72 | 
            -
                    cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                    sample = True
         | 
| 75 | 
            -
             | 
| 76 | 
            -
                    for i in range(start_i,cshape[2]-0):
         | 
| 77 | 
            -
                        if i <= 8:
         | 
| 78 | 
            -
                            local_i = i
         | 
| 79 | 
            -
                        elif cshape[2]-i < 8:
         | 
| 80 | 
            -
                            local_i = 16-(cshape[2]-i)
         | 
| 81 | 
            -
                        else:
         | 
| 82 | 
            -
                            local_i = 8
         | 
| 83 | 
            -
                        for j in range(start_j,cshape[3]-0):
         | 
| 84 | 
            -
                            if j <= 8:
         | 
| 85 | 
            -
                                local_j = j
         | 
| 86 | 
            -
                            elif cshape[3]-j < 8:
         | 
| 87 | 
            -
                                local_j = 16-(cshape[3]-j)
         | 
| 88 | 
            -
                            else:
         | 
| 89 | 
            -
                                local_j = 8
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                            i_start = i-local_i
         | 
| 92 | 
            -
                            i_end = i_start+16
         | 
| 93 | 
            -
                            j_start = j-local_j
         | 
| 94 | 
            -
                            j_end = j_start+16
         | 
| 95 | 
            -
                            patch = idx[:,i_start:i_end,j_start:j_end]
         | 
| 96 | 
            -
                            patch = patch.reshape(patch.shape[0],-1)
         | 
| 97 | 
            -
                            cpatch = cidx[:, i_start:i_end, j_start:j_end]
         | 
| 98 | 
            -
                            cpatch = cpatch.reshape(cpatch.shape[0], -1)
         | 
| 99 | 
            -
                            patch = torch.cat((cpatch, patch), dim=1)
         | 
| 100 | 
            -
                            logits,_ = model.transformer(patch[:,:-1])
         | 
| 101 | 
            -
                            logits = logits[:, -256:, :]
         | 
| 102 | 
            -
                            logits = logits.reshape(cshape[0],16,16,-1)
         | 
| 103 | 
            -
                            logits = logits[:,local_i,local_j,:]
         | 
| 104 | 
            -
             | 
| 105 | 
            -
                            logits = logits/temperature
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                            if top_k is not None:
         | 
| 108 | 
            -
                                logits = model.top_k_logits(logits, top_k)
         | 
| 109 | 
            -
                            # apply softmax to convert to probabilities
         | 
| 110 | 
            -
                            probs = torch.nn.functional.softmax(logits, dim=-1)
         | 
| 111 | 
            -
                            # sample from the distribution or take the most likely
         | 
| 112 | 
            -
                            if sample:
         | 
| 113 | 
            -
                                ix = torch.multinomial(probs, num_samples=1)
         | 
| 114 | 
            -
                            else:
         | 
| 115 | 
            -
                                _, ix = torch.topk(probs, k=1, dim=-1)
         | 
| 116 | 
            -
                            idx[:,i,j] = ix
         | 
| 117 | 
            -
             | 
| 118 | 
            -
                    xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
         | 
| 119 | 
            -
                    for i in range(xsample.shape[0]):
         | 
| 120 | 
            -
                        save_image(xsample[i], os.path.join(outdir, "samples",
         | 
| 121 | 
            -
                                                            "{:06}.png".format(indices[i])))
         | 
| 122 | 
            -
             | 
| 123 | 
            -
             | 
| 124 | 
            -
            def get_parser():
         | 
| 125 | 
            -
                parser = argparse.ArgumentParser()
         | 
| 126 | 
            -
                parser.add_argument(
         | 
| 127 | 
            -
                    "-r",
         | 
| 128 | 
            -
                    "--resume",
         | 
| 129 | 
            -
                    type=str,
         | 
| 130 | 
            -
                    nargs="?",
         | 
| 131 | 
            -
                    help="load from logdir or checkpoint in logdir",
         | 
| 132 | 
            -
                )
         | 
| 133 | 
            -
                parser.add_argument(
         | 
| 134 | 
            -
                    "-b",
         | 
| 135 | 
            -
                    "--base",
         | 
| 136 | 
            -
                    nargs="*",
         | 
| 137 | 
            -
                    metavar="base_config.yaml",
         | 
| 138 | 
            -
                    help="paths to base configs. Loaded from left-to-right. "
         | 
| 139 | 
            -
                    "Parameters can be overwritten or added with command-line options of the form `--key value`.",
         | 
| 140 | 
            -
                    default=list(),
         | 
| 141 | 
            -
                )
         | 
| 142 | 
            -
                parser.add_argument(
         | 
| 143 | 
            -
                    "-c",
         | 
| 144 | 
            -
                    "--config",
         | 
| 145 | 
            -
                    nargs="?",
         | 
| 146 | 
            -
                    metavar="single_config.yaml",
         | 
| 147 | 
            -
                    help="path to single config. If specified, base configs will be ignored "
         | 
| 148 | 
            -
                    "(except for the last one if left unspecified).",
         | 
| 149 | 
            -
                    const=True,
         | 
| 150 | 
            -
                    default="",
         | 
| 151 | 
            -
                )
         | 
| 152 | 
            -
                parser.add_argument(
         | 
| 153 | 
            -
                    "--ignore_base_data",
         | 
| 154 | 
            -
                    action="store_true",
         | 
| 155 | 
            -
                    help="Ignore data specification from base configs. Useful if you want "
         | 
| 156 | 
            -
                    "to specify a custom datasets on the command line.",
         | 
| 157 | 
            -
                )
         | 
| 158 | 
            -
                parser.add_argument(
         | 
| 159 | 
            -
                    "--outdir",
         | 
| 160 | 
            -
                    required=True,
         | 
| 161 | 
            -
                    type=str,
         | 
| 162 | 
            -
                    help="Where to write outputs to.",
         | 
| 163 | 
            -
                )
         | 
| 164 | 
            -
                parser.add_argument(
         | 
| 165 | 
            -
                    "--top_k",
         | 
| 166 | 
            -
                    type=int,
         | 
| 167 | 
            -
                    default=100,
         | 
| 168 | 
            -
                    help="Sample from among top-k predictions.",
         | 
| 169 | 
            -
                )
         | 
| 170 | 
            -
                parser.add_argument(
         | 
| 171 | 
            -
                    "--temperature",
         | 
| 172 | 
            -
                    type=float,
         | 
| 173 | 
            -
                    default=1.0,
         | 
| 174 | 
            -
                    help="Sampling temperature.",
         | 
| 175 | 
            -
                )
         | 
| 176 | 
            -
                return parser
         | 
| 177 | 
            -
             | 
| 178 | 
            -
             | 
| 179 | 
            -
            def load_model_from_config(config, sd, gpu=True, eval_mode=True):
         | 
| 180 | 
            -
                if "ckpt_path" in config.params:
         | 
| 181 | 
            -
                    print("Deleting the restore-ckpt path from the config...")
         | 
| 182 | 
            -
                    config.params.ckpt_path = None
         | 
| 183 | 
            -
                if "downsample_cond_size" in config.params:
         | 
| 184 | 
            -
                    print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
         | 
| 185 | 
            -
                    config.params.downsample_cond_size = -1
         | 
| 186 | 
            -
                    config.params["downsample_cond_factor"] = 0.5
         | 
| 187 | 
            -
                try:
         | 
| 188 | 
            -
                    if "ckpt_path" in config.params.first_stage_config.params:
         | 
| 189 | 
            -
                        config.params.first_stage_config.params.ckpt_path = None
         | 
| 190 | 
            -
                        print("Deleting the first-stage restore-ckpt path from the config...")
         | 
| 191 | 
            -
                    if "ckpt_path" in config.params.cond_stage_config.params:
         | 
| 192 | 
            -
                        config.params.cond_stage_config.params.ckpt_path = None
         | 
| 193 | 
            -
                        print("Deleting the cond-stage restore-ckpt path from the config...")
         | 
| 194 | 
            -
                except:
         | 
| 195 | 
            -
                    pass
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                model = instantiate_from_config(config)
         | 
| 198 | 
            -
                if sd is not None:
         | 
| 199 | 
            -
                    missing, unexpected = model.load_state_dict(sd, strict=False)
         | 
| 200 | 
            -
                    print(f"Missing Keys in State Dict: {missing}")
         | 
| 201 | 
            -
                    print(f"Unexpected Keys in State Dict: {unexpected}")
         | 
| 202 | 
            -
                if gpu:
         | 
| 203 | 
            -
                    model.cuda()
         | 
| 204 | 
            -
                if eval_mode:
         | 
| 205 | 
            -
                    model.eval()
         | 
| 206 | 
            -
                return {"model": model}
         | 
| 207 | 
            -
             | 
| 208 | 
            -
             | 
| 209 | 
            -
            def get_data(config):
         | 
| 210 | 
            -
                # get data
         | 
| 211 | 
            -
                data = instantiate_from_config(config.data)
         | 
| 212 | 
            -
                data.prepare_data()
         | 
| 213 | 
            -
                data.setup()
         | 
| 214 | 
            -
                return data
         | 
| 215 | 
            -
             | 
| 216 | 
            -
             | 
| 217 | 
            -
            def load_model_and_dset(config, ckpt, gpu, eval_mode):
         | 
| 218 | 
            -
                # get data
         | 
| 219 | 
            -
                dsets = get_data(config)   # calls data.config ...
         | 
| 220 | 
            -
             | 
| 221 | 
            -
                # now load the specified checkpoint
         | 
| 222 | 
            -
                if ckpt:
         | 
| 223 | 
            -
                    pl_sd = torch.load(ckpt, map_location="cpu")
         | 
| 224 | 
            -
                    global_step = pl_sd["global_step"]
         | 
| 225 | 
            -
                else:
         | 
| 226 | 
            -
                    pl_sd = {"state_dict": None}
         | 
| 227 | 
            -
                    global_step = None
         | 
| 228 | 
            -
                model = load_model_from_config(config.model,
         | 
| 229 | 
            -
                                               pl_sd["state_dict"],
         | 
| 230 | 
            -
                                               gpu=gpu,
         | 
| 231 | 
            -
                                               eval_mode=eval_mode)["model"]
         | 
| 232 | 
            -
                return dsets, model, global_step
         | 
| 233 | 
            -
             | 
| 234 | 
            -
             | 
| 235 | 
            -
            if __name__ == "__main__":
         | 
| 236 | 
            -
                sys.path.append(os.getcwd())
         | 
| 237 | 
            -
             | 
| 238 | 
            -
                parser = get_parser()
         | 
| 239 | 
            -
             | 
| 240 | 
            -
                opt, unknown = parser.parse_known_args()
         | 
| 241 | 
            -
             | 
| 242 | 
            -
                ckpt = None
         | 
| 243 | 
            -
                if opt.resume:
         | 
| 244 | 
            -
                    if not os.path.exists(opt.resume):
         | 
| 245 | 
            -
                        raise ValueError("Cannot find {}".format(opt.resume))
         | 
| 246 | 
            -
                    if os.path.isfile(opt.resume):
         | 
| 247 | 
            -
                        paths = opt.resume.split("/")
         | 
| 248 | 
            -
                        try:
         | 
| 249 | 
            -
                            idx = len(paths)-paths[::-1].index("logs")+1
         | 
| 250 | 
            -
                        except ValueError:
         | 
| 251 | 
            -
                            idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
         | 
| 252 | 
            -
                        logdir = "/".join(paths[:idx])
         | 
| 253 | 
            -
                        ckpt = opt.resume
         | 
| 254 | 
            -
                    else:
         | 
| 255 | 
            -
                        assert os.path.isdir(opt.resume), opt.resume
         | 
| 256 | 
            -
                        logdir = opt.resume.rstrip("/")
         | 
| 257 | 
            -
                        ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
         | 
| 258 | 
            -
                    print(f"logdir:{logdir}")
         | 
| 259 | 
            -
                    base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
         | 
| 260 | 
            -
                    opt.base = base_configs+opt.base
         | 
| 261 | 
            -
             | 
| 262 | 
            -
                if opt.config:
         | 
| 263 | 
            -
                    if type(opt.config) == str:
         | 
| 264 | 
            -
                        opt.base = [opt.config]
         | 
| 265 | 
            -
                    else:
         | 
| 266 | 
            -
                        opt.base = [opt.base[-1]]
         | 
| 267 | 
            -
             | 
| 268 | 
            -
                configs = [OmegaConf.load(cfg) for cfg in opt.base]
         | 
| 269 | 
            -
                cli = OmegaConf.from_dotlist(unknown)
         | 
| 270 | 
            -
                if opt.ignore_base_data:
         | 
| 271 | 
            -
                    for config in configs:
         | 
| 272 | 
            -
                        if hasattr(config, "data"): del config["data"]
         | 
| 273 | 
            -
                config = OmegaConf.merge(*configs, cli)
         | 
| 274 | 
            -
             | 
| 275 | 
            -
                print(ckpt)
         | 
| 276 | 
            -
                gpu = True
         | 
| 277 | 
            -
                eval_mode = True
         | 
| 278 | 
            -
                show_config = False
         | 
| 279 | 
            -
                if show_config:
         | 
| 280 | 
            -
                    print(OmegaConf.to_container(config))
         | 
| 281 | 
            -
             | 
| 282 | 
            -
                dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
         | 
| 283 | 
            -
                print(f"Global step: {global_step}")
         | 
| 284 | 
            -
             | 
| 285 | 
            -
                outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
         | 
| 286 | 
            -
                                                                       opt.top_k,
         | 
| 287 | 
            -
                                                                       opt.temperature))
         | 
| 288 | 
            -
                os.makedirs(outdir, exist_ok=True)
         | 
| 289 | 
            -
                print("Writing samples to ", outdir)
         | 
| 290 | 
            -
                for k in ["originals", "reconstructions", "samples"]:
         | 
| 291 | 
            -
                    os.makedirs(os.path.join(outdir, k), exist_ok=True)
         | 
| 292 | 
            -
                run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/scripts/make_scene_samples.py
    DELETED
    
    | @@ -1,198 +0,0 @@ | |
| 1 | 
            -
            import glob
         | 
| 2 | 
            -
            import os
         | 
| 3 | 
            -
            import sys
         | 
| 4 | 
            -
            from itertools import product
         | 
| 5 | 
            -
            from pathlib import Path
         | 
| 6 | 
            -
            from typing import Literal, List, Optional, Tuple
         | 
| 7 | 
            -
             | 
| 8 | 
            -
            import numpy as np
         | 
| 9 | 
            -
            import torch
         | 
| 10 | 
            -
            from omegaconf import OmegaConf
         | 
| 11 | 
            -
            from pytorch_lightning import seed_everything
         | 
| 12 | 
            -
            from torch import Tensor
         | 
| 13 | 
            -
            from torchvision.utils import save_image
         | 
| 14 | 
            -
            from tqdm import tqdm
         | 
| 15 | 
            -
             | 
| 16 | 
            -
            from scripts.make_samples import get_parser, load_model_and_dset
         | 
| 17 | 
            -
            from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
         | 
| 18 | 
            -
            from taming.data.helper_types import BoundingBox, Annotation
         | 
| 19 | 
            -
            from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
         | 
| 20 | 
            -
            from taming.models.cond_transformer import Net2NetTransformer
         | 
| 21 | 
            -
             | 
| 22 | 
            -
            seed_everything(42424242)
         | 
| 23 | 
            -
            device: Literal['cuda', 'cpu'] = 'cuda'
         | 
| 24 | 
            -
            first_stage_factor = 16
         | 
| 25 | 
            -
            trained_on_res = 256
         | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
            def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
         | 
| 29 | 
            -
                assert 0 <= coord < coord_max
         | 
| 30 | 
            -
                coord_desired_center = (coord_window - 1) // 2
         | 
| 31 | 
            -
                return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
         | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
            def get_crop_coordinates(x: int, y: int) -> BoundingBox:
         | 
| 35 | 
            -
                WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
         | 
| 36 | 
            -
                x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
         | 
| 37 | 
            -
                y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
         | 
| 38 | 
            -
                w = first_stage_factor / WIDTH
         | 
| 39 | 
            -
                h = first_stage_factor / HEIGHT
         | 
| 40 | 
            -
                return x0, y0, w, h
         | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
            def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
         | 
| 44 | 
            -
                WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
         | 
| 45 | 
            -
                x0 = _helper(predict_x, WIDTH, first_stage_factor)
         | 
| 46 | 
            -
                y0 = _helper(predict_y, HEIGHT, first_stage_factor)
         | 
| 47 | 
            -
                no_images = z_indices.shape[0]
         | 
| 48 | 
            -
                cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
         | 
| 49 | 
            -
                cut_out_2 = z_indices[:, predict_y, x0:predict_x]
         | 
| 50 | 
            -
                return torch.cat((cut_out_1, cut_out_2), dim=1)
         | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
            @torch.no_grad()
         | 
| 54 | 
            -
            def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
         | 
| 55 | 
            -
                       conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
         | 
| 56 | 
            -
                       temperature: float, top_k: int) -> Tensor:
         | 
| 57 | 
            -
                x_max, y_max = desired_z_shape[1], desired_z_shape[0]
         | 
| 58 | 
            -
             | 
| 59 | 
            -
                annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
         | 
| 62 | 
            -
                if not recompute_conditional:
         | 
| 63 | 
            -
                    crop_coordinates = get_crop_coordinates(0, 0)
         | 
| 64 | 
            -
                    conditional_indices = conditional_builder.build(annotations, crop_coordinates)
         | 
| 65 | 
            -
                    c_indices = conditional_indices.to(device).repeat(no_samples, 1)
         | 
| 66 | 
            -
                    z_indices = torch.zeros((no_samples, 0), device=device).long()
         | 
| 67 | 
            -
                    output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
         | 
| 68 | 
            -
                                                  sample=True, top_k=top_k)
         | 
| 69 | 
            -
                else:
         | 
| 70 | 
            -
                    output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
         | 
| 71 | 
            -
                    for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
         | 
| 72 | 
            -
                        crop_coordinates = get_crop_coordinates(predict_x, predict_y)
         | 
| 73 | 
            -
                        z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
         | 
| 74 | 
            -
                        conditional_indices = conditional_builder.build(annotations, crop_coordinates)
         | 
| 75 | 
            -
                        c_indices = conditional_indices.to(device).repeat(no_samples, 1)
         | 
| 76 | 
            -
                        new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
         | 
| 77 | 
            -
                        output_indices[:, predict_y, predict_x] = new_index[:, -1]
         | 
| 78 | 
            -
                z_shape = (
         | 
| 79 | 
            -
                    no_samples,
         | 
| 80 | 
            -
                    model.first_stage_model.quantize.e_dim,  # codebook embed_dim
         | 
| 81 | 
            -
                    desired_z_shape[0],  # z_height
         | 
| 82 | 
            -
                    desired_z_shape[1]  # z_width
         | 
| 83 | 
            -
                )
         | 
| 84 | 
            -
                x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
         | 
| 85 | 
            -
                x_sample = x_sample.to('cpu')
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                plotter = conditional_builder.plot
         | 
| 88 | 
            -
                figure_size = (x_sample.shape[2], x_sample.shape[3])
         | 
| 89 | 
            -
                scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
         | 
| 90 | 
            -
                plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
         | 
| 91 | 
            -
                return torch.cat((x_sample, plot.unsqueeze(0)))
         | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 | 
            -
            def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
         | 
| 95 | 
            -
                if not resolution_str.count(',') == 1:
         | 
| 96 | 
            -
                    raise ValueError("Give resolution as in 'height,width'")
         | 
| 97 | 
            -
                res_h, res_w = resolution_str.split(',')
         | 
| 98 | 
            -
                res_h = max(int(res_h), trained_on_res)
         | 
| 99 | 
            -
                res_w = max(int(res_w), trained_on_res)
         | 
| 100 | 
            -
                z_h = int(round(res_h/first_stage_factor))
         | 
| 101 | 
            -
                z_w = int(round(res_w/first_stage_factor))
         | 
| 102 | 
            -
                return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
         | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
            def add_arg_to_parser(parser):
         | 
| 106 | 
            -
                parser.add_argument(
         | 
| 107 | 
            -
                    "-R",
         | 
| 108 | 
            -
                    "--resolution",
         | 
| 109 | 
            -
                    type=str,
         | 
| 110 | 
            -
                    default='256,256',
         | 
| 111 | 
            -
                    help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
         | 
| 112 | 
            -
                )
         | 
| 113 | 
            -
                parser.add_argument(
         | 
| 114 | 
            -
                    "-C",
         | 
| 115 | 
            -
                    "--conditional",
         | 
| 116 | 
            -
                    type=str,
         | 
| 117 | 
            -
                    default='objects_bbox',
         | 
| 118 | 
            -
                    help=f"objects_bbox or objects_center_points",
         | 
| 119 | 
            -
                )
         | 
| 120 | 
            -
                parser.add_argument(
         | 
| 121 | 
            -
                    "-N",
         | 
| 122 | 
            -
                    "--n_samples_per_layout",
         | 
| 123 | 
            -
                    type=int,
         | 
| 124 | 
            -
                    default=4,
         | 
| 125 | 
            -
                    help=f"how many samples to generate per layout",
         | 
| 126 | 
            -
                )
         | 
| 127 | 
            -
                return parser
         | 
| 128 | 
            -
             | 
| 129 | 
            -
             | 
| 130 | 
            -
            if __name__ == "__main__":
         | 
| 131 | 
            -
                sys.path.append(os.getcwd())
         | 
| 132 | 
            -
             | 
| 133 | 
            -
                parser = get_parser()
         | 
| 134 | 
            -
                parser = add_arg_to_parser(parser)
         | 
| 135 | 
            -
             | 
| 136 | 
            -
                opt, unknown = parser.parse_known_args()
         | 
| 137 | 
            -
             | 
| 138 | 
            -
                ckpt = None
         | 
| 139 | 
            -
                if opt.resume:
         | 
| 140 | 
            -
                    if not os.path.exists(opt.resume):
         | 
| 141 | 
            -
                        raise ValueError("Cannot find {}".format(opt.resume))
         | 
| 142 | 
            -
                    if os.path.isfile(opt.resume):
         | 
| 143 | 
            -
                        paths = opt.resume.split("/")
         | 
| 144 | 
            -
                        try:
         | 
| 145 | 
            -
                            idx = len(paths)-paths[::-1].index("logs")+1
         | 
| 146 | 
            -
                        except ValueError:
         | 
| 147 | 
            -
                            idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
         | 
| 148 | 
            -
                        logdir = "/".join(paths[:idx])
         | 
| 149 | 
            -
                        ckpt = opt.resume
         | 
| 150 | 
            -
                    else:
         | 
| 151 | 
            -
                        assert os.path.isdir(opt.resume), opt.resume
         | 
| 152 | 
            -
                        logdir = opt.resume.rstrip("/")
         | 
| 153 | 
            -
                        ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
         | 
| 154 | 
            -
                    print(f"logdir:{logdir}")
         | 
| 155 | 
            -
                    base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
         | 
| 156 | 
            -
                    opt.base = base_configs+opt.base
         | 
| 157 | 
            -
             | 
| 158 | 
            -
                if opt.config:
         | 
| 159 | 
            -
                    if type(opt.config) == str:
         | 
| 160 | 
            -
                        opt.base = [opt.config]
         | 
| 161 | 
            -
                    else:
         | 
| 162 | 
            -
                        opt.base = [opt.base[-1]]
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                configs = [OmegaConf.load(cfg) for cfg in opt.base]
         | 
| 165 | 
            -
                cli = OmegaConf.from_dotlist(unknown)
         | 
| 166 | 
            -
                if opt.ignore_base_data:
         | 
| 167 | 
            -
                    for config in configs:
         | 
| 168 | 
            -
                        if hasattr(config, "data"):
         | 
| 169 | 
            -
                            del config["data"]
         | 
| 170 | 
            -
                config = OmegaConf.merge(*configs, cli)
         | 
| 171 | 
            -
                desired_z_shape, desired_resolution = get_resolution(opt.resolution)
         | 
| 172 | 
            -
                conditional = opt.conditional
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                print(ckpt)
         | 
| 175 | 
            -
                gpu = True
         | 
| 176 | 
            -
                eval_mode = True
         | 
| 177 | 
            -
                show_config = False
         | 
| 178 | 
            -
                if show_config:
         | 
| 179 | 
            -
                    print(OmegaConf.to_container(config))
         | 
| 180 | 
            -
             | 
| 181 | 
            -
                dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
         | 
| 182 | 
            -
                print(f"Global step: {global_step}")
         | 
| 183 | 
            -
             | 
| 184 | 
            -
                data_loader = dsets.val_dataloader()
         | 
| 185 | 
            -
                print(dsets.datasets["validation"].conditional_builders)
         | 
| 186 | 
            -
                conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
         | 
| 187 | 
            -
             | 
| 188 | 
            -
                outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
         | 
| 189 | 
            -
                outdir.mkdir(exist_ok=True, parents=True)
         | 
| 190 | 
            -
                print("Writing samples to ", outdir)
         | 
| 191 | 
            -
             | 
| 192 | 
            -
                p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
         | 
| 193 | 
            -
                for batch_no, batch in p_bar_1:
         | 
| 194 | 
            -
                    save_img: Optional[Tensor] = None
         | 
| 195 | 
            -
                    for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
         | 
| 196 | 
            -
                        imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
         | 
| 197 | 
            -
                                      opt.n_samples_per_layout, opt.temperature, opt.top_k)
         | 
| 198 | 
            -
                        save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), n_row=opt.n_samples_per_layout+1)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/scripts/sample_conditional.py
    DELETED
    
    | @@ -1,355 +0,0 @@ | |
| 1 | 
            -
            import argparse, os, sys, glob, math, time
         | 
| 2 | 
            -
            import torch
         | 
| 3 | 
            -
            import numpy as np
         | 
| 4 | 
            -
            from omegaconf import OmegaConf
         | 
| 5 | 
            -
            import streamlit as st
         | 
| 6 | 
            -
            from streamlit import caching
         | 
| 7 | 
            -
            from PIL import Image
         | 
| 8 | 
            -
            from main import instantiate_from_config, DataModuleFromConfig
         | 
| 9 | 
            -
            from torch.utils.data import DataLoader
         | 
| 10 | 
            -
            from torch.utils.data.dataloader import default_collate
         | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
            rescale = lambda x: (x + 1.) / 2.
         | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
            def bchw_to_st(x):
         | 
| 17 | 
            -
                return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            def save_img(xstart, fname):
         | 
| 20 | 
            -
                I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
         | 
| 21 | 
            -
                Image.fromarray(I).save(fname)
         | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
            def get_interactive_image(resize=False):
         | 
| 26 | 
            -
                image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
         | 
| 27 | 
            -
                if image is not None:
         | 
| 28 | 
            -
                    image = Image.open(image)
         | 
| 29 | 
            -
                    if not image.mode == "RGB":
         | 
| 30 | 
            -
                        image = image.convert("RGB")
         | 
| 31 | 
            -
                    image = np.array(image).astype(np.uint8)
         | 
| 32 | 
            -
                    print("upload image shape: {}".format(image.shape))
         | 
| 33 | 
            -
                    img = Image.fromarray(image)
         | 
| 34 | 
            -
                    if resize:
         | 
| 35 | 
            -
                        img = img.resize((256, 256))
         | 
| 36 | 
            -
                    image = np.array(img)
         | 
| 37 | 
            -
                    return image
         | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
            def single_image_to_torch(x, permute=True):
         | 
| 41 | 
            -
                assert x is not None, "Please provide an image through the upload function"
         | 
| 42 | 
            -
                x = np.array(x)
         | 
| 43 | 
            -
                x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
         | 
| 44 | 
            -
                if permute:
         | 
| 45 | 
            -
                    x = x.permute(0, 3, 1, 2)
         | 
| 46 | 
            -
                return x
         | 
| 47 | 
            -
             | 
| 48 | 
            -
             | 
| 49 | 
            -
            def pad_to_M(x, M):
         | 
| 50 | 
            -
                hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
         | 
| 51 | 
            -
                wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
         | 
| 52 | 
            -
                x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
         | 
| 53 | 
            -
                return x
         | 
| 54 | 
            -
             | 
| 55 | 
            -
            @torch.no_grad()
         | 
| 56 | 
            -
            def run_conditional(model, dsets):
         | 
| 57 | 
            -
                if len(dsets.datasets) > 1:
         | 
| 58 | 
            -
                    split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
         | 
| 59 | 
            -
                    dset = dsets.datasets[split]
         | 
| 60 | 
            -
                else:
         | 
| 61 | 
            -
                    dset = next(iter(dsets.datasets.values()))
         | 
| 62 | 
            -
                batch_size = 1
         | 
| 63 | 
            -
                start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
         | 
| 64 | 
            -
                                                      min_value=0,
         | 
| 65 | 
            -
                                                      max_value=len(dset)-batch_size)
         | 
| 66 | 
            -
                indices = list(range(start_index, start_index+batch_size))
         | 
| 67 | 
            -
             | 
| 68 | 
            -
                example = default_collate([dset[i] for i in indices])
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                x = model.get_input("image", example).to(model.device)
         | 
| 71 | 
            -
             | 
| 72 | 
            -
                cond_key = model.cond_stage_key
         | 
| 73 | 
            -
                c = model.get_input(cond_key, example).to(model.device)
         | 
| 74 | 
            -
             | 
| 75 | 
            -
                scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
         | 
| 76 | 
            -
                if scale_factor != 1.0:
         | 
| 77 | 
            -
                    x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
         | 
| 78 | 
            -
                    c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                quant_z, z_indices = model.encode_to_z(x)
         | 
| 81 | 
            -
                quant_c, c_indices = model.encode_to_c(c)
         | 
| 82 | 
            -
             | 
| 83 | 
            -
                cshape = quant_z.shape
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                xrec = model.first_stage_model.decode(quant_z)
         | 
| 86 | 
            -
                st.write("image: {}".format(x.shape))
         | 
| 87 | 
            -
                st.image(bchw_to_st(x), clamp=True, output_format="PNG")
         | 
| 88 | 
            -
                st.write("image reconstruction: {}".format(xrec.shape))
         | 
| 89 | 
            -
                st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                if cond_key == "segmentation":
         | 
| 92 | 
            -
                    # get image from segmentation mask
         | 
| 93 | 
            -
                    num_classes = c.shape[1]
         | 
| 94 | 
            -
                    c = torch.argmax(c, dim=1, keepdim=True)
         | 
| 95 | 
            -
                    c = torch.nn.functional.one_hot(c, num_classes=num_classes)
         | 
| 96 | 
            -
                    c = c.squeeze(1).permute(0, 3, 1, 2).float()
         | 
| 97 | 
            -
                    c = model.cond_stage_model.to_rgb(c)
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                st.write(f"{cond_key}: {tuple(c.shape)}")
         | 
| 100 | 
            -
                st.image(bchw_to_st(c), clamp=True, output_format="PNG")
         | 
| 101 | 
            -
             | 
| 102 | 
            -
                idx = z_indices
         | 
| 103 | 
            -
             | 
| 104 | 
            -
                half_sample = st.sidebar.checkbox("Image Completion", value=False)
         | 
| 105 | 
            -
                if half_sample:
         | 
| 106 | 
            -
                    start = idx.shape[1]//2
         | 
| 107 | 
            -
                else:
         | 
| 108 | 
            -
                    start = 0
         | 
| 109 | 
            -
             | 
| 110 | 
            -
                idx[:,start:] = 0
         | 
| 111 | 
            -
                idx = idx.reshape(cshape[0],cshape[2],cshape[3])
         | 
| 112 | 
            -
                start_i = start//cshape[3]
         | 
| 113 | 
            -
                start_j = start %cshape[3]
         | 
| 114 | 
            -
             | 
| 115 | 
            -
                if not half_sample and quant_z.shape == quant_c.shape:
         | 
| 116 | 
            -
                    st.info("Setting idx to c_indices")
         | 
| 117 | 
            -
                    idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])
         | 
| 118 | 
            -
             | 
| 119 | 
            -
                cidx = c_indices
         | 
| 120 | 
            -
                cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
         | 
| 121 | 
            -
             | 
| 122 | 
            -
                xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
         | 
| 123 | 
            -
                st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
         | 
| 124 | 
            -
             | 
| 125 | 
            -
                temperature = st.number_input("Temperature", value=1.0)
         | 
| 126 | 
            -
                top_k = st.number_input("Top k", value=100)
         | 
| 127 | 
            -
                sample = st.checkbox("Sample", value=True)
         | 
| 128 | 
            -
                update_every = st.number_input("Update every", value=75)
         | 
| 129 | 
            -
             | 
| 130 | 
            -
                st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")
         | 
| 131 | 
            -
             | 
| 132 | 
            -
                animate = st.checkbox("animate")
         | 
| 133 | 
            -
                if animate:
         | 
| 134 | 
            -
                    import imageio
         | 
| 135 | 
            -
                    outvid = "sampling.mp4"
         | 
| 136 | 
            -
                    writer = imageio.get_writer(outvid, fps=25)
         | 
| 137 | 
            -
                elapsed_t = st.empty()
         | 
| 138 | 
            -
                info = st.empty()
         | 
| 139 | 
            -
                st.text("Sampled")
         | 
| 140 | 
            -
                if st.button("Sample"):
         | 
| 141 | 
            -
                    output = st.empty()
         | 
| 142 | 
            -
                    start_t = time.time()
         | 
| 143 | 
            -
                    for i in range(start_i,cshape[2]-0):
         | 
| 144 | 
            -
                        if i <= 8:
         | 
| 145 | 
            -
                            local_i = i
         | 
| 146 | 
            -
                        elif cshape[2]-i < 8:
         | 
| 147 | 
            -
                            local_i = 16-(cshape[2]-i)
         | 
| 148 | 
            -
                        else:
         | 
| 149 | 
            -
                            local_i = 8
         | 
| 150 | 
            -
                        for j in range(start_j,cshape[3]-0):
         | 
| 151 | 
            -
                            if j <= 8:
         | 
| 152 | 
            -
                                local_j = j
         | 
| 153 | 
            -
                            elif cshape[3]-j < 8:
         | 
| 154 | 
            -
                                local_j = 16-(cshape[3]-j)
         | 
| 155 | 
            -
                            else:
         | 
| 156 | 
            -
                                local_j = 8
         | 
| 157 | 
            -
             | 
| 158 | 
            -
                            i_start = i-local_i
         | 
| 159 | 
            -
                            i_end = i_start+16
         | 
| 160 | 
            -
                            j_start = j-local_j
         | 
| 161 | 
            -
                            j_end = j_start+16
         | 
| 162 | 
            -
                            elapsed_t.text(f"Time: {time.time() - start_t} seconds")
         | 
| 163 | 
            -
                            info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
         | 
| 164 | 
            -
                            patch = idx[:,i_start:i_end,j_start:j_end]
         | 
| 165 | 
            -
                            patch = patch.reshape(patch.shape[0],-1)
         | 
| 166 | 
            -
                            cpatch = cidx[:, i_start:i_end, j_start:j_end]
         | 
| 167 | 
            -
                            cpatch = cpatch.reshape(cpatch.shape[0], -1)
         | 
| 168 | 
            -
                            patch = torch.cat((cpatch, patch), dim=1)
         | 
| 169 | 
            -
                            logits,_ = model.transformer(patch[:,:-1])
         | 
| 170 | 
            -
                            logits = logits[:, -256:, :]
         | 
| 171 | 
            -
                            logits = logits.reshape(cshape[0],16,16,-1)
         | 
| 172 | 
            -
                            logits = logits[:,local_i,local_j,:]
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                            logits = logits/temperature
         | 
| 175 | 
            -
             | 
| 176 | 
            -
                            if top_k is not None:
         | 
| 177 | 
            -
                                logits = model.top_k_logits(logits, top_k)
         | 
| 178 | 
            -
                            # apply softmax to convert to probabilities
         | 
| 179 | 
            -
                            probs = torch.nn.functional.softmax(logits, dim=-1)
         | 
| 180 | 
            -
                            # sample from the distribution or take the most likely
         | 
| 181 | 
            -
                            if sample:
         | 
| 182 | 
            -
                                ix = torch.multinomial(probs, num_samples=1)
         | 
| 183 | 
            -
                            else:
         | 
| 184 | 
            -
                                _, ix = torch.topk(probs, k=1, dim=-1)
         | 
| 185 | 
            -
                            idx[:,i,j] = ix
         | 
| 186 | 
            -
             | 
| 187 | 
            -
                            if (i*cshape[3]+j)%update_every==0:
         | 
| 188 | 
            -
                                xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
         | 
| 189 | 
            -
             | 
| 190 | 
            -
                                xstart = bchw_to_st(xstart)
         | 
| 191 | 
            -
                                output.image(xstart, clamp=True, output_format="PNG")
         | 
| 192 | 
            -
             | 
| 193 | 
            -
                                if animate:
         | 
| 194 | 
            -
                                    writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))
         | 
| 195 | 
            -
             | 
| 196 | 
            -
                    xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
         | 
| 197 | 
            -
                    xstart = bchw_to_st(xstart)
         | 
| 198 | 
            -
                    output.image(xstart, clamp=True, output_format="PNG")
         | 
| 199 | 
            -
                    #save_img(xstart, "full_res_sample.png")
         | 
| 200 | 
            -
                    if animate:
         | 
| 201 | 
            -
                        writer.close()
         | 
| 202 | 
            -
                        st.video(outvid)
         | 
| 203 | 
            -
             | 
| 204 | 
            -
             | 
| 205 | 
            -
            def get_parser():
         | 
| 206 | 
            -
                parser = argparse.ArgumentParser()
         | 
| 207 | 
            -
                parser.add_argument(
         | 
| 208 | 
            -
                    "-r",
         | 
| 209 | 
            -
                    "--resume",
         | 
| 210 | 
            -
                    type=str,
         | 
| 211 | 
            -
                    nargs="?",
         | 
| 212 | 
            -
                    help="load from logdir or checkpoint in logdir",
         | 
| 213 | 
            -
                )
         | 
| 214 | 
            -
                parser.add_argument(
         | 
| 215 | 
            -
                    "-b",
         | 
| 216 | 
            -
                    "--base",
         | 
| 217 | 
            -
                    nargs="*",
         | 
| 218 | 
            -
                    metavar="base_config.yaml",
         | 
| 219 | 
            -
                    help="paths to base configs. Loaded from left-to-right. "
         | 
| 220 | 
            -
                    "Parameters can be overwritten or added with command-line options of the form `--key value`.",
         | 
| 221 | 
            -
                    default=list(),
         | 
| 222 | 
            -
                )
         | 
| 223 | 
            -
                parser.add_argument(
         | 
| 224 | 
            -
                    "-c",
         | 
| 225 | 
            -
                    "--config",
         | 
| 226 | 
            -
                    nargs="?",
         | 
| 227 | 
            -
                    metavar="single_config.yaml",
         | 
| 228 | 
            -
                    help="path to single config. If specified, base configs will be ignored "
         | 
| 229 | 
            -
                    "(except for the last one if left unspecified).",
         | 
| 230 | 
            -
                    const=True,
         | 
| 231 | 
            -
                    default="",
         | 
| 232 | 
            -
                )
         | 
| 233 | 
            -
                parser.add_argument(
         | 
| 234 | 
            -
                    "--ignore_base_data",
         | 
| 235 | 
            -
                    action="store_true",
         | 
| 236 | 
            -
                    help="Ignore data specification from base configs. Useful if you want "
         | 
| 237 | 
            -
                    "to specify a custom datasets on the command line.",
         | 
| 238 | 
            -
                )
         | 
| 239 | 
            -
                return parser
         | 
| 240 | 
            -
             | 
| 241 | 
            -
             | 
| 242 | 
            -
            def load_model_from_config(config, sd, gpu=True, eval_mode=True):
         | 
| 243 | 
            -
                if "ckpt_path" in config.params:
         | 
| 244 | 
            -
                    st.warning("Deleting the restore-ckpt path from the config...")
         | 
| 245 | 
            -
                    config.params.ckpt_path = None
         | 
| 246 | 
            -
                if "downsample_cond_size" in config.params:
         | 
| 247 | 
            -
                    st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
         | 
| 248 | 
            -
                    config.params.downsample_cond_size = -1
         | 
| 249 | 
            -
                    config.params["downsample_cond_factor"] = 0.5
         | 
| 250 | 
            -
                try:
         | 
| 251 | 
            -
                    if "ckpt_path" in config.params.first_stage_config.params:
         | 
| 252 | 
            -
                        config.params.first_stage_config.params.ckpt_path = None
         | 
| 253 | 
            -
                        st.warning("Deleting the first-stage restore-ckpt path from the config...")
         | 
| 254 | 
            -
                    if "ckpt_path" in config.params.cond_stage_config.params:
         | 
| 255 | 
            -
                        config.params.cond_stage_config.params.ckpt_path = None
         | 
| 256 | 
            -
                        st.warning("Deleting the cond-stage restore-ckpt path from the config...")
         | 
| 257 | 
            -
                except:
         | 
| 258 | 
            -
                    pass
         | 
| 259 | 
            -
             | 
| 260 | 
            -
                model = instantiate_from_config(config)
         | 
| 261 | 
            -
                if sd is not None:
         | 
| 262 | 
            -
                    missing, unexpected = model.load_state_dict(sd, strict=False)
         | 
| 263 | 
            -
                    st.info(f"Missing Keys in State Dict: {missing}")
         | 
| 264 | 
            -
                    st.info(f"Unexpected Keys in State Dict: {unexpected}")
         | 
| 265 | 
            -
                if gpu:
         | 
| 266 | 
            -
                    model.cuda()
         | 
| 267 | 
            -
                if eval_mode:
         | 
| 268 | 
            -
                    model.eval()
         | 
| 269 | 
            -
                return {"model": model}
         | 
| 270 | 
            -
             | 
| 271 | 
            -
             | 
| 272 | 
            -
            def get_data(config):
         | 
| 273 | 
            -
                # get data
         | 
| 274 | 
            -
                data = instantiate_from_config(config.data)
         | 
| 275 | 
            -
                data.prepare_data()
         | 
| 276 | 
            -
                data.setup()
         | 
| 277 | 
            -
                return data
         | 
| 278 | 
            -
             | 
| 279 | 
            -
             | 
| 280 | 
            -
            @st.cache(allow_output_mutation=True, suppress_st_warning=True)
         | 
| 281 | 
            -
            def load_model_and_dset(config, ckpt, gpu, eval_mode):
         | 
| 282 | 
            -
                # get data
         | 
| 283 | 
            -
                dsets = get_data(config)   # calls data.config ...
         | 
| 284 | 
            -
             | 
| 285 | 
            -
                # now load the specified checkpoint
         | 
| 286 | 
            -
                if ckpt:
         | 
| 287 | 
            -
                    pl_sd = torch.load(ckpt, map_location="cpu")
         | 
| 288 | 
            -
                    global_step = pl_sd["global_step"]
         | 
| 289 | 
            -
                else:
         | 
| 290 | 
            -
                    pl_sd = {"state_dict": None}
         | 
| 291 | 
            -
                    global_step = None
         | 
| 292 | 
            -
                model = load_model_from_config(config.model,
         | 
| 293 | 
            -
                                               pl_sd["state_dict"],
         | 
| 294 | 
            -
                                               gpu=gpu,
         | 
| 295 | 
            -
                                               eval_mode=eval_mode)["model"]
         | 
| 296 | 
            -
                return dsets, model, global_step
         | 
| 297 | 
            -
             | 
| 298 | 
            -
             | 
| 299 | 
            -
            if __name__ == "__main__":
         | 
| 300 | 
            -
                sys.path.append(os.getcwd())
         | 
| 301 | 
            -
             | 
| 302 | 
            -
                parser = get_parser()
         | 
| 303 | 
            -
             | 
| 304 | 
            -
                opt, unknown = parser.parse_known_args()
         | 
| 305 | 
            -
             | 
| 306 | 
            -
                ckpt = None
         | 
| 307 | 
            -
                if opt.resume:
         | 
| 308 | 
            -
                    if not os.path.exists(opt.resume):
         | 
| 309 | 
            -
                        raise ValueError("Cannot find {}".format(opt.resume))
         | 
| 310 | 
            -
                    if os.path.isfile(opt.resume):
         | 
| 311 | 
            -
                        paths = opt.resume.split("/")
         | 
| 312 | 
            -
                        try:
         | 
| 313 | 
            -
                            idx = len(paths)-paths[::-1].index("logs")+1
         | 
| 314 | 
            -
                        except ValueError:
         | 
| 315 | 
            -
                            idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
         | 
| 316 | 
            -
                        logdir = "/".join(paths[:idx])
         | 
| 317 | 
            -
                        ckpt = opt.resume
         | 
| 318 | 
            -
                    else:
         | 
| 319 | 
            -
                        assert os.path.isdir(opt.resume), opt.resume
         | 
| 320 | 
            -
                        logdir = opt.resume.rstrip("/")
         | 
| 321 | 
            -
                        ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
         | 
| 322 | 
            -
                    print(f"logdir:{logdir}")
         | 
| 323 | 
            -
                    base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
         | 
| 324 | 
            -
                    opt.base = base_configs+opt.base
         | 
| 325 | 
            -
             | 
| 326 | 
            -
                if opt.config:
         | 
| 327 | 
            -
                    if type(opt.config) == str:
         | 
| 328 | 
            -
                        opt.base = [opt.config]
         | 
| 329 | 
            -
                    else:
         | 
| 330 | 
            -
                        opt.base = [opt.base[-1]]
         | 
| 331 | 
            -
             | 
| 332 | 
            -
                configs = [OmegaConf.load(cfg) for cfg in opt.base]
         | 
| 333 | 
            -
                cli = OmegaConf.from_dotlist(unknown)
         | 
| 334 | 
            -
                if opt.ignore_base_data:
         | 
| 335 | 
            -
                    for config in configs:
         | 
| 336 | 
            -
                        if hasattr(config, "data"): del config["data"]
         | 
| 337 | 
            -
                config = OmegaConf.merge(*configs, cli)
         | 
| 338 | 
            -
             | 
| 339 | 
            -
                st.sidebar.text(ckpt)
         | 
| 340 | 
            -
                gs = st.sidebar.empty()
         | 
| 341 | 
            -
                gs.text(f"Global step: ?")
         | 
| 342 | 
            -
                st.sidebar.text("Options")
         | 
| 343 | 
            -
                #gpu = st.sidebar.checkbox("GPU", value=True)
         | 
| 344 | 
            -
                gpu = True
         | 
| 345 | 
            -
                #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
         | 
| 346 | 
            -
                eval_mode = True
         | 
| 347 | 
            -
                #show_config = st.sidebar.checkbox("Show Config", value=False)
         | 
| 348 | 
            -
                show_config = False
         | 
| 349 | 
            -
                if show_config:
         | 
| 350 | 
            -
                    st.info("Checkpoint: {}".format(ckpt))
         | 
| 351 | 
            -
                    st.json(OmegaConf.to_container(config))
         | 
| 352 | 
            -
             | 
| 353 | 
            -
                dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
         | 
| 354 | 
            -
                gs.text(f"Global step: {global_step}")
         | 
| 355 | 
            -
                run_conditional(model, dsets)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/scripts/sample_fast.py
    DELETED
    
    | @@ -1,260 +0,0 @@ | |
| 1 | 
            -
            import argparse, os, sys, glob
         | 
| 2 | 
            -
            import torch
         | 
| 3 | 
            -
            import time
         | 
| 4 | 
            -
            import numpy as np
         | 
| 5 | 
            -
            from omegaconf import OmegaConf
         | 
| 6 | 
            -
            from PIL import Image
         | 
| 7 | 
            -
            from tqdm import tqdm, trange
         | 
| 8 | 
            -
            from einops import repeat
         | 
| 9 | 
            -
             | 
| 10 | 
            -
            from main import instantiate_from_config
         | 
| 11 | 
            -
            from taming.modules.transformer.mingpt import sample_with_past
         | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
            rescale = lambda x: (x + 1.) / 2.
         | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
            def chw_to_pillow(x):
         | 
| 18 | 
            -
                return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
         | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
            @torch.no_grad()
         | 
| 22 | 
            -
            def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
         | 
| 23 | 
            -
                                        dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
         | 
| 24 | 
            -
                log = dict()
         | 
| 25 | 
            -
                assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
         | 
| 26 | 
            -
                qzshape = [batch_size, dim_z, h, w]
         | 
| 27 | 
            -
                assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
         | 
| 28 | 
            -
                c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device)  # class token
         | 
| 29 | 
            -
                t1 = time.time()
         | 
| 30 | 
            -
                index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
         | 
| 31 | 
            -
                                                sample_logits=True, top_k=top_k, callback=callback,
         | 
| 32 | 
            -
                                                temperature=temperature, top_p=top_p)
         | 
| 33 | 
            -
                if verbose_time:
         | 
| 34 | 
            -
                    sampling_time = time.time() - t1
         | 
| 35 | 
            -
                    print(f"Full sampling takes about {sampling_time:.2f} seconds.")
         | 
| 36 | 
            -
                x_sample = model.decode_to_img(index_sample, qzshape)
         | 
| 37 | 
            -
                log["samples"] = x_sample
         | 
| 38 | 
            -
                log["class_label"] = c_indices
         | 
| 39 | 
            -
                return log
         | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
            @torch.no_grad()
         | 
| 43 | 
            -
            def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
         | 
| 44 | 
            -
                                     dim_z=256, h=16, w=16, verbose_time=False):
         | 
| 45 | 
            -
                log = dict()
         | 
| 46 | 
            -
                qzshape = [batch_size, dim_z, h, w]
         | 
| 47 | 
            -
                assert model.be_unconditional, 'Expecting an unconditional model.'
         | 
| 48 | 
            -
                c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device)  # sos token
         | 
| 49 | 
            -
                t1 = time.time()
         | 
| 50 | 
            -
                index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
         | 
| 51 | 
            -
                                                sample_logits=True, top_k=top_k, callback=callback,
         | 
| 52 | 
            -
                                                temperature=temperature, top_p=top_p)
         | 
| 53 | 
            -
                if verbose_time:
         | 
| 54 | 
            -
                    sampling_time = time.time() - t1
         | 
| 55 | 
            -
                    print(f"Full sampling takes about {sampling_time:.2f} seconds.")
         | 
| 56 | 
            -
                x_sample = model.decode_to_img(index_sample, qzshape)
         | 
| 57 | 
            -
                log["samples"] = x_sample
         | 
| 58 | 
            -
                return log
         | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 | 
            -
            @torch.no_grad()
         | 
| 62 | 
            -
            def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
         | 
| 63 | 
            -
                    given_classes=None, top_p=None):
         | 
| 64 | 
            -
                batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
         | 
| 65 | 
            -
                if not unconditional:
         | 
| 66 | 
            -
                    assert given_classes is not None
         | 
| 67 | 
            -
                    print("Running in pure class-conditional sampling mode. I will produce "
         | 
| 68 | 
            -
                          f"{num_samples} samples for each of the {len(given_classes)} classes, "
         | 
| 69 | 
            -
                          f"i.e. {num_samples*len(given_classes)} in total.")
         | 
| 70 | 
            -
                    for class_label in tqdm(given_classes, desc="Classes"):
         | 
| 71 | 
            -
                        for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
         | 
| 72 | 
            -
                            if bs == 0: break
         | 
| 73 | 
            -
                            logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
         | 
| 74 | 
            -
                                                           temperature=temperature, top_k=top_k, top_p=top_p)
         | 
| 75 | 
            -
                            save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
         | 
| 76 | 
            -
                else:
         | 
| 77 | 
            -
                    print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
         | 
| 78 | 
            -
                    for n, bs in tqdm(enumerate(batches), desc="Sampling"):
         | 
| 79 | 
            -
                        if bs == 0: break
         | 
| 80 | 
            -
                        logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
         | 
| 81 | 
            -
                        save_from_logs(logs, logdir, base_count=n * batch_size)
         | 
| 82 | 
            -
             | 
| 83 | 
            -
             | 
| 84 | 
            -
            def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
         | 
| 85 | 
            -
                xx = logs[key]
         | 
| 86 | 
            -
                for i, x in enumerate(xx):
         | 
| 87 | 
            -
                    x = chw_to_pillow(x)
         | 
| 88 | 
            -
                    count = base_count + i
         | 
| 89 | 
            -
                    if cond_key is None:
         | 
| 90 | 
            -
                        x.save(os.path.join(logdir, f"{count:06}.png"))
         | 
| 91 | 
            -
                    else:
         | 
| 92 | 
            -
                        condlabel = cond_key[i]
         | 
| 93 | 
            -
                        if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
         | 
| 94 | 
            -
                        os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
         | 
| 95 | 
            -
                        x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
         | 
| 96 | 
            -
             | 
| 97 | 
            -
             | 
| 98 | 
            -
            def get_parser():
         | 
| 99 | 
            -
                def str2bool(v):
         | 
| 100 | 
            -
                    if isinstance(v, bool):
         | 
| 101 | 
            -
                        return v
         | 
| 102 | 
            -
                    if v.lower() in ("yes", "true", "t", "y", "1"):
         | 
| 103 | 
            -
                        return True
         | 
| 104 | 
            -
                    elif v.lower() in ("no", "false", "f", "n", "0"):
         | 
| 105 | 
            -
                        return False
         | 
| 106 | 
            -
                    else:
         | 
| 107 | 
            -
                        raise argparse.ArgumentTypeError("Boolean value expected.")
         | 
| 108 | 
            -
             | 
| 109 | 
            -
                parser = argparse.ArgumentParser()
         | 
| 110 | 
            -
                parser.add_argument(
         | 
| 111 | 
            -
                    "-r",
         | 
| 112 | 
            -
                    "--resume",
         | 
| 113 | 
            -
                    type=str,
         | 
| 114 | 
            -
                    nargs="?",
         | 
| 115 | 
            -
                    help="load from logdir or checkpoint in logdir",
         | 
| 116 | 
            -
                )
         | 
| 117 | 
            -
                parser.add_argument(
         | 
| 118 | 
            -
                    "-o",
         | 
| 119 | 
            -
                    "--outdir",
         | 
| 120 | 
            -
                    type=str,
         | 
| 121 | 
            -
                    nargs="?",
         | 
| 122 | 
            -
                    help="path where the samples will be logged to.",
         | 
| 123 | 
            -
                    default=""
         | 
| 124 | 
            -
                )
         | 
| 125 | 
            -
                parser.add_argument(
         | 
| 126 | 
            -
                    "-b",
         | 
| 127 | 
            -
                    "--base",
         | 
| 128 | 
            -
                    nargs="*",
         | 
| 129 | 
            -
                    metavar="base_config.yaml",
         | 
| 130 | 
            -
                    help="paths to base configs. Loaded from left-to-right. "
         | 
| 131 | 
            -
                    "Parameters can be overwritten or added with command-line options of the form `--key value`.",
         | 
| 132 | 
            -
                    default=list(),
         | 
| 133 | 
            -
                )
         | 
| 134 | 
            -
                parser.add_argument(
         | 
| 135 | 
            -
                    "-n",
         | 
| 136 | 
            -
                    "--num_samples",
         | 
| 137 | 
            -
                    type=int,
         | 
| 138 | 
            -
                    nargs="?",
         | 
| 139 | 
            -
                    help="num_samples to draw",
         | 
| 140 | 
            -
                    default=50000
         | 
| 141 | 
            -
                )
         | 
| 142 | 
            -
                parser.add_argument(
         | 
| 143 | 
            -
                    "--batch_size",
         | 
| 144 | 
            -
                    type=int,
         | 
| 145 | 
            -
                    nargs="?",
         | 
| 146 | 
            -
                    help="the batch size",
         | 
| 147 | 
            -
                    default=25
         | 
| 148 | 
            -
                )
         | 
| 149 | 
            -
                parser.add_argument(
         | 
| 150 | 
            -
                    "-k",
         | 
| 151 | 
            -
                    "--top_k",
         | 
| 152 | 
            -
                    type=int,
         | 
| 153 | 
            -
                    nargs="?",
         | 
| 154 | 
            -
                    help="top-k value to sample with",
         | 
| 155 | 
            -
                    default=250,
         | 
| 156 | 
            -
                )
         | 
| 157 | 
            -
                parser.add_argument(
         | 
| 158 | 
            -
                    "-t",
         | 
| 159 | 
            -
                    "--temperature",
         | 
| 160 | 
            -
                    type=float,
         | 
| 161 | 
            -
                    nargs="?",
         | 
| 162 | 
            -
                    help="temperature value to sample with",
         | 
| 163 | 
            -
                    default=1.0
         | 
| 164 | 
            -
                )
         | 
| 165 | 
            -
                parser.add_argument(
         | 
| 166 | 
            -
                    "-p",
         | 
| 167 | 
            -
                    "--top_p",
         | 
| 168 | 
            -
                    type=float,
         | 
| 169 | 
            -
                    nargs="?",
         | 
| 170 | 
            -
                    help="top-p value to sample with",
         | 
| 171 | 
            -
                    default=1.0
         | 
| 172 | 
            -
                )
         | 
| 173 | 
            -
                parser.add_argument(
         | 
| 174 | 
            -
                    "--classes",
         | 
| 175 | 
            -
                    type=str,
         | 
| 176 | 
            -
                    nargs="?",
         | 
| 177 | 
            -
                    help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
         | 
| 178 | 
            -
                    default="imagenet"
         | 
| 179 | 
            -
                )
         | 
| 180 | 
            -
                return parser
         | 
| 181 | 
            -
             | 
| 182 | 
            -
             | 
| 183 | 
            -
            def load_model_from_config(config, sd, gpu=True, eval_mode=True):
         | 
| 184 | 
            -
                model = instantiate_from_config(config)
         | 
| 185 | 
            -
                if sd is not None:
         | 
| 186 | 
            -
                    model.load_state_dict(sd)
         | 
| 187 | 
            -
                if gpu:
         | 
| 188 | 
            -
                    model.cuda()
         | 
| 189 | 
            -
                if eval_mode:
         | 
| 190 | 
            -
                    model.eval()
         | 
| 191 | 
            -
                return {"model": model}
         | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 194 | 
            -
            def load_model(config, ckpt, gpu, eval_mode):
         | 
| 195 | 
            -
                # load the specified checkpoint
         | 
| 196 | 
            -
                if ckpt:
         | 
| 197 | 
            -
                    pl_sd = torch.load(ckpt, map_location="cpu")
         | 
| 198 | 
            -
                    global_step = pl_sd["global_step"]
         | 
| 199 | 
            -
                    print(f"loaded model from global step {global_step}.")
         | 
| 200 | 
            -
                else:
         | 
| 201 | 
            -
                    pl_sd = {"state_dict": None}
         | 
| 202 | 
            -
                    global_step = None
         | 
| 203 | 
            -
                model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
         | 
| 204 | 
            -
                return model, global_step
         | 
| 205 | 
            -
             | 
| 206 | 
            -
             | 
| 207 | 
            -
            if __name__ == "__main__":
         | 
| 208 | 
            -
                sys.path.append(os.getcwd())
         | 
| 209 | 
            -
                parser = get_parser()
         | 
| 210 | 
            -
             | 
| 211 | 
            -
                opt, unknown = parser.parse_known_args()
         | 
| 212 | 
            -
                assert opt.resume
         | 
| 213 | 
            -
             | 
| 214 | 
            -
                ckpt = None
         | 
| 215 | 
            -
             | 
| 216 | 
            -
                if not os.path.exists(opt.resume):
         | 
| 217 | 
            -
                    raise ValueError("Cannot find {}".format(opt.resume))
         | 
| 218 | 
            -
                if os.path.isfile(opt.resume):
         | 
| 219 | 
            -
                    paths = opt.resume.split("/")
         | 
| 220 | 
            -
                    try:
         | 
| 221 | 
            -
                        idx = len(paths)-paths[::-1].index("logs")+1
         | 
| 222 | 
            -
                    except ValueError:
         | 
| 223 | 
            -
                        idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
         | 
| 224 | 
            -
                    logdir = "/".join(paths[:idx])
         | 
| 225 | 
            -
                    ckpt = opt.resume
         | 
| 226 | 
            -
                else:
         | 
| 227 | 
            -
                    assert os.path.isdir(opt.resume), opt.resume
         | 
| 228 | 
            -
                    logdir = opt.resume.rstrip("/")
         | 
| 229 | 
            -
                    ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
         | 
| 230 | 
            -
             | 
| 231 | 
            -
                base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
         | 
| 232 | 
            -
                opt.base = base_configs+opt.base
         | 
| 233 | 
            -
             | 
| 234 | 
            -
                configs = [OmegaConf.load(cfg) for cfg in opt.base]
         | 
| 235 | 
            -
                cli = OmegaConf.from_dotlist(unknown)
         | 
| 236 | 
            -
                config = OmegaConf.merge(*configs, cli)
         | 
| 237 | 
            -
             | 
| 238 | 
            -
                model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
         | 
| 239 | 
            -
             | 
| 240 | 
            -
                if opt.outdir:
         | 
| 241 | 
            -
                    print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
         | 
| 242 | 
            -
                    logdir = opt.outdir
         | 
| 243 | 
            -
             | 
| 244 | 
            -
                if opt.classes == "imagenet":
         | 
| 245 | 
            -
                    given_classes = [i for i in range(1000)]
         | 
| 246 | 
            -
                else:
         | 
| 247 | 
            -
                    cls_str = opt.classes
         | 
| 248 | 
            -
                    assert not cls_str.endswith(","), 'class string should not end with a ","'
         | 
| 249 | 
            -
                    given_classes = [int(c) for c in cls_str.split(",")]
         | 
| 250 | 
            -
             | 
| 251 | 
            -
                logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
         | 
| 252 | 
            -
                                      f"{global_step}")
         | 
| 253 | 
            -
             | 
| 254 | 
            -
                print(f"Logging to {logdir}")
         | 
| 255 | 
            -
                os.makedirs(logdir, exist_ok=True)
         | 
| 256 | 
            -
             | 
| 257 | 
            -
                run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
         | 
| 258 | 
            -
                    given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
         | 
| 259 | 
            -
             | 
| 260 | 
            -
                print("done.")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/setup.py
    DELETED
    
    | @@ -1,13 +0,0 @@ | |
| 1 | 
            -
            from setuptools import setup, find_packages
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            setup(
         | 
| 4 | 
            -
                name='taming-transformers',
         | 
| 5 | 
            -
                version='0.0.1',
         | 
| 6 | 
            -
                description='Taming Transformers for High-Resolution Image Synthesis',
         | 
| 7 | 
            -
                packages=find_packages(),
         | 
| 8 | 
            -
                install_requires=[
         | 
| 9 | 
            -
                    'torch',
         | 
| 10 | 
            -
                    'numpy',
         | 
| 11 | 
            -
                    'tqdm',
         | 
| 12 | 
            -
                ],
         | 
| 13 | 
            -
            )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/lr_scheduler.py
    DELETED
    
    | @@ -1,34 +0,0 @@ | |
| 1 | 
            -
            import numpy as np
         | 
| 2 | 
            -
             | 
| 3 | 
            -
             | 
| 4 | 
            -
            class LambdaWarmUpCosineScheduler:
         | 
| 5 | 
            -
                """
         | 
| 6 | 
            -
                note: use with a base_lr of 1.0
         | 
| 7 | 
            -
                """
         | 
| 8 | 
            -
                def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
         | 
| 9 | 
            -
                    self.lr_warm_up_steps = warm_up_steps
         | 
| 10 | 
            -
                    self.lr_start = lr_start
         | 
| 11 | 
            -
                    self.lr_min = lr_min
         | 
| 12 | 
            -
                    self.lr_max = lr_max
         | 
| 13 | 
            -
                    self.lr_max_decay_steps = max_decay_steps
         | 
| 14 | 
            -
                    self.last_lr = 0.
         | 
| 15 | 
            -
                    self.verbosity_interval = verbosity_interval
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                def schedule(self, n):
         | 
| 18 | 
            -
                    if self.verbosity_interval > 0:
         | 
| 19 | 
            -
                        if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         | 
| 20 | 
            -
                    if n < self.lr_warm_up_steps:
         | 
| 21 | 
            -
                        lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
         | 
| 22 | 
            -
                        self.last_lr = lr
         | 
| 23 | 
            -
                        return lr
         | 
| 24 | 
            -
                    else:
         | 
| 25 | 
            -
                        t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
         | 
| 26 | 
            -
                        t = min(t, 1.0)
         | 
| 27 | 
            -
                        lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
         | 
| 28 | 
            -
                                1 + np.cos(t * np.pi))
         | 
| 29 | 
            -
                        self.last_lr = lr
         | 
| 30 | 
            -
                        return lr
         | 
| 31 | 
            -
             | 
| 32 | 
            -
                def __call__(self, n):
         | 
| 33 | 
            -
                    return self.schedule(n)
         | 
| 34 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/models/cond_transformer.py
    DELETED
    
    | @@ -1,352 +0,0 @@ | |
| 1 | 
            -
            import os, math
         | 
| 2 | 
            -
            import torch
         | 
| 3 | 
            -
            import torch.nn.functional as F
         | 
| 4 | 
            -
            import pytorch_lightning as pl
         | 
| 5 | 
            -
             | 
| 6 | 
            -
            from main import instantiate_from_config
         | 
| 7 | 
            -
            from taming.modules.util import SOSProvider
         | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
            def disabled_train(self, mode=True):
         | 
| 11 | 
            -
                """Overwrite model.train with this function to make sure train/eval mode
         | 
| 12 | 
            -
                does not change anymore."""
         | 
| 13 | 
            -
                return self
         | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
            class Net2NetTransformer(pl.LightningModule):
         | 
| 17 | 
            -
                def __init__(self,
         | 
| 18 | 
            -
                             transformer_config,
         | 
| 19 | 
            -
                             first_stage_config,
         | 
| 20 | 
            -
                             cond_stage_config,
         | 
| 21 | 
            -
                             permuter_config=None,
         | 
| 22 | 
            -
                             ckpt_path=None,
         | 
| 23 | 
            -
                             ignore_keys=[],
         | 
| 24 | 
            -
                             first_stage_key="image",
         | 
| 25 | 
            -
                             cond_stage_key="depth",
         | 
| 26 | 
            -
                             downsample_cond_size=-1,
         | 
| 27 | 
            -
                             pkeep=1.0,
         | 
| 28 | 
            -
                             sos_token=0,
         | 
| 29 | 
            -
                             unconditional=False,
         | 
| 30 | 
            -
                             ):
         | 
| 31 | 
            -
                    super().__init__()
         | 
| 32 | 
            -
                    self.be_unconditional = unconditional
         | 
| 33 | 
            -
                    self.sos_token = sos_token
         | 
| 34 | 
            -
                    self.first_stage_key = first_stage_key
         | 
| 35 | 
            -
                    self.cond_stage_key = cond_stage_key
         | 
| 36 | 
            -
                    self.init_first_stage_from_ckpt(first_stage_config)
         | 
| 37 | 
            -
                    self.init_cond_stage_from_ckpt(cond_stage_config)
         | 
| 38 | 
            -
                    if permuter_config is None:
         | 
| 39 | 
            -
                        permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
         | 
| 40 | 
            -
                    self.permuter = instantiate_from_config(config=permuter_config)
         | 
| 41 | 
            -
                    self.transformer = instantiate_from_config(config=transformer_config)
         | 
| 42 | 
            -
             | 
| 43 | 
            -
                    if ckpt_path is not None:
         | 
| 44 | 
            -
                        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
         | 
| 45 | 
            -
                    self.downsample_cond_size = downsample_cond_size
         | 
| 46 | 
            -
                    self.pkeep = pkeep
         | 
| 47 | 
            -
             | 
| 48 | 
            -
                def init_from_ckpt(self, path, ignore_keys=list()):
         | 
| 49 | 
            -
                    sd = torch.load(path, map_location="cpu")["state_dict"]
         | 
| 50 | 
            -
                    for k in sd.keys():
         | 
| 51 | 
            -
                        for ik in ignore_keys:
         | 
| 52 | 
            -
                            if k.startswith(ik):
         | 
| 53 | 
            -
                                self.print("Deleting key {} from state_dict.".format(k))
         | 
| 54 | 
            -
                                del sd[k]
         | 
| 55 | 
            -
                    self.load_state_dict(sd, strict=False)
         | 
| 56 | 
            -
                    print(f"Restored from {path}")
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                def init_first_stage_from_ckpt(self, config):
         | 
| 59 | 
            -
                    model = instantiate_from_config(config)
         | 
| 60 | 
            -
                    model = model.eval()
         | 
| 61 | 
            -
                    model.train = disabled_train
         | 
| 62 | 
            -
                    self.first_stage_model = model
         | 
| 63 | 
            -
             | 
| 64 | 
            -
                def init_cond_stage_from_ckpt(self, config):
         | 
| 65 | 
            -
                    if config == "__is_first_stage__":
         | 
| 66 | 
            -
                        print("Using first stage also as cond stage.")
         | 
| 67 | 
            -
                        self.cond_stage_model = self.first_stage_model
         | 
| 68 | 
            -
                    elif config == "__is_unconditional__" or self.be_unconditional:
         | 
| 69 | 
            -
                        print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
         | 
| 70 | 
            -
                              f"Prepending {self.sos_token} as a sos token.")
         | 
| 71 | 
            -
                        self.be_unconditional = True
         | 
| 72 | 
            -
                        self.cond_stage_key = self.first_stage_key
         | 
| 73 | 
            -
                        self.cond_stage_model = SOSProvider(self.sos_token)
         | 
| 74 | 
            -
                    else:
         | 
| 75 | 
            -
                        model = instantiate_from_config(config)
         | 
| 76 | 
            -
                        model = model.eval()
         | 
| 77 | 
            -
                        model.train = disabled_train
         | 
| 78 | 
            -
                        self.cond_stage_model = model
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                def forward(self, x, c):
         | 
| 81 | 
            -
                    # one step to produce the logits
         | 
| 82 | 
            -
                    _, z_indices = self.encode_to_z(x)
         | 
| 83 | 
            -
                    _, c_indices = self.encode_to_c(c)
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                    if self.training and self.pkeep < 1.0:
         | 
| 86 | 
            -
                        mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
         | 
| 87 | 
            -
                                                                     device=z_indices.device))
         | 
| 88 | 
            -
                        mask = mask.round().to(dtype=torch.int64)
         | 
| 89 | 
            -
                        r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
         | 
| 90 | 
            -
                        a_indices = mask*z_indices+(1-mask)*r_indices
         | 
| 91 | 
            -
                    else:
         | 
| 92 | 
            -
                        a_indices = z_indices
         | 
| 93 | 
            -
             | 
| 94 | 
            -
                    cz_indices = torch.cat((c_indices, a_indices), dim=1)
         | 
| 95 | 
            -
             | 
| 96 | 
            -
                    # target includes all sequence elements (no need to handle first one
         | 
| 97 | 
            -
                    # differently because we are conditioning)
         | 
| 98 | 
            -
                    target = z_indices
         | 
| 99 | 
            -
                    # make the prediction
         | 
| 100 | 
            -
                    logits, _ = self.transformer(cz_indices[:, :-1])
         | 
| 101 | 
            -
                    # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
         | 
| 102 | 
            -
                    logits = logits[:, c_indices.shape[1]-1:]
         | 
| 103 | 
            -
             | 
| 104 | 
            -
                    return logits, target
         | 
| 105 | 
            -
             | 
| 106 | 
            -
                def top_k_logits(self, logits, k):
         | 
| 107 | 
            -
                    v, ix = torch.topk(logits, k)
         | 
| 108 | 
            -
                    out = logits.clone()
         | 
| 109 | 
            -
                    out[out < v[..., [-1]]] = -float('Inf')
         | 
| 110 | 
            -
                    return out
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                @torch.no_grad()
         | 
| 113 | 
            -
                def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
         | 
| 114 | 
            -
                           callback=lambda k: None):
         | 
| 115 | 
            -
                    x = torch.cat((c,x),dim=1)
         | 
| 116 | 
            -
                    block_size = self.transformer.get_block_size()
         | 
| 117 | 
            -
                    assert not self.transformer.training
         | 
| 118 | 
            -
                    if self.pkeep <= 0.0:
         | 
| 119 | 
            -
                        # one pass suffices since input is pure noise anyway
         | 
| 120 | 
            -
                        assert len(x.shape)==2
         | 
| 121 | 
            -
                        noise_shape = (x.shape[0], steps-1)
         | 
| 122 | 
            -
                        #noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
         | 
| 123 | 
            -
                        noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
         | 
| 124 | 
            -
                        x = torch.cat((x,noise),dim=1)
         | 
| 125 | 
            -
                        logits, _ = self.transformer(x)
         | 
| 126 | 
            -
                        # take all logits for now and scale by temp
         | 
| 127 | 
            -
                        logits = logits / temperature
         | 
| 128 | 
            -
                        # optionally crop probabilities to only the top k options
         | 
| 129 | 
            -
                        if top_k is not None:
         | 
| 130 | 
            -
                            logits = self.top_k_logits(logits, top_k)
         | 
| 131 | 
            -
                        # apply softmax to convert to probabilities
         | 
| 132 | 
            -
                        probs = F.softmax(logits, dim=-1)
         | 
| 133 | 
            -
                        # sample from the distribution or take the most likely
         | 
| 134 | 
            -
                        if sample:
         | 
| 135 | 
            -
                            shape = probs.shape
         | 
| 136 | 
            -
                            probs = probs.reshape(shape[0]*shape[1],shape[2])
         | 
| 137 | 
            -
                            ix = torch.multinomial(probs, num_samples=1)
         | 
| 138 | 
            -
                            probs = probs.reshape(shape[0],shape[1],shape[2])
         | 
| 139 | 
            -
                            ix = ix.reshape(shape[0],shape[1])
         | 
| 140 | 
            -
                        else:
         | 
| 141 | 
            -
                            _, ix = torch.topk(probs, k=1, dim=-1)
         | 
| 142 | 
            -
                        # cut off conditioning
         | 
| 143 | 
            -
                        x = ix[:, c.shape[1]-1:]
         | 
| 144 | 
            -
                    else:
         | 
| 145 | 
            -
                        for k in range(steps):
         | 
| 146 | 
            -
                            callback(k)
         | 
| 147 | 
            -
                            assert x.size(1) <= block_size # make sure model can see conditioning
         | 
| 148 | 
            -
                            x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
         | 
| 149 | 
            -
                            logits, _ = self.transformer(x_cond)
         | 
| 150 | 
            -
                            # pluck the logits at the final step and scale by temperature
         | 
| 151 | 
            -
                            logits = logits[:, -1, :] / temperature
         | 
| 152 | 
            -
                            # optionally crop probabilities to only the top k options
         | 
| 153 | 
            -
                            if top_k is not None:
         | 
| 154 | 
            -
                                logits = self.top_k_logits(logits, top_k)
         | 
| 155 | 
            -
                            # apply softmax to convert to probabilities
         | 
| 156 | 
            -
                            probs = F.softmax(logits, dim=-1)
         | 
| 157 | 
            -
                            # sample from the distribution or take the most likely
         | 
| 158 | 
            -
                            if sample:
         | 
| 159 | 
            -
                                ix = torch.multinomial(probs, num_samples=1)
         | 
| 160 | 
            -
                            else:
         | 
| 161 | 
            -
                                _, ix = torch.topk(probs, k=1, dim=-1)
         | 
| 162 | 
            -
                            # append to the sequence and continue
         | 
| 163 | 
            -
                            x = torch.cat((x, ix), dim=1)
         | 
| 164 | 
            -
                        # cut off conditioning
         | 
| 165 | 
            -
                        x = x[:, c.shape[1]:]
         | 
| 166 | 
            -
                    return x
         | 
| 167 | 
            -
             | 
| 168 | 
            -
                @torch.no_grad()
         | 
| 169 | 
            -
                def encode_to_z(self, x):
         | 
| 170 | 
            -
                    quant_z, _, info = self.first_stage_model.encode(x)
         | 
| 171 | 
            -
                    indices = info[2].view(quant_z.shape[0], -1)
         | 
| 172 | 
            -
                    indices = self.permuter(indices)
         | 
| 173 | 
            -
                    return quant_z, indices
         | 
| 174 | 
            -
             | 
| 175 | 
            -
                @torch.no_grad()
         | 
| 176 | 
            -
                def encode_to_c(self, c):
         | 
| 177 | 
            -
                    if self.downsample_cond_size > -1:
         | 
| 178 | 
            -
                        c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
         | 
| 179 | 
            -
                    quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
         | 
| 180 | 
            -
                    if len(indices.shape) > 2:
         | 
| 181 | 
            -
                        indices = indices.view(c.shape[0], -1)
         | 
| 182 | 
            -
                    return quant_c, indices
         | 
| 183 | 
            -
             | 
| 184 | 
            -
                @torch.no_grad()
         | 
| 185 | 
            -
                def decode_to_img(self, index, zshape):
         | 
| 186 | 
            -
                    index = self.permuter(index, reverse=True)
         | 
| 187 | 
            -
                    bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
         | 
| 188 | 
            -
                    quant_z = self.first_stage_model.quantize.get_codebook_entry(
         | 
| 189 | 
            -
                        index.reshape(-1), shape=bhwc)
         | 
| 190 | 
            -
                    x = self.first_stage_model.decode(quant_z)
         | 
| 191 | 
            -
                    return x
         | 
| 192 | 
            -
             | 
| 193 | 
            -
                @torch.no_grad()
         | 
| 194 | 
            -
                def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
         | 
| 195 | 
            -
                    log = dict()
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                    N = 4
         | 
| 198 | 
            -
                    if lr_interface:
         | 
| 199 | 
            -
                        x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
         | 
| 200 | 
            -
                    else:
         | 
| 201 | 
            -
                        x, c = self.get_xc(batch, N)
         | 
| 202 | 
            -
                    x = x.to(device=self.device)
         | 
| 203 | 
            -
                    c = c.to(device=self.device)
         | 
| 204 | 
            -
             | 
| 205 | 
            -
                    quant_z, z_indices = self.encode_to_z(x)
         | 
| 206 | 
            -
                    quant_c, c_indices = self.encode_to_c(c)
         | 
| 207 | 
            -
             | 
| 208 | 
            -
                    # create a "half"" sample
         | 
| 209 | 
            -
                    z_start_indices = z_indices[:,:z_indices.shape[1]//2]
         | 
| 210 | 
            -
                    index_sample = self.sample(z_start_indices, c_indices,
         | 
| 211 | 
            -
                                               steps=z_indices.shape[1]-z_start_indices.shape[1],
         | 
| 212 | 
            -
                                               temperature=temperature if temperature is not None else 1.0,
         | 
| 213 | 
            -
                                               sample=True,
         | 
| 214 | 
            -
                                               top_k=top_k if top_k is not None else 100,
         | 
| 215 | 
            -
                                               callback=callback if callback is not None else lambda k: None)
         | 
| 216 | 
            -
                    x_sample = self.decode_to_img(index_sample, quant_z.shape)
         | 
| 217 | 
            -
             | 
| 218 | 
            -
                    # sample
         | 
| 219 | 
            -
                    z_start_indices = z_indices[:, :0]
         | 
| 220 | 
            -
                    index_sample = self.sample(z_start_indices, c_indices,
         | 
| 221 | 
            -
                                               steps=z_indices.shape[1],
         | 
| 222 | 
            -
                                               temperature=temperature if temperature is not None else 1.0,
         | 
| 223 | 
            -
                                               sample=True,
         | 
| 224 | 
            -
                                               top_k=top_k if top_k is not None else 100,
         | 
| 225 | 
            -
                                               callback=callback if callback is not None else lambda k: None)
         | 
| 226 | 
            -
                    x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
         | 
| 227 | 
            -
             | 
| 228 | 
            -
                    # det sample
         | 
| 229 | 
            -
                    z_start_indices = z_indices[:, :0]
         | 
| 230 | 
            -
                    index_sample = self.sample(z_start_indices, c_indices,
         | 
| 231 | 
            -
                                               steps=z_indices.shape[1],
         | 
| 232 | 
            -
                                               sample=False,
         | 
| 233 | 
            -
                                               callback=callback if callback is not None else lambda k: None)
         | 
| 234 | 
            -
                    x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
         | 
| 235 | 
            -
             | 
| 236 | 
            -
                    # reconstruction
         | 
| 237 | 
            -
                    x_rec = self.decode_to_img(z_indices, quant_z.shape)
         | 
| 238 | 
            -
             | 
| 239 | 
            -
                    log["inputs"] = x
         | 
| 240 | 
            -
                    log["reconstructions"] = x_rec
         | 
| 241 | 
            -
             | 
| 242 | 
            -
                    if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
         | 
| 243 | 
            -
                        figure_size = (x_rec.shape[2], x_rec.shape[3])
         | 
| 244 | 
            -
                        dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
         | 
| 245 | 
            -
                        label_for_category_no = dataset.get_textual_label_for_category_no
         | 
| 246 | 
            -
                        plotter = dataset.conditional_builders[self.cond_stage_key].plot
         | 
| 247 | 
            -
                        log["conditioning"] = torch.zeros_like(log["reconstructions"])
         | 
| 248 | 
            -
                        for i in range(quant_c.shape[0]):
         | 
| 249 | 
            -
                            log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
         | 
| 250 | 
            -
                        log["conditioning_rec"] = log["conditioning"]
         | 
| 251 | 
            -
                    elif self.cond_stage_key != "image":
         | 
| 252 | 
            -
                        cond_rec = self.cond_stage_model.decode(quant_c)
         | 
| 253 | 
            -
                        if self.cond_stage_key == "segmentation":
         | 
| 254 | 
            -
                            # get image from segmentation mask
         | 
| 255 | 
            -
                            num_classes = cond_rec.shape[1]
         | 
| 256 | 
            -
             | 
| 257 | 
            -
                            c = torch.argmax(c, dim=1, keepdim=True)
         | 
| 258 | 
            -
                            c = F.one_hot(c, num_classes=num_classes)
         | 
| 259 | 
            -
                            c = c.squeeze(1).permute(0, 3, 1, 2).float()
         | 
| 260 | 
            -
                            c = self.cond_stage_model.to_rgb(c)
         | 
| 261 | 
            -
             | 
| 262 | 
            -
                            cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
         | 
| 263 | 
            -
                            cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
         | 
| 264 | 
            -
                            cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
         | 
| 265 | 
            -
                            cond_rec = self.cond_stage_model.to_rgb(cond_rec)
         | 
| 266 | 
            -
                        log["conditioning_rec"] = cond_rec
         | 
| 267 | 
            -
                        log["conditioning"] = c
         | 
| 268 | 
            -
             | 
| 269 | 
            -
                    log["samples_half"] = x_sample
         | 
| 270 | 
            -
                    log["samples_nopix"] = x_sample_nopix
         | 
| 271 | 
            -
                    log["samples_det"] = x_sample_det
         | 
| 272 | 
            -
                    return log
         | 
| 273 | 
            -
             | 
| 274 | 
            -
                def get_input(self, key, batch):
         | 
| 275 | 
            -
                    x = batch[key]
         | 
| 276 | 
            -
                    if len(x.shape) == 3:
         | 
| 277 | 
            -
                        x = x[..., None]
         | 
| 278 | 
            -
                    if len(x.shape) == 4:
         | 
| 279 | 
            -
                        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
         | 
| 280 | 
            -
                    if x.dtype == torch.double:
         | 
| 281 | 
            -
                        x = x.float()
         | 
| 282 | 
            -
                    return x
         | 
| 283 | 
            -
             | 
| 284 | 
            -
                def get_xc(self, batch, N=None):
         | 
| 285 | 
            -
                    x = self.get_input(self.first_stage_key, batch)
         | 
| 286 | 
            -
                    c = self.get_input(self.cond_stage_key, batch)
         | 
| 287 | 
            -
                    if N is not None:
         | 
| 288 | 
            -
                        x = x[:N]
         | 
| 289 | 
            -
                        c = c[:N]
         | 
| 290 | 
            -
                    return x, c
         | 
| 291 | 
            -
             | 
| 292 | 
            -
                def shared_step(self, batch, batch_idx):
         | 
| 293 | 
            -
                    x, c = self.get_xc(batch)
         | 
| 294 | 
            -
                    logits, target = self(x, c)
         | 
| 295 | 
            -
                    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
         | 
| 296 | 
            -
                    return loss
         | 
| 297 | 
            -
             | 
| 298 | 
            -
                def training_step(self, batch, batch_idx):
         | 
| 299 | 
            -
                    loss = self.shared_step(batch, batch_idx)
         | 
| 300 | 
            -
                    self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
         | 
| 301 | 
            -
                    return loss
         | 
| 302 | 
            -
             | 
| 303 | 
            -
                def validation_step(self, batch, batch_idx):
         | 
| 304 | 
            -
                    loss = self.shared_step(batch, batch_idx)
         | 
| 305 | 
            -
                    self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
         | 
| 306 | 
            -
                    return loss
         | 
| 307 | 
            -
             | 
| 308 | 
            -
                def configure_optimizers(self):
         | 
| 309 | 
            -
                    """
         | 
| 310 | 
            -
                    Following minGPT:
         | 
| 311 | 
            -
                    This long function is unfortunately doing something very simple and is being very defensive:
         | 
| 312 | 
            -
                    We are separating out all parameters of the model into two buckets: those that will experience
         | 
| 313 | 
            -
                    weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
         | 
| 314 | 
            -
                    We are then returning the PyTorch optimizer object.
         | 
| 315 | 
            -
                    """
         | 
| 316 | 
            -
                    # separate out all parameters to those that will and won't experience regularizing weight decay
         | 
| 317 | 
            -
                    decay = set()
         | 
| 318 | 
            -
                    no_decay = set()
         | 
| 319 | 
            -
                    whitelist_weight_modules = (torch.nn.Linear, )
         | 
| 320 | 
            -
                    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         | 
| 321 | 
            -
                    for mn, m in self.transformer.named_modules():
         | 
| 322 | 
            -
                        for pn, p in m.named_parameters():
         | 
| 323 | 
            -
                            fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
         | 
| 324 | 
            -
             | 
| 325 | 
            -
                            if pn.endswith('bias'):
         | 
| 326 | 
            -
                                # all biases will not be decayed
         | 
| 327 | 
            -
                                no_decay.add(fpn)
         | 
| 328 | 
            -
                            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
         | 
| 329 | 
            -
                                # weights of whitelist modules will be weight decayed
         | 
| 330 | 
            -
                                decay.add(fpn)
         | 
| 331 | 
            -
                            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
         | 
| 332 | 
            -
                                # weights of blacklist modules will NOT be weight decayed
         | 
| 333 | 
            -
                                no_decay.add(fpn)
         | 
| 334 | 
            -
             | 
| 335 | 
            -
                    # special case the position embedding parameter in the root GPT module as not decayed
         | 
| 336 | 
            -
                    no_decay.add('pos_emb')
         | 
| 337 | 
            -
             | 
| 338 | 
            -
                    # validate that we considered every parameter
         | 
| 339 | 
            -
                    param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
         | 
| 340 | 
            -
                    inter_params = decay & no_decay
         | 
| 341 | 
            -
                    union_params = decay | no_decay
         | 
| 342 | 
            -
                    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
         | 
| 343 | 
            -
                    assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
         | 
| 344 | 
            -
                                                                % (str(param_dict.keys() - union_params), )
         | 
| 345 | 
            -
             | 
| 346 | 
            -
                    # create the pytorch optimizer object
         | 
| 347 | 
            -
                    optim_groups = [
         | 
| 348 | 
            -
                        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
         | 
| 349 | 
            -
                        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         | 
| 350 | 
            -
                    ]
         | 
| 351 | 
            -
                    optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
         | 
| 352 | 
            -
                    return optimizer
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/models/dummy_cond_stage.py
    DELETED
    
    | @@ -1,22 +0,0 @@ | |
| 1 | 
            -
            from torch import Tensor
         | 
| 2 | 
            -
             | 
| 3 | 
            -
             | 
| 4 | 
            -
            class DummyCondStage:
         | 
| 5 | 
            -
                def __init__(self, conditional_key):
         | 
| 6 | 
            -
                    self.conditional_key = conditional_key
         | 
| 7 | 
            -
                    self.train = None
         | 
| 8 | 
            -
             | 
| 9 | 
            -
                def eval(self):
         | 
| 10 | 
            -
                    return self
         | 
| 11 | 
            -
             | 
| 12 | 
            -
                @staticmethod
         | 
| 13 | 
            -
                def encode(c: Tensor):
         | 
| 14 | 
            -
                    return c, None, (None, None, c)
         | 
| 15 | 
            -
             | 
| 16 | 
            -
                @staticmethod
         | 
| 17 | 
            -
                def decode(c: Tensor):
         | 
| 18 | 
            -
                    return c
         | 
| 19 | 
            -
             | 
| 20 | 
            -
                @staticmethod
         | 
| 21 | 
            -
                def to_rgb(c: Tensor):
         | 
| 22 | 
            -
                    return c
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/models/vqgan.py
    DELETED
    
    | @@ -1,404 +0,0 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            import torch.nn.functional as F
         | 
| 3 | 
            -
            import pytorch_lightning as pl
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            from main import instantiate_from_config
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            from taming.modules.diffusionmodules.model import Encoder, Decoder
         | 
| 8 | 
            -
            from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
         | 
| 9 | 
            -
            from taming.modules.vqvae.quantize import GumbelQuantize
         | 
| 10 | 
            -
            from taming.modules.vqvae.quantize import EMAVectorQuantizer
         | 
| 11 | 
            -
             | 
| 12 | 
            -
            class VQModel(pl.LightningModule):
         | 
| 13 | 
            -
                def __init__(self,
         | 
| 14 | 
            -
                             ddconfig,
         | 
| 15 | 
            -
                             lossconfig,
         | 
| 16 | 
            -
                             n_embed,
         | 
| 17 | 
            -
                             embed_dim,
         | 
| 18 | 
            -
                             ckpt_path=None,
         | 
| 19 | 
            -
                             ignore_keys=[],
         | 
| 20 | 
            -
                             image_key="image",
         | 
| 21 | 
            -
                             colorize_nlabels=None,
         | 
| 22 | 
            -
                             monitor=None,
         | 
| 23 | 
            -
                             remap=None,
         | 
| 24 | 
            -
                             sane_index_shape=False,  # tell vector quantizer to return indices as bhw
         | 
| 25 | 
            -
                             ):
         | 
| 26 | 
            -
                    super().__init__()
         | 
| 27 | 
            -
                    self.image_key = image_key
         | 
| 28 | 
            -
                    self.encoder = Encoder(**ddconfig)
         | 
| 29 | 
            -
                    self.decoder = Decoder(**ddconfig)
         | 
| 30 | 
            -
                    self.loss = instantiate_from_config(lossconfig)
         | 
| 31 | 
            -
                    self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
         | 
| 32 | 
            -
                                                    remap=remap, sane_index_shape=sane_index_shape)
         | 
| 33 | 
            -
                    self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
         | 
| 34 | 
            -
                    self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         | 
| 35 | 
            -
                    if ckpt_path is not None:
         | 
| 36 | 
            -
                        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
         | 
| 37 | 
            -
                    self.image_key = image_key
         | 
| 38 | 
            -
                    if colorize_nlabels is not None:
         | 
| 39 | 
            -
                        assert type(colorize_nlabels)==int
         | 
| 40 | 
            -
                        self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
         | 
| 41 | 
            -
                    if monitor is not None:
         | 
| 42 | 
            -
                        self.monitor = monitor
         | 
| 43 | 
            -
             | 
| 44 | 
            -
                def init_from_ckpt(self, path, ignore_keys=list()):
         | 
| 45 | 
            -
                    sd = torch.load(path, map_location="cpu")["state_dict"]
         | 
| 46 | 
            -
                    keys = list(sd.keys())
         | 
| 47 | 
            -
                    for k in keys:
         | 
| 48 | 
            -
                        for ik in ignore_keys:
         | 
| 49 | 
            -
                            if k.startswith(ik):
         | 
| 50 | 
            -
                                print("Deleting key {} from state_dict.".format(k))
         | 
| 51 | 
            -
                                del sd[k]
         | 
| 52 | 
            -
                    self.load_state_dict(sd, strict=False)
         | 
| 53 | 
            -
                    print(f"Restored from {path}")
         | 
| 54 | 
            -
             | 
| 55 | 
            -
                def encode(self, x):
         | 
| 56 | 
            -
                    h = self.encoder(x)
         | 
| 57 | 
            -
                    h = self.quant_conv(h)
         | 
| 58 | 
            -
                    quant, emb_loss, info = self.quantize(h)
         | 
| 59 | 
            -
                    return quant, emb_loss, info
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                def decode(self, quant):
         | 
| 62 | 
            -
                    quant = self.post_quant_conv(quant)
         | 
| 63 | 
            -
                    dec = self.decoder(quant)
         | 
| 64 | 
            -
                    return dec
         | 
| 65 | 
            -
             | 
| 66 | 
            -
                def decode_code(self, code_b):
         | 
| 67 | 
            -
                    quant_b = self.quantize.embed_code(code_b)
         | 
| 68 | 
            -
                    dec = self.decode(quant_b)
         | 
| 69 | 
            -
                    return dec
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                def forward(self, input):
         | 
| 72 | 
            -
                    quant, diff, _ = self.encode(input)
         | 
| 73 | 
            -
                    dec = self.decode(quant)
         | 
| 74 | 
            -
                    return dec, diff
         | 
| 75 | 
            -
             | 
| 76 | 
            -
                def get_input(self, batch, k):
         | 
| 77 | 
            -
                    x = batch[k]
         | 
| 78 | 
            -
                    if len(x.shape) == 3:
         | 
| 79 | 
            -
                        x = x[..., None]
         | 
| 80 | 
            -
                    x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
         | 
| 81 | 
            -
                    return x.float()
         | 
| 82 | 
            -
             | 
| 83 | 
            -
                def training_step(self, batch, batch_idx, optimizer_idx):
         | 
| 84 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 85 | 
            -
                    xrec, qloss = self(x)
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                    if optimizer_idx == 0:
         | 
| 88 | 
            -
                        # autoencode
         | 
| 89 | 
            -
                        aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
         | 
| 90 | 
            -
                                                        last_layer=self.get_last_layer(), split="train")
         | 
| 91 | 
            -
             | 
| 92 | 
            -
                        self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
         | 
| 93 | 
            -
                        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
         | 
| 94 | 
            -
                        return aeloss
         | 
| 95 | 
            -
             | 
| 96 | 
            -
                    if optimizer_idx == 1:
         | 
| 97 | 
            -
                        # discriminator
         | 
| 98 | 
            -
                        discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
         | 
| 99 | 
            -
                                                        last_layer=self.get_last_layer(), split="train")
         | 
| 100 | 
            -
                        self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
         | 
| 101 | 
            -
                        self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
         | 
| 102 | 
            -
                        return discloss
         | 
| 103 | 
            -
             | 
| 104 | 
            -
                def validation_step(self, batch, batch_idx):
         | 
| 105 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 106 | 
            -
                    xrec, qloss = self(x)
         | 
| 107 | 
            -
                    aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
         | 
| 108 | 
            -
                                                        last_layer=self.get_last_layer(), split="val")
         | 
| 109 | 
            -
             | 
| 110 | 
            -
                    discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
         | 
| 111 | 
            -
                                                        last_layer=self.get_last_layer(), split="val")
         | 
| 112 | 
            -
                    rec_loss = log_dict_ae["val/rec_loss"]
         | 
| 113 | 
            -
                    self.log("val/rec_loss", rec_loss,
         | 
| 114 | 
            -
                               prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
         | 
| 115 | 
            -
                    self.log("val/aeloss", aeloss,
         | 
| 116 | 
            -
                               prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
         | 
| 117 | 
            -
                    self.log_dict(log_dict_ae)
         | 
| 118 | 
            -
                    self.log_dict(log_dict_disc)
         | 
| 119 | 
            -
                    return self.log_dict
         | 
| 120 | 
            -
             | 
| 121 | 
            -
                def configure_optimizers(self):
         | 
| 122 | 
            -
                    lr = self.learning_rate
         | 
| 123 | 
            -
                    opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
         | 
| 124 | 
            -
                                              list(self.decoder.parameters())+
         | 
| 125 | 
            -
                                              list(self.quantize.parameters())+
         | 
| 126 | 
            -
                                              list(self.quant_conv.parameters())+
         | 
| 127 | 
            -
                                              list(self.post_quant_conv.parameters()),
         | 
| 128 | 
            -
                                              lr=lr, betas=(0.5, 0.9))
         | 
| 129 | 
            -
                    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
         | 
| 130 | 
            -
                                                lr=lr, betas=(0.5, 0.9))
         | 
| 131 | 
            -
                    return [opt_ae, opt_disc], []
         | 
| 132 | 
            -
             | 
| 133 | 
            -
                def get_last_layer(self):
         | 
| 134 | 
            -
                    return self.decoder.conv_out.weight
         | 
| 135 | 
            -
             | 
| 136 | 
            -
                def log_images(self, batch, **kwargs):
         | 
| 137 | 
            -
                    log = dict()
         | 
| 138 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 139 | 
            -
                    x = x.to(self.device)
         | 
| 140 | 
            -
                    xrec, _ = self(x)
         | 
| 141 | 
            -
                    if x.shape[1] > 3:
         | 
| 142 | 
            -
                        # colorize with random projection
         | 
| 143 | 
            -
                        assert xrec.shape[1] > 3
         | 
| 144 | 
            -
                        x = self.to_rgb(x)
         | 
| 145 | 
            -
                        xrec = self.to_rgb(xrec)
         | 
| 146 | 
            -
                    log["inputs"] = x
         | 
| 147 | 
            -
                    log["reconstructions"] = xrec
         | 
| 148 | 
            -
                    return log
         | 
| 149 | 
            -
             | 
| 150 | 
            -
                def to_rgb(self, x):
         | 
| 151 | 
            -
                    assert self.image_key == "segmentation"
         | 
| 152 | 
            -
                    if not hasattr(self, "colorize"):
         | 
| 153 | 
            -
                        self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
         | 
| 154 | 
            -
                    x = F.conv2d(x, weight=self.colorize)
         | 
| 155 | 
            -
                    x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
         | 
| 156 | 
            -
                    return x
         | 
| 157 | 
            -
             | 
| 158 | 
            -
             | 
| 159 | 
            -
            class VQSegmentationModel(VQModel):
         | 
| 160 | 
            -
                def __init__(self, n_labels, *args, **kwargs):
         | 
| 161 | 
            -
                    super().__init__(*args, **kwargs)
         | 
| 162 | 
            -
                    self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                def configure_optimizers(self):
         | 
| 165 | 
            -
                    lr = self.learning_rate
         | 
| 166 | 
            -
                    opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
         | 
| 167 | 
            -
                                              list(self.decoder.parameters())+
         | 
| 168 | 
            -
                                              list(self.quantize.parameters())+
         | 
| 169 | 
            -
                                              list(self.quant_conv.parameters())+
         | 
| 170 | 
            -
                                              list(self.post_quant_conv.parameters()),
         | 
| 171 | 
            -
                                              lr=lr, betas=(0.5, 0.9))
         | 
| 172 | 
            -
                    return opt_ae
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                def training_step(self, batch, batch_idx):
         | 
| 175 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 176 | 
            -
                    xrec, qloss = self(x)
         | 
| 177 | 
            -
                    aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
         | 
| 178 | 
            -
                    self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
         | 
| 179 | 
            -
                    return aeloss
         | 
| 180 | 
            -
             | 
| 181 | 
            -
                def validation_step(self, batch, batch_idx):
         | 
| 182 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 183 | 
            -
                    xrec, qloss = self(x)
         | 
| 184 | 
            -
                    aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
         | 
| 185 | 
            -
                    self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
         | 
| 186 | 
            -
                    total_loss = log_dict_ae["val/total_loss"]
         | 
| 187 | 
            -
                    self.log("val/total_loss", total_loss,
         | 
| 188 | 
            -
                             prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
         | 
| 189 | 
            -
                    return aeloss
         | 
| 190 | 
            -
             | 
| 191 | 
            -
                @torch.no_grad()
         | 
| 192 | 
            -
                def log_images(self, batch, **kwargs):
         | 
| 193 | 
            -
                    log = dict()
         | 
| 194 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 195 | 
            -
                    x = x.to(self.device)
         | 
| 196 | 
            -
                    xrec, _ = self(x)
         | 
| 197 | 
            -
                    if x.shape[1] > 3:
         | 
| 198 | 
            -
                        # colorize with random projection
         | 
| 199 | 
            -
                        assert xrec.shape[1] > 3
         | 
| 200 | 
            -
                        # convert logits to indices
         | 
| 201 | 
            -
                        xrec = torch.argmax(xrec, dim=1, keepdim=True)
         | 
| 202 | 
            -
                        xrec = F.one_hot(xrec, num_classes=x.shape[1])
         | 
| 203 | 
            -
                        xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
         | 
| 204 | 
            -
                        x = self.to_rgb(x)
         | 
| 205 | 
            -
                        xrec = self.to_rgb(xrec)
         | 
| 206 | 
            -
                    log["inputs"] = x
         | 
| 207 | 
            -
                    log["reconstructions"] = xrec
         | 
| 208 | 
            -
                    return log
         | 
| 209 | 
            -
             | 
| 210 | 
            -
             | 
| 211 | 
            -
            class VQNoDiscModel(VQModel):
         | 
| 212 | 
            -
                def __init__(self,
         | 
| 213 | 
            -
                             ddconfig,
         | 
| 214 | 
            -
                             lossconfig,
         | 
| 215 | 
            -
                             n_embed,
         | 
| 216 | 
            -
                             embed_dim,
         | 
| 217 | 
            -
                             ckpt_path=None,
         | 
| 218 | 
            -
                             ignore_keys=[],
         | 
| 219 | 
            -
                             image_key="image",
         | 
| 220 | 
            -
                             colorize_nlabels=None
         | 
| 221 | 
            -
                             ):
         | 
| 222 | 
            -
                    super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
         | 
| 223 | 
            -
                                     ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
         | 
| 224 | 
            -
                                     colorize_nlabels=colorize_nlabels)
         | 
| 225 | 
            -
             | 
| 226 | 
            -
                def training_step(self, batch, batch_idx):
         | 
| 227 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 228 | 
            -
                    xrec, qloss = self(x)
         | 
| 229 | 
            -
                    # autoencode
         | 
| 230 | 
            -
                    aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
         | 
| 231 | 
            -
                    output = pl.TrainResult(minimize=aeloss)
         | 
| 232 | 
            -
                    output.log("train/aeloss", aeloss,
         | 
| 233 | 
            -
                               prog_bar=True, logger=True, on_step=True, on_epoch=True)
         | 
| 234 | 
            -
                    output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
         | 
| 235 | 
            -
                    return output
         | 
| 236 | 
            -
             | 
| 237 | 
            -
                def validation_step(self, batch, batch_idx):
         | 
| 238 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 239 | 
            -
                    xrec, qloss = self(x)
         | 
| 240 | 
            -
                    aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
         | 
| 241 | 
            -
                    rec_loss = log_dict_ae["val/rec_loss"]
         | 
| 242 | 
            -
                    output = pl.EvalResult(checkpoint_on=rec_loss)
         | 
| 243 | 
            -
                    output.log("val/rec_loss", rec_loss,
         | 
| 244 | 
            -
                               prog_bar=True, logger=True, on_step=True, on_epoch=True)
         | 
| 245 | 
            -
                    output.log("val/aeloss", aeloss,
         | 
| 246 | 
            -
                               prog_bar=True, logger=True, on_step=True, on_epoch=True)
         | 
| 247 | 
            -
                    output.log_dict(log_dict_ae)
         | 
| 248 | 
            -
             | 
| 249 | 
            -
                    return output
         | 
| 250 | 
            -
             | 
| 251 | 
            -
                def configure_optimizers(self):
         | 
| 252 | 
            -
                    optimizer = torch.optim.Adam(list(self.encoder.parameters())+
         | 
| 253 | 
            -
                                              list(self.decoder.parameters())+
         | 
| 254 | 
            -
                                              list(self.quantize.parameters())+
         | 
| 255 | 
            -
                                              list(self.quant_conv.parameters())+
         | 
| 256 | 
            -
                                              list(self.post_quant_conv.parameters()),
         | 
| 257 | 
            -
                                              lr=self.learning_rate, betas=(0.5, 0.9))
         | 
| 258 | 
            -
                    return optimizer
         | 
| 259 | 
            -
             | 
| 260 | 
            -
             | 
| 261 | 
            -
            class GumbelVQ(VQModel):
         | 
| 262 | 
            -
                def __init__(self,
         | 
| 263 | 
            -
                             ddconfig,
         | 
| 264 | 
            -
                             lossconfig,
         | 
| 265 | 
            -
                             n_embed,
         | 
| 266 | 
            -
                             embed_dim,
         | 
| 267 | 
            -
                             temperature_scheduler_config,
         | 
| 268 | 
            -
                             ckpt_path=None,
         | 
| 269 | 
            -
                             ignore_keys=[],
         | 
| 270 | 
            -
                             image_key="image",
         | 
| 271 | 
            -
                             colorize_nlabels=None,
         | 
| 272 | 
            -
                             monitor=None,
         | 
| 273 | 
            -
                             kl_weight=1e-8,
         | 
| 274 | 
            -
                             remap=None,
         | 
| 275 | 
            -
                             ):
         | 
| 276 | 
            -
             | 
| 277 | 
            -
                    z_channels = ddconfig["z_channels"]
         | 
| 278 | 
            -
                    super().__init__(ddconfig,
         | 
| 279 | 
            -
                                     lossconfig,
         | 
| 280 | 
            -
                                     n_embed,
         | 
| 281 | 
            -
                                     embed_dim,
         | 
| 282 | 
            -
                                     ckpt_path=None,
         | 
| 283 | 
            -
                                     ignore_keys=ignore_keys,
         | 
| 284 | 
            -
                                     image_key=image_key,
         | 
| 285 | 
            -
                                     colorize_nlabels=colorize_nlabels,
         | 
| 286 | 
            -
                                     monitor=monitor,
         | 
| 287 | 
            -
                                     )
         | 
| 288 | 
            -
             | 
| 289 | 
            -
                    self.loss.n_classes = n_embed
         | 
| 290 | 
            -
                    self.vocab_size = n_embed
         | 
| 291 | 
            -
             | 
| 292 | 
            -
                    self.quantize = GumbelQuantize(z_channels, embed_dim,
         | 
| 293 | 
            -
                                                   n_embed=n_embed,
         | 
| 294 | 
            -
                                                   kl_weight=kl_weight, temp_init=1.0,
         | 
| 295 | 
            -
                                                   remap=remap)
         | 
| 296 | 
            -
             | 
| 297 | 
            -
                    self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config)   # annealing of temp
         | 
| 298 | 
            -
             | 
| 299 | 
            -
                    if ckpt_path is not None:
         | 
| 300 | 
            -
                        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
         | 
| 301 | 
            -
             | 
| 302 | 
            -
                def temperature_scheduling(self):
         | 
| 303 | 
            -
                    self.quantize.temperature = self.temperature_scheduler(self.global_step)
         | 
| 304 | 
            -
             | 
| 305 | 
            -
                def encode_to_prequant(self, x):
         | 
| 306 | 
            -
                    h = self.encoder(x)
         | 
| 307 | 
            -
                    h = self.quant_conv(h)
         | 
| 308 | 
            -
                    return h
         | 
| 309 | 
            -
             | 
| 310 | 
            -
                def decode_code(self, code_b):
         | 
| 311 | 
            -
                    raise NotImplementedError
         | 
| 312 | 
            -
             | 
| 313 | 
            -
                def training_step(self, batch, batch_idx, optimizer_idx):
         | 
| 314 | 
            -
                    self.temperature_scheduling()
         | 
| 315 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 316 | 
            -
                    xrec, qloss = self(x)
         | 
| 317 | 
            -
             | 
| 318 | 
            -
                    if optimizer_idx == 0:
         | 
| 319 | 
            -
                        # autoencode
         | 
| 320 | 
            -
                        aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
         | 
| 321 | 
            -
                                                        last_layer=self.get_last_layer(), split="train")
         | 
| 322 | 
            -
             | 
| 323 | 
            -
                        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
         | 
| 324 | 
            -
                        self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
         | 
| 325 | 
            -
                        return aeloss
         | 
| 326 | 
            -
             | 
| 327 | 
            -
                    if optimizer_idx == 1:
         | 
| 328 | 
            -
                        # discriminator
         | 
| 329 | 
            -
                        discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
         | 
| 330 | 
            -
                                                        last_layer=self.get_last_layer(), split="train")
         | 
| 331 | 
            -
                        self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
         | 
| 332 | 
            -
                        return discloss
         | 
| 333 | 
            -
             | 
| 334 | 
            -
                def validation_step(self, batch, batch_idx):
         | 
| 335 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 336 | 
            -
                    xrec, qloss = self(x, return_pred_indices=True)
         | 
| 337 | 
            -
                    aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
         | 
| 338 | 
            -
                                                    last_layer=self.get_last_layer(), split="val")
         | 
| 339 | 
            -
             | 
| 340 | 
            -
                    discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
         | 
| 341 | 
            -
                                                        last_layer=self.get_last_layer(), split="val")
         | 
| 342 | 
            -
                    rec_loss = log_dict_ae["val/rec_loss"]
         | 
| 343 | 
            -
                    self.log("val/rec_loss", rec_loss,
         | 
| 344 | 
            -
                             prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
         | 
| 345 | 
            -
                    self.log("val/aeloss", aeloss,
         | 
| 346 | 
            -
                             prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
         | 
| 347 | 
            -
                    self.log_dict(log_dict_ae)
         | 
| 348 | 
            -
                    self.log_dict(log_dict_disc)
         | 
| 349 | 
            -
                    return self.log_dict
         | 
| 350 | 
            -
             | 
| 351 | 
            -
                def log_images(self, batch, **kwargs):
         | 
| 352 | 
            -
                    log = dict()
         | 
| 353 | 
            -
                    x = self.get_input(batch, self.image_key)
         | 
| 354 | 
            -
                    x = x.to(self.device)
         | 
| 355 | 
            -
                    # encode
         | 
| 356 | 
            -
                    h = self.encoder(x)
         | 
| 357 | 
            -
                    h = self.quant_conv(h)
         | 
| 358 | 
            -
                    quant, _, _ = self.quantize(h)
         | 
| 359 | 
            -
                    # decode
         | 
| 360 | 
            -
                    x_rec = self.decode(quant)
         | 
| 361 | 
            -
                    log["inputs"] = x
         | 
| 362 | 
            -
                    log["reconstructions"] = x_rec
         | 
| 363 | 
            -
                    return log
         | 
| 364 | 
            -
             | 
| 365 | 
            -
             | 
| 366 | 
            -
            class EMAVQ(VQModel):
         | 
| 367 | 
            -
                def __init__(self,
         | 
| 368 | 
            -
                             ddconfig,
         | 
| 369 | 
            -
                             lossconfig,
         | 
| 370 | 
            -
                             n_embed,
         | 
| 371 | 
            -
                             embed_dim,
         | 
| 372 | 
            -
                             ckpt_path=None,
         | 
| 373 | 
            -
                             ignore_keys=[],
         | 
| 374 | 
            -
                             image_key="image",
         | 
| 375 | 
            -
                             colorize_nlabels=None,
         | 
| 376 | 
            -
                             monitor=None,
         | 
| 377 | 
            -
                             remap=None,
         | 
| 378 | 
            -
                             sane_index_shape=False,  # tell vector quantizer to return indices as bhw
         | 
| 379 | 
            -
                             ):
         | 
| 380 | 
            -
                    super().__init__(ddconfig,
         | 
| 381 | 
            -
                                     lossconfig,
         | 
| 382 | 
            -
                                     n_embed,
         | 
| 383 | 
            -
                                     embed_dim,
         | 
| 384 | 
            -
                                     ckpt_path=None,
         | 
| 385 | 
            -
                                     ignore_keys=ignore_keys,
         | 
| 386 | 
            -
                                     image_key=image_key,
         | 
| 387 | 
            -
                                     colorize_nlabels=colorize_nlabels,
         | 
| 388 | 
            -
                                     monitor=monitor,
         | 
| 389 | 
            -
                                     )
         | 
| 390 | 
            -
                    self.quantize = EMAVectorQuantizer(n_embed=n_embed,
         | 
| 391 | 
            -
                                                       embedding_dim=embed_dim,
         | 
| 392 | 
            -
                                                       beta=0.25,
         | 
| 393 | 
            -
                                                       remap=remap)
         | 
| 394 | 
            -
                def configure_optimizers(self):
         | 
| 395 | 
            -
                    lr = self.learning_rate
         | 
| 396 | 
            -
                    #Remove self.quantize from parameter list since it is updated via EMA
         | 
| 397 | 
            -
                    opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
         | 
| 398 | 
            -
                                              list(self.decoder.parameters())+
         | 
| 399 | 
            -
                                              list(self.quant_conv.parameters())+
         | 
| 400 | 
            -
                                              list(self.post_quant_conv.parameters()),
         | 
| 401 | 
            -
                                              lr=lr, betas=(0.5, 0.9))
         | 
| 402 | 
            -
                    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
         | 
| 403 | 
            -
                                                lr=lr, betas=(0.5, 0.9))
         | 
| 404 | 
            -
                    return [opt_ae, opt_disc], []                                           
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/diffusionmodules/model.py
    DELETED
    
    | @@ -1,776 +0,0 @@ | |
| 1 | 
            -
            # pytorch_diffusion + derived encoder decoder
         | 
| 2 | 
            -
            import math
         | 
| 3 | 
            -
            import torch
         | 
| 4 | 
            -
            import torch.nn as nn
         | 
| 5 | 
            -
            import numpy as np
         | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
            def get_timestep_embedding(timesteps, embedding_dim):
         | 
| 9 | 
            -
                """
         | 
| 10 | 
            -
                This matches the implementation in Denoising Diffusion Probabilistic Models:
         | 
| 11 | 
            -
                From Fairseq.
         | 
| 12 | 
            -
                Build sinusoidal embeddings.
         | 
| 13 | 
            -
                This matches the implementation in tensor2tensor, but differs slightly
         | 
| 14 | 
            -
                from the description in Section 3.5 of "Attention Is All You Need".
         | 
| 15 | 
            -
                """
         | 
| 16 | 
            -
                assert len(timesteps.shape) == 1
         | 
| 17 | 
            -
             | 
| 18 | 
            -
                half_dim = embedding_dim // 2
         | 
| 19 | 
            -
                emb = math.log(10000) / (half_dim - 1)
         | 
| 20 | 
            -
                emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
         | 
| 21 | 
            -
                emb = emb.to(device=timesteps.device)
         | 
| 22 | 
            -
                emb = timesteps.float()[:, None] * emb[None, :]
         | 
| 23 | 
            -
                emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
         | 
| 24 | 
            -
                if embedding_dim % 2 == 1:  # zero pad
         | 
| 25 | 
            -
                    emb = torch.nn.functional.pad(emb, (0,1,0,0))
         | 
| 26 | 
            -
                return emb
         | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
            def nonlinearity(x):
         | 
| 30 | 
            -
                # swish
         | 
| 31 | 
            -
                return x*torch.sigmoid(x)
         | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
            def Normalize(in_channels):
         | 
| 35 | 
            -
                return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
         | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
            class Upsample(nn.Module):
         | 
| 39 | 
            -
                def __init__(self, in_channels, with_conv):
         | 
| 40 | 
            -
                    super().__init__()
         | 
| 41 | 
            -
                    self.with_conv = with_conv
         | 
| 42 | 
            -
                    if self.with_conv:
         | 
| 43 | 
            -
                        self.conv = torch.nn.Conv2d(in_channels,
         | 
| 44 | 
            -
                                                    in_channels,
         | 
| 45 | 
            -
                                                    kernel_size=3,
         | 
| 46 | 
            -
                                                    stride=1,
         | 
| 47 | 
            -
                                                    padding=1)
         | 
| 48 | 
            -
             | 
| 49 | 
            -
                def forward(self, x):
         | 
| 50 | 
            -
                    x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
         | 
| 51 | 
            -
                    if self.with_conv:
         | 
| 52 | 
            -
                        x = self.conv(x)
         | 
| 53 | 
            -
                    return x
         | 
| 54 | 
            -
             | 
| 55 | 
            -
             | 
| 56 | 
            -
            class Downsample(nn.Module):
         | 
| 57 | 
            -
                def __init__(self, in_channels, with_conv):
         | 
| 58 | 
            -
                    super().__init__()
         | 
| 59 | 
            -
                    self.with_conv = with_conv
         | 
| 60 | 
            -
                    if self.with_conv:
         | 
| 61 | 
            -
                        # no asymmetric padding in torch conv, must do it ourselves
         | 
| 62 | 
            -
                        self.conv = torch.nn.Conv2d(in_channels,
         | 
| 63 | 
            -
                                                    in_channels,
         | 
| 64 | 
            -
                                                    kernel_size=3,
         | 
| 65 | 
            -
                                                    stride=2,
         | 
| 66 | 
            -
                                                    padding=0)
         | 
| 67 | 
            -
             | 
| 68 | 
            -
                def forward(self, x):
         | 
| 69 | 
            -
                    if self.with_conv:
         | 
| 70 | 
            -
                        pad = (0,1,0,1)
         | 
| 71 | 
            -
                        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
         | 
| 72 | 
            -
                        x = self.conv(x)
         | 
| 73 | 
            -
                    else:
         | 
| 74 | 
            -
                        x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
         | 
| 75 | 
            -
                    return x
         | 
| 76 | 
            -
             | 
| 77 | 
            -
             | 
| 78 | 
            -
            class ResnetBlock(nn.Module):
         | 
| 79 | 
            -
                def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
         | 
| 80 | 
            -
                             dropout, temb_channels=512):
         | 
| 81 | 
            -
                    super().__init__()
         | 
| 82 | 
            -
                    self.in_channels = in_channels
         | 
| 83 | 
            -
                    out_channels = in_channels if out_channels is None else out_channels
         | 
| 84 | 
            -
                    self.out_channels = out_channels
         | 
| 85 | 
            -
                    self.use_conv_shortcut = conv_shortcut
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                    self.norm1 = Normalize(in_channels)
         | 
| 88 | 
            -
                    self.conv1 = torch.nn.Conv2d(in_channels,
         | 
| 89 | 
            -
                                                 out_channels,
         | 
| 90 | 
            -
                                                 kernel_size=3,
         | 
| 91 | 
            -
                                                 stride=1,
         | 
| 92 | 
            -
                                                 padding=1)
         | 
| 93 | 
            -
                    if temb_channels > 0:
         | 
| 94 | 
            -
                        self.temb_proj = torch.nn.Linear(temb_channels,
         | 
| 95 | 
            -
                                                         out_channels)
         | 
| 96 | 
            -
                    self.norm2 = Normalize(out_channels)
         | 
| 97 | 
            -
                    self.dropout = torch.nn.Dropout(dropout)
         | 
| 98 | 
            -
                    self.conv2 = torch.nn.Conv2d(out_channels,
         | 
| 99 | 
            -
                                                 out_channels,
         | 
| 100 | 
            -
                                                 kernel_size=3,
         | 
| 101 | 
            -
                                                 stride=1,
         | 
| 102 | 
            -
                                                 padding=1)
         | 
| 103 | 
            -
                    if self.in_channels != self.out_channels:
         | 
| 104 | 
            -
                        if self.use_conv_shortcut:
         | 
| 105 | 
            -
                            self.conv_shortcut = torch.nn.Conv2d(in_channels,
         | 
| 106 | 
            -
                                                                 out_channels,
         | 
| 107 | 
            -
                                                                 kernel_size=3,
         | 
| 108 | 
            -
                                                                 stride=1,
         | 
| 109 | 
            -
                                                                 padding=1)
         | 
| 110 | 
            -
                        else:
         | 
| 111 | 
            -
                            self.nin_shortcut = torch.nn.Conv2d(in_channels,
         | 
| 112 | 
            -
                                                                out_channels,
         | 
| 113 | 
            -
                                                                kernel_size=1,
         | 
| 114 | 
            -
                                                                stride=1,
         | 
| 115 | 
            -
                                                                padding=0)
         | 
| 116 | 
            -
             | 
| 117 | 
            -
                def forward(self, x, temb):
         | 
| 118 | 
            -
                    h = x
         | 
| 119 | 
            -
                    h = self.norm1(h)
         | 
| 120 | 
            -
                    h = nonlinearity(h)
         | 
| 121 | 
            -
                    h = self.conv1(h)
         | 
| 122 | 
            -
             | 
| 123 | 
            -
                    if temb is not None:
         | 
| 124 | 
            -
                        h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
         | 
| 125 | 
            -
             | 
| 126 | 
            -
                    h = self.norm2(h)
         | 
| 127 | 
            -
                    h = nonlinearity(h)
         | 
| 128 | 
            -
                    h = self.dropout(h)
         | 
| 129 | 
            -
                    h = self.conv2(h)
         | 
| 130 | 
            -
             | 
| 131 | 
            -
                    if self.in_channels != self.out_channels:
         | 
| 132 | 
            -
                        if self.use_conv_shortcut:
         | 
| 133 | 
            -
                            x = self.conv_shortcut(x)
         | 
| 134 | 
            -
                        else:
         | 
| 135 | 
            -
                            x = self.nin_shortcut(x)
         | 
| 136 | 
            -
             | 
| 137 | 
            -
                    return x+h
         | 
| 138 | 
            -
             | 
| 139 | 
            -
             | 
| 140 | 
            -
            class AttnBlock(nn.Module):
         | 
| 141 | 
            -
                def __init__(self, in_channels):
         | 
| 142 | 
            -
                    super().__init__()
         | 
| 143 | 
            -
                    self.in_channels = in_channels
         | 
| 144 | 
            -
             | 
| 145 | 
            -
                    self.norm = Normalize(in_channels)
         | 
| 146 | 
            -
                    self.q = torch.nn.Conv2d(in_channels,
         | 
| 147 | 
            -
                                             in_channels,
         | 
| 148 | 
            -
                                             kernel_size=1,
         | 
| 149 | 
            -
                                             stride=1,
         | 
| 150 | 
            -
                                             padding=0)
         | 
| 151 | 
            -
                    self.k = torch.nn.Conv2d(in_channels,
         | 
| 152 | 
            -
                                             in_channels,
         | 
| 153 | 
            -
                                             kernel_size=1,
         | 
| 154 | 
            -
                                             stride=1,
         | 
| 155 | 
            -
                                             padding=0)
         | 
| 156 | 
            -
                    self.v = torch.nn.Conv2d(in_channels,
         | 
| 157 | 
            -
                                             in_channels,
         | 
| 158 | 
            -
                                             kernel_size=1,
         | 
| 159 | 
            -
                                             stride=1,
         | 
| 160 | 
            -
                                             padding=0)
         | 
| 161 | 
            -
                    self.proj_out = torch.nn.Conv2d(in_channels,
         | 
| 162 | 
            -
                                                    in_channels,
         | 
| 163 | 
            -
                                                    kernel_size=1,
         | 
| 164 | 
            -
                                                    stride=1,
         | 
| 165 | 
            -
                                                    padding=0)
         | 
| 166 | 
            -
             | 
| 167 | 
            -
             | 
| 168 | 
            -
                def forward(self, x):
         | 
| 169 | 
            -
                    h_ = x
         | 
| 170 | 
            -
                    h_ = self.norm(h_)
         | 
| 171 | 
            -
                    q = self.q(h_)
         | 
| 172 | 
            -
                    k = self.k(h_)
         | 
| 173 | 
            -
                    v = self.v(h_)
         | 
| 174 | 
            -
             | 
| 175 | 
            -
                    # compute attention
         | 
| 176 | 
            -
                    b,c,h,w = q.shape
         | 
| 177 | 
            -
                    q = q.reshape(b,c,h*w)
         | 
| 178 | 
            -
                    q = q.permute(0,2,1)   # b,hw,c
         | 
| 179 | 
            -
                    k = k.reshape(b,c,h*w) # b,c,hw
         | 
| 180 | 
            -
                    w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
         | 
| 181 | 
            -
                    w_ = w_ * (int(c)**(-0.5))
         | 
| 182 | 
            -
                    w_ = torch.nn.functional.softmax(w_, dim=2)
         | 
| 183 | 
            -
             | 
| 184 | 
            -
                    # attend to values
         | 
| 185 | 
            -
                    v = v.reshape(b,c,h*w)
         | 
| 186 | 
            -
                    w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
         | 
| 187 | 
            -
                    h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
         | 
| 188 | 
            -
                    h_ = h_.reshape(b,c,h,w)
         | 
| 189 | 
            -
             | 
| 190 | 
            -
                    h_ = self.proj_out(h_)
         | 
| 191 | 
            -
             | 
| 192 | 
            -
                    return x+h_
         | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
| 195 | 
            -
            class Model(nn.Module):
         | 
| 196 | 
            -
                def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
         | 
| 197 | 
            -
                             attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
         | 
| 198 | 
            -
                             resolution, use_timestep=True):
         | 
| 199 | 
            -
                    super().__init__()
         | 
| 200 | 
            -
                    self.ch = ch
         | 
| 201 | 
            -
                    self.temb_ch = self.ch*4
         | 
| 202 | 
            -
                    self.num_resolutions = len(ch_mult)
         | 
| 203 | 
            -
                    self.num_res_blocks = num_res_blocks
         | 
| 204 | 
            -
                    self.resolution = resolution
         | 
| 205 | 
            -
                    self.in_channels = in_channels
         | 
| 206 | 
            -
             | 
| 207 | 
            -
                    self.use_timestep = use_timestep
         | 
| 208 | 
            -
                    if self.use_timestep:
         | 
| 209 | 
            -
                        # timestep embedding
         | 
| 210 | 
            -
                        self.temb = nn.Module()
         | 
| 211 | 
            -
                        self.temb.dense = nn.ModuleList([
         | 
| 212 | 
            -
                            torch.nn.Linear(self.ch,
         | 
| 213 | 
            -
                                            self.temb_ch),
         | 
| 214 | 
            -
                            torch.nn.Linear(self.temb_ch,
         | 
| 215 | 
            -
                                            self.temb_ch),
         | 
| 216 | 
            -
                        ])
         | 
| 217 | 
            -
             | 
| 218 | 
            -
                    # downsampling
         | 
| 219 | 
            -
                    self.conv_in = torch.nn.Conv2d(in_channels,
         | 
| 220 | 
            -
                                                   self.ch,
         | 
| 221 | 
            -
                                                   kernel_size=3,
         | 
| 222 | 
            -
                                                   stride=1,
         | 
| 223 | 
            -
                                                   padding=1)
         | 
| 224 | 
            -
             | 
| 225 | 
            -
                    curr_res = resolution
         | 
| 226 | 
            -
                    in_ch_mult = (1,)+tuple(ch_mult)
         | 
| 227 | 
            -
                    self.down = nn.ModuleList()
         | 
| 228 | 
            -
                    for i_level in range(self.num_resolutions):
         | 
| 229 | 
            -
                        block = nn.ModuleList()
         | 
| 230 | 
            -
                        attn = nn.ModuleList()
         | 
| 231 | 
            -
                        block_in = ch*in_ch_mult[i_level]
         | 
| 232 | 
            -
                        block_out = ch*ch_mult[i_level]
         | 
| 233 | 
            -
                        for i_block in range(self.num_res_blocks):
         | 
| 234 | 
            -
                            block.append(ResnetBlock(in_channels=block_in,
         | 
| 235 | 
            -
                                                     out_channels=block_out,
         | 
| 236 | 
            -
                                                     temb_channels=self.temb_ch,
         | 
| 237 | 
            -
                                                     dropout=dropout))
         | 
| 238 | 
            -
                            block_in = block_out
         | 
| 239 | 
            -
                            if curr_res in attn_resolutions:
         | 
| 240 | 
            -
                                attn.append(AttnBlock(block_in))
         | 
| 241 | 
            -
                        down = nn.Module()
         | 
| 242 | 
            -
                        down.block = block
         | 
| 243 | 
            -
                        down.attn = attn
         | 
| 244 | 
            -
                        if i_level != self.num_resolutions-1:
         | 
| 245 | 
            -
                            down.downsample = Downsample(block_in, resamp_with_conv)
         | 
| 246 | 
            -
                            curr_res = curr_res // 2
         | 
| 247 | 
            -
                        self.down.append(down)
         | 
| 248 | 
            -
             | 
| 249 | 
            -
                    # middle
         | 
| 250 | 
            -
                    self.mid = nn.Module()
         | 
| 251 | 
            -
                    self.mid.block_1 = ResnetBlock(in_channels=block_in,
         | 
| 252 | 
            -
                                                   out_channels=block_in,
         | 
| 253 | 
            -
                                                   temb_channels=self.temb_ch,
         | 
| 254 | 
            -
                                                   dropout=dropout)
         | 
| 255 | 
            -
                    self.mid.attn_1 = AttnBlock(block_in)
         | 
| 256 | 
            -
                    self.mid.block_2 = ResnetBlock(in_channels=block_in,
         | 
| 257 | 
            -
                                                   out_channels=block_in,
         | 
| 258 | 
            -
                                                   temb_channels=self.temb_ch,
         | 
| 259 | 
            -
                                                   dropout=dropout)
         | 
| 260 | 
            -
             | 
| 261 | 
            -
                    # upsampling
         | 
| 262 | 
            -
                    self.up = nn.ModuleList()
         | 
| 263 | 
            -
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 264 | 
            -
                        block = nn.ModuleList()
         | 
| 265 | 
            -
                        attn = nn.ModuleList()
         | 
| 266 | 
            -
                        block_out = ch*ch_mult[i_level]
         | 
| 267 | 
            -
                        skip_in = ch*ch_mult[i_level]
         | 
| 268 | 
            -
                        for i_block in range(self.num_res_blocks+1):
         | 
| 269 | 
            -
                            if i_block == self.num_res_blocks:
         | 
| 270 | 
            -
                                skip_in = ch*in_ch_mult[i_level]
         | 
| 271 | 
            -
                            block.append(ResnetBlock(in_channels=block_in+skip_in,
         | 
| 272 | 
            -
                                                     out_channels=block_out,
         | 
| 273 | 
            -
                                                     temb_channels=self.temb_ch,
         | 
| 274 | 
            -
                                                     dropout=dropout))
         | 
| 275 | 
            -
                            block_in = block_out
         | 
| 276 | 
            -
                            if curr_res in attn_resolutions:
         | 
| 277 | 
            -
                                attn.append(AttnBlock(block_in))
         | 
| 278 | 
            -
                        up = nn.Module()
         | 
| 279 | 
            -
                        up.block = block
         | 
| 280 | 
            -
                        up.attn = attn
         | 
| 281 | 
            -
                        if i_level != 0:
         | 
| 282 | 
            -
                            up.upsample = Upsample(block_in, resamp_with_conv)
         | 
| 283 | 
            -
                            curr_res = curr_res * 2
         | 
| 284 | 
            -
                        self.up.insert(0, up) # prepend to get consistent order
         | 
| 285 | 
            -
             | 
| 286 | 
            -
                    # end
         | 
| 287 | 
            -
                    self.norm_out = Normalize(block_in)
         | 
| 288 | 
            -
                    self.conv_out = torch.nn.Conv2d(block_in,
         | 
| 289 | 
            -
                                                    out_ch,
         | 
| 290 | 
            -
                                                    kernel_size=3,
         | 
| 291 | 
            -
                                                    stride=1,
         | 
| 292 | 
            -
                                                    padding=1)
         | 
| 293 | 
            -
             | 
| 294 | 
            -
             | 
| 295 | 
            -
                def forward(self, x, t=None):
         | 
| 296 | 
            -
                    #assert x.shape[2] == x.shape[3] == self.resolution
         | 
| 297 | 
            -
             | 
| 298 | 
            -
                    if self.use_timestep:
         | 
| 299 | 
            -
                        # timestep embedding
         | 
| 300 | 
            -
                        assert t is not None
         | 
| 301 | 
            -
                        temb = get_timestep_embedding(t, self.ch)
         | 
| 302 | 
            -
                        temb = self.temb.dense[0](temb)
         | 
| 303 | 
            -
                        temb = nonlinearity(temb)
         | 
| 304 | 
            -
                        temb = self.temb.dense[1](temb)
         | 
| 305 | 
            -
                    else:
         | 
| 306 | 
            -
                        temb = None
         | 
| 307 | 
            -
             | 
| 308 | 
            -
                    # downsampling
         | 
| 309 | 
            -
                    hs = [self.conv_in(x)]
         | 
| 310 | 
            -
                    for i_level in range(self.num_resolutions):
         | 
| 311 | 
            -
                        for i_block in range(self.num_res_blocks):
         | 
| 312 | 
            -
                            h = self.down[i_level].block[i_block](hs[-1], temb)
         | 
| 313 | 
            -
                            if len(self.down[i_level].attn) > 0:
         | 
| 314 | 
            -
                                h = self.down[i_level].attn[i_block](h)
         | 
| 315 | 
            -
                            hs.append(h)
         | 
| 316 | 
            -
                        if i_level != self.num_resolutions-1:
         | 
| 317 | 
            -
                            hs.append(self.down[i_level].downsample(hs[-1]))
         | 
| 318 | 
            -
             | 
| 319 | 
            -
                    # middle
         | 
| 320 | 
            -
                    h = hs[-1]
         | 
| 321 | 
            -
                    h = self.mid.block_1(h, temb)
         | 
| 322 | 
            -
                    h = self.mid.attn_1(h)
         | 
| 323 | 
            -
                    h = self.mid.block_2(h, temb)
         | 
| 324 | 
            -
             | 
| 325 | 
            -
                    # upsampling
         | 
| 326 | 
            -
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 327 | 
            -
                        for i_block in range(self.num_res_blocks+1):
         | 
| 328 | 
            -
                            h = self.up[i_level].block[i_block](
         | 
| 329 | 
            -
                                torch.cat([h, hs.pop()], dim=1), temb)
         | 
| 330 | 
            -
                            if len(self.up[i_level].attn) > 0:
         | 
| 331 | 
            -
                                h = self.up[i_level].attn[i_block](h)
         | 
| 332 | 
            -
                        if i_level != 0:
         | 
| 333 | 
            -
                            h = self.up[i_level].upsample(h)
         | 
| 334 | 
            -
             | 
| 335 | 
            -
                    # end
         | 
| 336 | 
            -
                    h = self.norm_out(h)
         | 
| 337 | 
            -
                    h = nonlinearity(h)
         | 
| 338 | 
            -
                    h = self.conv_out(h)
         | 
| 339 | 
            -
                    return h
         | 
| 340 | 
            -
             | 
| 341 | 
            -
             | 
| 342 | 
            -
            class Encoder(nn.Module):
         | 
| 343 | 
            -
                def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
         | 
| 344 | 
            -
                             attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
         | 
| 345 | 
            -
                             resolution, z_channels, double_z=True, **ignore_kwargs):
         | 
| 346 | 
            -
                    super().__init__()
         | 
| 347 | 
            -
                    self.ch = ch
         | 
| 348 | 
            -
                    self.temb_ch = 0
         | 
| 349 | 
            -
                    self.num_resolutions = len(ch_mult)
         | 
| 350 | 
            -
                    self.num_res_blocks = num_res_blocks
         | 
| 351 | 
            -
                    self.resolution = resolution
         | 
| 352 | 
            -
                    self.in_channels = in_channels
         | 
| 353 | 
            -
             | 
| 354 | 
            -
                    # downsampling
         | 
| 355 | 
            -
                    self.conv_in = torch.nn.Conv2d(in_channels,
         | 
| 356 | 
            -
                                                   self.ch,
         | 
| 357 | 
            -
                                                   kernel_size=3,
         | 
| 358 | 
            -
                                                   stride=1,
         | 
| 359 | 
            -
                                                   padding=1)
         | 
| 360 | 
            -
             | 
| 361 | 
            -
                    curr_res = resolution
         | 
| 362 | 
            -
                    in_ch_mult = (1,)+tuple(ch_mult)
         | 
| 363 | 
            -
                    self.down = nn.ModuleList()
         | 
| 364 | 
            -
                    for i_level in range(self.num_resolutions):
         | 
| 365 | 
            -
                        block = nn.ModuleList()
         | 
| 366 | 
            -
                        attn = nn.ModuleList()
         | 
| 367 | 
            -
                        block_in = ch*in_ch_mult[i_level]
         | 
| 368 | 
            -
                        block_out = ch*ch_mult[i_level]
         | 
| 369 | 
            -
                        for i_block in range(self.num_res_blocks):
         | 
| 370 | 
            -
                            block.append(ResnetBlock(in_channels=block_in,
         | 
| 371 | 
            -
                                                     out_channels=block_out,
         | 
| 372 | 
            -
                                                     temb_channels=self.temb_ch,
         | 
| 373 | 
            -
                                                     dropout=dropout))
         | 
| 374 | 
            -
                            block_in = block_out
         | 
| 375 | 
            -
                            if curr_res in attn_resolutions:
         | 
| 376 | 
            -
                                attn.append(AttnBlock(block_in))
         | 
| 377 | 
            -
                        down = nn.Module()
         | 
| 378 | 
            -
                        down.block = block
         | 
| 379 | 
            -
                        down.attn = attn
         | 
| 380 | 
            -
                        if i_level != self.num_resolutions-1:
         | 
| 381 | 
            -
                            down.downsample = Downsample(block_in, resamp_with_conv)
         | 
| 382 | 
            -
                            curr_res = curr_res // 2
         | 
| 383 | 
            -
                        self.down.append(down)
         | 
| 384 | 
            -
             | 
| 385 | 
            -
                    # middle
         | 
| 386 | 
            -
                    self.mid = nn.Module()
         | 
| 387 | 
            -
                    self.mid.block_1 = ResnetBlock(in_channels=block_in,
         | 
| 388 | 
            -
                                                   out_channels=block_in,
         | 
| 389 | 
            -
                                                   temb_channels=self.temb_ch,
         | 
| 390 | 
            -
                                                   dropout=dropout)
         | 
| 391 | 
            -
                    self.mid.attn_1 = AttnBlock(block_in)
         | 
| 392 | 
            -
                    self.mid.block_2 = ResnetBlock(in_channels=block_in,
         | 
| 393 | 
            -
                                                   out_channels=block_in,
         | 
| 394 | 
            -
                                                   temb_channels=self.temb_ch,
         | 
| 395 | 
            -
                                                   dropout=dropout)
         | 
| 396 | 
            -
             | 
| 397 | 
            -
                    # end
         | 
| 398 | 
            -
                    self.norm_out = Normalize(block_in)
         | 
| 399 | 
            -
                    self.conv_out = torch.nn.Conv2d(block_in,
         | 
| 400 | 
            -
                                                    2*z_channels if double_z else z_channels,
         | 
| 401 | 
            -
                                                    kernel_size=3,
         | 
| 402 | 
            -
                                                    stride=1,
         | 
| 403 | 
            -
                                                    padding=1)
         | 
| 404 | 
            -
             | 
| 405 | 
            -
             | 
| 406 | 
            -
                def forward(self, x):
         | 
| 407 | 
            -
                    #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
         | 
| 408 | 
            -
             | 
| 409 | 
            -
                    # timestep embedding
         | 
| 410 | 
            -
                    temb = None
         | 
| 411 | 
            -
             | 
| 412 | 
            -
                    # downsampling
         | 
| 413 | 
            -
                    hs = [self.conv_in(x)]
         | 
| 414 | 
            -
                    for i_level in range(self.num_resolutions):
         | 
| 415 | 
            -
                        for i_block in range(self.num_res_blocks):
         | 
| 416 | 
            -
                            h = self.down[i_level].block[i_block](hs[-1], temb)
         | 
| 417 | 
            -
                            if len(self.down[i_level].attn) > 0:
         | 
| 418 | 
            -
                                h = self.down[i_level].attn[i_block](h)
         | 
| 419 | 
            -
                            hs.append(h)
         | 
| 420 | 
            -
                        if i_level != self.num_resolutions-1:
         | 
| 421 | 
            -
                            hs.append(self.down[i_level].downsample(hs[-1]))
         | 
| 422 | 
            -
             | 
| 423 | 
            -
                    # middle
         | 
| 424 | 
            -
                    h = hs[-1]
         | 
| 425 | 
            -
                    h = self.mid.block_1(h, temb)
         | 
| 426 | 
            -
                    h = self.mid.attn_1(h)
         | 
| 427 | 
            -
                    h = self.mid.block_2(h, temb)
         | 
| 428 | 
            -
             | 
| 429 | 
            -
                    # end
         | 
| 430 | 
            -
                    h = self.norm_out(h)
         | 
| 431 | 
            -
                    h = nonlinearity(h)
         | 
| 432 | 
            -
                    h = self.conv_out(h)
         | 
| 433 | 
            -
                    return h
         | 
| 434 | 
            -
             | 
| 435 | 
            -
             | 
| 436 | 
            -
            class Decoder(nn.Module):
         | 
| 437 | 
            -
                def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
         | 
| 438 | 
            -
                             attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
         | 
| 439 | 
            -
                             resolution, z_channels, give_pre_end=False, **ignorekwargs):
         | 
| 440 | 
            -
                    super().__init__()
         | 
| 441 | 
            -
                    self.ch = ch
         | 
| 442 | 
            -
                    self.temb_ch = 0
         | 
| 443 | 
            -
                    self.num_resolutions = len(ch_mult)
         | 
| 444 | 
            -
                    self.num_res_blocks = num_res_blocks
         | 
| 445 | 
            -
                    self.resolution = resolution
         | 
| 446 | 
            -
                    self.in_channels = in_channels
         | 
| 447 | 
            -
                    self.give_pre_end = give_pre_end
         | 
| 448 | 
            -
             | 
| 449 | 
            -
                    # compute in_ch_mult, block_in and curr_res at lowest res
         | 
| 450 | 
            -
                    in_ch_mult = (1,)+tuple(ch_mult)
         | 
| 451 | 
            -
                    block_in = ch*ch_mult[self.num_resolutions-1]
         | 
| 452 | 
            -
                    curr_res = resolution // 2**(self.num_resolutions-1)
         | 
| 453 | 
            -
                    self.z_shape = (1,z_channels,curr_res,curr_res)
         | 
| 454 | 
            -
                    print("Working with z of shape {} = {} dimensions.".format(
         | 
| 455 | 
            -
                        self.z_shape, np.prod(self.z_shape)))
         | 
| 456 | 
            -
             | 
| 457 | 
            -
                    # z to block_in
         | 
| 458 | 
            -
                    self.conv_in = torch.nn.Conv2d(z_channels,
         | 
| 459 | 
            -
                                                   block_in,
         | 
| 460 | 
            -
                                                   kernel_size=3,
         | 
| 461 | 
            -
                                                   stride=1,
         | 
| 462 | 
            -
                                                   padding=1)
         | 
| 463 | 
            -
             | 
| 464 | 
            -
                    # middle
         | 
| 465 | 
            -
                    self.mid = nn.Module()
         | 
| 466 | 
            -
                    self.mid.block_1 = ResnetBlock(in_channels=block_in,
         | 
| 467 | 
            -
                                                   out_channels=block_in,
         | 
| 468 | 
            -
                                                   temb_channels=self.temb_ch,
         | 
| 469 | 
            -
                                                   dropout=dropout)
         | 
| 470 | 
            -
                    self.mid.attn_1 = AttnBlock(block_in)
         | 
| 471 | 
            -
                    self.mid.block_2 = ResnetBlock(in_channels=block_in,
         | 
| 472 | 
            -
                                                   out_channels=block_in,
         | 
| 473 | 
            -
                                                   temb_channels=self.temb_ch,
         | 
| 474 | 
            -
                                                   dropout=dropout)
         | 
| 475 | 
            -
             | 
| 476 | 
            -
                    # upsampling
         | 
| 477 | 
            -
                    self.up = nn.ModuleList()
         | 
| 478 | 
            -
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 479 | 
            -
                        block = nn.ModuleList()
         | 
| 480 | 
            -
                        attn = nn.ModuleList()
         | 
| 481 | 
            -
                        block_out = ch*ch_mult[i_level]
         | 
| 482 | 
            -
                        for i_block in range(self.num_res_blocks+1):
         | 
| 483 | 
            -
                            block.append(ResnetBlock(in_channels=block_in,
         | 
| 484 | 
            -
                                                     out_channels=block_out,
         | 
| 485 | 
            -
                                                     temb_channels=self.temb_ch,
         | 
| 486 | 
            -
                                                     dropout=dropout))
         | 
| 487 | 
            -
                            block_in = block_out
         | 
| 488 | 
            -
                            if curr_res in attn_resolutions:
         | 
| 489 | 
            -
                                attn.append(AttnBlock(block_in))
         | 
| 490 | 
            -
                        up = nn.Module()
         | 
| 491 | 
            -
                        up.block = block
         | 
| 492 | 
            -
                        up.attn = attn
         | 
| 493 | 
            -
                        if i_level != 0:
         | 
| 494 | 
            -
                            up.upsample = Upsample(block_in, resamp_with_conv)
         | 
| 495 | 
            -
                            curr_res = curr_res * 2
         | 
| 496 | 
            -
                        self.up.insert(0, up) # prepend to get consistent order
         | 
| 497 | 
            -
             | 
| 498 | 
            -
                    # end
         | 
| 499 | 
            -
                    self.norm_out = Normalize(block_in)
         | 
| 500 | 
            -
                    self.conv_out = torch.nn.Conv2d(block_in,
         | 
| 501 | 
            -
                                                    out_ch,
         | 
| 502 | 
            -
                                                    kernel_size=3,
         | 
| 503 | 
            -
                                                    stride=1,
         | 
| 504 | 
            -
                                                    padding=1)
         | 
| 505 | 
            -
             | 
| 506 | 
            -
                def forward(self, z):
         | 
| 507 | 
            -
                    #assert z.shape[1:] == self.z_shape[1:]
         | 
| 508 | 
            -
                    self.last_z_shape = z.shape
         | 
| 509 | 
            -
             | 
| 510 | 
            -
                    # timestep embedding
         | 
| 511 | 
            -
                    temb = None
         | 
| 512 | 
            -
             | 
| 513 | 
            -
                    # z to block_in
         | 
| 514 | 
            -
                    h = self.conv_in(z)
         | 
| 515 | 
            -
             | 
| 516 | 
            -
                    # middle
         | 
| 517 | 
            -
                    h = self.mid.block_1(h, temb)
         | 
| 518 | 
            -
                    h = self.mid.attn_1(h)
         | 
| 519 | 
            -
                    h = self.mid.block_2(h, temb)
         | 
| 520 | 
            -
             | 
| 521 | 
            -
                    # upsampling
         | 
| 522 | 
            -
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 523 | 
            -
                        for i_block in range(self.num_res_blocks+1):
         | 
| 524 | 
            -
                            h = self.up[i_level].block[i_block](h, temb)
         | 
| 525 | 
            -
                            if len(self.up[i_level].attn) > 0:
         | 
| 526 | 
            -
                                h = self.up[i_level].attn[i_block](h)
         | 
| 527 | 
            -
                        if i_level != 0:
         | 
| 528 | 
            -
                            h = self.up[i_level].upsample(h)
         | 
| 529 | 
            -
             | 
| 530 | 
            -
                    # end
         | 
| 531 | 
            -
                    if self.give_pre_end:
         | 
| 532 | 
            -
                        return h
         | 
| 533 | 
            -
             | 
| 534 | 
            -
                    h = self.norm_out(h)
         | 
| 535 | 
            -
                    h = nonlinearity(h)
         | 
| 536 | 
            -
                    h = self.conv_out(h)
         | 
| 537 | 
            -
                    return h
         | 
| 538 | 
            -
             | 
| 539 | 
            -
             | 
| 540 | 
            -
            class VUNet(nn.Module):
         | 
| 541 | 
            -
                def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
         | 
| 542 | 
            -
                             attn_resolutions, dropout=0.0, resamp_with_conv=True,
         | 
| 543 | 
            -
                             in_channels, c_channels,
         | 
| 544 | 
            -
                             resolution, z_channels, use_timestep=False, **ignore_kwargs):
         | 
| 545 | 
            -
                    super().__init__()
         | 
| 546 | 
            -
                    self.ch = ch
         | 
| 547 | 
            -
                    self.temb_ch = self.ch*4
         | 
| 548 | 
            -
                    self.num_resolutions = len(ch_mult)
         | 
| 549 | 
            -
                    self.num_res_blocks = num_res_blocks
         | 
| 550 | 
            -
                    self.resolution = resolution
         | 
| 551 | 
            -
             | 
| 552 | 
            -
                    self.use_timestep = use_timestep
         | 
| 553 | 
            -
                    if self.use_timestep:
         | 
| 554 | 
            -
                        # timestep embedding
         | 
| 555 | 
            -
                        self.temb = nn.Module()
         | 
| 556 | 
            -
                        self.temb.dense = nn.ModuleList([
         | 
| 557 | 
            -
                            torch.nn.Linear(self.ch,
         | 
| 558 | 
            -
                                            self.temb_ch),
         | 
| 559 | 
            -
                            torch.nn.Linear(self.temb_ch,
         | 
| 560 | 
            -
                                            self.temb_ch),
         | 
| 561 | 
            -
                        ])
         | 
| 562 | 
            -
             | 
| 563 | 
            -
                    # downsampling
         | 
| 564 | 
            -
                    self.conv_in = torch.nn.Conv2d(c_channels,
         | 
| 565 | 
            -
                                                   self.ch,
         | 
| 566 | 
            -
                                                   kernel_size=3,
         | 
| 567 | 
            -
                                                   stride=1,
         | 
| 568 | 
            -
                                                   padding=1)
         | 
| 569 | 
            -
             | 
| 570 | 
            -
                    curr_res = resolution
         | 
| 571 | 
            -
                    in_ch_mult = (1,)+tuple(ch_mult)
         | 
| 572 | 
            -
                    self.down = nn.ModuleList()
         | 
| 573 | 
            -
                    for i_level in range(self.num_resolutions):
         | 
| 574 | 
            -
                        block = nn.ModuleList()
         | 
| 575 | 
            -
                        attn = nn.ModuleList()
         | 
| 576 | 
            -
                        block_in = ch*in_ch_mult[i_level]
         | 
| 577 | 
            -
                        block_out = ch*ch_mult[i_level]
         | 
| 578 | 
            -
                        for i_block in range(self.num_res_blocks):
         | 
| 579 | 
            -
                            block.append(ResnetBlock(in_channels=block_in,
         | 
| 580 | 
            -
                                                     out_channels=block_out,
         | 
| 581 | 
            -
                                                     temb_channels=self.temb_ch,
         | 
| 582 | 
            -
                                                     dropout=dropout))
         | 
| 583 | 
            -
                            block_in = block_out
         | 
| 584 | 
            -
                            if curr_res in attn_resolutions:
         | 
| 585 | 
            -
                                attn.append(AttnBlock(block_in))
         | 
| 586 | 
            -
                        down = nn.Module()
         | 
| 587 | 
            -
                        down.block = block
         | 
| 588 | 
            -
                        down.attn = attn
         | 
| 589 | 
            -
                        if i_level != self.num_resolutions-1:
         | 
| 590 | 
            -
                            down.downsample = Downsample(block_in, resamp_with_conv)
         | 
| 591 | 
            -
                            curr_res = curr_res // 2
         | 
| 592 | 
            -
                        self.down.append(down)
         | 
| 593 | 
            -
             | 
| 594 | 
            -
                    self.z_in = torch.nn.Conv2d(z_channels,
         | 
| 595 | 
            -
                                                block_in,
         | 
| 596 | 
            -
                                                kernel_size=1,
         | 
| 597 | 
            -
                                                stride=1,
         | 
| 598 | 
            -
                                                padding=0)
         | 
| 599 | 
            -
                    # middle
         | 
| 600 | 
            -
                    self.mid = nn.Module()
         | 
| 601 | 
            -
                    self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
         | 
| 602 | 
            -
                                                   out_channels=block_in,
         | 
| 603 | 
            -
                                                   temb_channels=self.temb_ch,
         | 
| 604 | 
            -
                                                   dropout=dropout)
         | 
| 605 | 
            -
                    self.mid.attn_1 = AttnBlock(block_in)
         | 
| 606 | 
            -
                    self.mid.block_2 = ResnetBlock(in_channels=block_in,
         | 
| 607 | 
            -
                                                   out_channels=block_in,
         | 
| 608 | 
            -
                                                   temb_channels=self.temb_ch,
         | 
| 609 | 
            -
                                                   dropout=dropout)
         | 
| 610 | 
            -
             | 
| 611 | 
            -
                    # upsampling
         | 
| 612 | 
            -
                    self.up = nn.ModuleList()
         | 
| 613 | 
            -
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 614 | 
            -
                        block = nn.ModuleList()
         | 
| 615 | 
            -
                        attn = nn.ModuleList()
         | 
| 616 | 
            -
                        block_out = ch*ch_mult[i_level]
         | 
| 617 | 
            -
                        skip_in = ch*ch_mult[i_level]
         | 
| 618 | 
            -
                        for i_block in range(self.num_res_blocks+1):
         | 
| 619 | 
            -
                            if i_block == self.num_res_blocks:
         | 
| 620 | 
            -
                                skip_in = ch*in_ch_mult[i_level]
         | 
| 621 | 
            -
                            block.append(ResnetBlock(in_channels=block_in+skip_in,
         | 
| 622 | 
            -
                                                     out_channels=block_out,
         | 
| 623 | 
            -
                                                     temb_channels=self.temb_ch,
         | 
| 624 | 
            -
                                                     dropout=dropout))
         | 
| 625 | 
            -
                            block_in = block_out
         | 
| 626 | 
            -
                            if curr_res in attn_resolutions:
         | 
| 627 | 
            -
                                attn.append(AttnBlock(block_in))
         | 
| 628 | 
            -
                        up = nn.Module()
         | 
| 629 | 
            -
                        up.block = block
         | 
| 630 | 
            -
                        up.attn = attn
         | 
| 631 | 
            -
                        if i_level != 0:
         | 
| 632 | 
            -
                            up.upsample = Upsample(block_in, resamp_with_conv)
         | 
| 633 | 
            -
                            curr_res = curr_res * 2
         | 
| 634 | 
            -
                        self.up.insert(0, up) # prepend to get consistent order
         | 
| 635 | 
            -
             | 
| 636 | 
            -
                    # end
         | 
| 637 | 
            -
                    self.norm_out = Normalize(block_in)
         | 
| 638 | 
            -
                    self.conv_out = torch.nn.Conv2d(block_in,
         | 
| 639 | 
            -
                                                    out_ch,
         | 
| 640 | 
            -
                                                    kernel_size=3,
         | 
| 641 | 
            -
                                                    stride=1,
         | 
| 642 | 
            -
                                                    padding=1)
         | 
| 643 | 
            -
             | 
| 644 | 
            -
             | 
| 645 | 
            -
                def forward(self, x, z):
         | 
| 646 | 
            -
                    #assert x.shape[2] == x.shape[3] == self.resolution
         | 
| 647 | 
            -
             | 
| 648 | 
            -
                    if self.use_timestep:
         | 
| 649 | 
            -
                        # timestep embedding
         | 
| 650 | 
            -
                        assert t is not None
         | 
| 651 | 
            -
                        temb = get_timestep_embedding(t, self.ch)
         | 
| 652 | 
            -
                        temb = self.temb.dense[0](temb)
         | 
| 653 | 
            -
                        temb = nonlinearity(temb)
         | 
| 654 | 
            -
                        temb = self.temb.dense[1](temb)
         | 
| 655 | 
            -
                    else:
         | 
| 656 | 
            -
                        temb = None
         | 
| 657 | 
            -
             | 
| 658 | 
            -
                    # downsampling
         | 
| 659 | 
            -
                    hs = [self.conv_in(x)]
         | 
| 660 | 
            -
                    for i_level in range(self.num_resolutions):
         | 
| 661 | 
            -
                        for i_block in range(self.num_res_blocks):
         | 
| 662 | 
            -
                            h = self.down[i_level].block[i_block](hs[-1], temb)
         | 
| 663 | 
            -
                            if len(self.down[i_level].attn) > 0:
         | 
| 664 | 
            -
                                h = self.down[i_level].attn[i_block](h)
         | 
| 665 | 
            -
                            hs.append(h)
         | 
| 666 | 
            -
                        if i_level != self.num_resolutions-1:
         | 
| 667 | 
            -
                            hs.append(self.down[i_level].downsample(hs[-1]))
         | 
| 668 | 
            -
             | 
| 669 | 
            -
                    # middle
         | 
| 670 | 
            -
                    h = hs[-1]
         | 
| 671 | 
            -
                    z = self.z_in(z)
         | 
| 672 | 
            -
                    h = torch.cat((h,z),dim=1)
         | 
| 673 | 
            -
                    h = self.mid.block_1(h, temb)
         | 
| 674 | 
            -
                    h = self.mid.attn_1(h)
         | 
| 675 | 
            -
                    h = self.mid.block_2(h, temb)
         | 
| 676 | 
            -
             | 
| 677 | 
            -
                    # upsampling
         | 
| 678 | 
            -
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 679 | 
            -
                        for i_block in range(self.num_res_blocks+1):
         | 
| 680 | 
            -
                            h = self.up[i_level].block[i_block](
         | 
| 681 | 
            -
                                torch.cat([h, hs.pop()], dim=1), temb)
         | 
| 682 | 
            -
                            if len(self.up[i_level].attn) > 0:
         | 
| 683 | 
            -
                                h = self.up[i_level].attn[i_block](h)
         | 
| 684 | 
            -
                        if i_level != 0:
         | 
| 685 | 
            -
                            h = self.up[i_level].upsample(h)
         | 
| 686 | 
            -
             | 
| 687 | 
            -
                    # end
         | 
| 688 | 
            -
                    h = self.norm_out(h)
         | 
| 689 | 
            -
                    h = nonlinearity(h)
         | 
| 690 | 
            -
                    h = self.conv_out(h)
         | 
| 691 | 
            -
                    return h
         | 
| 692 | 
            -
             | 
| 693 | 
            -
             | 
| 694 | 
            -
            class SimpleDecoder(nn.Module):
         | 
| 695 | 
            -
                def __init__(self, in_channels, out_channels, *args, **kwargs):
         | 
| 696 | 
            -
                    super().__init__()
         | 
| 697 | 
            -
                    self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
         | 
| 698 | 
            -
                                                 ResnetBlock(in_channels=in_channels,
         | 
| 699 | 
            -
                                                             out_channels=2 * in_channels,
         | 
| 700 | 
            -
                                                             temb_channels=0, dropout=0.0),
         | 
| 701 | 
            -
                                                 ResnetBlock(in_channels=2 * in_channels,
         | 
| 702 | 
            -
                                                            out_channels=4 * in_channels,
         | 
| 703 | 
            -
                                                            temb_channels=0, dropout=0.0),
         | 
| 704 | 
            -
                                                 ResnetBlock(in_channels=4 * in_channels,
         | 
| 705 | 
            -
                                                            out_channels=2 * in_channels,
         | 
| 706 | 
            -
                                                            temb_channels=0, dropout=0.0),
         | 
| 707 | 
            -
                                                 nn.Conv2d(2*in_channels, in_channels, 1),
         | 
| 708 | 
            -
                                                 Upsample(in_channels, with_conv=True)])
         | 
| 709 | 
            -
                    # end
         | 
| 710 | 
            -
                    self.norm_out = Normalize(in_channels)
         | 
| 711 | 
            -
                    self.conv_out = torch.nn.Conv2d(in_channels,
         | 
| 712 | 
            -
                                                    out_channels,
         | 
| 713 | 
            -
                                                    kernel_size=3,
         | 
| 714 | 
            -
                                                    stride=1,
         | 
| 715 | 
            -
                                                    padding=1)
         | 
| 716 | 
            -
             | 
| 717 | 
            -
                def forward(self, x):
         | 
| 718 | 
            -
                    for i, layer in enumerate(self.model):
         | 
| 719 | 
            -
                        if i in [1,2,3]:
         | 
| 720 | 
            -
                            x = layer(x, None)
         | 
| 721 | 
            -
                        else:
         | 
| 722 | 
            -
                            x = layer(x)
         | 
| 723 | 
            -
             | 
| 724 | 
            -
                    h = self.norm_out(x)
         | 
| 725 | 
            -
                    h = nonlinearity(h)
         | 
| 726 | 
            -
                    x = self.conv_out(h)
         | 
| 727 | 
            -
                    return x
         | 
| 728 | 
            -
             | 
| 729 | 
            -
             | 
| 730 | 
            -
            class UpsampleDecoder(nn.Module):
         | 
| 731 | 
            -
                def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
         | 
| 732 | 
            -
                             ch_mult=(2,2), dropout=0.0):
         | 
| 733 | 
            -
                    super().__init__()
         | 
| 734 | 
            -
                    # upsampling
         | 
| 735 | 
            -
                    self.temb_ch = 0
         | 
| 736 | 
            -
                    self.num_resolutions = len(ch_mult)
         | 
| 737 | 
            -
                    self.num_res_blocks = num_res_blocks
         | 
| 738 | 
            -
                    block_in = in_channels
         | 
| 739 | 
            -
                    curr_res = resolution // 2 ** (self.num_resolutions - 1)
         | 
| 740 | 
            -
                    self.res_blocks = nn.ModuleList()
         | 
| 741 | 
            -
                    self.upsample_blocks = nn.ModuleList()
         | 
| 742 | 
            -
                    for i_level in range(self.num_resolutions):
         | 
| 743 | 
            -
                        res_block = []
         | 
| 744 | 
            -
                        block_out = ch * ch_mult[i_level]
         | 
| 745 | 
            -
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 746 | 
            -
                            res_block.append(ResnetBlock(in_channels=block_in,
         | 
| 747 | 
            -
                                                     out_channels=block_out,
         | 
| 748 | 
            -
                                                     temb_channels=self.temb_ch,
         | 
| 749 | 
            -
                                                     dropout=dropout))
         | 
| 750 | 
            -
                            block_in = block_out
         | 
| 751 | 
            -
                        self.res_blocks.append(nn.ModuleList(res_block))
         | 
| 752 | 
            -
                        if i_level != self.num_resolutions - 1:
         | 
| 753 | 
            -
                            self.upsample_blocks.append(Upsample(block_in, True))
         | 
| 754 | 
            -
                            curr_res = curr_res * 2
         | 
| 755 | 
            -
             | 
| 756 | 
            -
                    # end
         | 
| 757 | 
            -
                    self.norm_out = Normalize(block_in)
         | 
| 758 | 
            -
                    self.conv_out = torch.nn.Conv2d(block_in,
         | 
| 759 | 
            -
                                                    out_channels,
         | 
| 760 | 
            -
                                                    kernel_size=3,
         | 
| 761 | 
            -
                                                    stride=1,
         | 
| 762 | 
            -
                                                    padding=1)
         | 
| 763 | 
            -
             | 
| 764 | 
            -
                def forward(self, x):
         | 
| 765 | 
            -
                    # upsampling
         | 
| 766 | 
            -
                    h = x
         | 
| 767 | 
            -
                    for k, i_level in enumerate(range(self.num_resolutions)):
         | 
| 768 | 
            -
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 769 | 
            -
                            h = self.res_blocks[i_level][i_block](h, None)
         | 
| 770 | 
            -
                        if i_level != self.num_resolutions - 1:
         | 
| 771 | 
            -
                            h = self.upsample_blocks[k](h)
         | 
| 772 | 
            -
                    h = self.norm_out(h)
         | 
| 773 | 
            -
                    h = nonlinearity(h)
         | 
| 774 | 
            -
                    h = self.conv_out(h)
         | 
| 775 | 
            -
                    return h
         | 
| 776 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/discriminator/model.py
    DELETED
    
    | @@ -1,67 +0,0 @@ | |
| 1 | 
            -
            import functools
         | 
| 2 | 
            -
            import torch.nn as nn
         | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
            from taming.modules.util import ActNorm
         | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
            def weights_init(m):
         | 
| 9 | 
            -
                classname = m.__class__.__name__
         | 
| 10 | 
            -
                if classname.find('Conv') != -1:
         | 
| 11 | 
            -
                    nn.init.normal_(m.weight.data, 0.0, 0.02)
         | 
| 12 | 
            -
                elif classname.find('BatchNorm') != -1:
         | 
| 13 | 
            -
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
         | 
| 14 | 
            -
                    nn.init.constant_(m.bias.data, 0)
         | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
            class NLayerDiscriminator(nn.Module):
         | 
| 18 | 
            -
                """Defines a PatchGAN discriminator as in Pix2Pix
         | 
| 19 | 
            -
                    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
         | 
| 20 | 
            -
                """
         | 
| 21 | 
            -
                def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
         | 
| 22 | 
            -
                    """Construct a PatchGAN discriminator
         | 
| 23 | 
            -
                    Parameters:
         | 
| 24 | 
            -
                        input_nc (int)  -- the number of channels in input images
         | 
| 25 | 
            -
                        ndf (int)       -- the number of filters in the last conv layer
         | 
| 26 | 
            -
                        n_layers (int)  -- the number of conv layers in the discriminator
         | 
| 27 | 
            -
                        norm_layer      -- normalization layer
         | 
| 28 | 
            -
                    """
         | 
| 29 | 
            -
                    super(NLayerDiscriminator, self).__init__()
         | 
| 30 | 
            -
                    if not use_actnorm:
         | 
| 31 | 
            -
                        norm_layer = nn.BatchNorm2d
         | 
| 32 | 
            -
                    else:
         | 
| 33 | 
            -
                        norm_layer = ActNorm
         | 
| 34 | 
            -
                    if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
         | 
| 35 | 
            -
                        use_bias = norm_layer.func != nn.BatchNorm2d
         | 
| 36 | 
            -
                    else:
         | 
| 37 | 
            -
                        use_bias = norm_layer != nn.BatchNorm2d
         | 
| 38 | 
            -
             | 
| 39 | 
            -
                    kw = 4
         | 
| 40 | 
            -
                    padw = 1
         | 
| 41 | 
            -
                    sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
         | 
| 42 | 
            -
                    nf_mult = 1
         | 
| 43 | 
            -
                    nf_mult_prev = 1
         | 
| 44 | 
            -
                    for n in range(1, n_layers):  # gradually increase the number of filters
         | 
| 45 | 
            -
                        nf_mult_prev = nf_mult
         | 
| 46 | 
            -
                        nf_mult = min(2 ** n, 8)
         | 
| 47 | 
            -
                        sequence += [
         | 
| 48 | 
            -
                            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
         | 
| 49 | 
            -
                            norm_layer(ndf * nf_mult),
         | 
| 50 | 
            -
                            nn.LeakyReLU(0.2, True)
         | 
| 51 | 
            -
                        ]
         | 
| 52 | 
            -
             | 
| 53 | 
            -
                    nf_mult_prev = nf_mult
         | 
| 54 | 
            -
                    nf_mult = min(2 ** n_layers, 8)
         | 
| 55 | 
            -
                    sequence += [
         | 
| 56 | 
            -
                        nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
         | 
| 57 | 
            -
                        norm_layer(ndf * nf_mult),
         | 
| 58 | 
            -
                        nn.LeakyReLU(0.2, True)
         | 
| 59 | 
            -
                    ]
         | 
| 60 | 
            -
             | 
| 61 | 
            -
                    sequence += [
         | 
| 62 | 
            -
                        nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
         | 
| 63 | 
            -
                    self.main = nn.Sequential(*sequence)
         | 
| 64 | 
            -
             | 
| 65 | 
            -
                def forward(self, input):
         | 
| 66 | 
            -
                    """Standard forward."""
         | 
| 67 | 
            -
                    return self.main(input)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/losses/__init__.py
    DELETED
    
    | @@ -1,2 +0,0 @@ | |
| 1 | 
            -
            from taming.modules.losses.vqperceptual import DummyLoss
         | 
| 2 | 
            -
             | 
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/losses/lpips.py
    DELETED
    
    | @@ -1,123 +0,0 @@ | |
| 1 | 
            -
            """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            import torch
         | 
| 4 | 
            -
            import torch.nn as nn
         | 
| 5 | 
            -
            from torchvision import models
         | 
| 6 | 
            -
            from collections import namedtuple
         | 
| 7 | 
            -
             | 
| 8 | 
            -
            from taming.util import get_ckpt_path
         | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
            class LPIPS(nn.Module):
         | 
| 12 | 
            -
                # Learned perceptual metric
         | 
| 13 | 
            -
                def __init__(self, use_dropout=True):
         | 
| 14 | 
            -
                    super().__init__()
         | 
| 15 | 
            -
                    self.scaling_layer = ScalingLayer()
         | 
| 16 | 
            -
                    self.chns = [64, 128, 256, 512, 512]  # vg16 features
         | 
| 17 | 
            -
                    self.net = vgg16(pretrained=True, requires_grad=False)
         | 
| 18 | 
            -
                    self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
         | 
| 19 | 
            -
                    self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
         | 
| 20 | 
            -
                    self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
         | 
| 21 | 
            -
                    self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
         | 
| 22 | 
            -
                    self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
         | 
| 23 | 
            -
                    self.load_from_pretrained()
         | 
| 24 | 
            -
                    for param in self.parameters():
         | 
| 25 | 
            -
                        param.requires_grad = False
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                def load_from_pretrained(self, name="vgg_lpips"):
         | 
| 28 | 
            -
                    ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
         | 
| 29 | 
            -
                    self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
         | 
| 30 | 
            -
                    print("loaded pretrained LPIPS loss from {}".format(ckpt))
         | 
| 31 | 
            -
             | 
| 32 | 
            -
                @classmethod
         | 
| 33 | 
            -
                def from_pretrained(cls, name="vgg_lpips"):
         | 
| 34 | 
            -
                    if name != "vgg_lpips":
         | 
| 35 | 
            -
                        raise NotImplementedError
         | 
| 36 | 
            -
                    model = cls()
         | 
| 37 | 
            -
                    ckpt = get_ckpt_path(name)
         | 
| 38 | 
            -
                    model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
         | 
| 39 | 
            -
                    return model
         | 
| 40 | 
            -
             | 
| 41 | 
            -
                def forward(self, input, target):
         | 
| 42 | 
            -
                    in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
         | 
| 43 | 
            -
                    outs0, outs1 = self.net(in0_input), self.net(in1_input)
         | 
| 44 | 
            -
                    feats0, feats1, diffs = {}, {}, {}
         | 
| 45 | 
            -
                    lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
         | 
| 46 | 
            -
                    for kk in range(len(self.chns)):
         | 
| 47 | 
            -
                        feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
         | 
| 48 | 
            -
                        diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
         | 
| 49 | 
            -
             | 
| 50 | 
            -
                    res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
         | 
| 51 | 
            -
                    val = res[0]
         | 
| 52 | 
            -
                    for l in range(1, len(self.chns)):
         | 
| 53 | 
            -
                        val += res[l]
         | 
| 54 | 
            -
                    return val
         | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
            class ScalingLayer(nn.Module):
         | 
| 58 | 
            -
                def __init__(self):
         | 
| 59 | 
            -
                    super(ScalingLayer, self).__init__()
         | 
| 60 | 
            -
                    self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
         | 
| 61 | 
            -
                    self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
         | 
| 62 | 
            -
             | 
| 63 | 
            -
                def forward(self, inp):
         | 
| 64 | 
            -
                    return (inp - self.shift) / self.scale
         | 
| 65 | 
            -
             | 
| 66 | 
            -
             | 
| 67 | 
            -
            class NetLinLayer(nn.Module):
         | 
| 68 | 
            -
                """ A single linear layer which does a 1x1 conv """
         | 
| 69 | 
            -
                def __init__(self, chn_in, chn_out=1, use_dropout=False):
         | 
| 70 | 
            -
                    super(NetLinLayer, self).__init__()
         | 
| 71 | 
            -
                    layers = [nn.Dropout(), ] if (use_dropout) else []
         | 
| 72 | 
            -
                    layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
         | 
| 73 | 
            -
                    self.model = nn.Sequential(*layers)
         | 
| 74 | 
            -
             | 
| 75 | 
            -
             | 
| 76 | 
            -
            class vgg16(torch.nn.Module):
         | 
| 77 | 
            -
                def __init__(self, requires_grad=False, pretrained=True):
         | 
| 78 | 
            -
                    super(vgg16, self).__init__()
         | 
| 79 | 
            -
                    vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
         | 
| 80 | 
            -
                    self.slice1 = torch.nn.Sequential()
         | 
| 81 | 
            -
                    self.slice2 = torch.nn.Sequential()
         | 
| 82 | 
            -
                    self.slice3 = torch.nn.Sequential()
         | 
| 83 | 
            -
                    self.slice4 = torch.nn.Sequential()
         | 
| 84 | 
            -
                    self.slice5 = torch.nn.Sequential()
         | 
| 85 | 
            -
                    self.N_slices = 5
         | 
| 86 | 
            -
                    for x in range(4):
         | 
| 87 | 
            -
                        self.slice1.add_module(str(x), vgg_pretrained_features[x])
         | 
| 88 | 
            -
                    for x in range(4, 9):
         | 
| 89 | 
            -
                        self.slice2.add_module(str(x), vgg_pretrained_features[x])
         | 
| 90 | 
            -
                    for x in range(9, 16):
         | 
| 91 | 
            -
                        self.slice3.add_module(str(x), vgg_pretrained_features[x])
         | 
| 92 | 
            -
                    for x in range(16, 23):
         | 
| 93 | 
            -
                        self.slice4.add_module(str(x), vgg_pretrained_features[x])
         | 
| 94 | 
            -
                    for x in range(23, 30):
         | 
| 95 | 
            -
                        self.slice5.add_module(str(x), vgg_pretrained_features[x])
         | 
| 96 | 
            -
                    if not requires_grad:
         | 
| 97 | 
            -
                        for param in self.parameters():
         | 
| 98 | 
            -
                            param.requires_grad = False
         | 
| 99 | 
            -
             | 
| 100 | 
            -
                def forward(self, X):
         | 
| 101 | 
            -
                    h = self.slice1(X)
         | 
| 102 | 
            -
                    h_relu1_2 = h
         | 
| 103 | 
            -
                    h = self.slice2(h)
         | 
| 104 | 
            -
                    h_relu2_2 = h
         | 
| 105 | 
            -
                    h = self.slice3(h)
         | 
| 106 | 
            -
                    h_relu3_3 = h
         | 
| 107 | 
            -
                    h = self.slice4(h)
         | 
| 108 | 
            -
                    h_relu4_3 = h
         | 
| 109 | 
            -
                    h = self.slice5(h)
         | 
| 110 | 
            -
                    h_relu5_3 = h
         | 
| 111 | 
            -
                    vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
         | 
| 112 | 
            -
                    out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
         | 
| 113 | 
            -
                    return out
         | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
            def normalize_tensor(x,eps=1e-10):
         | 
| 117 | 
            -
                norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
         | 
| 118 | 
            -
                return x/(norm_factor+eps)
         | 
| 119 | 
            -
             | 
| 120 | 
            -
             | 
| 121 | 
            -
            def spatial_average(x, keepdim=True):
         | 
| 122 | 
            -
                return x.mean([2,3],keepdim=keepdim)
         | 
| 123 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/losses/segmentation.py
    DELETED
    
    | @@ -1,22 +0,0 @@ | |
| 1 | 
            -
            import torch.nn as nn
         | 
| 2 | 
            -
            import torch.nn.functional as F
         | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
            class BCELoss(nn.Module):
         | 
| 6 | 
            -
                def forward(self, prediction, target):
         | 
| 7 | 
            -
                    loss = F.binary_cross_entropy_with_logits(prediction,target)
         | 
| 8 | 
            -
                    return loss, {}
         | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
            class BCELossWithQuant(nn.Module):
         | 
| 12 | 
            -
                def __init__(self, codebook_weight=1.):
         | 
| 13 | 
            -
                    super().__init__()
         | 
| 14 | 
            -
                    self.codebook_weight = codebook_weight
         | 
| 15 | 
            -
             | 
| 16 | 
            -
                def forward(self, qloss, target, prediction, split):
         | 
| 17 | 
            -
                    bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
         | 
| 18 | 
            -
                    loss = bce_loss + self.codebook_weight*qloss
         | 
| 19 | 
            -
                    return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
         | 
| 20 | 
            -
                                  "{}/bce_loss".format(split): bce_loss.detach().mean(),
         | 
| 21 | 
            -
                                  "{}/quant_loss".format(split): qloss.detach().mean()
         | 
| 22 | 
            -
                                  }
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/losses/vqperceptual.py
    DELETED
    
    | @@ -1,136 +0,0 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            import torch.nn as nn
         | 
| 3 | 
            -
            import torch.nn.functional as F
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            from taming.modules.losses.lpips import LPIPS
         | 
| 6 | 
            -
            from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
         | 
| 7 | 
            -
             | 
| 8 | 
            -
             | 
| 9 | 
            -
            class DummyLoss(nn.Module):
         | 
| 10 | 
            -
                def __init__(self):
         | 
| 11 | 
            -
                    super().__init__()
         | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
            def adopt_weight(weight, global_step, threshold=0, value=0.):
         | 
| 15 | 
            -
                if global_step < threshold:
         | 
| 16 | 
            -
                    weight = value
         | 
| 17 | 
            -
                return weight
         | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
            def hinge_d_loss(logits_real, logits_fake):
         | 
| 21 | 
            -
                loss_real = torch.mean(F.relu(1. - logits_real))
         | 
| 22 | 
            -
                loss_fake = torch.mean(F.relu(1. + logits_fake))
         | 
| 23 | 
            -
                d_loss = 0.5 * (loss_real + loss_fake)
         | 
| 24 | 
            -
                return d_loss
         | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
            def vanilla_d_loss(logits_real, logits_fake):
         | 
| 28 | 
            -
                d_loss = 0.5 * (
         | 
| 29 | 
            -
                    torch.mean(torch.nn.functional.softplus(-logits_real)) +
         | 
| 30 | 
            -
                    torch.mean(torch.nn.functional.softplus(logits_fake)))
         | 
| 31 | 
            -
                return d_loss
         | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
            class VQLPIPSWithDiscriminator(nn.Module):
         | 
| 35 | 
            -
                def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
         | 
| 36 | 
            -
                             disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
         | 
| 37 | 
            -
                             perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
         | 
| 38 | 
            -
                             disc_ndf=64, disc_loss="hinge"):
         | 
| 39 | 
            -
                    super().__init__()
         | 
| 40 | 
            -
                    assert disc_loss in ["hinge", "vanilla"]
         | 
| 41 | 
            -
                    self.codebook_weight = codebook_weight
         | 
| 42 | 
            -
                    self.pixel_weight = pixelloss_weight
         | 
| 43 | 
            -
                    self.perceptual_loss = LPIPS().eval()
         | 
| 44 | 
            -
                    self.perceptual_weight = perceptual_weight
         | 
| 45 | 
            -
             | 
| 46 | 
            -
                    self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
         | 
| 47 | 
            -
                                                             n_layers=disc_num_layers,
         | 
| 48 | 
            -
                                                             use_actnorm=use_actnorm,
         | 
| 49 | 
            -
                                                             ndf=disc_ndf
         | 
| 50 | 
            -
                                                             ).apply(weights_init)
         | 
| 51 | 
            -
                    self.discriminator_iter_start = disc_start
         | 
| 52 | 
            -
                    if disc_loss == "hinge":
         | 
| 53 | 
            -
                        self.disc_loss = hinge_d_loss
         | 
| 54 | 
            -
                    elif disc_loss == "vanilla":
         | 
| 55 | 
            -
                        self.disc_loss = vanilla_d_loss
         | 
| 56 | 
            -
                    else:
         | 
| 57 | 
            -
                        raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
         | 
| 58 | 
            -
                    print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
         | 
| 59 | 
            -
                    self.disc_factor = disc_factor
         | 
| 60 | 
            -
                    self.discriminator_weight = disc_weight
         | 
| 61 | 
            -
                    self.disc_conditional = disc_conditional
         | 
| 62 | 
            -
             | 
| 63 | 
            -
                def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
         | 
| 64 | 
            -
                    if last_layer is not None:
         | 
| 65 | 
            -
                        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
         | 
| 66 | 
            -
                        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
         | 
| 67 | 
            -
                    else:
         | 
| 68 | 
            -
                        nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
         | 
| 69 | 
            -
                        g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
         | 
| 72 | 
            -
                    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
         | 
| 73 | 
            -
                    d_weight = d_weight * self.discriminator_weight
         | 
| 74 | 
            -
                    return d_weight
         | 
| 75 | 
            -
             | 
| 76 | 
            -
                def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
         | 
| 77 | 
            -
                            global_step, last_layer=None, cond=None, split="train"):
         | 
| 78 | 
            -
                    rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
         | 
| 79 | 
            -
                    if self.perceptual_weight > 0:
         | 
| 80 | 
            -
                        p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
         | 
| 81 | 
            -
                        rec_loss = rec_loss + self.perceptual_weight * p_loss
         | 
| 82 | 
            -
                    else:
         | 
| 83 | 
            -
                        p_loss = torch.tensor([0.0])
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                    nll_loss = rec_loss
         | 
| 86 | 
            -
                    #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
         | 
| 87 | 
            -
                    nll_loss = torch.mean(nll_loss)
         | 
| 88 | 
            -
             | 
| 89 | 
            -
                    # now the GAN part
         | 
| 90 | 
            -
                    if optimizer_idx == 0:
         | 
| 91 | 
            -
                        # generator update
         | 
| 92 | 
            -
                        if cond is None:
         | 
| 93 | 
            -
                            assert not self.disc_conditional
         | 
| 94 | 
            -
                            logits_fake = self.discriminator(reconstructions.contiguous())
         | 
| 95 | 
            -
                        else:
         | 
| 96 | 
            -
                            assert self.disc_conditional
         | 
| 97 | 
            -
                            logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
         | 
| 98 | 
            -
                        g_loss = -torch.mean(logits_fake)
         | 
| 99 | 
            -
             | 
| 100 | 
            -
                        try:
         | 
| 101 | 
            -
                            d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
         | 
| 102 | 
            -
                        except RuntimeError:
         | 
| 103 | 
            -
                            assert not self.training
         | 
| 104 | 
            -
                            d_weight = torch.tensor(0.0)
         | 
| 105 | 
            -
             | 
| 106 | 
            -
                        disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
         | 
| 107 | 
            -
                        loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
         | 
| 108 | 
            -
             | 
| 109 | 
            -
                        log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
         | 
| 110 | 
            -
                               "{}/quant_loss".format(split): codebook_loss.detach().mean(),
         | 
| 111 | 
            -
                               "{}/nll_loss".format(split): nll_loss.detach().mean(),
         | 
| 112 | 
            -
                               "{}/rec_loss".format(split): rec_loss.detach().mean(),
         | 
| 113 | 
            -
                               "{}/p_loss".format(split): p_loss.detach().mean(),
         | 
| 114 | 
            -
                               "{}/d_weight".format(split): d_weight.detach(),
         | 
| 115 | 
            -
                               "{}/disc_factor".format(split): torch.tensor(disc_factor),
         | 
| 116 | 
            -
                               "{}/g_loss".format(split): g_loss.detach().mean(),
         | 
| 117 | 
            -
                               }
         | 
| 118 | 
            -
                        return loss, log
         | 
| 119 | 
            -
             | 
| 120 | 
            -
                    if optimizer_idx == 1:
         | 
| 121 | 
            -
                        # second pass for discriminator update
         | 
| 122 | 
            -
                        if cond is None:
         | 
| 123 | 
            -
                            logits_real = self.discriminator(inputs.contiguous().detach())
         | 
| 124 | 
            -
                            logits_fake = self.discriminator(reconstructions.contiguous().detach())
         | 
| 125 | 
            -
                        else:
         | 
| 126 | 
            -
                            logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
         | 
| 127 | 
            -
                            logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
         | 
| 128 | 
            -
             | 
| 129 | 
            -
                        disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
         | 
| 130 | 
            -
                        d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
         | 
| 131 | 
            -
             | 
| 132 | 
            -
                        log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
         | 
| 133 | 
            -
                               "{}/logits_real".format(split): logits_real.detach().mean(),
         | 
| 134 | 
            -
                               "{}/logits_fake".format(split): logits_fake.detach().mean()
         | 
| 135 | 
            -
                               }
         | 
| 136 | 
            -
                        return d_loss, log
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/misc/coord.py
    DELETED
    
    | @@ -1,31 +0,0 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            class CoordStage(object):
         | 
| 4 | 
            -
                def __init__(self, n_embed, down_factor):
         | 
| 5 | 
            -
                    self.n_embed = n_embed
         | 
| 6 | 
            -
                    self.down_factor = down_factor
         | 
| 7 | 
            -
             | 
| 8 | 
            -
                def eval(self):
         | 
| 9 | 
            -
                    return self
         | 
| 10 | 
            -
             | 
| 11 | 
            -
                def encode(self, c):
         | 
| 12 | 
            -
                    """fake vqmodel interface"""
         | 
| 13 | 
            -
                    assert 0.0 <= c.min() and c.max() <= 1.0
         | 
| 14 | 
            -
                    b,ch,h,w = c.shape
         | 
| 15 | 
            -
                    assert ch == 1
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                    c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
         | 
| 18 | 
            -
                                                        mode="area")
         | 
| 19 | 
            -
                    c = c.clamp(0.0, 1.0)
         | 
| 20 | 
            -
                    c = self.n_embed*c
         | 
| 21 | 
            -
                    c_quant = c.round()
         | 
| 22 | 
            -
                    c_ind = c_quant.to(dtype=torch.long)
         | 
| 23 | 
            -
             | 
| 24 | 
            -
                    info = None, None, c_ind
         | 
| 25 | 
            -
                    return c_quant, None, info
         | 
| 26 | 
            -
             | 
| 27 | 
            -
                def decode(self, c):
         | 
| 28 | 
            -
                    c = c/self.n_embed
         | 
| 29 | 
            -
                    c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
         | 
| 30 | 
            -
                                                        mode="nearest")
         | 
| 31 | 
            -
                    return c
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/transformer/mingpt.py
    DELETED
    
    | @@ -1,415 +0,0 @@ | |
| 1 | 
            -
            """
         | 
| 2 | 
            -
            taken from: https://github.com/karpathy/minGPT/
         | 
| 3 | 
            -
            GPT model:
         | 
| 4 | 
            -
            - the initial stem consists of a combination of token encoding and a positional encoding
         | 
| 5 | 
            -
            - the meat of it is a uniform sequence of Transformer blocks
         | 
| 6 | 
            -
                - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
         | 
| 7 | 
            -
                - all blocks feed into a central residual pathway similar to resnets
         | 
| 8 | 
            -
            - the final decoder is a linear projection into a vanilla Softmax classifier
         | 
| 9 | 
            -
            """
         | 
| 10 | 
            -
             | 
| 11 | 
            -
            import math
         | 
| 12 | 
            -
            import logging
         | 
| 13 | 
            -
             | 
| 14 | 
            -
            import torch
         | 
| 15 | 
            -
            import torch.nn as nn
         | 
| 16 | 
            -
            from torch.nn import functional as F
         | 
| 17 | 
            -
            from transformers import top_k_top_p_filtering
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            logger = logging.getLogger(__name__)
         | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
            class GPTConfig:
         | 
| 23 | 
            -
                """ base GPT config, params common to all GPT versions """
         | 
| 24 | 
            -
                embd_pdrop = 0.1
         | 
| 25 | 
            -
                resid_pdrop = 0.1
         | 
| 26 | 
            -
                attn_pdrop = 0.1
         | 
| 27 | 
            -
             | 
| 28 | 
            -
                def __init__(self, vocab_size, block_size, **kwargs):
         | 
| 29 | 
            -
                    self.vocab_size = vocab_size
         | 
| 30 | 
            -
                    self.block_size = block_size
         | 
| 31 | 
            -
                    for k,v in kwargs.items():
         | 
| 32 | 
            -
                        setattr(self, k, v)
         | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
            class GPT1Config(GPTConfig):
         | 
| 36 | 
            -
                """ GPT-1 like network roughly 125M params """
         | 
| 37 | 
            -
                n_layer = 12
         | 
| 38 | 
            -
                n_head = 12
         | 
| 39 | 
            -
                n_embd = 768
         | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
            class CausalSelfAttention(nn.Module):
         | 
| 43 | 
            -
                """
         | 
| 44 | 
            -
                A vanilla multi-head masked self-attention layer with a projection at the end.
         | 
| 45 | 
            -
                It is possible to use torch.nn.MultiheadAttention here but I am including an
         | 
| 46 | 
            -
                explicit implementation here to show that there is nothing too scary here.
         | 
| 47 | 
            -
                """
         | 
| 48 | 
            -
             | 
| 49 | 
            -
                def __init__(self, config):
         | 
| 50 | 
            -
                    super().__init__()
         | 
| 51 | 
            -
                    assert config.n_embd % config.n_head == 0
         | 
| 52 | 
            -
                    # key, query, value projections for all heads
         | 
| 53 | 
            -
                    self.key = nn.Linear(config.n_embd, config.n_embd)
         | 
| 54 | 
            -
                    self.query = nn.Linear(config.n_embd, config.n_embd)
         | 
| 55 | 
            -
                    self.value = nn.Linear(config.n_embd, config.n_embd)
         | 
| 56 | 
            -
                    # regularization
         | 
| 57 | 
            -
                    self.attn_drop = nn.Dropout(config.attn_pdrop)
         | 
| 58 | 
            -
                    self.resid_drop = nn.Dropout(config.resid_pdrop)
         | 
| 59 | 
            -
                    # output projection
         | 
| 60 | 
            -
                    self.proj = nn.Linear(config.n_embd, config.n_embd)
         | 
| 61 | 
            -
                    # causal mask to ensure that attention is only applied to the left in the input sequence
         | 
| 62 | 
            -
                    mask = torch.tril(torch.ones(config.block_size,
         | 
| 63 | 
            -
                                                 config.block_size))
         | 
| 64 | 
            -
                    if hasattr(config, "n_unmasked"):
         | 
| 65 | 
            -
                        mask[:config.n_unmasked, :config.n_unmasked] = 1
         | 
| 66 | 
            -
                    self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
         | 
| 67 | 
            -
                    self.n_head = config.n_head
         | 
| 68 | 
            -
             | 
| 69 | 
            -
                def forward(self, x, layer_past=None):
         | 
| 70 | 
            -
                    B, T, C = x.size()
         | 
| 71 | 
            -
             | 
| 72 | 
            -
                    # calculate query, key, values for all heads in batch and move head forward to be the batch dim
         | 
| 73 | 
            -
                    k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         | 
| 74 | 
            -
                    q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         | 
| 75 | 
            -
                    v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         | 
| 76 | 
            -
             | 
| 77 | 
            -
                    present = torch.stack((k, v))
         | 
| 78 | 
            -
                    if layer_past is not None:
         | 
| 79 | 
            -
                        past_key, past_value = layer_past
         | 
| 80 | 
            -
                        k = torch.cat((past_key, k), dim=-2)
         | 
| 81 | 
            -
                        v = torch.cat((past_value, v), dim=-2)
         | 
| 82 | 
            -
             | 
| 83 | 
            -
                    # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         | 
| 84 | 
            -
                    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
         | 
| 85 | 
            -
                    if layer_past is None:
         | 
| 86 | 
            -
                        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                    att = F.softmax(att, dim=-1)
         | 
| 89 | 
            -
                    att = self.attn_drop(att)
         | 
| 90 | 
            -
                    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
         | 
| 91 | 
            -
                    y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
         | 
| 92 | 
            -
             | 
| 93 | 
            -
                    # output projection
         | 
| 94 | 
            -
                    y = self.resid_drop(self.proj(y))
         | 
| 95 | 
            -
                    return y, present   # TODO: check that this does not break anything
         | 
| 96 | 
            -
             | 
| 97 | 
            -
             | 
| 98 | 
            -
            class Block(nn.Module):
         | 
| 99 | 
            -
                """ an unassuming Transformer block """
         | 
| 100 | 
            -
                def __init__(self, config):
         | 
| 101 | 
            -
                    super().__init__()
         | 
| 102 | 
            -
                    self.ln1 = nn.LayerNorm(config.n_embd)
         | 
| 103 | 
            -
                    self.ln2 = nn.LayerNorm(config.n_embd)
         | 
| 104 | 
            -
                    self.attn = CausalSelfAttention(config)
         | 
| 105 | 
            -
                    self.mlp = nn.Sequential(
         | 
| 106 | 
            -
                        nn.Linear(config.n_embd, 4 * config.n_embd),
         | 
| 107 | 
            -
                        nn.GELU(),  # nice
         | 
| 108 | 
            -
                        nn.Linear(4 * config.n_embd, config.n_embd),
         | 
| 109 | 
            -
                        nn.Dropout(config.resid_pdrop),
         | 
| 110 | 
            -
                    )
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                def forward(self, x, layer_past=None, return_present=False):
         | 
| 113 | 
            -
                    # TODO: check that training still works
         | 
| 114 | 
            -
                    if return_present: assert not self.training
         | 
| 115 | 
            -
                    # layer past: tuple of length two with B, nh, T, hs
         | 
| 116 | 
            -
                    attn, present = self.attn(self.ln1(x), layer_past=layer_past)
         | 
| 117 | 
            -
             | 
| 118 | 
            -
                    x = x + attn
         | 
| 119 | 
            -
                    x = x + self.mlp(self.ln2(x))
         | 
| 120 | 
            -
                    if layer_past is not None or return_present:
         | 
| 121 | 
            -
                        return x, present
         | 
| 122 | 
            -
                    return x
         | 
| 123 | 
            -
             | 
| 124 | 
            -
             | 
| 125 | 
            -
            class GPT(nn.Module):
         | 
| 126 | 
            -
                """  the full GPT language model, with a context size of block_size """
         | 
| 127 | 
            -
                def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
         | 
| 128 | 
            -
                             embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
         | 
| 129 | 
            -
                    super().__init__()
         | 
| 130 | 
            -
                    config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
         | 
| 131 | 
            -
                                       embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
         | 
| 132 | 
            -
                                       n_layer=n_layer, n_head=n_head, n_embd=n_embd,
         | 
| 133 | 
            -
                                       n_unmasked=n_unmasked)
         | 
| 134 | 
            -
                    # input embedding stem
         | 
| 135 | 
            -
                    self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
         | 
| 136 | 
            -
                    self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
         | 
| 137 | 
            -
                    self.drop = nn.Dropout(config.embd_pdrop)
         | 
| 138 | 
            -
                    # transformer
         | 
| 139 | 
            -
                    self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
         | 
| 140 | 
            -
                    # decoder head
         | 
| 141 | 
            -
                    self.ln_f = nn.LayerNorm(config.n_embd)
         | 
| 142 | 
            -
                    self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         | 
| 143 | 
            -
                    self.block_size = config.block_size
         | 
| 144 | 
            -
                    self.apply(self._init_weights)
         | 
| 145 | 
            -
                    self.config = config
         | 
| 146 | 
            -
                    logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
         | 
| 147 | 
            -
             | 
| 148 | 
            -
                def get_block_size(self):
         | 
| 149 | 
            -
                    return self.block_size
         | 
| 150 | 
            -
             | 
| 151 | 
            -
                def _init_weights(self, module):
         | 
| 152 | 
            -
                    if isinstance(module, (nn.Linear, nn.Embedding)):
         | 
| 153 | 
            -
                        module.weight.data.normal_(mean=0.0, std=0.02)
         | 
| 154 | 
            -
                        if isinstance(module, nn.Linear) and module.bias is not None:
         | 
| 155 | 
            -
                            module.bias.data.zero_()
         | 
| 156 | 
            -
                    elif isinstance(module, nn.LayerNorm):
         | 
| 157 | 
            -
                        module.bias.data.zero_()
         | 
| 158 | 
            -
                        module.weight.data.fill_(1.0)
         | 
| 159 | 
            -
             | 
| 160 | 
            -
                def forward(self, idx, embeddings=None, targets=None):
         | 
| 161 | 
            -
                    # forward the GPT model
         | 
| 162 | 
            -
                    token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                    if embeddings is not None: # prepend explicit embeddings
         | 
| 165 | 
            -
                        token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
         | 
| 166 | 
            -
             | 
| 167 | 
            -
                    t = token_embeddings.shape[1]
         | 
| 168 | 
            -
                    assert t <= self.block_size, "Cannot forward, model block size is exhausted."
         | 
| 169 | 
            -
                    position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
         | 
| 170 | 
            -
                    x = self.drop(token_embeddings + position_embeddings)
         | 
| 171 | 
            -
                    x = self.blocks(x)
         | 
| 172 | 
            -
                    x = self.ln_f(x)
         | 
| 173 | 
            -
                    logits = self.head(x)
         | 
| 174 | 
            -
             | 
| 175 | 
            -
                    # if we are given some desired targets also calculate the loss
         | 
| 176 | 
            -
                    loss = None
         | 
| 177 | 
            -
                    if targets is not None:
         | 
| 178 | 
            -
                        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
         | 
| 179 | 
            -
             | 
| 180 | 
            -
                    return logits, loss
         | 
| 181 | 
            -
             | 
| 182 | 
            -
                def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
         | 
| 183 | 
            -
                    # inference only
         | 
| 184 | 
            -
                    assert not self.training
         | 
| 185 | 
            -
                    token_embeddings = self.tok_emb(idx)    # each index maps to a (learnable) vector
         | 
| 186 | 
            -
                    if embeddings is not None:              # prepend explicit embeddings
         | 
| 187 | 
            -
                        token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
         | 
| 188 | 
            -
             | 
| 189 | 
            -
                    if past is not None:
         | 
| 190 | 
            -
                        assert past_length is not None
         | 
| 191 | 
            -
                        past = torch.cat(past, dim=-2)   # n_layer, 2, b, nh, len_past, dim_head
         | 
| 192 | 
            -
                        past_shape = list(past.shape)
         | 
| 193 | 
            -
                        expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
         | 
| 194 | 
            -
                        assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
         | 
| 195 | 
            -
                        position_embeddings = self.pos_emb[:, past_length, :]  # each position maps to a (learnable) vector
         | 
| 196 | 
            -
                    else:
         | 
| 197 | 
            -
                        position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
         | 
| 198 | 
            -
             | 
| 199 | 
            -
                    x = self.drop(token_embeddings + position_embeddings)
         | 
| 200 | 
            -
                    presents = []  # accumulate over layers
         | 
| 201 | 
            -
                    for i, block in enumerate(self.blocks):
         | 
| 202 | 
            -
                        x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
         | 
| 203 | 
            -
                        presents.append(present)
         | 
| 204 | 
            -
             | 
| 205 | 
            -
                    x = self.ln_f(x)
         | 
| 206 | 
            -
                    logits = self.head(x)
         | 
| 207 | 
            -
                    # if we are given some desired targets also calculate the loss
         | 
| 208 | 
            -
                    loss = None
         | 
| 209 | 
            -
                    if targets is not None:
         | 
| 210 | 
            -
                        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
         | 
| 211 | 
            -
             | 
| 212 | 
            -
                    return logits, loss, torch.stack(presents)  # _, _, n_layer, 2, b, nh, 1, dim_head
         | 
| 213 | 
            -
             | 
| 214 | 
            -
             | 
| 215 | 
            -
            class DummyGPT(nn.Module):
         | 
| 216 | 
            -
                # for debugging
         | 
| 217 | 
            -
                def __init__(self, add_value=1):
         | 
| 218 | 
            -
                    super().__init__()
         | 
| 219 | 
            -
                    self.add_value = add_value
         | 
| 220 | 
            -
             | 
| 221 | 
            -
                def forward(self, idx):
         | 
| 222 | 
            -
                    return idx + self.add_value, None
         | 
| 223 | 
            -
             | 
| 224 | 
            -
             | 
| 225 | 
            -
            class CodeGPT(nn.Module):
         | 
| 226 | 
            -
                """Takes in semi-embeddings"""
         | 
| 227 | 
            -
                def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
         | 
| 228 | 
            -
                             embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
         | 
| 229 | 
            -
                    super().__init__()
         | 
| 230 | 
            -
                    config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
         | 
| 231 | 
            -
                                       embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
         | 
| 232 | 
            -
                                       n_layer=n_layer, n_head=n_head, n_embd=n_embd,
         | 
| 233 | 
            -
                                       n_unmasked=n_unmasked)
         | 
| 234 | 
            -
                    # input embedding stem
         | 
| 235 | 
            -
                    self.tok_emb = nn.Linear(in_channels, config.n_embd)
         | 
| 236 | 
            -
                    self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
         | 
| 237 | 
            -
                    self.drop = nn.Dropout(config.embd_pdrop)
         | 
| 238 | 
            -
                    # transformer
         | 
| 239 | 
            -
                    self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
         | 
| 240 | 
            -
                    # decoder head
         | 
| 241 | 
            -
                    self.ln_f = nn.LayerNorm(config.n_embd)
         | 
| 242 | 
            -
                    self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         | 
| 243 | 
            -
                    self.block_size = config.block_size
         | 
| 244 | 
            -
                    self.apply(self._init_weights)
         | 
| 245 | 
            -
                    self.config = config
         | 
| 246 | 
            -
                    logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
         | 
| 247 | 
            -
             | 
| 248 | 
            -
                def get_block_size(self):
         | 
| 249 | 
            -
                    return self.block_size
         | 
| 250 | 
            -
             | 
| 251 | 
            -
                def _init_weights(self, module):
         | 
| 252 | 
            -
                    if isinstance(module, (nn.Linear, nn.Embedding)):
         | 
| 253 | 
            -
                        module.weight.data.normal_(mean=0.0, std=0.02)
         | 
| 254 | 
            -
                        if isinstance(module, nn.Linear) and module.bias is not None:
         | 
| 255 | 
            -
                            module.bias.data.zero_()
         | 
| 256 | 
            -
                    elif isinstance(module, nn.LayerNorm):
         | 
| 257 | 
            -
                        module.bias.data.zero_()
         | 
| 258 | 
            -
                        module.weight.data.fill_(1.0)
         | 
| 259 | 
            -
             | 
| 260 | 
            -
                def forward(self, idx, embeddings=None, targets=None):
         | 
| 261 | 
            -
                    # forward the GPT model
         | 
| 262 | 
            -
                    token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
         | 
| 263 | 
            -
             | 
| 264 | 
            -
                    if embeddings is not None: # prepend explicit embeddings
         | 
| 265 | 
            -
                        token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
         | 
| 266 | 
            -
             | 
| 267 | 
            -
                    t = token_embeddings.shape[1]
         | 
| 268 | 
            -
                    assert t <= self.block_size, "Cannot forward, model block size is exhausted."
         | 
| 269 | 
            -
                    position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
         | 
| 270 | 
            -
                    x = self.drop(token_embeddings + position_embeddings)
         | 
| 271 | 
            -
                    x = self.blocks(x)
         | 
| 272 | 
            -
                    x = self.taming_cinln_f(x)
         | 
| 273 | 
            -
                    logits = self.head(x)
         | 
| 274 | 
            -
             | 
| 275 | 
            -
                    # if we are given some desired targets also calculate the loss
         | 
| 276 | 
            -
                    loss = None
         | 
| 277 | 
            -
                    if targets is not None:
         | 
| 278 | 
            -
                        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
         | 
| 279 | 
            -
             | 
| 280 | 
            -
                    return logits, loss
         | 
| 281 | 
            -
             | 
| 282 | 
            -
             | 
| 283 | 
            -
             | 
| 284 | 
            -
            #### sampling utils
         | 
| 285 | 
            -
             | 
| 286 | 
            -
            def top_k_logits(logits, k):
         | 
| 287 | 
            -
                v, ix = torch.topk(logits, k)
         | 
| 288 | 
            -
                out = logits.clone()
         | 
| 289 | 
            -
                out[out < v[:, [-1]]] = -float('Inf')
         | 
| 290 | 
            -
                return out
         | 
| 291 | 
            -
             | 
| 292 | 
            -
            @torch.no_grad()
         | 
| 293 | 
            -
            def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
         | 
| 294 | 
            -
                """
         | 
| 295 | 
            -
                take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
         | 
| 296 | 
            -
                the sequence, feeding the predictions back into the model each time. Clearly the sampling
         | 
| 297 | 
            -
                has quadratic complexity unlike an RNN that is only linear, and has a finite context window
         | 
| 298 | 
            -
                of block_size, unlike an RNN that has an infinite context window.
         | 
| 299 | 
            -
                """
         | 
| 300 | 
            -
                block_size = model.get_block_size()
         | 
| 301 | 
            -
                model.eval()
         | 
| 302 | 
            -
                for k in range(steps):
         | 
| 303 | 
            -
                    x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
         | 
| 304 | 
            -
                    logits, _ = model(x_cond)
         | 
| 305 | 
            -
                    # pluck the logits at the final step and scale by temperature
         | 
| 306 | 
            -
                    logits = logits[:, -1, :] / temperature
         | 
| 307 | 
            -
                    # optionally crop probabilities to only the top k options
         | 
| 308 | 
            -
                    if top_k is not None:
         | 
| 309 | 
            -
                        logits = top_k_logits(logits, top_k)
         | 
| 310 | 
            -
                    # apply softmax to convert to probabilities
         | 
| 311 | 
            -
                    probs = F.softmax(logits, dim=-1)
         | 
| 312 | 
            -
                    # sample from the distribution or take the most likely
         | 
| 313 | 
            -
                    if sample:
         | 
| 314 | 
            -
                        ix = torch.multinomial(probs, num_samples=1)
         | 
| 315 | 
            -
                    else:
         | 
| 316 | 
            -
                        _, ix = torch.topk(probs, k=1, dim=-1)
         | 
| 317 | 
            -
                    # append to the sequence and continue
         | 
| 318 | 
            -
                    x = torch.cat((x, ix), dim=1)
         | 
| 319 | 
            -
             | 
| 320 | 
            -
                return x
         | 
| 321 | 
            -
             | 
| 322 | 
            -
             | 
| 323 | 
            -
            @torch.no_grad()
         | 
| 324 | 
            -
            def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
         | 
| 325 | 
            -
                                 top_k=None, top_p=None, callback=None):
         | 
| 326 | 
            -
                # x is conditioning
         | 
| 327 | 
            -
                sample = x
         | 
| 328 | 
            -
                cond_len = x.shape[1]
         | 
| 329 | 
            -
                past = None
         | 
| 330 | 
            -
                for n in range(steps):
         | 
| 331 | 
            -
                    if callback is not None:
         | 
| 332 | 
            -
                        callback(n)
         | 
| 333 | 
            -
                    logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
         | 
| 334 | 
            -
                    if past is None:
         | 
| 335 | 
            -
                        past = [present]
         | 
| 336 | 
            -
                    else:
         | 
| 337 | 
            -
                        past.append(present)
         | 
| 338 | 
            -
                    logits = logits[:, -1, :] / temperature
         | 
| 339 | 
            -
                    if top_k is not None:
         | 
| 340 | 
            -
                        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
         | 
| 341 | 
            -
             | 
| 342 | 
            -
                    probs = F.softmax(logits, dim=-1)
         | 
| 343 | 
            -
                    if not sample_logits:
         | 
| 344 | 
            -
                        _, x = torch.topk(probs, k=1, dim=-1)
         | 
| 345 | 
            -
                    else:
         | 
| 346 | 
            -
                        x = torch.multinomial(probs, num_samples=1)
         | 
| 347 | 
            -
                    # append to the sequence and continue
         | 
| 348 | 
            -
                    sample = torch.cat((sample, x), dim=1)
         | 
| 349 | 
            -
                del past
         | 
| 350 | 
            -
                sample = sample[:, cond_len:]  # cut conditioning off
         | 
| 351 | 
            -
                return sample
         | 
| 352 | 
            -
             | 
| 353 | 
            -
             | 
| 354 | 
            -
            #### clustering utils
         | 
| 355 | 
            -
             | 
| 356 | 
            -
            class KMeans(nn.Module):
         | 
| 357 | 
            -
                def __init__(self, ncluster=512, nc=3, niter=10):
         | 
| 358 | 
            -
                    super().__init__()
         | 
| 359 | 
            -
                    self.ncluster = ncluster
         | 
| 360 | 
            -
                    self.nc = nc
         | 
| 361 | 
            -
                    self.niter = niter
         | 
| 362 | 
            -
                    self.shape = (3,32,32)
         | 
| 363 | 
            -
                    self.register_buffer("C", torch.zeros(self.ncluster,nc))
         | 
| 364 | 
            -
                    self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
         | 
| 365 | 
            -
             | 
| 366 | 
            -
                def is_initialized(self):
         | 
| 367 | 
            -
                    return self.initialized.item() == 1
         | 
| 368 | 
            -
             | 
| 369 | 
            -
                @torch.no_grad()
         | 
| 370 | 
            -
                def initialize(self, x):
         | 
| 371 | 
            -
                    N, D = x.shape
         | 
| 372 | 
            -
                    assert D == self.nc, D
         | 
| 373 | 
            -
                    c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
         | 
| 374 | 
            -
                    for i in range(self.niter):
         | 
| 375 | 
            -
                        # assign all pixels to the closest codebook element
         | 
| 376 | 
            -
                        a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
         | 
| 377 | 
            -
                        # move each codebook element to be the mean of the pixels that assigned to it
         | 
| 378 | 
            -
                        c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
         | 
| 379 | 
            -
                        # re-assign any poorly positioned codebook elements
         | 
| 380 | 
            -
                        nanix = torch.any(torch.isnan(c), dim=1)
         | 
| 381 | 
            -
                        ndead = nanix.sum().item()
         | 
| 382 | 
            -
                        print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
         | 
| 383 | 
            -
                        c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
         | 
| 384 | 
            -
             | 
| 385 | 
            -
                    self.C.copy_(c)
         | 
| 386 | 
            -
                    self.initialized.fill_(1)
         | 
| 387 | 
            -
             | 
| 388 | 
            -
             | 
| 389 | 
            -
                def forward(self, x, reverse=False, shape=None):
         | 
| 390 | 
            -
                    if not reverse:
         | 
| 391 | 
            -
                        # flatten
         | 
| 392 | 
            -
                        bs,c,h,w = x.shape
         | 
| 393 | 
            -
                        assert c == self.nc
         | 
| 394 | 
            -
                        x = x.reshape(bs,c,h*w,1)
         | 
| 395 | 
            -
                        C = self.C.permute(1,0)
         | 
| 396 | 
            -
                        C = C.reshape(1,c,1,self.ncluster)
         | 
| 397 | 
            -
                        a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
         | 
| 398 | 
            -
                        return a
         | 
| 399 | 
            -
                    else:
         | 
| 400 | 
            -
                        # flatten
         | 
| 401 | 
            -
                        bs, HW = x.shape
         | 
| 402 | 
            -
                        """
         | 
| 403 | 
            -
                        c = self.C.reshape( 1, self.nc,  1, self.ncluster)
         | 
| 404 | 
            -
                        c = c[bs*[0],:,:,:]
         | 
| 405 | 
            -
                        c = c[:,:,HW*[0],:]
         | 
| 406 | 
            -
                        x =      x.reshape(bs,       1, HW,             1)
         | 
| 407 | 
            -
                        x = x[:,3*[0],:,:]
         | 
| 408 | 
            -
                        x = torch.gather(c, dim=3, index=x)
         | 
| 409 | 
            -
                        """
         | 
| 410 | 
            -
                        x = self.C[x]
         | 
| 411 | 
            -
                        x = x.permute(0,2,1)
         | 
| 412 | 
            -
                        shape = shape if shape is not None else self.shape
         | 
| 413 | 
            -
                        x = x.reshape(bs, *shape)
         | 
| 414 | 
            -
             | 
| 415 | 
            -
                        return x
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/transformer/permuter.py
    DELETED
    
    | @@ -1,248 +0,0 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            import torch.nn as nn
         | 
| 3 | 
            -
            import numpy as np
         | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
            class AbstractPermuter(nn.Module):
         | 
| 7 | 
            -
                def __init__(self, *args, **kwargs):
         | 
| 8 | 
            -
                    super().__init__()
         | 
| 9 | 
            -
                def forward(self, x, reverse=False):
         | 
| 10 | 
            -
                    raise NotImplementedError
         | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
            class Identity(AbstractPermuter):
         | 
| 14 | 
            -
                def __init__(self):
         | 
| 15 | 
            -
                    super().__init__()
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                def forward(self, x, reverse=False):
         | 
| 18 | 
            -
                    return x
         | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
            class Subsample(AbstractPermuter):
         | 
| 22 | 
            -
                def __init__(self, H, W):
         | 
| 23 | 
            -
                    super().__init__()
         | 
| 24 | 
            -
                    C = 1
         | 
| 25 | 
            -
                    indices = np.arange(H*W).reshape(C,H,W)
         | 
| 26 | 
            -
                    while min(H, W) > 1:
         | 
| 27 | 
            -
                        indices = indices.reshape(C,H//2,2,W//2,2)
         | 
| 28 | 
            -
                        indices = indices.transpose(0,2,4,1,3)
         | 
| 29 | 
            -
                        indices = indices.reshape(C*4,H//2, W//2)
         | 
| 30 | 
            -
                        H = H//2
         | 
| 31 | 
            -
                        W = W//2
         | 
| 32 | 
            -
                        C = C*4
         | 
| 33 | 
            -
                    assert H == W == 1
         | 
| 34 | 
            -
                    idx = torch.tensor(indices.ravel())
         | 
| 35 | 
            -
                    self.register_buffer('forward_shuffle_idx',
         | 
| 36 | 
            -
                                         nn.Parameter(idx, requires_grad=False))
         | 
| 37 | 
            -
                    self.register_buffer('backward_shuffle_idx',
         | 
| 38 | 
            -
                                         nn.Parameter(torch.argsort(idx), requires_grad=False))
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                def forward(self, x, reverse=False):
         | 
| 41 | 
            -
                    if not reverse:
         | 
| 42 | 
            -
                        return x[:, self.forward_shuffle_idx]
         | 
| 43 | 
            -
                    else:
         | 
| 44 | 
            -
                        return x[:, self.backward_shuffle_idx]
         | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
            def mortonify(i, j):
         | 
| 48 | 
            -
                """(i,j) index to linear morton code"""
         | 
| 49 | 
            -
                i = np.uint64(i)
         | 
| 50 | 
            -
                j = np.uint64(j)
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                z = np.uint(0)
         | 
| 53 | 
            -
             | 
| 54 | 
            -
                for pos in range(32):
         | 
| 55 | 
            -
                    z = (z |
         | 
| 56 | 
            -
                         ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
         | 
| 57 | 
            -
                         ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
         | 
| 58 | 
            -
                         )
         | 
| 59 | 
            -
                return z
         | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
            class ZCurve(AbstractPermuter):
         | 
| 63 | 
            -
                def __init__(self, H, W):
         | 
| 64 | 
            -
                    super().__init__()
         | 
| 65 | 
            -
                    reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
         | 
| 66 | 
            -
                    idx = np.argsort(reverseidx)
         | 
| 67 | 
            -
                    idx = torch.tensor(idx)
         | 
| 68 | 
            -
                    reverseidx = torch.tensor(reverseidx)
         | 
| 69 | 
            -
                    self.register_buffer('forward_shuffle_idx',
         | 
| 70 | 
            -
                                         idx)
         | 
| 71 | 
            -
                    self.register_buffer('backward_shuffle_idx',
         | 
| 72 | 
            -
                                         reverseidx)
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                def forward(self, x, reverse=False):
         | 
| 75 | 
            -
                    if not reverse:
         | 
| 76 | 
            -
                        return x[:, self.forward_shuffle_idx]
         | 
| 77 | 
            -
                    else:
         | 
| 78 | 
            -
                        return x[:, self.backward_shuffle_idx]
         | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
            class SpiralOut(AbstractPermuter):
         | 
| 82 | 
            -
                def __init__(self, H, W):
         | 
| 83 | 
            -
                    super().__init__()
         | 
| 84 | 
            -
                    assert H == W
         | 
| 85 | 
            -
                    size = W
         | 
| 86 | 
            -
                    indices = np.arange(size*size).reshape(size,size)
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                    i0 = size//2
         | 
| 89 | 
            -
                    j0 = size//2-1
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                    i = i0
         | 
| 92 | 
            -
                    j = j0
         | 
| 93 | 
            -
             | 
| 94 | 
            -
                    idx = [indices[i0, j0]]
         | 
| 95 | 
            -
                    step_mult = 0
         | 
| 96 | 
            -
                    for c in range(1, size//2+1):
         | 
| 97 | 
            -
                        step_mult += 1
         | 
| 98 | 
            -
                        # steps left
         | 
| 99 | 
            -
                        for k in range(step_mult):
         | 
| 100 | 
            -
                            i = i - 1
         | 
| 101 | 
            -
                            j = j
         | 
| 102 | 
            -
                            idx.append(indices[i, j])
         | 
| 103 | 
            -
             | 
| 104 | 
            -
                        # step down
         | 
| 105 | 
            -
                        for k in range(step_mult):
         | 
| 106 | 
            -
                            i = i
         | 
| 107 | 
            -
                            j = j + 1
         | 
| 108 | 
            -
                            idx.append(indices[i, j])
         | 
| 109 | 
            -
             | 
| 110 | 
            -
                        step_mult += 1
         | 
| 111 | 
            -
                        if c < size//2:
         | 
| 112 | 
            -
                            # step right
         | 
| 113 | 
            -
                            for k in range(step_mult):
         | 
| 114 | 
            -
                                i = i + 1
         | 
| 115 | 
            -
                                j = j
         | 
| 116 | 
            -
                                idx.append(indices[i, j])
         | 
| 117 | 
            -
             | 
| 118 | 
            -
                            # step up
         | 
| 119 | 
            -
                            for k in range(step_mult):
         | 
| 120 | 
            -
                                i = i
         | 
| 121 | 
            -
                                j = j - 1
         | 
| 122 | 
            -
                                idx.append(indices[i, j])
         | 
| 123 | 
            -
                        else:
         | 
| 124 | 
            -
                            # end reached
         | 
| 125 | 
            -
                            for k in range(step_mult-1):
         | 
| 126 | 
            -
                                i = i + 1
         | 
| 127 | 
            -
                                idx.append(indices[i, j])
         | 
| 128 | 
            -
             | 
| 129 | 
            -
                    assert len(idx) == size*size
         | 
| 130 | 
            -
                    idx = torch.tensor(idx)
         | 
| 131 | 
            -
                    self.register_buffer('forward_shuffle_idx', idx)
         | 
| 132 | 
            -
                    self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
         | 
| 133 | 
            -
             | 
| 134 | 
            -
                def forward(self, x, reverse=False):
         | 
| 135 | 
            -
                    if not reverse:
         | 
| 136 | 
            -
                        return x[:, self.forward_shuffle_idx]
         | 
| 137 | 
            -
                    else:
         | 
| 138 | 
            -
                        return x[:, self.backward_shuffle_idx]
         | 
| 139 | 
            -
             | 
| 140 | 
            -
             | 
| 141 | 
            -
            class SpiralIn(AbstractPermuter):
         | 
| 142 | 
            -
                def __init__(self, H, W):
         | 
| 143 | 
            -
                    super().__init__()
         | 
| 144 | 
            -
                    assert H == W
         | 
| 145 | 
            -
                    size = W
         | 
| 146 | 
            -
                    indices = np.arange(size*size).reshape(size,size)
         | 
| 147 | 
            -
             | 
| 148 | 
            -
                    i0 = size//2
         | 
| 149 | 
            -
                    j0 = size//2-1
         | 
| 150 | 
            -
             | 
| 151 | 
            -
                    i = i0
         | 
| 152 | 
            -
                    j = j0
         | 
| 153 | 
            -
             | 
| 154 | 
            -
                    idx = [indices[i0, j0]]
         | 
| 155 | 
            -
                    step_mult = 0
         | 
| 156 | 
            -
                    for c in range(1, size//2+1):
         | 
| 157 | 
            -
                        step_mult += 1
         | 
| 158 | 
            -
                        # steps left
         | 
| 159 | 
            -
                        for k in range(step_mult):
         | 
| 160 | 
            -
                            i = i - 1
         | 
| 161 | 
            -
                            j = j
         | 
| 162 | 
            -
                            idx.append(indices[i, j])
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                        # step down
         | 
| 165 | 
            -
                        for k in range(step_mult):
         | 
| 166 | 
            -
                            i = i
         | 
| 167 | 
            -
                            j = j + 1
         | 
| 168 | 
            -
                            idx.append(indices[i, j])
         | 
| 169 | 
            -
             | 
| 170 | 
            -
                        step_mult += 1
         | 
| 171 | 
            -
                        if c < size//2:
         | 
| 172 | 
            -
                            # step right
         | 
| 173 | 
            -
                            for k in range(step_mult):
         | 
| 174 | 
            -
                                i = i + 1
         | 
| 175 | 
            -
                                j = j
         | 
| 176 | 
            -
                                idx.append(indices[i, j])
         | 
| 177 | 
            -
             | 
| 178 | 
            -
                            # step up
         | 
| 179 | 
            -
                            for k in range(step_mult):
         | 
| 180 | 
            -
                                i = i
         | 
| 181 | 
            -
                                j = j - 1
         | 
| 182 | 
            -
                                idx.append(indices[i, j])
         | 
| 183 | 
            -
                        else:
         | 
| 184 | 
            -
                            # end reached
         | 
| 185 | 
            -
                            for k in range(step_mult-1):
         | 
| 186 | 
            -
                                i = i + 1
         | 
| 187 | 
            -
                                idx.append(indices[i, j])
         | 
| 188 | 
            -
             | 
| 189 | 
            -
                    assert len(idx) == size*size
         | 
| 190 | 
            -
                    idx = idx[::-1]
         | 
| 191 | 
            -
                    idx = torch.tensor(idx)
         | 
| 192 | 
            -
                    self.register_buffer('forward_shuffle_idx', idx)
         | 
| 193 | 
            -
                    self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
         | 
| 194 | 
            -
             | 
| 195 | 
            -
                def forward(self, x, reverse=False):
         | 
| 196 | 
            -
                    if not reverse:
         | 
| 197 | 
            -
                        return x[:, self.forward_shuffle_idx]
         | 
| 198 | 
            -
                    else:
         | 
| 199 | 
            -
                        return x[:, self.backward_shuffle_idx]
         | 
| 200 | 
            -
             | 
| 201 | 
            -
             | 
| 202 | 
            -
            class Random(nn.Module):
         | 
| 203 | 
            -
                def __init__(self, H, W):
         | 
| 204 | 
            -
                    super().__init__()
         | 
| 205 | 
            -
                    indices = np.random.RandomState(1).permutation(H*W)
         | 
| 206 | 
            -
                    idx = torch.tensor(indices.ravel())
         | 
| 207 | 
            -
                    self.register_buffer('forward_shuffle_idx', idx)
         | 
| 208 | 
            -
                    self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
         | 
| 209 | 
            -
             | 
| 210 | 
            -
                def forward(self, x, reverse=False):
         | 
| 211 | 
            -
                    if not reverse:
         | 
| 212 | 
            -
                        return x[:, self.forward_shuffle_idx]
         | 
| 213 | 
            -
                    else:
         | 
| 214 | 
            -
                        return x[:, self.backward_shuffle_idx]
         | 
| 215 | 
            -
             | 
| 216 | 
            -
             | 
| 217 | 
            -
            class AlternateParsing(AbstractPermuter):
         | 
| 218 | 
            -
                def __init__(self, H, W):
         | 
| 219 | 
            -
                    super().__init__()
         | 
| 220 | 
            -
                    indices = np.arange(W*H).reshape(H,W)
         | 
| 221 | 
            -
                    for i in range(1, H, 2):
         | 
| 222 | 
            -
                        indices[i, :] = indices[i, ::-1]
         | 
| 223 | 
            -
                    idx = indices.flatten()
         | 
| 224 | 
            -
                    assert len(idx) == H*W
         | 
| 225 | 
            -
                    idx = torch.tensor(idx)
         | 
| 226 | 
            -
                    self.register_buffer('forward_shuffle_idx', idx)
         | 
| 227 | 
            -
                    self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
         | 
| 228 | 
            -
             | 
| 229 | 
            -
                def forward(self, x, reverse=False):
         | 
| 230 | 
            -
                    if not reverse:
         | 
| 231 | 
            -
                        return x[:, self.forward_shuffle_idx]
         | 
| 232 | 
            -
                    else:
         | 
| 233 | 
            -
                        return x[:, self.backward_shuffle_idx]
         | 
| 234 | 
            -
             | 
| 235 | 
            -
             | 
| 236 | 
            -
            if __name__ == "__main__":
         | 
| 237 | 
            -
                p0 = AlternateParsing(16, 16)
         | 
| 238 | 
            -
                print(p0.forward_shuffle_idx)
         | 
| 239 | 
            -
                print(p0.backward_shuffle_idx)
         | 
| 240 | 
            -
             | 
| 241 | 
            -
                x = torch.randint(0, 768, size=(11, 256))
         | 
| 242 | 
            -
                y = p0(x)
         | 
| 243 | 
            -
                xre = p0(y, reverse=True)
         | 
| 244 | 
            -
                assert torch.equal(x, xre)
         | 
| 245 | 
            -
             | 
| 246 | 
            -
                p1 = SpiralOut(2, 2)
         | 
| 247 | 
            -
                print(p1.forward_shuffle_idx)
         | 
| 248 | 
            -
                print(p1.backward_shuffle_idx)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/util.py
    DELETED
    
    | @@ -1,130 +0,0 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            import torch.nn as nn
         | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
            def count_params(model):
         | 
| 6 | 
            -
                total_params = sum(p.numel() for p in model.parameters())
         | 
| 7 | 
            -
                return total_params
         | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
            class ActNorm(nn.Module):
         | 
| 11 | 
            -
                def __init__(self, num_features, logdet=False, affine=True,
         | 
| 12 | 
            -
                             allow_reverse_init=False):
         | 
| 13 | 
            -
                    assert affine
         | 
| 14 | 
            -
                    super().__init__()
         | 
| 15 | 
            -
                    self.logdet = logdet
         | 
| 16 | 
            -
                    self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
         | 
| 17 | 
            -
                    self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
         | 
| 18 | 
            -
                    self.allow_reverse_init = allow_reverse_init
         | 
| 19 | 
            -
             | 
| 20 | 
            -
                    self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
         | 
| 21 | 
            -
             | 
| 22 | 
            -
                def initialize(self, input):
         | 
| 23 | 
            -
                    with torch.no_grad():
         | 
| 24 | 
            -
                        flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
         | 
| 25 | 
            -
                        mean = (
         | 
| 26 | 
            -
                            flatten.mean(1)
         | 
| 27 | 
            -
                            .unsqueeze(1)
         | 
| 28 | 
            -
                            .unsqueeze(2)
         | 
| 29 | 
            -
                            .unsqueeze(3)
         | 
| 30 | 
            -
                            .permute(1, 0, 2, 3)
         | 
| 31 | 
            -
                        )
         | 
| 32 | 
            -
                        std = (
         | 
| 33 | 
            -
                            flatten.std(1)
         | 
| 34 | 
            -
                            .unsqueeze(1)
         | 
| 35 | 
            -
                            .unsqueeze(2)
         | 
| 36 | 
            -
                            .unsqueeze(3)
         | 
| 37 | 
            -
                            .permute(1, 0, 2, 3)
         | 
| 38 | 
            -
                        )
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                        self.loc.data.copy_(-mean)
         | 
| 41 | 
            -
                        self.scale.data.copy_(1 / (std + 1e-6))
         | 
| 42 | 
            -
             | 
| 43 | 
            -
                def forward(self, input, reverse=False):
         | 
| 44 | 
            -
                    if reverse:
         | 
| 45 | 
            -
                        return self.reverse(input)
         | 
| 46 | 
            -
                    if len(input.shape) == 2:
         | 
| 47 | 
            -
                        input = input[:,:,None,None]
         | 
| 48 | 
            -
                        squeeze = True
         | 
| 49 | 
            -
                    else:
         | 
| 50 | 
            -
                        squeeze = False
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                    _, _, height, width = input.shape
         | 
| 53 | 
            -
             | 
| 54 | 
            -
                    if self.training and self.initialized.item() == 0:
         | 
| 55 | 
            -
                        self.initialize(input)
         | 
| 56 | 
            -
                        self.initialized.fill_(1)
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                    h = self.scale * (input + self.loc)
         | 
| 59 | 
            -
             | 
| 60 | 
            -
                    if squeeze:
         | 
| 61 | 
            -
                        h = h.squeeze(-1).squeeze(-1)
         | 
| 62 | 
            -
             | 
| 63 | 
            -
                    if self.logdet:
         | 
| 64 | 
            -
                        log_abs = torch.log(torch.abs(self.scale))
         | 
| 65 | 
            -
                        logdet = height*width*torch.sum(log_abs)
         | 
| 66 | 
            -
                        logdet = logdet * torch.ones(input.shape[0]).to(input)
         | 
| 67 | 
            -
                        return h, logdet
         | 
| 68 | 
            -
             | 
| 69 | 
            -
                    return h
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                def reverse(self, output):
         | 
| 72 | 
            -
                    if self.training and self.initialized.item() == 0:
         | 
| 73 | 
            -
                        if not self.allow_reverse_init:
         | 
| 74 | 
            -
                            raise RuntimeError(
         | 
| 75 | 
            -
                                "Initializing ActNorm in reverse direction is "
         | 
| 76 | 
            -
                                "disabled by default. Use allow_reverse_init=True to enable."
         | 
| 77 | 
            -
                            )
         | 
| 78 | 
            -
                        else:
         | 
| 79 | 
            -
                            self.initialize(output)
         | 
| 80 | 
            -
                            self.initialized.fill_(1)
         | 
| 81 | 
            -
             | 
| 82 | 
            -
                    if len(output.shape) == 2:
         | 
| 83 | 
            -
                        output = output[:,:,None,None]
         | 
| 84 | 
            -
                        squeeze = True
         | 
| 85 | 
            -
                    else:
         | 
| 86 | 
            -
                        squeeze = False
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                    h = output / self.scale - self.loc
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                    if squeeze:
         | 
| 91 | 
            -
                        h = h.squeeze(-1).squeeze(-1)
         | 
| 92 | 
            -
                    return h
         | 
| 93 | 
            -
             | 
| 94 | 
            -
             | 
| 95 | 
            -
            class AbstractEncoder(nn.Module):
         | 
| 96 | 
            -
                def __init__(self):
         | 
| 97 | 
            -
                    super().__init__()
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                def encode(self, *args, **kwargs):
         | 
| 100 | 
            -
                    raise NotImplementedError
         | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
| 103 | 
            -
            class Labelator(AbstractEncoder):
         | 
| 104 | 
            -
                """Net2Net Interface for Class-Conditional Model"""
         | 
| 105 | 
            -
                def __init__(self, n_classes, quantize_interface=True):
         | 
| 106 | 
            -
                    super().__init__()
         | 
| 107 | 
            -
                    self.n_classes = n_classes
         | 
| 108 | 
            -
                    self.quantize_interface = quantize_interface
         | 
| 109 | 
            -
             | 
| 110 | 
            -
                def encode(self, c):
         | 
| 111 | 
            -
                    c = c[:,None]
         | 
| 112 | 
            -
                    if self.quantize_interface:
         | 
| 113 | 
            -
                        return c, None, [None, None, c.long()]
         | 
| 114 | 
            -
                    return c
         | 
| 115 | 
            -
             | 
| 116 | 
            -
             | 
| 117 | 
            -
            class SOSProvider(AbstractEncoder):
         | 
| 118 | 
            -
                # for unconditional training
         | 
| 119 | 
            -
                def __init__(self, sos_token, quantize_interface=True):
         | 
| 120 | 
            -
                    super().__init__()
         | 
| 121 | 
            -
                    self.sos_token = sos_token
         | 
| 122 | 
            -
                    self.quantize_interface = quantize_interface
         | 
| 123 | 
            -
             | 
| 124 | 
            -
                def encode(self, x):
         | 
| 125 | 
            -
                    # get batch size from data and replicate sos_token
         | 
| 126 | 
            -
                    c = torch.ones(x.shape[0], 1)*self.sos_token
         | 
| 127 | 
            -
                    c = c.long().to(x.device)
         | 
| 128 | 
            -
                    if self.quantize_interface:
         | 
| 129 | 
            -
                        return c, None, [None, None, c]
         | 
| 130 | 
            -
                    return c
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/modules/vqvae/quantize.py
    DELETED
    
    | @@ -1,445 +0,0 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            import torch.nn as nn
         | 
| 3 | 
            -
            import torch.nn.functional as F
         | 
| 4 | 
            -
            import numpy as np
         | 
| 5 | 
            -
            from torch import einsum
         | 
| 6 | 
            -
            from einops import rearrange
         | 
| 7 | 
            -
             | 
| 8 | 
            -
             | 
| 9 | 
            -
            class VectorQuantizer(nn.Module):
         | 
| 10 | 
            -
                """
         | 
| 11 | 
            -
                see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
         | 
| 12 | 
            -
                ____________________________________________
         | 
| 13 | 
            -
                Discretization bottleneck part of the VQ-VAE.
         | 
| 14 | 
            -
                Inputs:
         | 
| 15 | 
            -
                - n_e : number of embeddings
         | 
| 16 | 
            -
                - e_dim : dimension of embedding
         | 
| 17 | 
            -
                - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
         | 
| 18 | 
            -
                _____________________________________________
         | 
| 19 | 
            -
                """
         | 
| 20 | 
            -
             | 
| 21 | 
            -
                # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
         | 
| 22 | 
            -
                # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
         | 
| 23 | 
            -
                # used wherever VectorQuantizer has been used before and is additionally
         | 
| 24 | 
            -
                # more efficient.
         | 
| 25 | 
            -
                def __init__(self, n_e, e_dim, beta):
         | 
| 26 | 
            -
                    super(VectorQuantizer, self).__init__()
         | 
| 27 | 
            -
                    self.n_e = n_e
         | 
| 28 | 
            -
                    self.e_dim = e_dim
         | 
| 29 | 
            -
                    self.beta = beta
         | 
| 30 | 
            -
             | 
| 31 | 
            -
                    self.embedding = nn.Embedding(self.n_e, self.e_dim)
         | 
| 32 | 
            -
                    self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
         | 
| 33 | 
            -
             | 
| 34 | 
            -
                def forward(self, z):
         | 
| 35 | 
            -
                    """
         | 
| 36 | 
            -
                    Inputs the output of the encoder network z and maps it to a discrete
         | 
| 37 | 
            -
                    one-hot vector that is the index of the closest embedding vector e_j
         | 
| 38 | 
            -
                    z (continuous) -> z_q (discrete)
         | 
| 39 | 
            -
                    z.shape = (batch, channel, height, width)
         | 
| 40 | 
            -
                    quantization pipeline:
         | 
| 41 | 
            -
                        1. get encoder input (B,C,H,W)
         | 
| 42 | 
            -
                        2. flatten input to (B*H*W,C)
         | 
| 43 | 
            -
                    """
         | 
| 44 | 
            -
                    # reshape z -> (batch, height, width, channel) and flatten
         | 
| 45 | 
            -
                    z = z.permute(0, 2, 3, 1).contiguous()
         | 
| 46 | 
            -
                    z_flattened = z.view(-1, self.e_dim)
         | 
| 47 | 
            -
                    # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
         | 
| 48 | 
            -
             | 
| 49 | 
            -
                    d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
         | 
| 50 | 
            -
                        torch.sum(self.embedding.weight**2, dim=1) - 2 * \
         | 
| 51 | 
            -
                        torch.matmul(z_flattened, self.embedding.weight.t())
         | 
| 52 | 
            -
             | 
| 53 | 
            -
                    ## could possible replace this here
         | 
| 54 | 
            -
                    # #\start...
         | 
| 55 | 
            -
                    # find closest encodings
         | 
| 56 | 
            -
                    min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                    min_encodings = torch.zeros(
         | 
| 59 | 
            -
                        min_encoding_indices.shape[0], self.n_e).to(z)
         | 
| 60 | 
            -
                    min_encodings.scatter_(1, min_encoding_indices, 1)
         | 
| 61 | 
            -
             | 
| 62 | 
            -
                    # dtype min encodings: torch.float32
         | 
| 63 | 
            -
                    # min_encodings shape: torch.Size([2048, 512])
         | 
| 64 | 
            -
                    # min_encoding_indices.shape: torch.Size([2048, 1])
         | 
| 65 | 
            -
             | 
| 66 | 
            -
                    # get quantized latent vectors
         | 
| 67 | 
            -
                    z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
         | 
| 68 | 
            -
                    #.........\end
         | 
| 69 | 
            -
             | 
| 70 | 
            -
                    # with:
         | 
| 71 | 
            -
                    # .........\start
         | 
| 72 | 
            -
                    #min_encoding_indices = torch.argmin(d, dim=1)
         | 
| 73 | 
            -
                    #z_q = self.embedding(min_encoding_indices)
         | 
| 74 | 
            -
                    # ......\end......... (TODO)
         | 
| 75 | 
            -
             | 
| 76 | 
            -
                    # compute loss for embedding
         | 
| 77 | 
            -
                    loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
         | 
| 78 | 
            -
                        torch.mean((z_q - z.detach()) ** 2)
         | 
| 79 | 
            -
             | 
| 80 | 
            -
                    # preserve gradients
         | 
| 81 | 
            -
                    z_q = z + (z_q - z).detach()
         | 
| 82 | 
            -
             | 
| 83 | 
            -
                    # perplexity
         | 
| 84 | 
            -
                    e_mean = torch.mean(min_encodings, dim=0)
         | 
| 85 | 
            -
                    perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                    # reshape back to match original input shape
         | 
| 88 | 
            -
                    z_q = z_q.permute(0, 3, 1, 2).contiguous()
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                    return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
         | 
| 91 | 
            -
             | 
| 92 | 
            -
                def get_codebook_entry(self, indices, shape):
         | 
| 93 | 
            -
                    # shape specifying (batch, height, width, channel)
         | 
| 94 | 
            -
                    # TODO: check for more easy handling with nn.Embedding
         | 
| 95 | 
            -
                    min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
         | 
| 96 | 
            -
                    min_encodings.scatter_(1, indices[:,None], 1)
         | 
| 97 | 
            -
             | 
| 98 | 
            -
                    # get quantized latent vectors
         | 
| 99 | 
            -
                    z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
         | 
| 100 | 
            -
             | 
| 101 | 
            -
                    if shape is not None:
         | 
| 102 | 
            -
                        z_q = z_q.view(shape)
         | 
| 103 | 
            -
             | 
| 104 | 
            -
                        # reshape back to match original input shape
         | 
| 105 | 
            -
                        z_q = z_q.permute(0, 3, 1, 2).contiguous()
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                    return z_q
         | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
            class GumbelQuantize(nn.Module):
         | 
| 111 | 
            -
                """
         | 
| 112 | 
            -
                credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
         | 
| 113 | 
            -
                Gumbel Softmax trick quantizer
         | 
| 114 | 
            -
                Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
         | 
| 115 | 
            -
                https://arxiv.org/abs/1611.01144
         | 
| 116 | 
            -
                """
         | 
| 117 | 
            -
                def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
         | 
| 118 | 
            -
                             kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
         | 
| 119 | 
            -
                             remap=None, unknown_index="random"):
         | 
| 120 | 
            -
                    super().__init__()
         | 
| 121 | 
            -
             | 
| 122 | 
            -
                    self.embedding_dim = embedding_dim
         | 
| 123 | 
            -
                    self.n_embed = n_embed
         | 
| 124 | 
            -
             | 
| 125 | 
            -
                    self.straight_through = straight_through
         | 
| 126 | 
            -
                    self.temperature = temp_init
         | 
| 127 | 
            -
                    self.kl_weight = kl_weight
         | 
| 128 | 
            -
             | 
| 129 | 
            -
                    self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
         | 
| 130 | 
            -
                    self.embed = nn.Embedding(n_embed, embedding_dim)
         | 
| 131 | 
            -
             | 
| 132 | 
            -
                    self.use_vqinterface = use_vqinterface
         | 
| 133 | 
            -
             | 
| 134 | 
            -
                    self.remap = remap
         | 
| 135 | 
            -
                    if self.remap is not None:
         | 
| 136 | 
            -
                        self.register_buffer("used", torch.tensor(np.load(self.remap)))
         | 
| 137 | 
            -
                        self.re_embed = self.used.shape[0]
         | 
| 138 | 
            -
                        self.unknown_index = unknown_index # "random" or "extra" or integer
         | 
| 139 | 
            -
                        if self.unknown_index == "extra":
         | 
| 140 | 
            -
                            self.unknown_index = self.re_embed
         | 
| 141 | 
            -
                            self.re_embed = self.re_embed+1
         | 
| 142 | 
            -
                        print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
         | 
| 143 | 
            -
                              f"Using {self.unknown_index} for unknown indices.")
         | 
| 144 | 
            -
                    else:
         | 
| 145 | 
            -
                        self.re_embed = n_embed
         | 
| 146 | 
            -
             | 
| 147 | 
            -
                def remap_to_used(self, inds):
         | 
| 148 | 
            -
                    ishape = inds.shape
         | 
| 149 | 
            -
                    assert len(ishape)>1
         | 
| 150 | 
            -
                    inds = inds.reshape(ishape[0],-1)
         | 
| 151 | 
            -
                    used = self.used.to(inds)
         | 
| 152 | 
            -
                    match = (inds[:,:,None]==used[None,None,...]).long()
         | 
| 153 | 
            -
                    new = match.argmax(-1)
         | 
| 154 | 
            -
                    unknown = match.sum(2)<1
         | 
| 155 | 
            -
                    if self.unknown_index == "random":
         | 
| 156 | 
            -
                        new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
         | 
| 157 | 
            -
                    else:
         | 
| 158 | 
            -
                        new[unknown] = self.unknown_index
         | 
| 159 | 
            -
                    return new.reshape(ishape)
         | 
| 160 | 
            -
             | 
| 161 | 
            -
                def unmap_to_all(self, inds):
         | 
| 162 | 
            -
                    ishape = inds.shape
         | 
| 163 | 
            -
                    assert len(ishape)>1
         | 
| 164 | 
            -
                    inds = inds.reshape(ishape[0],-1)
         | 
| 165 | 
            -
                    used = self.used.to(inds)
         | 
| 166 | 
            -
                    if self.re_embed > self.used.shape[0]: # extra token
         | 
| 167 | 
            -
                        inds[inds>=self.used.shape[0]] = 0 # simply set to zero
         | 
| 168 | 
            -
                    back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
         | 
| 169 | 
            -
                    return back.reshape(ishape)
         | 
| 170 | 
            -
             | 
| 171 | 
            -
                def forward(self, z, temp=None, return_logits=False):
         | 
| 172 | 
            -
                    # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
         | 
| 173 | 
            -
                    hard = self.straight_through if self.training else True
         | 
| 174 | 
            -
                    temp = self.temperature if temp is None else temp
         | 
| 175 | 
            -
             | 
| 176 | 
            -
                    logits = self.proj(z)
         | 
| 177 | 
            -
                    if self.remap is not None:
         | 
| 178 | 
            -
                        # continue only with used logits
         | 
| 179 | 
            -
                        full_zeros = torch.zeros_like(logits)
         | 
| 180 | 
            -
                        logits = logits[:,self.used,...]
         | 
| 181 | 
            -
             | 
| 182 | 
            -
                    soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
         | 
| 183 | 
            -
                    if self.remap is not None:
         | 
| 184 | 
            -
                        # go back to all entries but unused set to zero
         | 
| 185 | 
            -
                        full_zeros[:,self.used,...] = soft_one_hot
         | 
| 186 | 
            -
                        soft_one_hot = full_zeros
         | 
| 187 | 
            -
                    z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
         | 
| 188 | 
            -
             | 
| 189 | 
            -
                    # + kl divergence to the prior loss
         | 
| 190 | 
            -
                    qy = F.softmax(logits, dim=1)
         | 
| 191 | 
            -
                    diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
         | 
| 192 | 
            -
             | 
| 193 | 
            -
                    ind = soft_one_hot.argmax(dim=1)
         | 
| 194 | 
            -
                    if self.remap is not None:
         | 
| 195 | 
            -
                        ind = self.remap_to_used(ind)
         | 
| 196 | 
            -
                    if self.use_vqinterface:
         | 
| 197 | 
            -
                        if return_logits:
         | 
| 198 | 
            -
                            return z_q, diff, (None, None, ind), logits
         | 
| 199 | 
            -
                        return z_q, diff, (None, None, ind)
         | 
| 200 | 
            -
                    return z_q, diff, ind
         | 
| 201 | 
            -
             | 
| 202 | 
            -
                def get_codebook_entry(self, indices, shape):
         | 
| 203 | 
            -
                    b, h, w, c = shape
         | 
| 204 | 
            -
                    assert b*h*w == indices.shape[0]
         | 
| 205 | 
            -
                    indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
         | 
| 206 | 
            -
                    if self.remap is not None:
         | 
| 207 | 
            -
                        indices = self.unmap_to_all(indices)
         | 
| 208 | 
            -
                    one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
         | 
| 209 | 
            -
                    z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
         | 
| 210 | 
            -
                    return z_q
         | 
| 211 | 
            -
             | 
| 212 | 
            -
             | 
| 213 | 
            -
            class VectorQuantizer2(nn.Module):
         | 
| 214 | 
            -
                """
         | 
| 215 | 
            -
                Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
         | 
| 216 | 
            -
                avoids costly matrix multiplications and allows for post-hoc remapping of indices.
         | 
| 217 | 
            -
                """
         | 
| 218 | 
            -
                # NOTE: due to a bug the beta term was applied to the wrong term. for
         | 
| 219 | 
            -
                # backwards compatibility we use the buggy version by default, but you can
         | 
| 220 | 
            -
                # specify legacy=False to fix it.
         | 
| 221 | 
            -
                def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
         | 
| 222 | 
            -
                             sane_index_shape=False, legacy=True):
         | 
| 223 | 
            -
                    super().__init__()
         | 
| 224 | 
            -
                    self.n_e = n_e
         | 
| 225 | 
            -
                    self.e_dim = e_dim
         | 
| 226 | 
            -
                    self.beta = beta
         | 
| 227 | 
            -
                    self.legacy = legacy
         | 
| 228 | 
            -
             | 
| 229 | 
            -
                    self.embedding = nn.Embedding(self.n_e, self.e_dim)
         | 
| 230 | 
            -
                    self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
         | 
| 231 | 
            -
             | 
| 232 | 
            -
                    self.remap = remap
         | 
| 233 | 
            -
                    if self.remap is not None:
         | 
| 234 | 
            -
                        self.register_buffer("used", torch.tensor(np.load(self.remap)))
         | 
| 235 | 
            -
                        self.re_embed = self.used.shape[0]
         | 
| 236 | 
            -
                        self.unknown_index = unknown_index # "random" or "extra" or integer
         | 
| 237 | 
            -
                        if self.unknown_index == "extra":
         | 
| 238 | 
            -
                            self.unknown_index = self.re_embed
         | 
| 239 | 
            -
                            self.re_embed = self.re_embed+1
         | 
| 240 | 
            -
                        print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
         | 
| 241 | 
            -
                              f"Using {self.unknown_index} for unknown indices.")
         | 
| 242 | 
            -
                    else:
         | 
| 243 | 
            -
                        self.re_embed = n_e
         | 
| 244 | 
            -
             | 
| 245 | 
            -
                    self.sane_index_shape = sane_index_shape
         | 
| 246 | 
            -
             | 
| 247 | 
            -
                def remap_to_used(self, inds):
         | 
| 248 | 
            -
                    ishape = inds.shape
         | 
| 249 | 
            -
                    assert len(ishape)>1
         | 
| 250 | 
            -
                    inds = inds.reshape(ishape[0],-1)
         | 
| 251 | 
            -
                    used = self.used.to(inds)
         | 
| 252 | 
            -
                    match = (inds[:,:,None]==used[None,None,...]).long()
         | 
| 253 | 
            -
                    new = match.argmax(-1)
         | 
| 254 | 
            -
                    unknown = match.sum(2)<1
         | 
| 255 | 
            -
                    if self.unknown_index == "random":
         | 
| 256 | 
            -
                        new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
         | 
| 257 | 
            -
                    else:
         | 
| 258 | 
            -
                        new[unknown] = self.unknown_index
         | 
| 259 | 
            -
                    return new.reshape(ishape)
         | 
| 260 | 
            -
             | 
| 261 | 
            -
                def unmap_to_all(self, inds):
         | 
| 262 | 
            -
                    ishape = inds.shape
         | 
| 263 | 
            -
                    assert len(ishape)>1
         | 
| 264 | 
            -
                    inds = inds.reshape(ishape[0],-1)
         | 
| 265 | 
            -
                    used = self.used.to(inds)
         | 
| 266 | 
            -
                    if self.re_embed > self.used.shape[0]: # extra token
         | 
| 267 | 
            -
                        inds[inds>=self.used.shape[0]] = 0 # simply set to zero
         | 
| 268 | 
            -
                    back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
         | 
| 269 | 
            -
                    return back.reshape(ishape)
         | 
| 270 | 
            -
             | 
| 271 | 
            -
                def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
         | 
| 272 | 
            -
                    assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
         | 
| 273 | 
            -
                    assert rescale_logits==False, "Only for interface compatible with Gumbel"
         | 
| 274 | 
            -
                    assert return_logits==False, "Only for interface compatible with Gumbel"
         | 
| 275 | 
            -
                    # reshape z -> (batch, height, width, channel) and flatten
         | 
| 276 | 
            -
                    z = rearrange(z, 'b c h w -> b h w c').contiguous()
         | 
| 277 | 
            -
                    z_flattened = z.view(-1, self.e_dim)
         | 
| 278 | 
            -
                    # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
         | 
| 279 | 
            -
             | 
| 280 | 
            -
                    d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
         | 
| 281 | 
            -
                        torch.sum(self.embedding.weight**2, dim=1) - 2 * \
         | 
| 282 | 
            -
                        torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
         | 
| 283 | 
            -
             | 
| 284 | 
            -
                    min_encoding_indices = torch.argmin(d, dim=1)
         | 
| 285 | 
            -
                    z_q = self.embedding(min_encoding_indices).view(z.shape)
         | 
| 286 | 
            -
                    perplexity = None
         | 
| 287 | 
            -
                    min_encodings = None
         | 
| 288 | 
            -
             | 
| 289 | 
            -
                    # compute loss for embedding
         | 
| 290 | 
            -
                    if not self.legacy:
         | 
| 291 | 
            -
                        loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
         | 
| 292 | 
            -
                               torch.mean((z_q - z.detach()) ** 2)
         | 
| 293 | 
            -
                    else:
         | 
| 294 | 
            -
                        loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
         | 
| 295 | 
            -
                               torch.mean((z_q - z.detach()) ** 2)
         | 
| 296 | 
            -
             | 
| 297 | 
            -
                    # preserve gradients
         | 
| 298 | 
            -
                    z_q = z + (z_q - z).detach()
         | 
| 299 | 
            -
             | 
| 300 | 
            -
                    # reshape back to match original input shape
         | 
| 301 | 
            -
                    z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
         | 
| 302 | 
            -
             | 
| 303 | 
            -
                    if self.remap is not None:
         | 
| 304 | 
            -
                        min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
         | 
| 305 | 
            -
                        min_encoding_indices = self.remap_to_used(min_encoding_indices)
         | 
| 306 | 
            -
                        min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
         | 
| 307 | 
            -
             | 
| 308 | 
            -
                    if self.sane_index_shape:
         | 
| 309 | 
            -
                        min_encoding_indices = min_encoding_indices.reshape(
         | 
| 310 | 
            -
                            z_q.shape[0], z_q.shape[2], z_q.shape[3])
         | 
| 311 | 
            -
             | 
| 312 | 
            -
                    return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
         | 
| 313 | 
            -
             | 
| 314 | 
            -
                def get_codebook_entry(self, indices, shape):
         | 
| 315 | 
            -
                    # shape specifying (batch, height, width, channel)
         | 
| 316 | 
            -
                    if self.remap is not None:
         | 
| 317 | 
            -
                        indices = indices.reshape(shape[0],-1) # add batch axis
         | 
| 318 | 
            -
                        indices = self.unmap_to_all(indices)
         | 
| 319 | 
            -
                        indices = indices.reshape(-1) # flatten again
         | 
| 320 | 
            -
             | 
| 321 | 
            -
                    # get quantized latent vectors
         | 
| 322 | 
            -
                    z_q = self.embedding(indices)
         | 
| 323 | 
            -
             | 
| 324 | 
            -
                    if shape is not None:
         | 
| 325 | 
            -
                        z_q = z_q.view(shape)
         | 
| 326 | 
            -
                        # reshape back to match original input shape
         | 
| 327 | 
            -
                        z_q = z_q.permute(0, 3, 1, 2).contiguous()
         | 
| 328 | 
            -
             | 
| 329 | 
            -
                    return z_q
         | 
| 330 | 
            -
             | 
| 331 | 
            -
            class EmbeddingEMA(nn.Module):
         | 
| 332 | 
            -
                def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
         | 
| 333 | 
            -
                    super().__init__()
         | 
| 334 | 
            -
                    self.decay = decay
         | 
| 335 | 
            -
                    self.eps = eps        
         | 
| 336 | 
            -
                    weight = torch.randn(num_tokens, codebook_dim)
         | 
| 337 | 
            -
                    self.weight = nn.Parameter(weight, requires_grad = False)
         | 
| 338 | 
            -
                    self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
         | 
| 339 | 
            -
                    self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
         | 
| 340 | 
            -
                    self.update = True
         | 
| 341 | 
            -
             | 
| 342 | 
            -
                def forward(self, embed_id):
         | 
| 343 | 
            -
                    return F.embedding(embed_id, self.weight)
         | 
| 344 | 
            -
             | 
| 345 | 
            -
                def cluster_size_ema_update(self, new_cluster_size):
         | 
| 346 | 
            -
                    self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
         | 
| 347 | 
            -
             | 
| 348 | 
            -
                def embed_avg_ema_update(self, new_embed_avg): 
         | 
| 349 | 
            -
                    self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
         | 
| 350 | 
            -
             | 
| 351 | 
            -
                def weight_update(self, num_tokens):
         | 
| 352 | 
            -
                    n = self.cluster_size.sum()
         | 
| 353 | 
            -
                    smoothed_cluster_size = (
         | 
| 354 | 
            -
                            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
         | 
| 355 | 
            -
                        )
         | 
| 356 | 
            -
                    #normalize embedding average with smoothed cluster size
         | 
| 357 | 
            -
                    embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
         | 
| 358 | 
            -
                    self.weight.data.copy_(embed_normalized)   
         | 
| 359 | 
            -
             | 
| 360 | 
            -
             | 
| 361 | 
            -
            class EMAVectorQuantizer(nn.Module):
         | 
| 362 | 
            -
                def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
         | 
| 363 | 
            -
                            remap=None, unknown_index="random"):
         | 
| 364 | 
            -
                    super().__init__()
         | 
| 365 | 
            -
                    self.codebook_dim = codebook_dim
         | 
| 366 | 
            -
                    self.num_tokens = num_tokens
         | 
| 367 | 
            -
                    self.beta = beta
         | 
| 368 | 
            -
                    self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
         | 
| 369 | 
            -
             | 
| 370 | 
            -
                    self.remap = remap
         | 
| 371 | 
            -
                    if self.remap is not None:
         | 
| 372 | 
            -
                        self.register_buffer("used", torch.tensor(np.load(self.remap)))
         | 
| 373 | 
            -
                        self.re_embed = self.used.shape[0]
         | 
| 374 | 
            -
                        self.unknown_index = unknown_index # "random" or "extra" or integer
         | 
| 375 | 
            -
                        if self.unknown_index == "extra":
         | 
| 376 | 
            -
                            self.unknown_index = self.re_embed
         | 
| 377 | 
            -
                            self.re_embed = self.re_embed+1
         | 
| 378 | 
            -
                        print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
         | 
| 379 | 
            -
                              f"Using {self.unknown_index} for unknown indices.")
         | 
| 380 | 
            -
                    else:
         | 
| 381 | 
            -
                        self.re_embed = n_embed
         | 
| 382 | 
            -
             | 
| 383 | 
            -
                def remap_to_used(self, inds):
         | 
| 384 | 
            -
                    ishape = inds.shape
         | 
| 385 | 
            -
                    assert len(ishape)>1
         | 
| 386 | 
            -
                    inds = inds.reshape(ishape[0],-1)
         | 
| 387 | 
            -
                    used = self.used.to(inds)
         | 
| 388 | 
            -
                    match = (inds[:,:,None]==used[None,None,...]).long()
         | 
| 389 | 
            -
                    new = match.argmax(-1)
         | 
| 390 | 
            -
                    unknown = match.sum(2)<1
         | 
| 391 | 
            -
                    if self.unknown_index == "random":
         | 
| 392 | 
            -
                        new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
         | 
| 393 | 
            -
                    else:
         | 
| 394 | 
            -
                        new[unknown] = self.unknown_index
         | 
| 395 | 
            -
                    return new.reshape(ishape)
         | 
| 396 | 
            -
             | 
| 397 | 
            -
                def unmap_to_all(self, inds):
         | 
| 398 | 
            -
                    ishape = inds.shape
         | 
| 399 | 
            -
                    assert len(ishape)>1
         | 
| 400 | 
            -
                    inds = inds.reshape(ishape[0],-1)
         | 
| 401 | 
            -
                    used = self.used.to(inds)
         | 
| 402 | 
            -
                    if self.re_embed > self.used.shape[0]: # extra token
         | 
| 403 | 
            -
                        inds[inds>=self.used.shape[0]] = 0 # simply set to zero
         | 
| 404 | 
            -
                    back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
         | 
| 405 | 
            -
                    return back.reshape(ishape)
         | 
| 406 | 
            -
             | 
| 407 | 
            -
                def forward(self, z):
         | 
| 408 | 
            -
                    # reshape z -> (batch, height, width, channel) and flatten
         | 
| 409 | 
            -
                    #z, 'b c h w -> b h w c'
         | 
| 410 | 
            -
                    z = rearrange(z, 'b c h w -> b h w c')
         | 
| 411 | 
            -
                    z_flattened = z.reshape(-1, self.codebook_dim)
         | 
| 412 | 
            -
                    
         | 
| 413 | 
            -
                    # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
         | 
| 414 | 
            -
                    d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
         | 
| 415 | 
            -
                        self.embedding.weight.pow(2).sum(dim=1) - 2 * \
         | 
| 416 | 
            -
                        torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
         | 
| 417 | 
            -
             | 
| 418 | 
            -
             | 
| 419 | 
            -
                    encoding_indices = torch.argmin(d, dim=1)
         | 
| 420 | 
            -
             | 
| 421 | 
            -
                    z_q = self.embedding(encoding_indices).view(z.shape)
         | 
| 422 | 
            -
                    encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)     
         | 
| 423 | 
            -
                    avg_probs = torch.mean(encodings, dim=0)
         | 
| 424 | 
            -
                    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
         | 
| 425 | 
            -
             | 
| 426 | 
            -
                    if self.training and self.embedding.update:
         | 
| 427 | 
            -
                        #EMA cluster size
         | 
| 428 | 
            -
                        encodings_sum = encodings.sum(0)            
         | 
| 429 | 
            -
                        self.embedding.cluster_size_ema_update(encodings_sum)
         | 
| 430 | 
            -
                        #EMA embedding average
         | 
| 431 | 
            -
                        embed_sum = encodings.transpose(0,1) @ z_flattened            
         | 
| 432 | 
            -
                        self.embedding.embed_avg_ema_update(embed_sum)
         | 
| 433 | 
            -
                        #normalize embed_avg and update weight
         | 
| 434 | 
            -
                        self.embedding.weight_update(self.num_tokens)
         | 
| 435 | 
            -
             | 
| 436 | 
            -
                    # compute loss for embedding
         | 
| 437 | 
            -
                    loss = self.beta * F.mse_loss(z_q.detach(), z) 
         | 
| 438 | 
            -
             | 
| 439 | 
            -
                    # preserve gradients
         | 
| 440 | 
            -
                    z_q = z + (z_q - z).detach()
         | 
| 441 | 
            -
             | 
| 442 | 
            -
                    # reshape back to match original input shape
         | 
| 443 | 
            -
                    #z_q, 'b h w c -> b c h w'
         | 
| 444 | 
            -
                    z_q = rearrange(z_q, 'b h w c -> b c h w')
         | 
| 445 | 
            -
                    return z_q, loss, (perplexity, encodings, encoding_indices)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming/util.py
    DELETED
    
    | @@ -1,157 +0,0 @@ | |
| 1 | 
            -
            import os, hashlib
         | 
| 2 | 
            -
            import requests
         | 
| 3 | 
            -
            from tqdm import tqdm
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            URL_MAP = {
         | 
| 6 | 
            -
                "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
         | 
| 7 | 
            -
            }
         | 
| 8 | 
            -
             | 
| 9 | 
            -
            CKPT_MAP = {
         | 
| 10 | 
            -
                "vgg_lpips": "vgg.pth"
         | 
| 11 | 
            -
            }
         | 
| 12 | 
            -
             | 
| 13 | 
            -
            MD5_MAP = {
         | 
| 14 | 
            -
                "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
         | 
| 15 | 
            -
            }
         | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
            def download(url, local_path, chunk_size=1024):
         | 
| 19 | 
            -
                os.makedirs(os.path.split(local_path)[0], exist_ok=True)
         | 
| 20 | 
            -
                with requests.get(url, stream=True) as r:
         | 
| 21 | 
            -
                    total_size = int(r.headers.get("content-length", 0))
         | 
| 22 | 
            -
                    with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
         | 
| 23 | 
            -
                        with open(local_path, "wb") as f:
         | 
| 24 | 
            -
                            for data in r.iter_content(chunk_size=chunk_size):
         | 
| 25 | 
            -
                                if data:
         | 
| 26 | 
            -
                                    f.write(data)
         | 
| 27 | 
            -
                                    pbar.update(chunk_size)
         | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
            def md5_hash(path):
         | 
| 31 | 
            -
                with open(path, "rb") as f:
         | 
| 32 | 
            -
                    content = f.read()
         | 
| 33 | 
            -
                return hashlib.md5(content).hexdigest()
         | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
            def get_ckpt_path(name, root, check=False):
         | 
| 37 | 
            -
                assert name in URL_MAP
         | 
| 38 | 
            -
                path = os.path.join(root, CKPT_MAP[name])
         | 
| 39 | 
            -
                if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
         | 
| 40 | 
            -
                    print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
         | 
| 41 | 
            -
                    download(URL_MAP[name], path)
         | 
| 42 | 
            -
                    md5 = md5_hash(path)
         | 
| 43 | 
            -
                    assert md5 == MD5_MAP[name], md5
         | 
| 44 | 
            -
                return path
         | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
            class KeyNotFoundError(Exception):
         | 
| 48 | 
            -
                def __init__(self, cause, keys=None, visited=None):
         | 
| 49 | 
            -
                    self.cause = cause
         | 
| 50 | 
            -
                    self.keys = keys
         | 
| 51 | 
            -
                    self.visited = visited
         | 
| 52 | 
            -
                    messages = list()
         | 
| 53 | 
            -
                    if keys is not None:
         | 
| 54 | 
            -
                        messages.append("Key not found: {}".format(keys))
         | 
| 55 | 
            -
                    if visited is not None:
         | 
| 56 | 
            -
                        messages.append("Visited: {}".format(visited))
         | 
| 57 | 
            -
                    messages.append("Cause:\n{}".format(cause))
         | 
| 58 | 
            -
                    message = "\n".join(messages)
         | 
| 59 | 
            -
                    super().__init__(message)
         | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
            def retrieve(
         | 
| 63 | 
            -
                list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
         | 
| 64 | 
            -
            ):
         | 
| 65 | 
            -
                """Given a nested list or dict return the desired value at key expanding
         | 
| 66 | 
            -
                callable nodes if necessary and :attr:`expand` is ``True``. The expansion
         | 
| 67 | 
            -
                is done in-place.
         | 
| 68 | 
            -
             | 
| 69 | 
            -
                Parameters
         | 
| 70 | 
            -
                ----------
         | 
| 71 | 
            -
                    list_or_dict : list or dict
         | 
| 72 | 
            -
                        Possibly nested list or dictionary.
         | 
| 73 | 
            -
                    key : str
         | 
| 74 | 
            -
                        key/to/value, path like string describing all keys necessary to
         | 
| 75 | 
            -
                        consider to get to the desired value. List indices can also be
         | 
| 76 | 
            -
                        passed here.
         | 
| 77 | 
            -
                    splitval : str
         | 
| 78 | 
            -
                        String that defines the delimiter between keys of the
         | 
| 79 | 
            -
                        different depth levels in `key`.
         | 
| 80 | 
            -
                    default : obj
         | 
| 81 | 
            -
                        Value returned if :attr:`key` is not found.
         | 
| 82 | 
            -
                    expand : bool
         | 
| 83 | 
            -
                        Whether to expand callable nodes on the path or not.
         | 
| 84 | 
            -
             | 
| 85 | 
            -
                Returns
         | 
| 86 | 
            -
                -------
         | 
| 87 | 
            -
                    The desired value or if :attr:`default` is not ``None`` and the
         | 
| 88 | 
            -
                    :attr:`key` is not found returns ``default``.
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                Raises
         | 
| 91 | 
            -
                ------
         | 
| 92 | 
            -
                    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
         | 
| 93 | 
            -
                    ``None``.
         | 
| 94 | 
            -
                """
         | 
| 95 | 
            -
             | 
| 96 | 
            -
                keys = key.split(splitval)
         | 
| 97 | 
            -
             | 
| 98 | 
            -
                success = True
         | 
| 99 | 
            -
                try:
         | 
| 100 | 
            -
                    visited = []
         | 
| 101 | 
            -
                    parent = None
         | 
| 102 | 
            -
                    last_key = None
         | 
| 103 | 
            -
                    for key in keys:
         | 
| 104 | 
            -
                        if callable(list_or_dict):
         | 
| 105 | 
            -
                            if not expand:
         | 
| 106 | 
            -
                                raise KeyNotFoundError(
         | 
| 107 | 
            -
                                    ValueError(
         | 
| 108 | 
            -
                                        "Trying to get past callable node with expand=False."
         | 
| 109 | 
            -
                                    ),
         | 
| 110 | 
            -
                                    keys=keys,
         | 
| 111 | 
            -
                                    visited=visited,
         | 
| 112 | 
            -
                                )
         | 
| 113 | 
            -
                            list_or_dict = list_or_dict()
         | 
| 114 | 
            -
                            parent[last_key] = list_or_dict
         | 
| 115 | 
            -
             | 
| 116 | 
            -
                        last_key = key
         | 
| 117 | 
            -
                        parent = list_or_dict
         | 
| 118 | 
            -
             | 
| 119 | 
            -
                        try:
         | 
| 120 | 
            -
                            if isinstance(list_or_dict, dict):
         | 
| 121 | 
            -
                                list_or_dict = list_or_dict[key]
         | 
| 122 | 
            -
                            else:
         | 
| 123 | 
            -
                                list_or_dict = list_or_dict[int(key)]
         | 
| 124 | 
            -
                        except (KeyError, IndexError, ValueError) as e:
         | 
| 125 | 
            -
                            raise KeyNotFoundError(e, keys=keys, visited=visited)
         | 
| 126 | 
            -
             | 
| 127 | 
            -
                        visited += [key]
         | 
| 128 | 
            -
                    # final expansion of retrieved value
         | 
| 129 | 
            -
                    if expand and callable(list_or_dict):
         | 
| 130 | 
            -
                        list_or_dict = list_or_dict()
         | 
| 131 | 
            -
                        parent[last_key] = list_or_dict
         | 
| 132 | 
            -
                except KeyNotFoundError as e:
         | 
| 133 | 
            -
                    if default is None:
         | 
| 134 | 
            -
                        raise e
         | 
| 135 | 
            -
                    else:
         | 
| 136 | 
            -
                        list_or_dict = default
         | 
| 137 | 
            -
                        success = False
         | 
| 138 | 
            -
             | 
| 139 | 
            -
                if not pass_success:
         | 
| 140 | 
            -
                    return list_or_dict
         | 
| 141 | 
            -
                else:
         | 
| 142 | 
            -
                    return list_or_dict, success
         | 
| 143 | 
            -
             | 
| 144 | 
            -
             | 
| 145 | 
            -
            if __name__ == "__main__":
         | 
| 146 | 
            -
                config = {"keya": "a",
         | 
| 147 | 
            -
                          "keyb": "b",
         | 
| 148 | 
            -
                          "keyc":
         | 
| 149 | 
            -
                              {"cc1": 1,
         | 
| 150 | 
            -
                               "cc2": 2,
         | 
| 151 | 
            -
                               }
         | 
| 152 | 
            -
                          }
         | 
| 153 | 
            -
                from omegaconf import OmegaConf
         | 
| 154 | 
            -
                config = OmegaConf.create(config)
         | 
| 155 | 
            -
                print(config)
         | 
| 156 | 
            -
                retrieve(config, "keya")
         | 
| 157 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming_transformers.egg-info/PKG-INFO
    DELETED
    
    | @@ -1,10 +0,0 @@ | |
| 1 | 
            -
            Metadata-Version: 2.1
         | 
| 2 | 
            -
            Name: taming-transformers
         | 
| 3 | 
            -
            Version: 0.0.1
         | 
| 4 | 
            -
            Summary: Taming Transformers for High-Resolution Image Synthesis
         | 
| 5 | 
            -
            Home-page: UNKNOWN
         | 
| 6 | 
            -
            License: UNKNOWN
         | 
| 7 | 
            -
            Platform: UNKNOWN
         | 
| 8 | 
            -
             | 
| 9 | 
            -
            UNKNOWN
         | 
| 10 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming_transformers.egg-info/SOURCES.txt
    DELETED
    
    | @@ -1,7 +0,0 @@ | |
| 1 | 
            -
            README.md
         | 
| 2 | 
            -
            setup.py
         | 
| 3 | 
            -
            taming_transformers.egg-info/PKG-INFO
         | 
| 4 | 
            -
            taming_transformers.egg-info/SOURCES.txt
         | 
| 5 | 
            -
            taming_transformers.egg-info/dependency_links.txt
         | 
| 6 | 
            -
            taming_transformers.egg-info/requires.txt
         | 
| 7 | 
            -
            taming_transformers.egg-info/top_level.txt
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming_transformers.egg-info/dependency_links.txt
    DELETED
    
    | @@ -1 +0,0 @@ | |
| 1 | 
            -
             | 
|  | |
|  | 
    	
        taming-transformers/taming_transformers.egg-info/requires.txt
    DELETED
    
    | @@ -1,3 +0,0 @@ | |
| 1 | 
            -
            torch
         | 
| 2 | 
            -
            numpy
         | 
| 3 | 
            -
            tqdm
         | 
|  | |
|  | |
|  | |
|  | 
    	
        taming-transformers/taming_transformers.egg-info/top_level.txt
    DELETED
    
    | @@ -1 +0,0 @@ | |
| 1 | 
            -
             | 
|  | |
|  | 
