""" Example usage script for LLM2Vec4CXR model. This demonstrates how to load and use the model for chest X-ray report analysis. Prerequisites: 1. Install the LLM2Vec4CXR package: pip install git+https://github.com/lukeingawesome/llm2vec4cxr.git Or clone and install in development mode: git clone https://github.com/lukeingawesome/llm2vec4cxr.git cd llm2vec4cxr pip install -e . 2. The model will be automatically downloaded from Hugging Face when first used. """ import torch import torch.nn.functional as F from llm2vec_wrapper import LLM2VecWrapper as LLM2Vec def load_llm2vec4cxr_model(model_name_or_path="lukeingawesome/llm2vec4cxr"): """ Load the LLM2Vec4CXR model with proper configuration. Args: model_name_or_path (str): Hugging Face model path or local path Returns: tuple: (model, tokenizer) """ # Load model with the specific configuration used for LLM2Vec4CXR model = LLM2Vec.from_pretrained( base_model_name_or_path=model_name_or_path, enable_bidirectional=True, pooling_mode="latent_attention", # This is the key modification max_length=512, torch_dtype=torch.bfloat16, ) # Configure tokenizer tokenizer = model.tokenizer tokenizer.padding_side = 'left' return model, tokenizer def tokenize_with_separator(texts, tokenizer, max_length=512): """ Tokenize texts with special handling for separator-based splitting. This is useful for instruction-following tasks. Args: texts (list): List of texts to tokenize tokenizer: The tokenizer to use max_length (int): Maximum sequence length Returns: dict: Tokenized inputs with attention masks and embed masks """ texts_2 = [] original_texts = [] separator = '!@#$%^&*()' for text in texts: parts = text.split(separator) texts_2.append(parts[1] if len(parts) > 1 else "") original_texts.append("".join(parts)) # Tokenize original texts tokenized = tokenizer( original_texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length, ) # Create embedding masks for the separated parts embed_mask = None for t_i, t in enumerate(texts_2): ids = tokenizer( [t], return_tensors="pt", padding=True, truncation=True, max_length=max_length, add_special_tokens=False, ) e_m = torch.zeros_like(tokenized["attention_mask"][t_i]) if len(ids["input_ids"][0]) > 0: e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0])) if embed_mask is None: embed_mask = e_m.unsqueeze(0) else: embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) tokenized["embed_mask"] = embed_mask return tokenized def compute_similarities(model, tokenizer, texts, device): """ Compute similarity scores between the first text and all other texts. 


def compute_similarities(model, tokenizer, texts, device):
    """
    Compute similarity scores between the first text and all other texts.

    Args:
        model: The LLM2Vec model
        tokenizer: The tokenizer
        texts (list): List of texts to compare (first text is the reference)
        device: The device to run computations on

    Returns:
        tuple: (embeddings, similarities)
    """
    with torch.no_grad():
        # Use separator-based tokenization if texts contain the separator
        if any('!@#$%^&*()' in text for text in texts):
            tokenized = tokenize_with_separator(texts, tokenizer, 512)
        else:
            tokenized = tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )

        # Move the inputs to the target device. Token ids and masks stay integer
        # tensors; the model weights are already loaded in bfloat16, so no dtype
        # cast of the inputs is needed (casting input_ids would break the
        # embedding lookup).
        if hasattr(tokenized, 'to'):
            tokenized = tokenized.to(device)
        else:
            # Plain-dict fallback: move each tensor individually
            tokenized = {k: v.to(device) if torch.is_tensor(v) else v for k, v in tokenized.items()}

        embeddings = model(tokenized)

        # Compute cosine similarities between the first embedding and all others
        similarities = F.cosine_similarity(embeddings[0], embeddings[1:], dim=1)

    return embeddings, similarities


def main():
    """
    Example usage of the LLM2Vec4CXR model for chest X-ray report analysis.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the model
    print("Loading LLM2Vec4CXR model...")
    model, tokenizer = load_llm2vec4cxr_model()
    model = model.to(device).to(torch.bfloat16)
    model.eval()

    # Example 1: Basic text embedding using built-in method
    print("\n" + "="*60)
    print("Example 1: Basic Text Embedding (Built-in Method)")
    print("="*60)

    report = "There is a small increase in the left-sided effusion. There continues to be volume loss at both bases."

    # Use the convenient built-in method
    embedding = model.encode_text(report)
    print(f"Report: {report}")
    print(f"Embedding shape: {embedding.shape}")
    print(f"Embedding norm: {torch.norm(embedding).item():.4f}")

    # Example 2: Instruction-based similarity comparison
    print("\n" + "="*60)
    print("Example 2: Instruction-based Similarity Comparison")
    print("="*60)

    separator = '!@#$%^&*()'
    instruction = 'Determine the change or the status of the pleural effusion.'
    report = 'There is a small increase in the left-sided effusion. There continues to be volume loss at both bases.'
    text = instruction + separator + report

    comparison_options = [
        'No pleural effusion',
        'Pleural effusion',
        'Effusion is seen in the right',
        'Effusion is seen in the left',
        'Pleural effusion is improving',
        'Pleural effusion is stable',
        'Pleural effusion is worsening'
    ]

    all_texts = [text] + comparison_options

    # Use built-in method for instruction-based encoding
    embeddings = model.encode_with_instruction(all_texts)
    similarities = F.cosine_similarity(embeddings[0], embeddings[1:], dim=1)

    print(f"Original text: {report}")
    print(f"Instruction: {instruction}")
    print("\nSimilarity Scores:")
    print("-" * 50)
    for option, score in zip(comparison_options, similarities):
        print(f"{option:<35} | {score.item():.4f}")

    # Find the most similar option
    best_match_idx = torch.argmax(similarities).item()
    print(f"\nBest match: {comparison_options[best_match_idx]} (score: {similarities[best_match_idx].item():.4f})")
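
    # Optional cross-check (not in the original example): score the same options
    # with the lower-level helpers defined above instead of the built-in method.
    # `compute_similarities` performs the separator-aware tokenization and calls
    # the model directly; scores may differ slightly from encode_with_instruction
    # depending on how the wrapper formats its inputs.
    _, helper_similarities = compute_similarities(model, tokenizer, all_texts, device)
    print("\nSimilarity Scores (via compute_similarities helper):")
    print("-" * 50)
    for option, score in zip(comparison_options, helper_similarities):
        print(f"{option:<35} | {score.item():.4f}")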
] print("Computing embeddings for multiple reports...") # Use built-in method for multiple texts embeddings = model.encode_text(reports) # Compute pairwise similarities similarity_matrix = F.cosine_similarity( embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2 ) print("\nPairwise Similarity Matrix:") print("-" * 30) for i, report1 in enumerate(reports): print(f"Report {i+1}: {report1[:30]}...") for j, report2 in enumerate(reports): print(f" vs Report {j+1}: {similarity_matrix[i][j].item():.4f}") print() if __name__ == "__main__": main()