|  |  | 
					
						
						|  | import os | 
					
						
						|  | import numpy as np | 
					
						
						|  | import io | 
					
						
						|  | import re | 
					
						
						|  | import requests | 
					
						
						|  | import html | 
					
						
						|  | import hashlib | 
					
						
						|  | import urllib | 
					
						
						|  | import urllib.request | 
					
						
						|  | import scipy.linalg | 
					
						
						|  | import multiprocessing as mp | 
					
						
						|  | import glob | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from tqdm import tqdm | 
					
						
						|  | from typing import Any, List, Tuple, Union, Dict, Callable | 
					
						
						|  |  | 
					
						
						|  | from torchvision.io import read_video | 
					
						
						|  | import torch; torch.set_grad_enabled(False) | 
					
						
						|  | from einops import rearrange | 
					
						
						|  |  | 
					
						
						|  | from nitro.util import isvideo | 
					
						
						|  |  | 
					
						
						|  | def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float: | 
					
						
						|  | print('Calculate frechet distance...') | 
					
						
						|  | m = np.square(mu_sample - mu_ref).sum() | 
					
						
						|  | s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) | 
					
						
						|  | fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2)) | 
					
						
						|  |  | 
					
						
						|  | return float(fid) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | 
					
						
						|  | mu = feats.mean(axis=0) | 
					
						
						|  | sigma = np.cov(feats, rowvar=False) | 
					
						
						|  |  | 
					
						
						|  | return mu, sigma | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any: | 
					
						
						|  | """Download the given URL and return a binary-mode file object to access the data.""" | 
					
						
						|  | assert num_attempts >= 1 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not re.match('^[a-z]+://', url): | 
					
						
						|  | return url if return_filename else open(url, "rb") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if url.startswith('file://'): | 
					
						
						|  | filename = urllib.parse.urlparse(url).path | 
					
						
						|  | if re.match(r'^/[a-zA-Z]:', filename): | 
					
						
						|  | filename = filename[1:] | 
					
						
						|  | return filename if return_filename else open(filename, "rb") | 
					
						
						|  |  | 
					
						
						|  | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | url_name = None | 
					
						
						|  | url_data = None | 
					
						
						|  | with requests.Session() as session: | 
					
						
						|  | if verbose: | 
					
						
						|  | print("Downloading %s ..." % url, end="", flush=True) | 
					
						
						|  | for attempts_left in reversed(range(num_attempts)): | 
					
						
						|  | try: | 
					
						
						|  | with session.get(url) as res: | 
					
						
						|  | res.raise_for_status() | 
					
						
						|  | if len(res.content) == 0: | 
					
						
						|  | raise IOError("No data received") | 
					
						
						|  |  | 
					
						
						|  | if len(res.content) < 8192: | 
					
						
						|  | content_str = res.content.decode("utf-8") | 
					
						
						|  | if "download_warning" in res.headers.get("Set-Cookie", ""): | 
					
						
						|  | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] | 
					
						
						|  | if len(links) == 1: | 
					
						
						|  | url = requests.compat.urljoin(url, links[0]) | 
					
						
						|  | raise IOError("Google Drive virus checker nag") | 
					
						
						|  | if "Google Drive - Quota exceeded" in content_str: | 
					
						
						|  | raise IOError("Google Drive download quota exceeded -- please try again later") | 
					
						
						|  |  | 
					
						
						|  | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) | 
					
						
						|  | url_name = match[1] if match else url | 
					
						
						|  | url_data = res.content | 
					
						
						|  | if verbose: | 
					
						
						|  | print(" done") | 
					
						
						|  | break | 
					
						
						|  | except KeyboardInterrupt: | 
					
						
						|  | raise | 
					
						
						|  | except: | 
					
						
						|  | if not attempts_left: | 
					
						
						|  | if verbose: | 
					
						
						|  | print(" failed") | 
					
						
						|  | raise | 
					
						
						|  | if verbose: | 
					
						
						|  | print(".", end="", flush=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | assert not return_filename | 
					
						
						|  | return io.BytesIO(url_data) | 
					
						
						|  |  | 
					
						
						|  | def load_video(ip): | 
					
						
						|  | vid, *_ = read_video(ip) | 
					
						
						|  | vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8) | 
					
						
						|  | return vid | 
					
						
						|  |  | 
					
						
						|  | def get_data_from_str(input_str,nprc = None): | 
					
						
						|  | assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory' | 
					
						
						|  | vid_filelist = glob.glob(os.path.join(input_str,'*.mp4')) | 
					
						
						|  | print(f'Found {len(vid_filelist)} videos in dir {input_str}') | 
					
						
						|  |  | 
					
						
						|  | if nprc is None: | 
					
						
						|  | try: | 
					
						
						|  | nprc = mp.cpu_count() | 
					
						
						|  | except NotImplementedError: | 
					
						
						|  | print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading') | 
					
						
						|  | nprc = 1 | 
					
						
						|  |  | 
					
						
						|  | pool = mp.Pool(processes=nprc) | 
					
						
						|  |  | 
					
						
						|  | vids = [] | 
					
						
						|  | for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'): | 
					
						
						|  | vids.append(v) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | vids = torch.stack(vids,dim=0).float() | 
					
						
						|  |  | 
					
						
						|  | return vids | 
					
						
						|  |  | 
					
						
						|  | def get_stats(stats): | 
					
						
						|  | assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}' | 
					
						
						|  |  | 
					
						
						|  | print(f'Using precomputed statistics under {stats}') | 
					
						
						|  | stats = np.load(stats) | 
					
						
						|  | stats = {key: stats[key] for key in stats.files} | 
					
						
						|  |  | 
					
						
						|  | return stats | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @torch.no_grad() | 
					
						
						|  | def compute_fvd(ref_input, sample_input, bs=32, | 
					
						
						|  | ref_stats=None, | 
					
						
						|  | sample_stats=None, | 
					
						
						|  | nprc_load=None): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | calc_stats = ref_stats is None or sample_stats is None | 
					
						
						|  |  | 
					
						
						|  | if calc_stats: | 
					
						
						|  |  | 
					
						
						|  | only_ref = sample_stats is not None | 
					
						
						|  | only_sample = ref_stats is not None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if isinstance(ref_input,str) and not only_sample: | 
					
						
						|  | ref_input = get_data_from_str(ref_input,nprc_load) | 
					
						
						|  |  | 
					
						
						|  | if isinstance(sample_input, str) and not only_ref: | 
					
						
						|  | sample_input = get_data_from_str(sample_input, nprc_load) | 
					
						
						|  |  | 
					
						
						|  | stats = compute_statistics(sample_input,ref_input, | 
					
						
						|  | device='cuda' if torch.cuda.is_available() else 'cpu', | 
					
						
						|  | bs=bs, | 
					
						
						|  | only_ref=only_ref, | 
					
						
						|  | only_sample=only_sample) | 
					
						
						|  |  | 
					
						
						|  | if only_ref: | 
					
						
						|  | stats.update(get_stats(sample_stats)) | 
					
						
						|  | elif only_sample: | 
					
						
						|  | stats.update(get_stats(ref_stats)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | stats = get_stats(sample_stats) | 
					
						
						|  | stats.update(get_stats(ref_stats)) | 
					
						
						|  |  | 
					
						
						|  | fvd = compute_frechet_distance(**stats) | 
					
						
						|  |  | 
					
						
						|  | return {'FVD' : fvd,} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @torch.no_grad() | 
					
						
						|  | def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict: | 
					
						
						|  | detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1' | 
					
						
						|  | detector_kwargs = dict(rescale=True, resize=True, return_features=True) | 
					
						
						|  |  | 
					
						
						|  | with open_url(detector_url, verbose=False) as f: | 
					
						
						|  | detector = torch.jit.load(f).eval().to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive' | 
					
						
						|  |  | 
					
						
						|  | ref_embed, sample_embed = [], [] | 
					
						
						|  |  | 
					
						
						|  | info = f'Computing I3D activations for FVD score with batch size {bs}' | 
					
						
						|  |  | 
					
						
						|  | if only_ref: | 
					
						
						|  |  | 
					
						
						|  | if not isvideo(videos_real): | 
					
						
						|  |  | 
					
						
						|  | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() | 
					
						
						|  | print(videos_real.shape) | 
					
						
						|  |  | 
					
						
						|  | if videos_real.shape[0] % bs == 0: | 
					
						
						|  | n_secs = videos_real.shape[0] // bs | 
					
						
						|  | else: | 
					
						
						|  | n_secs = videos_real.shape[0] // bs + 1 | 
					
						
						|  |  | 
					
						
						|  | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) | 
					
						
						|  |  | 
					
						
						|  | for ref_v in tqdm(videos_real, total=len(videos_real),desc=info): | 
					
						
						|  |  | 
					
						
						|  | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() | 
					
						
						|  | ref_embed.append(feats_ref) | 
					
						
						|  |  | 
					
						
						|  | elif only_sample: | 
					
						
						|  |  | 
					
						
						|  | if not isvideo(videos_fake): | 
					
						
						|  |  | 
					
						
						|  | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() | 
					
						
						|  | print(videos_fake.shape) | 
					
						
						|  |  | 
					
						
						|  | if videos_fake.shape[0] % bs == 0: | 
					
						
						|  | n_secs = videos_fake.shape[0] // bs | 
					
						
						|  | else: | 
					
						
						|  | n_secs = videos_fake.shape[0] // bs + 1 | 
					
						
						|  |  | 
					
						
						|  | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) | 
					
						
						|  |  | 
					
						
						|  | for sample_v in tqdm(videos_fake, total=len(videos_real),desc=info): | 
					
						
						|  | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() | 
					
						
						|  | sample_embed.append(feats_sample) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | if not isvideo(videos_real): | 
					
						
						|  |  | 
					
						
						|  | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() | 
					
						
						|  |  | 
					
						
						|  | if not isvideo(videos_fake): | 
					
						
						|  | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() | 
					
						
						|  |  | 
					
						
						|  | if videos_fake.shape[0] % bs == 0: | 
					
						
						|  | n_secs = videos_fake.shape[0] // bs | 
					
						
						|  | else: | 
					
						
						|  | n_secs = videos_fake.shape[0] // bs + 1 | 
					
						
						|  |  | 
					
						
						|  | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) | 
					
						
						|  | videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0) | 
					
						
						|  |  | 
					
						
						|  | for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() | 
					
						
						|  | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() | 
					
						
						|  | sample_embed.append(feats_sample) | 
					
						
						|  | ref_embed.append(feats_ref) | 
					
						
						|  |  | 
					
						
						|  | out = dict() | 
					
						
						|  | if len(sample_embed) > 0: | 
					
						
						|  | sample_embed = np.concatenate(sample_embed,axis=0) | 
					
						
						|  | mu_sample, sigma_sample = compute_stats(sample_embed) | 
					
						
						|  | out.update({'mu_sample': mu_sample, | 
					
						
						|  | 'sigma_sample': sigma_sample}) | 
					
						
						|  |  | 
					
						
						|  | if len(ref_embed) > 0: | 
					
						
						|  | ref_embed = np.concatenate(ref_embed,axis=0) | 
					
						
						|  | mu_ref, sigma_ref = compute_stats(ref_embed) | 
					
						
						|  | out.update({'mu_ref': mu_ref, | 
					
						
						|  | 'sigma_ref': sigma_ref}) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return out | 
					
						
						|  |  |