lukecq committed on
Commit 8a69f7f · 1 Parent(s): d4a063b

Update README.md

Files changed (1):
  1. README.md +20 -11
README.md CHANGED
@@ -19,6 +19,7 @@ The model backbone is RoBERTa-base.
 
 The model is tuned with unlabeled data using a learning objective called first sentence prediction (FSP).
 The FSP task is designed by considering both the nature of the unlabeled corpus and the input/output format of classification tasks.
+
 The training and validation sets are constructed from the unlabeled corpus using FSP.
 
 During tuning, BERT-like pre-trained masked language
@@ -56,8 +57,9 @@ model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-sho
 text = "I love this place! The food is always so fresh and delicious."
 list_label = ["negative", "positive"]
 
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 list_ABC = [x for x in string.ascii_uppercase]
-def add_prefix(text, list_label, shuffle = False):
+def add_prefix(text, list_label, shuffle=False):
     list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
     list_label_new = list_label + [tokenizer.pad_token]* (20 - len(list_label))
     if shuffle:
@@ -65,16 +67,23 @@ def add_prefix(text, list_label, shuffle = False):
     s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
     return f'{s_option} {tokenizer.sep_token} {text}', list_label_new
 
-text_new, list_label_new = add_prefix(text, list_label, shuffle=False)
-
-encoding = tokenizer([text_new], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
-with torch.no_grad():
-    logits = model(**encoding).logits
-probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
-predictions = torch.argmax(logits, dim=-1)
+def check_text(model, text, list_label, shuffle=False):
+    text, list_label_new = add_prefix(text, list_label, shuffle=shuffle)
+    model.to(device).eval()
+    encoding = tokenizer([text], truncation=True, max_length=512)
+    item = {key: torch.tensor(val).to(device) for key, val in encoding.items()}
+    logits = model(**item).logits
+    logits = logits if shuffle else logits[:, 0:len(list_label)]
+    probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
+    predictions = torch.argmax(logits, dim=-1).item()
+    probabilities = [round(x,5) for x in probs[0]]
+
+    print(f'prediction: {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
+    print(f'probability: {round(probabilities[predictions]*100,2)}%')
 
-print(probs)
-print(predictions)
+check_text(model, text, list_label)
+# prediction: 1 => (B) positive.
+# probability: 99.92%
 ```
@@ -89,8 +98,8 @@ print(predictions)
         Chip Hong Chang and
         Lidong Bing},
   title = {Zero-Shot Text Classification via Self-Supervised Tuning},
-  booktitle = {Findings of the 2023 ACL},
+  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
   year = {2023},
-  url = {},
+  url = {https://arxiv.org/abs/2305.11442},
 }
 ```
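
A note for reviewers on the FSP objective described in the first hunk: the task pairs a paragraph's body with a set of candidate first sentences, rendered in the same "(A) … (B) … <sep> text" format as the inference snippet. The sketch below is a hypothetical illustration of that idea only; `build_fsp_example`, the distractor-sampling scheme, and the option count are assumptions, not the paper's exact construction.

```python
# Hypothetical sketch of one first-sentence-prediction (FSP) instance.
# Assumption: the correct option is the paragraph's own first sentence and
# the distractors are first sentences of other paragraphs; the paper's
# actual recipe may differ in its details.
import random
import string

def build_fsp_example(paragraphs, idx, num_options=3, seed=0):
    rng = random.Random(seed)
    first, _, rest = paragraphs[idx].partition('. ')
    # Distractors: first sentences drawn from the other paragraphs.
    others = [p.partition('. ')[0] for i, p in enumerate(paragraphs) if i != idx]
    options = [first] + rng.sample(others, num_options - 1)
    rng.shuffle(options)
    label = options.index(first)  # index of the true first sentence
    s_option = ' '.join(f'({string.ascii_uppercase[i]}) {o}.'
                        for i, o in enumerate(options))
    # '</s>' stands in for tokenizer.sep_token (RoBERTa-base backbone).
    return f'{s_option} </s> {rest}', label

paragraphs = [
    "The food was excellent. We will definitely come back.",
    "The plot was predictable. I left the cinema halfway through.",
    "Delivery took three weeks. The packaging was also damaged.",
]
text, label = build_fsp_example(paragraphs, idx=0)
print(text)   # (A)/(B)/(C) candidates </s> We will definitely come back.
print(label)  # position of "The food was excellent"
```

Each paragraph yields one such instance, which matches how the README describes constructing the training and validation sets from the unlabeled corpus.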
 
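For sanity-checking the new snippet end to end, here is a self-contained assembly of the updated code. Three things are assumptions rather than part of this commit: the imports and model-loading lines (which sit outside the hunks), the full checkpoint ID (the hunk header truncates it at "DAMO-NLP-SG/zero-sho"), and the body of the elided `if shuffle:` branch.

```python
import random
import string
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "DAMO-NLP-SG/zero-shot-classify-SSTuning-base"  # assumed full ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = [x for x in string.ascii_uppercase]

def add_prefix(text, list_label, shuffle=False):
    # Pad the label list to 20 options and render it as "(A) ... (B) ...".
    list_label = [x + '.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)  # assumed body; elided by the diff
    s_option = ' '.join('(' + list_ABC[i] + ') ' + list_label_new[i]
                        for i in range(len(list_label_new)))
    return f'{s_option} {tokenizer.sep_token} {text}', list_label_new

def check_text(model, text, list_label, shuffle=False):
    text, list_label_new = add_prefix(text, list_label, shuffle=shuffle)
    model.to(device).eval()
    encoding = tokenizer([text], truncation=True, max_length=512)
    item = {key: torch.tensor(val).to(device) for key, val in encoding.items()}
    with torch.no_grad():  # inference only; not in the diff, added for safety
        logits = model(**item).logits
    # Without shuffling, only the first len(list_label) options are real
    # labels; the rest are pad tokens, so their logits are dropped.
    logits = logits if shuffle else logits[:, 0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()
    predictions = torch.argmax(logits, dim=-1).item()
    probabilities = [round(x, 5) for x in probs[0]]
    print(f'prediction: {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability: {round(probabilities[predictions] * 100, 2)}%')

check_text(model, text, list_label)
# Expected per the README: prediction: 1 => (B) positive. / probability: 99.92%
```

The two behavioral changes in this commit, dropping `padding='max_length'` and slicing the logits down to the real labels, are both reflected above.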