|
--- |
|
language: |
|
- ru |
|
library_name: lstm |
|
pipeline_tag: text-classification |
|
tags: |
|
- news |
|
- media |
|
- russian |
|
datasets: |
|
- data-silence/rus_news_classifier |
|
--- |
|
|
|
# LSTM Text Classifier |
|
|
|
This is an LSTM model for text classification, trained on
my [news dataset](https://huggingface.co/datasets/data-silence/rus_news_classifier) hosted on the Hugging Face Hub.
The training data is a well-balanced sample of Russian news from the last five years.
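
If you want to inspect the data yourself, it can be loaded straight from the Hub; a minimal sketch, assuming the default configuration and split names:

```python
from datasets import load_dataset

# Load the news dataset from the Hugging Face Hub
dataset = load_dataset("data-silence/rus_news_classifier")
print(dataset)  # shows the available splits and columns
```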
|
|
|
## Model Description |
|
|
|
This model combines frozen multilingual BERT embeddings with a bidirectional LSTM to classify text into 11 categories. It was trained on ~70,000 examples and achieves an
accuracy of 0.8691 on the test set.
|
|
|
## Task |
|
|
|
The model is designed to classify Russian-language news articles into 11 categories.
|
|
|
## Categories |
|
|
|
The classifier assigns each news item to one of 11 categories, listed here in the alphabetical order used for the model's output indices (see the mapping sketch after the list):
|
|
|
- climate (климат) |
|
- conflicts (конфликты) |
|
- culture (культура) |
|
- economy (экономика) |
|
- gloss (глянец) |
|
- health (здоровье) |
|
- politics (политика) |
|
- science (наука) |
|
- society (общество) |
|
- sports (спорт) |
|
- travel (путешествия) |
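
The model's output index corresponds to this list in alphabetical order (the same `categories` list appears in the usage example below); a minimal mapping sketch:

```python
# Index-to-label mapping, matching the order of the categories list above
categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
              'politics', 'science', 'society', 'sports', 'travel']
id2label = {i: label for i, label in enumerate(categories)}
label2id = {label: i for i, label in enumerate(categories)}
# id2label[9] -> 'sports'
```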
|
|
|
|
|
## Intended uses & limitations |
|
|
|
This model was trained and uploaded for educational purposes only.
|
|
|
You should not use this model to solve practical problems: an LSTM is neither the fastest nor the most accurate solution for text classification.
Moreover, the model architecture is not fully compatible with the Hugging Face ecosystem (pipelines, Inference Endpoints, etc. are not supported).
|
|
|
The "gloss" category is used to select yellow press, trashy and dubious news. The model can get confused in the |
|
classification of news categories politics, society and conflicts. |
|
|
|
## Usage

Example of how to use the model:
|
|
|
```python |
|
import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertModel
from huggingface_hub import hf_hub_download
|
|
|
|
|
# BiLSTM text classifier on top of frozen multilingual BERT embeddings
class BiLSTMClassifier(nn.Module):
|
def __init__(self, hidden_dim, output_dim, n_layers, dropout): |
|
super(BiLSTMClassifier, self).__init__() |
|
self.bert = BertModel.from_pretrained("bert-base-multilingual-cased") |
|
self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, num_layers=n_layers, |
|
bidirectional=True, dropout=dropout, batch_first=True) |
|
self.fc = nn.Linear(hidden_dim * 2, output_dim) |
|
self.dropout = nn.Dropout(dropout) |
|
|
|
def forward(self, input_ids, attention_mask, labels=None): |
|
        with torch.no_grad():  # BERT weights are frozen; used only as a feature extractor
            embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
|
        lstm_out, _ = self.lstm(embedded)
        pooled = torch.mean(lstm_out, dim=1)  # mean-pool the BiLSTM outputs over the sequence
        logits = self.fc(self.dropout(pooled))
|
|
|
if labels is not None: |
|
loss_fn = nn.CrossEntropyLoss() |
|
loss = loss_fn(logits, labels) |
|
return {"loss": loss, "logits": logits} # Возвращаем словарь |
|
return logits # Возвращаем логиты, если метки не переданы |
|
|
|
|
|
categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health', |
|
'politics', 'science', 'society', 'sports', 'travel'] |
|
|
|
repo_id = "data-silence/lstm-news-classifier" |
|
tokenizer = AutoTokenizer.from_pretrained(repo_id) |
|
model_path = hf_hub_download(repo_id=repo_id, filename="model.pth") |
|
|
|
model = torch.load(model_path)  # unpickles the full model; on PyTorch >= 2.6 you may need weights_only=False
model.eval()  # disable dropout for inference
|
|
|
def get_predictions(news: str, model) -> str:
    with torch.no_grad():
        inputs = tokenizer(news, return_tensors="pt")
        del inputs['token_type_ids']  # forward() does not accept token_type_ids
        logits = model(**inputs)
        id_best_label = torch.argmax(logits, dim=-1).item()
    return categories[id_best_label]
|
|
|
|
|
# Using the classifier
get_predictions('В Париже завершилась церемония завершения Олимпийских игр', model=model)
# 'sports'
|
``` |
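
Note that `torch.load` unpickles the whole model object, so the `BiLSTMClassifier` class must be defined in your script (as above) before loading. If a GPU is available, inference can be moved to it; a minimal sketch, reusing the objects defined above:

```python
# Move the model to GPU when available (assumes model, tokenizer and categories from above)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

inputs = tokenizer('В Париже завершилась церемония завершения Олимпийских игр', return_tensors="pt")
del inputs['token_type_ids']
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
    print(categories[torch.argmax(model(**inputs), dim=-1).item()])  # 'sports'
```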