---
language:
  - en
base_model:
  - distilbert/distilbert-base-cased
pipeline_tag: text-classification
---

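Fine-tuned DistilBERT model, built for a university project, that identifies which entities a sentence mentions: Trump, Kamala, or Others. It is a multilabel classifier, so a single sentence can be tagged with more than one entity.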

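```python
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer from the base model; weights from the local fine-tuned checkpoint.
model_name = "distilbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained('./model/multilabel_entity_distil_bert_7epochs_lr5e-5_d&c_dataset')

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # disable dropout for inference


def get_probabilities(texts, score=False):
    """Return per-label sigmoid probabilities (score=True) or 0/1 predictions
    using a 0.65 threshold (score=False) for the first text in `texts`."""
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Multilabel: each label gets an independent sigmoid, not a softmax.
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]
    if score:
        return probabilities
    return (probabilities > 0.65).astype(int)


# `df` is a pandas DataFrame with a `sentence` column.
df['entity_ids'] = df.sentence.apply(get_probabilities)
```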
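The thresholding step in `get_probabilities` can also be written as a standalone post-processing function that handles a whole batch of logits (the function above keeps only the first row) and returns label names instead of a 0/1 vector. This is a minimal sketch: `decode_multilabel` and `LABELS` are hypothetical names, and the label order (Trump, Kamala, Others) is assumed from the description above; the checkpoint's `model.config.id2label` is the authoritative mapping.

```python
import numpy as np

# Hypothetical label order; check model.config.id2label on the real checkpoint.
LABELS = ("Trump", "Kamala", "Others")


def decode_multilabel(logits, labels=LABELS, threshold=0.65):
    """Turn raw logits of shape (batch, num_labels) into lists of label names.

    Each label gets an independent sigmoid probability, so a sentence can
    match several entities at once, or none of them.
    """
    logits = np.asarray(logits, dtype=float)
    if logits.ndim == 1:  # accept the logits of a single sentence
        logits = logits[None, :]
    probs = 1.0 / (1.0 + np.exp(-logits))
    return [[labels[j] for j in np.flatnonzero(row > threshold)] for row in probs]


example_logits = [
    [2.0, -3.0, -1.0],   # sigmoid ~ 0.88, 0.05, 0.27 -> Trump only
    [1.5, 0.9, -2.0],    # sigmoid ~ 0.82, 0.71, 0.12 -> Trump and Kamala
    [-2.0, -2.0, -2.0],  # sigmoid ~ 0.12 each -> no entity
]
print(decode_multilabel(example_logits))  # [['Trump'], ['Trump', 'Kamala'], []]
```

With the model loaded, the input would be `model(**inputs).logits.cpu().numpy()` for a padded batch of sentences, which avoids calling the model once per row of `df`.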