"""Merge per-frame diffusion features into one ``.npz`` file per video.

For every entry in ``--video_folder`` the script:

1. Collects the per-frame feature files whose names start with the video's
   base name, from ``<diff_feat_folder>/feat_content`` and
   ``<diff_feat_folder>/feat_degradation``.
2. Resamples them to exactly ``--num_frames`` frames.
3. Stacks every array along a new leading axis.
4. Writes ``<merged_feat_folder>/feat_content/<video>.npz`` and
   ``<merged_feat_folder>/feat_degradation/<video>.npz``.
"""
import os
import numpy as np
import json
from argparse import ArgumentParser, Namespace

# Must match the number of frames set in generate_frame.py.
DEFAULT_NUM_FRAMES = 15

# Keys read from each per-frame content file -> key written to the merged file.
CONTENT_KEY_MAP = {
    'pred_latent_000': 'pred_latent',
    'input_latent_000': 'input_latent',
    'input_unet_000_000': 'input_unet_000',
    'input_unet_000_001': 'input_unet_001',
    'output_unet_000_000': 'output_unet_000',
    'output_unet_000_001': 'output_unet_001',
}


def build_parser():
    """Return the command-line parser for this script."""
    parser = ArgumentParser()
    parser.add_argument("--video_folder", default='demo_videos', type=str, help="Folder path of your videos")
    parser.add_argument("--diff_feat_folder", default='output_diffusion_features', type=str, help="Folder path of your extracted diffusion features")
    parser.add_argument("--merged_feat_folder", default='output_merged_diffusion_features', type=str, help="Folder path of output merged diffusion features")
    parser.add_argument("--num_frames", default=DEFAULT_NUM_FRAMES, type=int, help="Number of frames per merged video (must match generate_frame.py)")
    parser.add_argument("--delete_original", action="store_true", help="Delete the merged per-frame feature files afterwards to save storage")
    return parser


def list_frame_files(root, all_files, video_name):
    """Return the sorted names in ``all_files`` that belong to ``video_name``.

    A name belongs to the video when it starts with ``video_name`` and is a
    regular file inside ``root``.
    """
    # NOTE(review): plain prefix matching means video 'a' also picks up frames
    # of video 'ab'. Confirm the frame naming used by generate_frame.py and
    # tighten the match (e.g. require a separator) if needed.
    return sorted(
        name for name in all_files
        if name.startswith(video_name) and os.path.isfile(os.path.join(root, name))
    )


def select_frame_indices(num_available, num_frames):
    """Choose ``num_frames`` indices into a list of ``num_available`` frames.

    If fewer frames are available than requested, every frame is repeated in
    order. The first ``num_frames % num_available`` frames get one extra copy.
    Otherwise frames are sampled with a uniform (floored) stride.

    Raises:
        ValueError: if either count is not positive.
    """
    if num_available <= 0:
        raise ValueError('no frames available to select from')
    if num_frames <= 0:
        raise ValueError('num_frames must be positive, got {}'.format(num_frames))
    if num_available < num_frames:
        quotient, remainder = divmod(num_frames, num_available)
        return [idx
                for idx in range(num_available)
                for _ in range(quotient + 1 if idx < remainder else quotient)]
    step = num_available / num_frames
    return [int(i * step) for i in range(num_frames)]


def merge_frame_features(root, frame_files, keys=None):
    """Load each per-frame ``.npz`` in ``frame_files`` and stack arrays per key.

    Args:
        root: directory containing the frame files.
        frame_files: file names to load, in output order (may repeat).
        keys: if given, only these keys are read, and a frame missing one
            raises KeyError. Otherwise every key found in the files is used.

    Returns:
        dict mapping each key to ``np.stack`` of that key's arrays along axis 0.
    """
    merged = {} if keys is None else {key: [] for key in keys}
    for name in frame_files:
        # Context manager closes the underlying zip file handle.
        with np.load(os.path.join(root, name)) as data:
            for key in (data.files if keys is None else keys):
                merged.setdefault(key, []).append(data[key])
    return {key: np.stack(arrays, axis=0) for key, arrays in merged.items()}


def delete_frame_files(root, frame_files):
    """Best-effort removal of per-frame files. Failures are reported, not raised."""
    # dict.fromkeys de-duplicates while keeping order. Repeated frames would
    # otherwise be removed twice, and the second attempt always fails.
    for name in dict.fromkeys(frame_files):
        file_path = os.path.join(root, name)
        try:
            os.remove(file_path)
        except OSError:
            print('{} cannot be deleted!'.format(file_path))


def main(argv=None):
    """Merge the per-frame features of every video in ``--video_folder``."""
    # Imported here so the helpers above stay importable without tqdm.
    from tqdm import tqdm

    args = build_parser().parse_args(argv)

    feat_content_root = os.path.join(args.diff_feat_folder, 'feat_content')
    feat_degradation_root = os.path.join(args.diff_feat_folder, 'feat_degradation')
    all_frames_content = os.listdir(feat_content_root)
    all_frames_degradation = os.listdir(feat_degradation_root)

    content_save_root = os.path.join(args.merged_feat_folder, 'feat_content')
    os.makedirs(content_save_root, exist_ok=True)
    degradation_save_root = os.path.join(args.merged_feat_folder, 'feat_degradation')
    os.makedirs(degradation_save_root, exist_ok=True)

    for video_file in tqdm(sorted(os.listdir(args.video_folder))):
        # Same naming rule as generate_frame.py presumably uses: the text
        # before the first '.'.
        video_name = video_file.split('.')[0]
        if not video_name:
            # e.g. '.DS_Store': an empty prefix would match every frame file.
            continue

        content_frames = list_frame_files(feat_content_root, all_frames_content, video_name)
        degradation_frames = list_frame_files(feat_degradation_root, all_frames_degradation, video_name)
        if not content_frames:
            tqdm.write('No feature files found for {}, skipped.'.format(video_name))
            continue
        if len(content_frames) != len(degradation_frames):
            tqdm.write('{}: {} content vs {} degradation feature files, skipped.'.format(
                video_name, len(content_frames), len(degradation_frames)))
            continue

        # One shared index list keeps content and degradation frames paired.
        indices = select_frame_indices(len(content_frames), args.num_frames)
        final_frames_content = [content_frames[i] for i in indices]
        final_frames_degradation = [degradation_frames[i] for i in indices]

        # Merge content feat (fixed key set, renamed on save).
        content = merge_frame_features(feat_content_root, final_frames_content,
                                       keys=list(CONTENT_KEY_MAP))
        np.savez(os.path.join(content_save_root, video_name + '.npz'),
                 **{CONTENT_KEY_MAP[key]: value for key, value in content.items()})

        # Merge degradation feat (whatever keys the frame files contain).
        degradation = merge_frame_features(feat_degradation_root, final_frames_degradation)
        np.savez(os.path.join(degradation_save_root, video_name + '.npz'), **degradation)

        if args.delete_original:
            delete_frame_files(feat_content_root, final_frames_content)
            delete_frame_files(feat_degradation_root, final_frames_degradation)


if __name__ == "__main__":
    main()