Chao Xu committed
Commit: c0c3e1b · Parent(s): 6c1250a

pruning

Files changed:
- sam_utils.py +3 -57
- zero123_utils.py +4 -4
sam_utils.py
CHANGED

@@ -1,14 +1,10 @@
 import os
 import numpy as np
 import torch
-# import matplotlib.pyplot as plt
-import cv2
 from PIL import Image
-# from PIL import Image
 import time
-from utils import find_image_file
 
-from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+from segment_anything import sam_model_registry, SamPredictor
 
 def sam_init(device_id=0):
     import inspect
@@ -22,60 +18,11 @@ def sam_init(device_id=0):
     sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
     sam.to(device=device)
     predictor = SamPredictor(sam)
-    # mask_generator = SamAutomaticMaskGenerator(sam)
     return predictor
 
-def sam_out(predictor, shape_dir):
-    image_path = os.path.join(shape_dir, find_image_file(shape_dir))
-    save_path = os.path.join(shape_dir, "image_sam.png")
-    bbox_path = os.path.join(shape_dir, "bbox.txt")
-    bbox = np.loadtxt(bbox_path, delimiter=',')
-    image = cv2.imread(image_path)
-    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-    start_time = time.time()
-    predictor.set_image(image)
-
-    h, w, _ = image.shape
-    input_point = np.array([[h//2, w//2]])
-    input_label = np.array([1])
-
-    masks, scores, logits = predictor.predict(
-        point_coords=input_point,
-        point_labels=input_label,
-        multimask_output=True,
-    )
-
-    masks_bbox, scores_bbox, logits_bbox = predictor.predict(
-        box=bbox,
-        multimask_output=True
-    )
-
-    print(f"SAM Time: {time.time() - start_time:.3f}s")
-    opt_idx = np.argmax(scores)
-    mask = masks[opt_idx]
-    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
-    out_image[:, :, :3] = image
-    out_image_bbox = out_image.copy()
-    out_image[:, :, 3] = mask.astype(np.uint8) * 255
-    out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255  # np.argmax(scores_bbox)
-    cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
-
-
-def convert_from_cv2_to_image(img: np.ndarray) -> Image:
-    return Image.fromarray(img)
-    # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
-
-def convert_from_image_to_cv2(img: Image) -> np.ndarray:
-    return np.asarray(img)
-    # return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
 def sam_out_nosave(predictor, input_image, *bbox_sliders):
-    # save_path = os.path.join(shape_dir, "image_sam.png")
-    # bbox_path = os.path.join(shape_dir, "bbox.txt")
-    # bbox = np.loadtxt(bbox_path, delimiter=',')
     bbox = np.array(bbox_sliders)
-    image = convert_from_image_to_cv2(input_image)
+    image = np.asarray(input_image)
 
     start_time = time.time()
     predictor.set_image(image)
@@ -104,5 +51,4 @@ def sam_out_nosave(predictor, input_image, *bbox_sliders):
     out_image[:, :, 3] = mask.astype(np.uint8) * 255
     out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255  # np.argmax(scores_bbox)
     torch.cuda.empty_cache()
-    return Image.fromarray(out_image_bbox, mode='RGBA')
-    cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
+    return Image.fromarray(out_image_bbox, mode='RGBA')
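
For reference, a minimal driver for the pruned module might look like the sketch below. The image path and box values are hypothetical, and the slider order is assumed to be SAM's XYXY pixel convention (x_min, y_min, x_max, y_max), which this diff does not itself show:

    from PIL import Image
    from sam_utils import sam_init, sam_out_nosave

    predictor = sam_init(device_id=0)  # load the SAM checkpoint once
    input_image = Image.open("demo.png").convert("RGB")  # hypothetical input
    # Four slider values forming the prompt box (assumed XYXY order).
    segmented = sam_out_nosave(predictor, input_image, 10, 10, 200, 200)
    segmented.save("image_sam.png")  # RGBA cutout; alpha is the predicted mask

With the on-disk `sam_out` path removed, `sam_out_nosave` is the module's only segmentation entry point; it returns the RGBA cutout directly instead of writing it next to the input image.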
zero123_utils.py
CHANGED

@@ -76,7 +76,7 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=
            cond = {}
            cond['c_crossattn'] = [c]
            # c_concat = model.encode_first_stage((input_im.to(c.device))).mode().detach()
-           cond['c_concat'] = [model.encode_first_stage((input_im.to(c.device))).mode().detach()
+           cond['c_concat'] = [model.encode_first_stage(input_im).mode().detach()
                                .repeat(n_samples, 1, 1, 1)]
            if scale != 1.0:
                uc = {}
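
The rewritten `c_concat` line keeps the deterministic encoder output (`.mode()` of the first-stage posterior rather than a random sample), detaches it from autograd, and tiles it across the sample batch; dropping the old `.to(c.device)` presumably relies on callers already placing `input_im` on the model's device. A toy shape check of the `repeat` call, with made-up latent dimensions:

    import torch

    n_samples = 4
    latent = torch.randn(1, 4, 32, 32)  # one encoded image (hypothetical shape)
    c_concat = latent.repeat(n_samples, 1, 1, 1)  # tile along the batch axis only
    assert c_concat.shape == (n_samples, 4, 32, 32)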
@@ -99,7 +99,8 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=
            # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
-           del cond, c, x_samples_ddim, samples_ddim, uc
+           del cond, c, x_samples_ddim, samples_ddim, uc, input_im
+           torch.cuda.empty_cache()
            return ret_imgs
 
 
@@ -126,6 +127,7 @@ def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], d
    del input_im
    torch.cuda.empty_cache()
 
+@torch.no_grad()
 def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="cuda", ddim_steps=75, scale=3.0):
    # raw_im = raw_im.resize([256, 256], Image.LANCZOS)
    # input_im_init = preprocess_image(models, raw_im, preprocess=False)
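
The remaining hunks all apply the same VRAM-pruning pattern: run inference under `@torch.no_grad()` so no autograd graph is kept alive, `del` the GPU intermediates once the result has been copied to the CPU, and call `torch.cuda.empty_cache()` to return cached blocks to the driver. A minimal sketch of the pattern (the model and shapes are placeholders, not this repo's API):

    import torch

    @torch.no_grad()  # inference only: no graph retained across the call
    def infer(model: torch.nn.Module, x_cpu: torch.Tensor) -> torch.Tensor:
        x = x_cpu.cuda()
        y = model(x)
        out = y.cpu()             # keep only the CPU copy of the result
        del x, y                  # drop this function's GPU references
        torch.cuda.empty_cache()  # release cached CUDA blocks to the driver
        return out

`empty_cache()` does not speed anything up by itself; it mainly matters when something else shares the GPU (as in a serving setup like this Gradio demo) and PyTorch's caching allocator would otherwise hold on to the freed memory.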
@@ -157,7 +159,6 @@ def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="
        out_image.save(os.path.join(save_path, '%d.png'%(stage1_idx)))
        sample_idx += 1
    del x_samples_ddims_8
-   del input_im
    del sampler
    torch.cuda.empty_cache()
    return ret_imgs
@@ -188,7 +189,6 @@ def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_
        x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c')
        Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
    del input_im
-   del sampler
    del x_samples_ddims_stage2
    torch.cuda.empty_cache()
 