|
import os |
|
import numpy as np |
|
import json |
|
from tqdm import tqdm |
|
from argparse import ArgumentParser, Namespace |
|
|
|
# Command-line options: where the source videos live, where the per-frame
# diffusion features were extracted to, and where the per-video merged
# features are written.
parser = ArgumentParser()
for _flag, _default_dir, _help_text in (
    ("--video_folder", 'demo_videos', "Folder path of your videos"),
    ("--diff_feat_folder", 'output_diffusion_features', "Folder path of your extracted diffusion features"),
    ("--merged_feat_folder", 'output_merged_diffusion_features', "Folder path of output merged diffusion features"),
):
    parser.add_argument(_flag, default=_default_dir, type=str, help=_help_text)
args = parser.parse_args()
|
|
|
# Inputs: per-frame feature files, one sub-folder per feature type.
feat_content_root, feat_degradation_root = (
    os.path.join(args.diff_feat_folder, sub_dir)
    for sub_dir in ('feat_content', 'feat_degradation')
)
all_frames_content, all_frames_degradation = map(
    os.listdir, (feat_content_root, feat_degradation_root)
)

# Outputs: same sub-folder layout under the merged-features folder.
content_save_root, degradation_save_root = (
    os.path.join(args.merged_feat_folder, sub_dir)
    for sub_dir in ('feat_content', 'feat_degradation')
)
for out_dir in (content_save_root, degradation_save_root):
    os.makedirs(out_dir, exist_ok=True)

# Every entry in the video folder is treated as one video to merge.
file_list = os.listdir(args.video_folder)
|
|
|
# Number of frames kept per video in the merged feature files. Videos with
# fewer frames are padded by repeating frames; longer ones are sub-sampled
# at a uniform stride.
NUM_FRAMES = 15

# Set to True to delete the per-frame source features that were merged for a
# video once its merged files have been written (saves storage).
DELETE_SOURCE_FEATURES = False

# Content arrays to merge: key inside each per-frame .npz -> key written to the
# merged .npz. Dict order is the order the arrays are saved.
CONTENT_KEY_MAP = {
    'pred_latent_000': 'pred_latent',
    'input_latent_000': 'input_latent',
    'input_unet_000_000': 'input_unet_000',
    'input_unet_000_001': 'input_unet_001',
    'output_unet_000_000': 'output_unet_000',
    'output_unet_000_001': 'output_unet_001',
}


def _list_video_frames(root, candidates, video_name):
    """Return, sorted, the names in ``candidates`` that are regular files under
    ``root`` and start with ``video_name``."""
    # NOTE(review): plain prefix matching also picks up frames of any other
    # video whose name extends this one (e.g. 'clip1' matches 'clip10_...').
    # Confirm the extractor's naming scheme and include its separator if any.
    return sorted(
        name for name in candidates
        if os.path.isfile(os.path.join(root, name)) and name.startswith(video_name)
    )


def _sample_frame_indices(num_available, num_frames):
    """Return ``num_frames`` indices into a list of ``num_available`` frames.

    If there are fewer frames than needed, every frame is repeated; the first
    ``num_frames % num_available`` frames get one extra copy, and the result
    stays in temporal order. Otherwise frames are picked at a uniform
    fractional stride. ``num_available`` must be > 0.
    """
    if num_available < num_frames:
        quotient, remainder = divmod(num_frames, num_available)
        indices = []
        for idx in range(num_available):
            repeats = quotient + 1 if idx < remainder else quotient
            indices.extend([idx] * repeats)
        return indices
    step = num_available / num_frames
    return [int(i * step) for i in range(num_frames)]


def _stack_npz_frames(root, file_names, keys=None):
    """Load each ``.npz`` in ``file_names`` (relative to ``root``) and stack
    every collected array along a new leading axis.

    ``keys`` selects which arrays to read (a missing one raises KeyError);
    ``None`` reads every array in each file, in first-seen order.
    Returns a dict mapping key -> stacked array.
    """
    collected = {}
    for name in file_names:
        # Context manager closes the npz archive; data[key] returns an
        # in-memory array, so it stays valid after the file is closed.
        with np.load(os.path.join(root, name)) as data:
            for key in (data.files if keys is None else keys):
                collected.setdefault(key, []).append(data[key])
    return {key: np.stack(arrays, axis=0) for key, arrays in collected.items()}


for video_file in tqdm(file_list):
    video_name = video_file.split('.')[0]

    content_feat_list = _list_video_frames(feat_content_root, all_frames_content, video_name)
    degradation_feat_list = _list_video_frames(feat_degradation_root, all_frames_degradation, video_name)

    # A video with no extracted frames would make divmod() divide by zero and
    # abort the whole run; skip it instead.
    if not content_feat_list:
        tqdm.write('Skipping {}: no content features found.'.format(video_file))
        continue
    # Content and degradation frames are paired by position in their sorted
    # lists, so the two lists must have the same length.
    if len(content_feat_list) != len(degradation_feat_list):
        tqdm.write('Skipping {}: {} content vs {} degradation feature files.'.format(
            video_file, len(content_feat_list), len(degradation_feat_list)))
        continue

    # One shared index list keeps content and degradation frames paired
    # (the previous up-sampling code restarted the degradation index at 0).
    indices = _sample_frame_indices(len(content_feat_list), NUM_FRAMES)
    final_frames_content = [content_feat_list[i] for i in indices]
    final_frames_degradation = [degradation_feat_list[i] for i in indices]

    # Merge content features: fixed key set, renamed on save.
    content_data = _stack_npz_frames(feat_content_root, final_frames_content, CONTENT_KEY_MAP)
    np.savez(
        os.path.join(content_save_root, video_name + '.npz'),
        **{new_key: content_data[old_key] for old_key, new_key in CONTENT_KEY_MAP.items()},
    )

    # Merge degradation features: every array found in the per-frame files,
    # saved under its original key.
    degradation_data = _stack_npz_frames(feat_degradation_root, final_frames_degradation)
    np.savez(os.path.join(degradation_save_root, video_name + '.npz'), **degradation_data)

    if DELETE_SOURCE_FEATURES:
        for root, names in ((feat_content_root, final_frames_content),
                            (feat_degradation_root, final_frames_degradation)):
            # dict.fromkeys drops the duplicates introduced by up-sampling,
            # so each file is removed once.
            for name in dict.fromkeys(names):
                file_path = os.path.join(root, name)
                try:
                    os.remove(file_path)
                except OSError as err:
                    print('{} cannot be deleted! ({})'.format(file_path, err))
|
|