jimnoneill committed on
Commit e14e4bc · verified · 1 Parent(s): 6a60d2d

Upload scientific_model_inference2.py with huggingface_hub

Files changed (1)
  1. scientific_model_inference2.py +989 -0
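
Since the commit message says the file was uploaded with huggingface_hub, it can be fetched back the same way. A minimal download sketch (the repo id below is a placeholder, not taken from this page):

from huggingface_hub import hf_hub_download

# Placeholder repo id -- substitute the repository this commit belongs to.
local_path = hf_hub_download(
    repo_id="jimnoneill/REPO_NAME",
    filename="scientific_model_inference2.py",
)
print(local_path)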
scientific_model_inference2.py ADDED
@@ -0,0 +1,989 @@
#!/usr/bin/env python3
"""
Scientific Summarization Model Inference Module - FIXED VERSION
Fixed generation errors and improved title quality
"""

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import pickle
import json
import re
from pathlib import Path
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType
from typing import Dict, List, Tuple, Optional
from datetime import datetime
import csv
from collections import defaultdict, Counter
from tqdm import tqdm
import unicodedata
import hashlib
import os
import gc
import warnings

# Suppress transformer warnings
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

# SPEED OPTIMIZATION: Enhanced environment setup for RTX 3080
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["NCCL_P2P_DISABLE"] = "0"
os.environ["NCCL_IB_DISABLE"] = "0"
os.environ["ACCELERATE_DEVICE_PLACEMENT"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512,expandable_segments:True"

# SPEED OPTIMIZATION: Enable all performance optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch._dynamo.config.suppress_errors = True

class Sbert2Prompt(nn.Module):
    """Prompt generator from SBERT embeddings - matching training architecture"""
    def __init__(self, sbert_dim, llama_hidden_dim, prompt_length=24):  # Using 24 from training
        super().__init__()
        self.prompt_length = prompt_length
        self.llama_hidden_dim = llama_hidden_dim

        self.projection = nn.Sequential(
            nn.Linear(sbert_dim, llama_hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(llama_hidden_dim * 2, llama_hidden_dim * prompt_length)
        )

    def forward(self, sbert_emb):
        B = sbert_emb.size(0)
        out = self.projection(sbert_emb)
        return out.view(B, self.prompt_length, self.llama_hidden_dim)


def normalize_characters(text):
    """Normalize various Unicode characters to standard ASCII equivalents"""
    if not isinstance(text, str):
        return str(text)

    # Normalize space characters
    space_chars = ['\xa0', '\u2000', '\u2001', '\u2002', '\u2003', '\u2004', '\u2005', '\u2006', '\u2007', '\u2008', '\u2009', '\u200a', '\u202f', '\u205f', '\u3000']
    for space in space_chars:
        text = text.replace(space, ' ')

    # Normalize single quotes (curly/prime variants, written as escapes so the
    # literals survive copy-paste and display)
    single_quotes = ['\u2018', '\u2019', '\u201b', '\u2032', '\u2039', '\u203a', '\u201a', '\u201f']
    for quote in single_quotes:
        text = text.replace(quote, "'")

    # Normalize double quotes (curly, angle, and CJK corner variants)
    double_quotes = ['\u201c', '\u201d', '\u201e', '\u201f', '\u00ab', '\u00bb', '\u301d', '\u301e', '\u301f', '\uff02']
    for quote in double_quotes:
        text = text.replace(quote, '"')

    # Remove or normalize any remaining special characters
    text = unicodedata.normalize('NFKD', text)
    return text


def clean_text(text):
    """Clean and validate text data"""
    if not text or str(text) in ['nan', 'None', '']:
        return ""

    text = normalize_characters(str(text))
    text = re.sub(r'\s+', ' ', text).strip()
    return text
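
# Illustrative note (not part of the original upload): with a hypothetical
# 1024-dim SBERT encoder and a Llama hidden size of 4096 (the real values come
# from config.json), Sbert2Prompt maps a (B, 1024) sentence embedding to
# (B, 24, 4096), i.e. 24 soft-prompt token embeddings per cluster:
#   gen = Sbert2Prompt(sbert_dim=1024, llama_hidden_dim=4096, prompt_length=24)
#   gen(torch.randn(2, 1024)).shape  # -> torch.Size([2, 24, 4096])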

class ScientificModelInference:
    """Main inference class with fixed generation and better titles"""

    def __init__(self, model_dir: str, device: str = "auto"):
        """
        Initialize the inference model with enhanced generation capabilities

        Args:
            model_dir: Path to saved model directory
            device: Device to use ('auto', 'cuda', 'cpu')
        """
        self.model_dir = Path(model_dir)
        self.device = device if device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu")

        # Load configuration
        with open(self.model_dir / "config.json", 'r') as f:
            self.config = json.load(f)

        # ENHANCED: Update prompt length to match training (24)
        if 'prompt_length' in self.config:
            self.config['prompt_length'] = 24  # Match training configuration

        print(f"🔧 Loading model on device: {self.device}")
        self._load_models()

        # Store keywords for title generation context
        self._last_keywords = []
        self._last_abstracts = []  # ENHANCED: Store abstracts for better context

        # ENHANCED: Track title generation patterns and word frequency to avoid repetition
        self._title_patterns_used = Counter()
        self._title_word_frequency = Counter()  # Track word usage across all titles

        # SPEED OPTIMIZATION: Compile model for faster inference if supported
        self._optimize_models()

    def _load_models(self):
        """Load all required models with speed optimizations"""
        # SPEED OPTIMIZATION: Load SBERT model with optimizations
        print("📊 Loading SBERT model with optimizations...")
        self.sbert_model = SentenceTransformer(self.config['sbert_model_name'])
        self.sbert_model = self.sbert_model.to(self.device)
        self.sbert_model.eval()

        # SPEED OPTIMIZATION: Disable gradients for SBERT
        for param in self.sbert_model.parameters():
            param.requires_grad = False

        # Load tokenizer with optimizations
        print("🔤 Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir / "model")
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # SPEED OPTIMIZATION: Load main model with better memory settings
        print("🧠 Loading language model with enhanced generation support...")
        # NOTE: scaled_dot_product_attention lives in torch.nn.functional, not
        # torch.nn, so the original hasattr check always fell through to "eager".
        # Use the built-in "sdpa" kernel when available ("flash_attention_2"
        # would additionally require the flash-attn package).
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_dir / "model",
            torch_dtype=torch.float16,
            device_map="auto" if self.device == "cuda" else None,
            low_cpu_mem_usage=True,
            use_cache=True,
            attn_implementation="sdpa" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') else "eager"
        )
        self.model.eval()

        # SPEED OPTIMIZATION: Disable gradients for inference
        for param in self.model.parameters():
            param.requires_grad = False

        # Load prompt generator with correct architecture
        print("⚡ Loading prompt generator (24 tokens)...")
        self.prompt_generator = Sbert2Prompt(
            self.config['embedding_dim'],
            self.config['llama_hidden_dim'],
            24  # Match training prompt length
        )
        self.prompt_generator.load_state_dict(
            torch.load(self.model_dir / "prompt_generator.pt", map_location=self.device, weights_only=False)
        )
        self.prompt_generator = self.prompt_generator.to(self.device, dtype=torch.float16)
        self.prompt_generator.eval()

        # SPEED OPTIMIZATION: Disable gradients for prompt generator
        for param in self.prompt_generator.parameters():
            param.requires_grad = False

        print("✅ All models loaded with enhanced generation support!")

    def _optimize_models(self):
        """Apply additional speed optimizations"""
        try:
            # SPEED OPTIMIZATION: Try to compile models for faster inference (PyTorch 2.0+)
            if hasattr(torch, 'compile') and torch.cuda.is_available():
                print("🚀 Applying torch.compile optimizations...")
                self.model = torch.compile(self.model, mode="reduce-overhead")
                self.prompt_generator = torch.compile(self.prompt_generator, mode="reduce-overhead")
                print("✅ Torch compile applied successfully!")
        except Exception as e:
            print(f"⚠️ Torch compile not available or failed: {e}")

        # Pre-warm GPU
        try:
            if self.device == "cuda":
                dummy_input = torch.randn(1, 1024, dtype=torch.float16, device=self.device)
                _ = self.sbert_model.encode(["test"], convert_to_tensor=True, device=self.device)
                del dummy_input
                torch.cuda.empty_cache()
                print("✅ GPU pre-warmed successfully!")
        except Exception as e:
            print(f"⚠️ GPU pre-warming failed: {e}")

    def create_cluster_embedding(self, pmid_abstracts: List[str], keywords: List[str]) -> torch.Tensor:
        """
        ENHANCED: Create better cluster embedding with keyword weighting
        """
        # Store for context
        self._last_keywords = keywords
        self._last_abstracts = pmid_abstracts

        # Combine all abstracts
        combined_abstracts = " ".join([clean_text(abstract) for abstract in pmid_abstracts if abstract])

        # ENHANCED: Better keyword processing with importance weighting
        if keywords:
            clean_keywords = []
            keyword_weights = []

            for i, kw in enumerate(keywords):
                if isinstance(kw, str):
                    clean_kw = re.sub(r'\s*\([^)]+\)', '', kw).strip()
                    if clean_kw and len(clean_kw) > 1:
                        clean_keywords.append(clean_kw)
                        # Higher weight for earlier keywords (assumed more important)
                        keyword_weights.append(1.0 / (i + 1))

            # Limit keywords but keep weights proportional
            if len(clean_keywords) > 20:
                clean_keywords = clean_keywords[:20]
                keyword_weights = keyword_weights[:20]

            # Normalize weights
            if keyword_weights:
                total_weight = sum(keyword_weights)
                keyword_weights = [w / total_weight for w in keyword_weights]

            # ENHANCED: Create weighted keyword text
            keyword_text = ', '.join(clean_keywords)

            # ENHANCED: Combine with emphasis on important keywords
            important_keywords = clean_keywords[:5] if len(clean_keywords) >= 5 else clean_keywords
            combined_text = f"{combined_abstracts}\n\nKey research topics: {', '.join(important_keywords)}. Additional concepts: {keyword_text}"
        else:
            combined_text = combined_abstracts

        # Generate embedding with enhanced method
        return self._compute_enhanced_embedding(combined_text, keywords)

    def _compute_enhanced_embedding(self, text: str, keywords: List[str] = None) -> torch.Tensor:
        """
        ENHANCED: Compute embedding with better chunking and keyword integration
        """
        with torch.no_grad():
            # Get main text embedding
            text_embedding = self._compute_robust_embedding(text)

            # ENHANCED: Add keyword embedding if available
            if keywords and len(keywords) > 0:
                # Create keyword-only embedding
                keyword_text = ' [SEP] '.join(keywords[:15])  # Use separator tokens
                keyword_embedding = self.sbert_model.encode(
                    [keyword_text],
                    convert_to_tensor=True,
                    device=self.device,
                    normalize_embeddings=True
                ).squeeze(0).cpu()

                # ENHANCED: Weighted combination (85% text, 15% keywords)
                alpha = 0.85  # Text weight
                beta = 0.15   # Keyword weight

                combined_embedding = alpha * text_embedding + beta * keyword_embedding
                combined_embedding = torch.nn.functional.normalize(combined_embedding.unsqueeze(0), p=2, dim=-1).squeeze(0)

                return combined_embedding

            return text_embedding

    def _compute_robust_embedding(self, text: str) -> torch.Tensor:
        """Compute robust embedding with chunking - optimized version"""
        with torch.no_grad():
            tokenized = self.sbert_model.tokenizer.encode(text, add_special_tokens=False)
            total_tokens = len(tokenized)

            if total_tokens <= 512:
                embedding = self.sbert_model.encode(
                    [text],
                    convert_to_tensor=True,
                    device=self.device,
                    batch_size=1,
                    show_progress_bar=False,
                    normalize_embeddings=True
                )
            else:
                # ENHANCED: Better chunking with overlap
                chunks = []
                chunk_weights = []

                # Use sliding window with overlap
                window_size = 512
                stride = 256  # 50% overlap for better context

                for i in range(0, total_tokens, stride):
                    chunk_tokens = tokenized[i:i + window_size]
                    if len(chunk_tokens) < 100:  # Skip tiny chunks
                        break

                    chunk_text = self.sbert_model.tokenizer.decode(chunk_tokens, skip_special_tokens=True)
                    chunks.append(chunk_text)

                    # ENHANCED: Position-based weighting (first and last chunks more important)
                    position_weight = 1.2 if i == 0 else (1.1 if i + window_size >= total_tokens else 1.0)
                    chunk_weights.append(position_weight * len(chunk_tokens))

                # Process chunks in batches
                chunk_batch_size = 16
                chunk_embeddings_list = []

                for i in range(0, len(chunks), chunk_batch_size):
                    batch_chunks = chunks[i:i + chunk_batch_size]
                    batch_embeds = self.sbert_model.encode(
                        batch_chunks,
                        convert_to_tensor=True,
                        device=self.device,
                        batch_size=len(batch_chunks),
                        show_progress_bar=False,
                        normalize_embeddings=True
                    )
                    chunk_embeddings_list.append(batch_embeds)

                chunk_embeddings = torch.cat(chunk_embeddings_list, dim=0)
                chunk_weights_tensor = torch.tensor(chunk_weights, dtype=torch.float16, device=chunk_embeddings.device)

                # Normalize weights
                chunk_weights_tensor = chunk_weights_tensor / chunk_weights_tensor.sum()

                # Weighted average
                embedding = torch.sum(chunk_embeddings * chunk_weights_tensor.unsqueeze(1), dim=0, keepdim=True)

            return embedding.squeeze(0).cpu()
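
    # Illustrative note (not part of the original upload): for an 860-token
    # text, the sliding window above yields chunks starting at tokens 0, 256,
    # and 512 (the chunk starting at 768 has only 92 < 100 tokens and is
    # skipped). Their raw weights are 1.2*512, 1.0*512, and 1.1*348; after
    # normalization they sum to 1, and the final embedding is the weighted
    # mean of the per-chunk SBERT vectors.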

    def generate_research_analysis(self, embedding: torch.Tensor, max_length: int = 500) -> Tuple[str, str, str]:
        """
        FIXED: Generate with corrected generation parameters
        """
        self.model.eval()
        self.prompt_generator.eval()

        # FIXED: Use compatible generation configurations
        generation_configs = [
            {
                'name': 'high_quality',
                'temperature': 0.7,
                'top_p': 0.9,
                'top_k': 50,
                'num_beams': 5,
                'do_sample': True,
                'repetition_penalty': 1.15
            },
            {
                'name': 'diverse_beam',
                'num_beams': 5,
                'num_beam_groups': 5,
                'diversity_penalty': 0.5,
                'do_sample': False,  # FIXED: Must be False for diverse beam search
                'temperature': 1.0,  # Not used when do_sample=False
                'repetition_penalty': 1.2
            },
            {
                'name': 'focused',
                'temperature': 0.6,
                'top_p': 0.85,
                'top_k': 40,
                'num_beams': 6,
                'do_sample': True,
                'repetition_penalty': 1.1
            }
        ]

        with torch.no_grad():
            if embedding.dim() == 1:
                embedding = embedding.unsqueeze(0)

            embedding = embedding.to(self.device, dtype=torch.float16)
            prefix_embeds = self.prompt_generator(embedding)

            # ENHANCED: Better keyword context
            if self._last_keywords:
                # Clean keywords for better prompting
                clean_keywords = []
                for kw in self._last_keywords[:5]:
                    clean_kw = re.sub(r'[_-]', ' ', str(kw)).strip()
                    if clean_kw:
                        clean_keywords.append(clean_kw)
                keywords_text = ', '.join(clean_keywords) if clean_keywords else 'research topics'
            else:
                keywords_text = 'research topics'

            # ENHANCED: Diverse vocabulary instruction prompt to reduce repetition
            instruction_start = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a scientific theme analyst. Generate exactly three outputs for a biomedical topic:

TITLE: [8-12 word distinctive title using diverse vocabulary - avoid repeating 'research', 'analysis', 'study'. Use terms like: mechanisms, pathways, connections, interactions, dynamics, networks, insights, perspectives, implications, applications]
SHORT_SUMMARY: [2-3 sentences, 50-100 words describing the scientific domain and scope]
ABSTRACT: [4-6 sentences, 150-300 words detailed description of mechanisms, pathways, and clinical significance]

Use varied scientific terminology. Avoid repetitive language patterns. Focus on biological mechanisms, molecular pathways, clinical implications, and therapeutic potential.<|eot_id|><|start_header_id|>user<|end_header_id|>

Generate content for biomedical domain involving: {keywords_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

TITLE: """

            instruction_tokens = self.tokenizer(
                instruction_start,
                return_tensors="pt",
                add_special_tokens=False
            )
            instruction_embeds = self.model.get_input_embeddings()(instruction_tokens["input_ids"].to(prefix_embeds.device))

            full_inputs_embeds = torch.cat([prefix_embeds, instruction_embeds], dim=1)

            seq_len = full_inputs_embeds.shape[1]
            attention_mask = torch.ones((1, seq_len), dtype=torch.long, device=prefix_embeds.device)

            # Try different generation strategies
            generated_text = None
            for config in generation_configs[:2]:  # Try first two configs
                try:
                    # Build generation kwargs based on config
                    gen_kwargs = {
                        'inputs_embeds': full_inputs_embeds,
                        'attention_mask': attention_mask,
                        'max_new_tokens': max_length,
                        'min_new_tokens': 200,
                        'num_beams': config.get('num_beams', 4),
                        'no_repeat_ngram_size': 4,
                        'length_penalty': 1.0,
                        'early_stopping': False,
                        'pad_token_id': self.tokenizer.pad_token_id,
                        'eos_token_id': self.tokenizer.eos_token_id,
                        'use_cache': True,
                        'repetition_penalty': config.get('repetition_penalty', 1.1)
                    }

                    # Add config-specific parameters
                    if 'num_beam_groups' in config:
                        gen_kwargs['num_beam_groups'] = config['num_beam_groups']
                    if 'diversity_penalty' in config:
                        gen_kwargs['diversity_penalty'] = config['diversity_penalty']
                    if 'do_sample' in config:
                        gen_kwargs['do_sample'] = config['do_sample']
                    if config.get('do_sample', False):  # Only add these if sampling
                        gen_kwargs['temperature'] = config.get('temperature', 0.7)
                        gen_kwargs['top_p'] = config.get('top_p', 0.9)
                        gen_kwargs['top_k'] = config.get('top_k', 50)

                    generated_ids = self.model.generate(**gen_kwargs)
                    generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

                    # Extract generated part
                    if "TITLE:" in generated_text:
                        parts = generated_text.split("TITLE:")
                        if len(parts) > 1:
                            generated_text = "TITLE:" + parts[-1]

                    # If we got a good generation, break
                    if generated_text and len(generated_text) > 100:
                        break

                except Exception as e:
                    if "diversity_penalty" not in str(e):  # Only print unexpected errors
                        print(f"⚠️ Generation with {config['name']} config failed: {e}")
                    continue

            # Parse the output
            if generated_text:
                return self._parse_generated_output_enhanced(generated_text)
            else:
                # Fallback if all attempts failed
                return self._generate_contextual_abstract(), self._generate_contextual_overview(), self._generate_contextual_title()
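
    # Illustrative note (not part of the original upload): a well-formed
    # generation looks like
    #   TITLE: Gut Microbiome Dynamics in Metabolic Disease
    #   SHORT_SUMMARY: This theme spans ... (2-3 sentences)
    #   ABSTRACT: Recent work links ... (4-6 sentences)
    # and the parser below returns it as the tuple (abstract, overview, title).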

    def _parse_generated_output_enhanced(self, text: str) -> Tuple[str, str, str]:
        """
        ENHANCED: Better parsing with validation and correction
        """
        text = text.strip()

        # Clean up artifacts
        text = re.sub(r'<\|.*?\|>', '', text).strip()

        # ENHANCED: More robust regex patterns matching training format
        title_match = re.search(
            r'(?:TITLE|Title):?\s*([^\n]+?)(?=\n|SHORT_SUMMARY:|SHORT SUMMARY:|$)',
            text,
            re.IGNORECASE
        )

        short_match = re.search(
            r'(?:SHORT[_ ]SUMMARY):?\s*([^\n]+(?:\n[^\n:]+)*?)(?=\nABSTRACT:|$)',
            text,
            re.IGNORECASE | re.DOTALL
        )

        abstract_match = re.search(
            r'(?:ABSTRACT|Abstract):?\s*(.+?)(?=$)',
            text,
            re.IGNORECASE | re.DOTALL
        )

        title = title_match.group(1).strip() if title_match else ""
        overview = short_match.group(1).strip() if short_match else ""
        abstract = abstract_match.group(1).strip() if abstract_match else ""

        # ENHANCED: Better validation and correction
        title = self._validate_and_correct_title(title)
        overview = self._validate_and_correct_overview(overview)
        abstract = self._validate_and_correct_abstract(abstract)

        # Final quality check
        if not self._is_quality_output(title, overview, abstract):
            # Try to salvage what we can
            if not title:
                title = self._generate_contextual_title()
            if not overview:
                overview = self._generate_contextual_overview()
            if not abstract:
                abstract = self._generate_contextual_abstract()

        return abstract, overview, title

    def _validate_and_correct_title(self, title: str) -> str:
        """ENHANCED: Validate and correct title, removing repetitive patterns and repeated words"""
        if not title:
            return ""

        # Remove common prefixes and suffixes
        title = re.sub(r'^(TITLE:?\s*|Title:?\s*)', '', title, flags=re.IGNORECASE)
        title = re.sub(r'^(Investigation of|Analysis of|Study of|Research on|Examination of)\s+', '', title, flags=re.IGNORECASE)

        # ENHANCED: Remove more repetitive endings and patterns
        repetitive_endings = [
            r'\s+in Clinical Research Applications?$',
            r'\s+in Biomedical Research$',
            r'\s+in Healthcare Settings?$',
            r'\s+in Medical Research$',
            r'\s+Research Applications?$',
            r'\s+Clinical Applications?$',
            r'\s+Research Theme$',
            r'\s+Theme Analysis$',
            r'\s+Research Analysis$',
            r'\s+Clinical Analysis$'
        ]

        for pattern in repetitive_endings:
            title = re.sub(pattern, '', title, flags=re.IGNORECASE)

        # ENHANCED: Remove repeated words within the title
        title = self._remove_repeated_words(title)

        # Clean whitespace
        title = re.sub(r'\s+', ' ', title).strip()

        # Enforce word count (trim to at most 15 words; require at least 5)
        words = title.split()
        if len(words) > 15:
            # Find natural break point
            for i in range(12, min(16, len(words))):
                if words[i].lower() in ['and', 'with', 'through', 'via', 'using', 'from', 'to', 'in', 'for']:
                    words = words[:i]
                    break
            else:
                words = words[:15]
            title = ' '.join(words)

        # Ensure minimum length
        if len(words) < 5:
            return ""

        # ENHANCED: Check for overused terms and suggest alternatives
        title = self._avoid_overused_terms(title)

        # Track word usage for future titles
        self._track_title_words(title)

        # Capitalize appropriately
        return self._smart_capitalize(title)

    def _remove_repeated_words(self, text: str) -> str:
        """Remove repeated words within a title while preserving meaning"""
        words = text.split()
        if len(words) <= 3:
            return text

        # Track word usage (case-insensitive)
        seen_words = set()
        filtered_words = []

        # Common words that can appear multiple times
        allowed_repeats = {'and', 'or', 'of', 'in', 'for', 'with', 'the', 'a', 'an', 'to', 'from', 'by'}

        for word in words:
            word_lower = word.lower()
            # Allow common words to repeat, but remove other repetitions
            if word_lower not in seen_words or word_lower in allowed_repeats:
                filtered_words.append(word)
                seen_words.add(word_lower)
            # Special case: if removing this word would make title too short, keep it
            elif len(filtered_words) < 6:
                filtered_words.append(word)

        return ' '.join(filtered_words)

    def _track_title_words(self, title: str) -> None:
        """Track word usage across all generated titles"""
        words = title.lower().split()
        # Filter out common words that don't affect diversity
        meaningful_words = [w for w in words if w not in {'and', 'or', 'of', 'in', 'for', 'with', 'the', 'a', 'an', 'to', 'from', 'by', 'on', 'at'}]
        self._title_word_frequency.update(meaningful_words)

    def _avoid_overused_terms(self, title: str) -> str:
        """Replace overused terms with alternatives to improve diversity"""
        words = title.split()

        # Replacement dictionary for overused terms
        replacements = {
            'research': ['investigation', 'exploration', 'inquiry', 'analysis'],
            'analysis': ['examination', 'evaluation', 'assessment', 'investigation'],
            'study': ['investigation', 'exploration', 'examination', 'inquiry'],
            'application': ['implementation', 'utilization', 'deployment', 'use'],
            'approach': ['strategy', 'method', 'technique', 'framework'],
            'system': ['network', 'framework', 'mechanism', 'pathway'],
            'method': ['technique', 'approach', 'strategy', 'protocol'],
            'role': ['function', 'impact', 'influence', 'effect'],
            'effect': ['impact', 'influence', 'consequence', 'outcome'],
            'factor': ['element', 'component', 'determinant', 'variable']
        }

        # Check each word for overuse
        for i, word in enumerate(words):
            word_lower = word.lower()
            # If word is overused (appears more than 5 times) and has replacements
            if (self._title_word_frequency[word_lower] > 5 and
                    word_lower in replacements):
                # Choose replacement based on current frequency
                alternatives = replacements[word_lower]
                best_alt = min(alternatives, key=lambda x: self._title_word_frequency[x])
                # Only replace if the alternative is less used
                if self._title_word_frequency[best_alt] < self._title_word_frequency[word_lower]:
                    # Preserve original capitalization
                    if word[0].isupper():
                        words[i] = best_alt.capitalize()
                    else:
                        words[i] = best_alt

        return ' '.join(words)

    def _validate_and_correct_overview(self, overview: str) -> str:
        """ENHANCED: Validate and correct overview"""
        if not overview:
            return ""

        # Remove label
        overview = re.sub(r'^(SHORT[_ ]SUMMARY|OVERVIEW):?\s*', '', overview, flags=re.IGNORECASE)
        overview = re.sub(r'\s+', ' ', overview).strip()

        # Check length (should be 20-150 words)
        words = overview.split()
        if len(words) < 20 or len(words) > 150:
            return ""

        # Ensure it ends with proper punctuation
        if overview and overview[-1] not in '.!?':
            overview += '.'

        return overview

    def _validate_and_correct_abstract(self, abstract: str) -> str:
        """ENHANCED: Validate and correct abstract"""
        if not abstract:
            return ""

        # Remove label
        abstract = re.sub(r'^(ABSTRACT):?\s*', '', abstract, flags=re.IGNORECASE)
        abstract = re.sub(r'\s+', ' ', abstract).strip()

        # Check length (at least 50 words; truncated to ~400 below)
        words = abstract.split()
        if len(words) < 50:
            return ""

        # Truncate if too long
        if len(words) > 400:
            # Try to find sentence boundary
            sentences = re.split(r'(?<=[.!?])\s+', abstract)
            result = []
            word_count = 0
            for sentence in sentences:
                sentence_words = len(sentence.split())
                if word_count + sentence_words <= 380:
                    result.append(sentence)
                    word_count += sentence_words
                else:
                    break
            abstract = ' '.join(result)

        # Ensure proper ending
        if abstract and abstract[-1] not in '.!?':
            abstract += '.'

        return abstract

    def _is_quality_output(self, title: str, overview: str, abstract: str) -> bool:
        """Check if output meets quality standards"""
        return (
            len(title.split()) >= 5 and len(title.split()) <= 20 and
            len(overview.split()) >= 20 and len(overview.split()) <= 150 and
            len(abstract.split()) >= 50 and len(abstract.split()) <= 400 and
            title != overview and title != abstract and overview != abstract
        )

    def _smart_capitalize(self, text: str) -> str:
        """Smart capitalization for titles"""
        words = text.split()
        if not words:
            return text

        # Always capitalize first word
        words[0] = words[0][0].upper() + words[0][1:] if len(words[0]) > 1 else words[0].upper()

        # Small words that shouldn't be capitalized (unless first)
        small_words = {'of', 'in', 'and', 'or', 'the', 'a', 'an', 'to', 'for', 'with', 'from', 'by', 'on', 'at'}

        for i in range(1, len(words)):
            if words[i].lower() not in small_words or i == len(words) - 1:
                # Keep short all-caps acronyms as-is; capitalize everything else
                if not words[i].isupper() or len(words[i]) > 4:
                    words[i] = words[i][0].upper() + words[i][1:] if len(words[i]) > 1 else words[i].upper()

        return ' '.join(words)

    def _generate_contextual_title(self) -> str:
        """ENHANCED: Generate diverse theme titles with varied vocabulary"""
        if self._last_keywords and len(self._last_keywords) >= 2:
            # Clean keywords
            kw1 = re.sub(r'[_-]', ' ', str(self._last_keywords[0])).strip().title()
            kw2 = re.sub(r'[_-]', ' ', str(self._last_keywords[1])).strip().title()

            # ENHANCED: More diverse templates with varied vocabulary
            templates = [
                f"{kw1} and {kw2} Integration",
                f"{kw1}-{kw2} Connections",
                f"{kw1} Influences on {kw2}",
                f"{kw2} Mechanisms in {kw1}",
                f"{kw1} and {kw2}: Clinical Insights",
                f"{kw1}-{kw2} Therapeutic Pathways",
                f"{kw1} Interactions with {kw2}",
                f"{kw2}-Mediated {kw1} Effects",
                f"{kw1} and {kw2}: Biomedical Perspectives",
                f"{kw1}-{kw2} Molecular Networks",
                f"{kw1} Impact on {kw2} Regulation",
                f"{kw2} Dynamics in {kw1} Context",
                f"{kw1} and {kw2}: Translational Science",
                f"{kw1}-{kw2} Disease Mechanisms",
                f"{kw1} and {kw2}: Precision Medicine",
                f"{kw2}-Associated {kw1} Pathways"
            ]

            # Select based on hash for consistency, but avoid repeating
            base_hash = hash(''.join(self._last_keywords[:2]))

            # Try to avoid recently used patterns
            for i in range(len(templates)):
                idx = (base_hash + i) % len(templates)
                candidate = templates[idx]
                pattern_key = f"{kw1[:3]}_{kw2[:3]}"  # Simple key for tracking

                if self._title_patterns_used[pattern_key] < 3:  # Allow each pattern 3 times max
                    self._title_patterns_used[pattern_key] += 1
                    return candidate

            # Fallback if all patterns used
            return templates[base_hash % len(templates)]

        return "Biomedical Mechanisms and Clinical Applications"

    def _generate_contextual_overview(self) -> str:
        """UPDATED: Generate theme overview using 'research theme covers' language"""
        if self._last_keywords and len(self._last_keywords) >= 2:
            # Clean keywords for natural language
            clean_kw = []
            for kw in self._last_keywords[:3]:
                clean = re.sub(r'[_-]', ' ', str(kw)).strip().lower()
                if clean:
                    clean_kw.append(clean)

            if len(clean_kw) >= 2:
                return (f"This research theme covers the relationships between {clean_kw[0]} and {clean_kw[1]}, "
                        f"encompassing significant implications for clinical practice. The theme covers "
                        f"novel mechanisms that could lead to improved therapeutic strategies and patient outcomes.")

        return ("This research theme covers important biomedical mechanisms with "
                "significant clinical implications. The theme encompasses new insights for "
                "developing more effective treatment strategies and improving patient care.")

    def _generate_contextual_abstract(self) -> str:
        """UPDATED: Generate theme abstract using theme-oriented language"""
        if self._last_keywords and len(self._last_keywords) >= 3:
            # Clean keywords
            kw1 = re.sub(r'[_-]', ' ', str(self._last_keywords[0])).strip().lower()
            kw2 = re.sub(r'[_-]', ' ', str(self._last_keywords[1])).strip().lower()
            kw3 = re.sub(r'[_-]', ' ', str(self._last_keywords[2])).strip().lower()

            return (f"This research theme covers the complex relationships between {kw1} and {kw2} "
                    f"through comprehensive analysis of clinical and experimental data. The theme encompasses "
                    f"novel interactions involving {kw3} that contribute to disease mechanisms and therapeutic responses. "
                    f"This research theme covers previously unrecognized pathways that regulate these processes in clinical "
                    f"populations. The theme demonstrates significant associations between these "
                    f"factors and patient outcomes, with important implications for treatment selection "
                    f"and optimization. This research theme provides a foundation for developing targeted "
                    f"interventions and improving clinical care through personalized medicine approaches.")

        return self._generate_fallback_abstract()

    def _generate_fallback_title(self) -> str:
        """ENHANCED: Generate diverse fallback titles"""
        if self._last_keywords and len(self._last_keywords) >= 2:
            kw1 = re.sub(r'[_-]', ' ', str(self._last_keywords[0])).strip().title()
            kw2 = re.sub(r'[_-]', ' ', str(self._last_keywords[1])).strip().title()
            fallback_patterns = [
                f"{kw1} and {kw2}: Molecular Insights",
                f"{kw1}-{kw2} Therapeutic Connections",
                f"{kw1} Interactions with {kw2}",
                f"{kw2}-Mediated {kw1} Pathways"
            ]
            # Use hash for consistent but varied selection
            idx = hash(''.join(self._last_keywords[:2])) % len(fallback_patterns)
            return fallback_patterns[idx]
        return "Biomedical Mechanisms and Clinical Applications"

    def _generate_fallback_overview(self) -> str:
        """UPDATED: Generate fallback theme overview"""
        return ("This research theme covers important insights into biomedical mechanisms "
                "and their clinical applications. The theme encompasses significant implications "
                "for improving patient care and developing new treatment strategies.")

    def _generate_fallback_abstract(self) -> str:
        """UPDATED: Generate fallback theme abstract"""
        return ("This research theme covers complex biomedical mechanisms "
                "through systematic analysis of clinical and experimental data. The theme encompasses "
                "novel pathways and interactions that contribute to disease progression and treatment response. "
                "This research theme covers important regulatory mechanisms that were previously unrecognized in clinical "
                "populations. The theme has significant implications for developing "
                "more effective therapeutic strategies and improving patient outcomes through "
                "personalized medicine approaches. This research theme provides a foundation for future "
                "research and clinical applications in precision medicine.")

    # Memory management utilities
    def cleanup_memory(self):
        """Aggressive memory cleanup for long-running inference"""
        torch.cuda.empty_cache()
        gc.collect()
        print("🧹 Memory cleanup completed")

    def get_memory_stats(self):
        """Get current GPU memory usage"""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            reserved = torch.cuda.memory_reserved() / 1024**3
            return f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB"
        return "CUDA not available"

    def process_pickle_data(self, pickle_file_path: str, keywords_dict: Dict = None) -> List[Dict]:
        """Process pickle file data with enhanced generation"""
        print(f"📂 Loading data from {pickle_file_path}")

        with open(pickle_file_path, 'rb') as f:
            data = pickle.load(f)

        num_clusters = data['metadata']['num_clusters']

        print(f"🔄 Processing {num_clusters} clusters with enhanced generation...")

        # Pre-allocate result list
        results = [None] * num_clusters

        # Process with progress bar
        for cluster_idx in tqdm(range(num_clusters), desc="Generating analyses"):
            try:
                # Extract cluster data
                cluster_docs = data['cluster_docs'][cluster_idx] if cluster_idx < len(data['cluster_docs']) else []
                pmid_abstracts = data['pmid_abstracts'][cluster_idx] if cluster_idx < len(data['pmid_abstracts']) else []
                keywords = keywords_dict.get(cluster_idx, []) if keywords_dict else []

                # Create embedding with enhanced method
                embedding = self.create_cluster_embedding(pmid_abstracts, keywords)

                # Generate content with enhanced parameters
                abstract, overview, title = self.generate_research_analysis(embedding, max_length=500)

                results[cluster_idx] = {
                    'cluster_id': cluster_idx,
                    'abstract': abstract,
                    'overview': overview,
                    'title': title,
                    'num_pmids': len(pmid_abstracts),
                    'keywords': keywords[:10]
                }

                # Memory cleanup every 10 clusters
                if cluster_idx % 10 == 0:
                    torch.cuda.empty_cache()
                    gc.collect()

            except Exception as e:
                print(f"⚠️ Error processing cluster {cluster_idx}: {e}")
                results[cluster_idx] = {
                    'cluster_id': cluster_idx,
                    'abstract': self._generate_fallback_abstract(),
                    'overview': self._generate_fallback_overview(),
                    'title': f"Research Theme {cluster_idx} Analysis",
                    'num_pmids': 0,
                    'keywords': []
                }

        # Final cleanup
        torch.cuda.empty_cache()
        gc.collect()

        # Filter out None results
        results = [r for r in results if r is not None]

        return results

    def save_results_tsv(self, results: List[Dict], output_path: str = None, prefix: str = "research_analyses"):
        """Save results to timestamped TSV file"""
        if output_path is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_path = f"{prefix}_{timestamp}.tsv"

        df = pd.DataFrame(results)
        df.to_csv(output_path, sep='\t', index=False)
        print(f"💾 Results saved to: {output_path}")
        return output_path

    # Backward compatibility wrapper
    def generate_research_summary(self, embedding: torch.Tensor, max_length: int = 500) -> Tuple[str, str, str]:
        """Backward compatibility wrapper"""
        return self.generate_research_analysis(embedding, max_length)


# Convenience function for easy usage
def load_model_and_generate(model_dir: str, pickle_files: List[str], keywords_dict: Dict = None,
                            output_prefix: str = "research_analyses") -> List[str]:
    """
    Convenience function to load model and generate analyses for multiple pickle files
    """
    print("🚀 Initializing model with fixed generation parameters...")
    model = ScientificModelInference(model_dir)

    print(f"📊 {model.get_memory_stats()}")

    output_files = []

    for i, pickle_file in enumerate(pickle_files):
        print(f"\n📋 Processing {pickle_file} ({i+1}/{len(pickle_files)})")

        # Process data with enhanced generation
        results = model.process_pickle_data(pickle_file, keywords_dict)

        # Generate unique output name
        period_name = Path(pickle_file).stem
        output_path = model.save_results_tsv(results, prefix=f"{output_prefix}_{period_name}")
        output_files.append(output_path)

        # Memory cleanup between files
        if len(pickle_files) > 1:
            model.cleanup_memory()
            print(f"📊 {model.get_memory_stats()}")

    print(f"🎉 Completed processing {len(pickle_files)} files with improved titles!")
    return output_files
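
For reference, a minimal usage sketch of the module's convenience entry point (the paths and the keywords mapping below are placeholders, not files from this commit):

from scientific_model_inference2 import load_model_and_generate

# Hypothetical inputs: a saved model directory and one pickle of clustered
# abstracts, plus an optional {cluster_idx: [keyword, ...]} mapping.
keywords = {0: ["gut microbiome", "insulin resistance", "inflammation"]}
output_files = load_model_and_generate(
    model_dir="./scientific_summarizer",
    pickle_files=["clusters_2020.pkl"],
    keywords_dict=keywords,
    output_prefix="research_analyses",
)
print(output_files)  # one timestamped TSV path per input pickle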