leannetanyt commited on
Commit
c4cefdf
·
verified ·
1 Parent(s): efdd6d9

feat: upload inference script

Browse files
Files changed (1) hide show
  1. inference.py +35 -10
inference.py CHANGED
@@ -1,18 +1,43 @@
 
1
  import os
 
2
  import numpy as np
3
  from openai import OpenAI
4
  from transformers import AutoModel
5
 
6
- texts = ["Eh you damn stupid lah!", "Have a nice day :)", "This is cool~"]
7
 
8
- # Load model directly from Hub
9
- model = AutoModel.from_pretrained("govtech/lionguard-2", trust_remote_code=True)
 
10
 
11
- # Get embeddings (users to input their own OpenAI API key)
12
- client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
13
- response = client.embeddings.create(input=texts, model="text-embedding-3-large")
14
- embeddings = np.array([data.embedding for data in response.data])
15
 
16
- # Run inference
17
- predictions = model.predict(embeddings)
18
- print(predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  import os
3
+ import sys
4
  import numpy as np
5
  from openai import OpenAI
6
  from transformers import AutoModel
7
 
 
8
 
9
+ def infer(texts):
10
+ # Load model directly from Hub
11
+ model = AutoModel.from_pretrained("govtech/lionguard-2", trust_remote_code=True)
12
 
13
+ # Get embeddings (users to input their own OpenAI API key)
14
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
15
+ response = client.embeddings.create(input=texts, model="text-embedding-3-large")
16
+ embeddings = np.array([data.embedding for data in response.data])
17
 
18
+ # Run inference
19
+ results = model.predict(embeddings)
20
+ return results
21
+
22
+
23
+ if __name__ == "__main__":
24
+
25
+ # Load the data
26
+ try:
27
+ input_data = sys.argv[1]
28
+ batch_text = json.loads(input_data)
29
+ print("Using provided input texts")
30
+
31
+ except (json.JSONDecodeError, IndexError) as e:
32
+ print(f"Error parsing input data: {e}")
33
+ print("Falling back to default sample texts")
34
+
35
+ batch_text = ["Eh you damn stupid lah!", "Have a nice day :)"]
36
+
37
+ # Generate the scores and predictions
38
+ results = infer(batch_text)
39
+ for i in range(len(batch_text)):
40
+ print(f"Text: '{batch_text[i]}'")
41
+ for category in results.keys():
42
+ print(f"[Text {i+1}] {category} score: {results[category][i]:.4f}")
43
+ print("---------------------------------------------")