|
from transformers import pipeline |
|
import re |
|
|
|
class ContextAwareLyricCleaner:
    """Swap profanity for mild substitutes, but only on lines flagged as explicit.

    Each lyric line is scored by a zero-shot classifier against the labels in
    ``explicit_labels`` plus ``"clean"``. Lines whose combined explicit score
    exceeds ``threshold`` are passed through the word-replacement table; all
    other lines are returned untouched.

    Attributes:
        classifier: Zero-shot classification pipeline used by ``is_explicit``.
        replacements: Mapping of regex source -> substitute text.
        patterns: Mapping of compiled, case-insensitive regex -> substitute
            text, iterated in order by ``clean_line``.
        explicit_labels: Candidate labels that count as "explicit".
        threshold: Minimum combined explicit score for a line to be cleaned.
    """

    # Default table: word-bounded regex source -> family-friendly substitute.
    # No substitute may itself be a key in this table, otherwise the result
    # would depend on the order in which the patterns are applied.
    _DEFAULT_REPLACEMENTS = {
        r'\bfuck\b': 'frick',
        r'\bshit\b': 'shoot',
        r'\bfucking\b': 'flipping',
        r'\bfucked\b': 'flipped',
        r'\bshitty\b': 'soggy',
        r'\bass\b': 'butt',
        r'\basses\b': 'butts',
        r'\basshole\b': 'jerkface',
        r'\bbitch\b': 'witch',
        r'\bbitches\b': 'witches',
        r'\bdamn\b': 'darn',
        r'\bcunt\b': 'punk',
        r'\bcrap\b': 'junk',
        # Previously mapped to 'prick', which the 'prick' rule then turned
        # into 'jerk'; map straight to the final word instead.
        r'\bdick\b': 'jerk',
        r'\bfag\b': 'nerd',
        r'\bfaggot\b': 'loser',
        r'\bmothafucka\b': 'motherlover',
        r'\bmotherfucker\b': 'motherlover',
        r'\bhell\b': 'heck',
        r'\bprick\b': 'jerk',
        r'\bpiss\b': 'pee',
        r'\bpissed\b': 'mad',
        r'\bshithead\b': 'knucklehead',
        r'\bslut\b': 'scout',
        r'\bwhore\b': 'score',
        r'\bwtf\b': 'what the flip',
        r'\bson of a bitch\b': 'son of a glitch',
        r'\bbastard\b': 'rascal',
        r'\bgod\b': 'gosh',
        r'\blord\b': 'love',
    }

    def __init__(self):
        """Load the zero-shot classifier and compile the replacement table."""
        self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
        # Per-instance copy so callers can customise one cleaner's table
        # without affecting others.
        self.replacements = dict(self._DEFAULT_REPLACEMENTS)
        self.patterns = self._compile_patterns(self.replacements)
        self.explicit_labels = ["explicit", "offensive", "inappropriate"]
        self.threshold = 0.7

    @staticmethod
    def _compile_patterns(replacements):
        """Compile a regex-source -> substitute mapping, longest source first.

        Longest-first ordering makes multi-word phrases win over the single
        words they contain (e.g. "son of a bitch" before "bitch").

        Args:
            replacements: Mapping of regex source string -> substitute text.

        Returns:
            dict of compiled case-insensitive pattern -> substitute text, in
            application order.
        """
        ordered = sorted(replacements.items(), key=lambda item: len(item[0]), reverse=True)
        return {re.compile(source, re.IGNORECASE): substitute for source, substitute in ordered}

    @staticmethod
    def _match_case(matched: str, substitute: str) -> str:
        """Return *substitute* re-cased to follow the casing of *matched*.

        ALL-CAPS input yields an all-caps substitute, a leading capital yields
        a capitalised substitute, anything else leaves the substitute as is.
        """
        if len(matched) > 1 and matched.isupper():
            return substitute.upper()
        if matched[:1].isupper():
            return substitute[:1].upper() + substitute[1:]
        return substitute

    def is_explicit(self, text: str) -> bool:
        """Return True when the classifier considers *text* explicit.

        Args:
            text: A single lyric line.

        Returns:
            True if the summed score of all ``explicit_labels`` exceeds
            ``threshold``; False otherwise, and always False for blank text.
        """
        # Blank lines (verse separators) carry nothing to classify; skip the
        # model call entirely.
        if not text.strip():
            return False

        result = self.classifier(
            text,
            candidate_labels=self.explicit_labels + ["clean"],
            multi_label=False,
        )
        scores = dict(zip(result['labels'], result['scores']))

        # With multi_label=False the scores are normalised across all labels,
        # so the explicit labels share their probability mass. Compare the
        # combined mass, not each label on its own, against the threshold.
        explicit_score = sum(scores.get(label, 0.0) for label in self.explicit_labels)
        return explicit_score > self.threshold

    def clean_line(self, line: str) -> str:
        """Apply every replacement pattern to *line* and return the result.

        Matching is case-insensitive and the substitute follows the casing of
        the text it replaces.

        Args:
            line: A single lyric line.

        Returns:
            The line with every table match replaced.
        """
        cleaned = line
        for pattern, replacement in self.patterns.items():
            # match.expand keeps template semantics (backreferences) identical
            # to passing the replacement string to sub() directly.
            cleaned = pattern.sub(
                lambda match, template=replacement: self._match_case(
                    match.group(0), match.expand(template)
                ),
                cleaned,
            )
        return cleaned

    def clean_lyrics(self, lyrics: str) -> str:
        """Clean only the explicit lines of a newline-separated lyric text.

        Args:
            lyrics: Full lyrics, one line per ``\\n``.

        Returns:
            The lyrics with the same line structure, where lines flagged by
            ``is_explicit`` have been passed through ``clean_line``.
        """
        return '\n'.join(
            self.clean_line(line) if self.is_explicit(line) else line
            for line in lyrics.split('\n')
        )
|
|