lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
1.56 kB
import os
import torch
import random
import argparse
import transformers
from s3prl import hub
from packaging import version
# Sampling rate, in samples per second, used to size the random test
# waveforms (each one is between 1x and 15x this many samples long).
SAMPLE_RATE = 16000
# Number of random waveforms passed to both models in a single call.
BATCH_SIZE = 8

parser = argparse.ArgumentParser(
    description=(
        "Compare the hidden states of an s3prl wav2vec2 upstream with its "
        "Hugging Face counterpart on random waveforms."
    )
)
parser.add_argument(
    "--base",
    action="store_true",
    help="compare wav2vec2_base_960 with wav2vec2_hug_base_960",
)
parser.add_argument(
    "--large",
    action="store_true",
    help=(
        "compare wav2vec2_large_ll60k with wav2vec2_hug_large_ll60k "
        "(ignored when --base is also given)"
    ),
)
parser.add_argument(
    "--device",
    default="cuda",
    help="torch device the models and waveforms are moved to (default: cuda)",
)
args = parser.parse_args()

# Explicit checks instead of `assert`: assertions are removed when Python
# runs with -O, which would silently skip both guards below.
if version.parse(transformers.__version__) > version.parse("4.9.0"):
    raise RuntimeError(
        "Newer version of transformers change the places for feature extraction."
    )

# At least one model size must be requested.  Giving both is still accepted,
# in which case --base wins (see the model selection that follows).
if not (args.base or args.large):
    parser.error("one of --base or --large is required")
# Resolve the hub entry names for the requested model size.  --base takes
# precedence when both size flags are given.
if args.base:
    s3prl_str = "wav2vec2_base_960"
    huggingface_str = "wav2vec2_hug_base_960"
else:
    s3prl_str = "wav2vec2_large_ll60k"
    huggingface_str = "wav2vec2_hug_large_ll60k"

# Instantiate each upstream through its hub entry point, then move it to the
# requested device.
s3prl = getattr(hub, s3prl_str)()
s3prl = s3prl.to(args.device)
huggingface = getattr(hub, huggingface_str)()
huggingface = huggingface.to(args.device)

# NOTE(review): the original indentation of this file was lost; all three
# switches below are assumed to belong to the `--base` branch -- confirm.
# Presumably they align the s3prl base model's input handling with the
# Hugging Face implementation; verify against the upstream definitions.
if args.base:
    s3prl.wav_normalize = True
    s3prl.apply_padding_mask = False
    s3prl.numpy_wav_normalize = True

# Put both models in evaluation mode before comparing their outputs.
for model in (s3prl, huggingface):
    model.eval()
# Build BATCH_SIZE random waveforms whose lengths are drawn uniformly between
# 1x and 15x SAMPLE_RATE samples (inclusive), each placed on the target device.
wavs = []
for _ in range(BATCH_SIZE):
    num_samples = random.randint(SAMPLE_RATE * 1, SAMPLE_RATE * 15)
    wavs.append(torch.randn(num_samples).to(args.device))
# Run the same batch through both models without tracking gradients and keep
# the "hidden_states" entry of each result.
with torch.no_grad():
    hiddens1 = s3prl(wavs)["hidden_states"]
    hiddens2 = huggingface(wavs)["hidden_states"]

# Both implementations must return the same number of hidden states.
assert len(hiddens1) == len(hiddens2)

# Report the largest absolute element-wise gap for every pair of hidden
# states, then the largest gap over all pairs.
diffs = []
for idx, (hidden1, hidden2) in enumerate(zip(hiddens1, hiddens2)):
    diff = torch.max(torch.abs(hidden1 - hidden2)).item()
    diffs.append(diff)
    print(f"hidden {idx} difference: {diff}")

# The reduction stays in torch so that a NaN difference is propagated to the
# summary instead of being dropped by comparison-based max.
overall_max = torch.FloatTensor(diffs).max().item()
print(f"Max difference: {overall_max}")