Xenova HF Staff committed on
Commit
4cebf10
·
verified ·
1 Parent(s): 79226a1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +124 -3
README.md CHANGED
@@ -1,3 +1,124 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ ## Usage
6
+
7
+ ### ONNXRuntime
8
+
9
+ <details>
10
+
11
+ <summary>
12
+ First, define the <em>read_gif_frames</em> helper function (click to expand):
13
+ </summary>
14
+
15
+ ```py
16
+ import numpy as np
17
+ from PIL import Image, ImageSequence
18
+ import requests
19
+ from io import BytesIO
20
+ import os
21
+
22
def read_gif_frames(path_or_url, shortest_edge=None, center_crop=None):
    """Load a GIF and return its frames as a channels-first uint8 array.

    Args:
        path_or_url: An ``http://`` / ``https://`` URL, or a local filesystem
            path (``str`` or ``os.PathLike``) pointing at a GIF file.
        shortest_edge: If given, each frame is resized with LANCZOS
            resampling so that its shorter side equals this many pixels,
            scaling the other side by the same ratio.
        center_crop: If given, a centered square with this side length is
            cropped from each frame (after any resize).

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with shape
        ``[num_frames, 3, height, width]``.

    Raises:
        ValueError: If ``path_or_url`` is neither an HTTP(S) URL nor an
            existing path, or if the file is not a GIF.
        requests.HTTPError: If the download returns an error status.
    """
    source = os.fspath(path_or_url)

    # Resolve the input to something Image.open can read.
    if source.startswith(("http://", "https://")):
        # Timeout so a stalled server cannot hang the call forever; fail on
        # HTTP errors here instead of letting PIL choke on an error page.
        response = requests.get(source, timeout=30)
        response.raise_for_status()
        image_source = BytesIO(response.content)
    elif os.path.exists(source):
        image_source = source
    else:
        raise ValueError("Invalid URL or file path")

    frames = []
    # The context manager releases the underlying file handle once all
    # frames have been copied into NumPy arrays.
    with Image.open(image_source) as gif:
        if gif.format != "GIF":
            raise ValueError("Not a GIF file")

        for frame in ImageSequence.Iterator(gif):
            rgb_frame = frame.convert("RGB")  # Force 3 channels

            if shortest_edge is not None:
                w, h = rgb_frame.size
                # Pin the shorter side to `shortest_edge`, scale the other.
                if h < w:
                    new_h = shortest_edge
                    new_w = int(w * shortest_edge / h)
                else:
                    new_w = shortest_edge
                    new_h = int(h * shortest_edge / w)
                rgb_frame = rgb_frame.resize((new_w, new_h), Image.LANCZOS)

            if center_crop is not None:
                w, h = rgb_frame.size
                left = (w - center_crop) // 2
                top = (h - center_crop) // 2
                rgb_frame = rgb_frame.crop(
                    (left, top, left + center_crop, top + center_crop)
                )

            frame_np = np.array(rgb_frame, dtype=np.uint8)
            frames.append(np.transpose(frame_np, (2, 0, 1)))  # HWC -> CHW

    return np.stack(frames)  # Shape: [num_frames, 3, height, width]
66
+ ```
67
+
68
+ </details>
69
+
70
+
71
+ You can then run the model as follows:
72
+ ```py
73
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import AutoConfig

# Fetch the config (for label names) and the ONNX weights from the Hub.
model_id = "onnx-community/vjepa2-vitl-fpc32-256-diving48-ONNX"
config = AutoConfig.from_pretrained(model_id)
model_path = hf_hub_download(repo_id=model_id, filename="onnx/model.onnx")
ort_session = ort.InferenceSession(model_path)

# Load the clip: resize the shorter side to 292 px, then center-crop to 256.
video = read_gif_frames(
    "http://www.svcl.ucsd.edu/projects/resound/imgs/19.gif",
    shortest_edge=292,
    center_crop=256,
)

# Scale to [0, 1] and normalize each channel; the (3, 1, 1) shape broadcasts
# over the [num_frames, 3, height, width] video array.
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
normalized = (video / 255 - mean) / std
feed = {
    # Add a leading batch dimension and cast to float32.
    "pixel_values_videos": np.expand_dims(normalized, axis=0).astype(np.float32),
}

# Run the model; the first output holds the logits.
outputs = ort_session.run(None, feed)
logits = outputs[0]

# Indices of the highest-scoring classes, best first.
top_k = 5
indices = np.argsort(logits[0])[::-1][:top_k]

# Softmax, shifted by the max logit for numerical stability.
shifted = logits[0] - np.max(logits[0])
exp_logits = np.exp(shifted)
softmax_probs = exp_logits / np.sum(exp_logits)

print(f"Top {top_k} predicted class names:")
for idx in indices:
    text_label = config.id2label[idx]
    print(f" - {text_label}: {softmax_probs[idx]:.2f}")
114
+ ```
115
+
116
+ Example output:
117
+ ```
118
+ Top 5 predicted class names:
119
+ - ['Forward', '15som', 'NoTwis', 'PIKE']: 0.69
120
+ - ['Reverse', 'Dive', 'NoTwis', 'PIKE']: 0.22
121
+ - ['Inward', '15som', 'NoTwis', 'PIKE']: 0.06
122
+ - ['Reverse', '15som', '05Twis', 'FREE']: 0.01
123
+ - ['Forward', '25som', 'NoTwis', 'PIKE']: 0.00
124
+ ```