Chao Xu
		
	committed on
		
		
					Commit 
							
							·
						
						6c1250a
	
1
								Parent(s):
							
							0e93edd
								
empty cache
Browse files
- sam_utils.py +1 -0
- zero123_utils.py +8 -9
    	
        sam_utils.py
    CHANGED
    
    | @@ -103,5 +103,6 @@ def sam_out_nosave(predictor, input_image, *bbox_sliders): | |
| 103 | 
             
                out_image_bbox = out_image.copy()
         | 
| 104 | 
             
                out_image[:, :, 3] = mask.astype(np.uint8) * 255
         | 
| 105 | 
             
                out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
         | 
|  | |
| 106 | 
             
                return Image.fromarray(out_image_bbox, mode='RGBA') 
         | 
| 107 | 
             
                cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
         | 
|  | |
| 103 | 
             
                out_image_bbox = out_image.copy()
         | 
| 104 | 
             
                out_image[:, :, 3] = mask.astype(np.uint8) * 255
         | 
| 105 | 
             
                out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
         | 
| 106 | 
            +
                torch.cuda.empty_cache()
         | 
| 107 | 
             
                return Image.fromarray(out_image_bbox, mode='RGBA') 
         | 
| 108 | 
             
                cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
         | 
    	
        zero123_utils.py
    CHANGED
    
    | @@ -61,9 +61,9 @@ def init_model(device, ckpt): | |
| 61 | 
             
                return models
         | 
| 62 |  | 
| 63 | 
             
            @torch.no_grad()
         | 
| 64 | 
            -
            def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=' | 
| 65 | 
             
                precision_scope = autocast if precision == 'autocast' else nullcontext
         | 
| 66 | 
            -
                with precision_scope( | 
| 67 | 
             
                    with model.ema_scope():
         | 
| 68 | 
             
                        c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
         | 
| 69 | 
             
                        T = []
         | 
| @@ -98,7 +98,9 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision= | |
| 98 | 
             
                        print(samples_ddim.shape)
         | 
| 99 | 
             
                        # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
         | 
| 100 | 
             
                        x_samples_ddim = model.decode_first_stage(samples_ddim)
         | 
| 101 | 
            -
                         | 
|  | |
|  | |
| 102 |  | 
| 103 |  | 
| 104 | 
             
            def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], device="cuda"):
         | 
| @@ -118,7 +120,7 @@ def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], d | |
| 118 | 
             
                for stage1_idx in range(len(x_samples_ddims_8)):
         | 
| 119 | 
             
                    if adjust_set != [] and stage1_idx not in adjust_set:
         | 
| 120 | 
             
                        continue
         | 
| 121 | 
            -
                    x_sample = 255.0 * rearrange(x_samples_ddims_8[stage1_idx]. | 
| 122 | 
             
                    Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(save_path_8, '%d.png'%(stage1_idx)))
         | 
| 123 | 
             
                del x_samples_ddims_8
         | 
| 124 | 
             
                del input_im
         | 
| @@ -148,7 +150,7 @@ def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device=" | |
| 148 | 
             
                for stage1_idx in range(len(delta_x_1_8)):
         | 
| 149 | 
             
                    if adjust_set != [] and stage1_idx not in adjust_set:
         | 
| 150 | 
             
                        continue
         | 
| 151 | 
            -
                    x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx]. | 
| 152 | 
             
                    out_image = Image.fromarray(x_sample.astype(np.uint8))
         | 
| 153 | 
             
                    ret_imgs.append(out_image)
         | 
| 154 | 
             
                    if save_path:
         | 
| @@ -177,16 +179,13 @@ def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_ | |
| 177 | 
             
                    input_im_init = input_im_init / 255.0
         | 
| 178 | 
             
                    input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
         | 
| 179 | 
             
                    input_im = input_im * 2 - 1
         | 
| 180 | 
            -
                    print("debug input device", input_im.device)
         | 
| 181 | 
            -
                    print("debug model device", model.device)
         | 
| 182 | 
             
                    # infer stage 2
         | 
| 183 | 
             
                    sampler = DDIMSampler(model)
         | 
| 184 | 
            -
                    print("debug sampler device", sampler.device)
         | 
| 185 | 
             
                    # sampler.to(device)
         | 
| 186 | 
             
                    # stage2_in = x_samples_ddims[stage1_idx][None, ...].to(device) * 2 - 1
         | 
| 187 | 
             
                    x_samples_ddims_stage2 = sample_model_batch(model, sampler, input_im, delta_x_2, delta_y_2, n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale)
         | 
| 188 | 
             
                    for stage2_idx in range(len(delta_x_2)):
         | 
| 189 | 
            -
                        x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx]. | 
| 190 | 
             
                        Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
         | 
| 191 | 
             
                    del input_im
         | 
| 192 | 
             
                    del sampler
         | 
|  | |
| 61 | 
             
                return models
         | 
| 62 |  | 
| 63 | 
             
            @torch.no_grad()
         | 
| 64 | 
            +
            def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='autocast', ddim_eta=1.0, ddim_steps=75, scale=3.0, h=256, w=256):
         | 
| 65 | 
             
                precision_scope = autocast if precision == 'autocast' else nullcontext
         | 
| 66 | 
            +
                with precision_scope("cuda"):
         | 
| 67 | 
             
                    with model.ema_scope():
         | 
| 68 | 
             
                        c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
         | 
| 69 | 
             
                        T = []
         | 
|  | |
| 98 | 
             
                        print(samples_ddim.shape)
         | 
| 99 | 
             
                        # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
         | 
| 100 | 
             
                        x_samples_ddim = model.decode_first_stage(samples_ddim)
         | 
| 101 | 
            +
                        ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
         | 
| 102 | 
            +
                        del cond, c, x_samples_ddim, samples_ddim, uc
         | 
| 103 | 
            +
                        return ret_imgs
         | 
| 104 |  | 
| 105 |  | 
| 106 | 
             
            def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], device="cuda"):
         | 
|  | |
| 120 | 
             
                for stage1_idx in range(len(x_samples_ddims_8)):
         | 
| 121 | 
             
                    if adjust_set != [] and stage1_idx not in adjust_set:
         | 
| 122 | 
             
                        continue
         | 
| 123 | 
            +
                    x_sample = 255.0 * rearrange(x_samples_ddims_8[stage1_idx].numpy(), 'c h w -> h w c')
         | 
| 124 | 
             
                    Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(save_path_8, '%d.png'%(stage1_idx)))
         | 
| 125 | 
             
                del x_samples_ddims_8
         | 
| 126 | 
             
                del input_im
         | 
|  | |
| 150 | 
             
                for stage1_idx in range(len(delta_x_1_8)):
         | 
| 151 | 
             
                    if adjust_set != [] and stage1_idx not in adjust_set:
         | 
| 152 | 
             
                        continue
         | 
| 153 | 
            +
                    x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx].numpy(), 'c h w -> h w c')
         | 
| 154 | 
             
                    out_image = Image.fromarray(x_sample.astype(np.uint8))
         | 
| 155 | 
             
                    ret_imgs.append(out_image)
         | 
| 156 | 
             
                    if save_path:
         | 
|  | |
| 179 | 
             
                    input_im_init = input_im_init / 255.0
         | 
| 180 | 
             
                    input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
         | 
| 181 | 
             
                    input_im = input_im * 2 - 1
         | 
|  | |
|  | |
| 182 | 
             
                    # infer stage 2
         | 
| 183 | 
             
                    sampler = DDIMSampler(model)
         | 
|  | |
| 184 | 
             
                    # sampler.to(device)
         | 
| 185 | 
             
                    # stage2_in = x_samples_ddims[stage1_idx][None, ...].to(device) * 2 - 1
         | 
| 186 | 
             
                    x_samples_ddims_stage2 = sample_model_batch(model, sampler, input_im, delta_x_2, delta_y_2, n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale)
         | 
| 187 | 
             
                    for stage2_idx in range(len(delta_x_2)):
         | 
| 188 | 
            +
                        x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c')
         | 
| 189 | 
             
                        Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
         | 
| 190 | 
             
                    del input_im
         | 
| 191 | 
             
                    del sampler
         | 
