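# UIVQA_demo / merge.py
# Uploaded by Madlord ("Upload 4 files"), commit 3beb455 (verified).
# Merges per-frame diffusion features (content and degradation) into one .npz per video.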
import os
import numpy as np
import json
from tqdm import tqdm
from argparse import ArgumentParser, Namespace
# Number of frames kept per video; must match the number set in generate_frame.py.
DEFAULT_NUM_FRAMES = 15

# Per-frame content-feature keys (as read from each frame's .npz) mapped to the
# key written into the merged per-video .npz. Insertion order fixes the order of
# the arrays inside the saved archive.
CONTENT_KEY_MAP = {
    'pred_latent_000': 'pred_latent',
    'input_latent_000': 'input_latent',
    'input_unet_000_000': 'input_unet_000',
    'input_unet_000_001': 'input_unet_001',
    'output_unet_000_000': 'output_unet_000',
    'output_unet_000_001': 'output_unet_001',
}


def parse_args(argv=None):
    """Parse the command-line options of the merge script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``video_folder``, ``diff_feat_folder``,
        ``merged_feat_folder``, ``num_frames`` and ``delete_original``.
    """
    parser = ArgumentParser()
    parser.add_argument("--video_folder", default='demo_videos', type=str, help="Folder path of your videos")
    parser.add_argument("--diff_feat_folder", default='output_diffusion_features', type=str, help="Folder path of your extracted diffusion features")
    parser.add_argument("--merged_feat_folder", default='output_merged_diffusion_features', type=str, help="Folder path of output merged diffusion features")
    parser.add_argument("--num_frames", default=DEFAULT_NUM_FRAMES, type=int, help="Frames per merged feature file (must match generate_frame.py)")
    parser.add_argument("--delete_original", action='store_true', help="Delete the merged per-frame feature files to save storage")
    return parser.parse_args(argv)


def list_video_features(root, names, video_name):
    """Return, sorted, the entries of *names* that are files in *root* and start with *video_name*."""
    # NOTE(review): plain prefix match, so a video named 'vid1' also picks up
    # files belonging to 'vid10'. Tighten (e.g. require a separator) once the
    # extractor's frame-file naming scheme is confirmed.
    return sorted(
        name for name in names
        if name.startswith(video_name) and os.path.isfile(os.path.join(root, name))
    )


def select_frame_indices(num_available, num_frames):
    """Pick exactly *num_frames* indices into a sequence of *num_available* frames.

    With fewer frames than needed, every frame is repeated ``quotient`` times and
    the first ``remainder`` frames once more, where ``quotient, remainder =
    divmod(num_frames, num_available)``. Otherwise the indices are
    ``int(i * num_available / num_frames)`` for ``i`` in ``range(num_frames)``.

    Raises:
        ValueError: if *num_available* is not positive.
    """
    if num_available <= 0:
        raise ValueError('num_available must be positive, got {}'.format(num_available))
    if num_available < num_frames:
        quotient, remainder = divmod(num_frames, num_available)
        indices = []
        for idx in range(num_available):
            indices.extend([idx] * (quotient + 1 if idx < remainder else quotient))
        return indices
    step = num_available / num_frames
    return [int(i * step) for i in range(num_frames)]


def merge_content_features(file_paths, save_path):
    """Stack the CONTENT_KEY_MAP arrays of each per-frame .npz along a new axis 0 and save them under the renamed keys."""
    merged = {src: [] for src in CONTENT_KEY_MAP}
    for path in file_paths:
        # Context manager closes the NpzFile handle; indexing loads the array into memory.
        with np.load(path) as data:
            for src in CONTENT_KEY_MAP:
                merged[src].append(data[src])
    np.savez(save_path, **{dst: np.stack(merged[src], axis=0) for src, dst in CONTENT_KEY_MAP.items()})


def merge_degradation_features(file_paths, save_path):
    """Stack every array of each per-frame .npz along a new axis 0 and save them under their original keys."""
    merged = {}
    for path in file_paths:
        with np.load(path) as data:
            for key in data.files:
                merged.setdefault(key, []).append(data[key])
    np.savez(save_path, **{key: np.stack(arrays, axis=0) for key, arrays in merged.items()})


def delete_files(paths):
    """Best-effort removal of *paths*; duplicates are removed once and failures are reported, not raised."""
    # Upsampled selections repeat paths; dict.fromkeys de-duplicates while keeping order.
    for path in dict.fromkeys(paths):
        try:
            os.remove(path)
        except OSError:
            print('{} cannot be deleted!'.format(path))


def main(argv=None):
    """Merge the per-frame diffusion features of every video in ``--video_folder``.

    For each video, ``num_frames`` content/degradation feature files are selected
    with one shared index list, so the two stay paired frame-for-frame. They are
    then stacked into ``<merged_feat_folder>/feat_content/<video>.npz`` and
    ``<merged_feat_folder>/feat_degradation/<video>.npz``.
    """
    args = parse_args(argv)
    feat_content_root = os.path.join(args.diff_feat_folder, 'feat_content')
    feat_degradation_root = os.path.join(args.diff_feat_folder, 'feat_degradation')
    all_frames_content = os.listdir(feat_content_root)
    all_frames_degradation = os.listdir(feat_degradation_root)
    content_save_root = os.path.join(args.merged_feat_folder, 'feat_content')
    os.makedirs(content_save_root, exist_ok=True)
    degradation_save_root = os.path.join(args.merged_feat_folder, 'feat_degradation')
    os.makedirs(degradation_save_root, exist_ok=True)

    for file in tqdm(sorted(os.listdir(args.video_folder))):
        # NOTE(review): split('.')[0] truncates names containing extra dots; kept
        # as-is so it matches the names the feature extractor produced — confirm.
        video_name = file.split('.')[0]
        if not video_name:
            # Dotfiles (e.g. '.DS_Store') give '' which would prefix-match every feature file.
            continue
        content_files = list_video_features(feat_content_root, all_frames_content, video_name)
        degradation_files = list_video_features(feat_degradation_root, all_frames_degradation, video_name)
        if not content_files:
            print('No content features found for {}, skipping.'.format(video_name))
            continue
        if len(content_files) != len(degradation_files):
            print('Frame count mismatch for {} (content {}, degradation {}), skipping.'.format(
                video_name, len(content_files), len(degradation_files)))
            continue

        indices = select_frame_indices(len(content_files), args.num_frames)
        selected_content = [os.path.join(feat_content_root, content_files[i]) for i in indices]
        selected_degradation = [os.path.join(feat_degradation_root, degradation_files[i]) for i in indices]

        merge_content_features(selected_content, os.path.join(content_save_root, video_name + '.npz'))
        merge_degradation_features(selected_degradation, os.path.join(degradation_save_root, video_name + '.npz'))

        if args.delete_original:
            delete_files(selected_content)
            delete_files(selected_degradation)


if __name__ == '__main__':
    main()