from dataclasses import dataclass
from typing import Dict, Any, Optional
import os
import base64
import asyncio
import logging
import random
import traceback

import torch
# note: there is no HunyuanImageToVideoPipeline yet in Diffusers
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig

from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

from teacache import enable_teacache, disable_teacache

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Check environment variable for pipeline support
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))

# Hugging Face token used when loading gated/private LoRA weights.
# `hf_token` was referenced later in the file but never defined; the
# "HF_API_TOKEN" variable name below is an assumption.
hf_token = os.getenv("HF_API_TOKEN")


@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # Content settings
    prompt: str
    negative_prompt: str = ""

    # Model settings
    num_frames: int = 49  # Should be 4k + 1 format
    height: int = 320
    width: int = 576
    num_inference_steps: int = 50
    guidance_scale: float = 7.0

    # Reproducibility
    seed: int = -1

    # Varnish post-processing settings
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0
    quality: int = 18  # CRF scale (0-51, lower is better)

    # Audio settings
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    # TeaCache settings
    enable_teacache: bool = True
    teacache_threshold: float = 0.15  # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)

    # Enhance-A-Video settings
    enable_enhance_a_video: bool = True
    enhance_a_video_weight: float = 5.0

    # Input image settings (only used by the image-to-video code path;
    # this field was referenced but missing from the original dataclass,
    # and the default of 100 is an assumption)
    input_image_quality: int = 100

    # LoRA settings
    lora_model_name: str = ""  # HuggingFace repo ID or path to LoRA model
    lora_model_weight_file: str = ""  # Specific weight file to load from the LoRA model
    lora_model_trigger: str = ""  # Optional trigger word to prepend to the prompt

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters"""
        # Ensure num_frames follows 4k + 1 format
        k = (self.num_frames - 1) // 4
        self.num_frames = (k * 4) + 1

        # Set random seed if not specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self


class EndpointHandler:
    """Handles video generation requests using HunyuanVideo and Varnish"""

    def __init__(self, path: str = ""):
        """Initialize handler with models

        Args:
            path: Path to model weights
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize transformer with Enhance-A-Video injection first
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16
        )

        if support_image_prompt:
            raise Exception("Please use a version of Diffusers that supports HunyuanImageToVideoPipeline")
            # # Initialize image-to-video pipeline
            # self.image_to_video = HunyuanImageToVideoPipeline.from_pretrained(
            #     path,
            #     transformer=transformer,
            #     torch_dtype=torch.float16,
            # ).to(self.device)
            #
            # # Initialize components in appropriate precision
            # self.image_to_video.text_encoder = self.image_to_video.text_encoder.half()
            # self.image_to_video.text_encoder_2 = self.image_to_video.text_encoder_2.half()
            # self.image_to_video.transformer = self.image_to_video.transformer.to(torch.bfloat16)
            # self.image_to_video.vae = self.image_to_video.vae.half()
        else:
            # Initialize text-to-video pipeline
            self.text_to_video = HunyuanVideoPipeline.from_pretrained(
                path,
                transformer=transformer,
                torch_dtype=torch.float16,
            ).to(self.device)

            # Initialize components in appropriate precision
            self.text_to_video.text_encoder = self.text_to_video.text_encoder.half()
            self.text_to_video.text_encoder_2 = self.text_to_video.text_encoder_2.half()
            self.text_to_video.transformer = self.text_to_video.transformer.to(torch.bfloat16)
            self.text_to_video.vae = self.text_to_video.vae.half()

        # Initialize LoRA tracking
        self._current_lora_model = None

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device=self.device,
            model_base_dir="/repository/varnish"
        )

    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish

        Args:
            frames: Generated video frames tensor
            config: Generation configuration

        Returns:
            Tuple of (video data URI, metadata dictionary)
        """
        try:
            # Process video with Varnish
            result = await self.varnish(
                input_data=frames,
                fps=config.fps,
                double_num_frames=config.double_num_frames,
                super_resolution=config.super_resolution,
                grain_amount=config.grain_amount,
                enable_audio=config.enable_audio,
                audio_prompt=config.audio_prompt,
                audio_negative_prompt=config.audio_negative_prompt
            )

            # Convert to data URI
            video_uri = await result.write(type="data-uri", quality=config.quality)

            # Collect metadata
            metadata = {
                "width": result.metadata.width,
                "height": result.metadata.height,
                "num_frames": result.metadata.frame_count,
                "fps": result.metadata.fps,
                "duration": result.metadata.duration,
                "seed": config.seed,
                "enable_teacache": config.enable_teacache,
                "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
                "enable_enhance_a_video": config.enable_enhance_a_video,
                "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
            }

            return video_uri, metadata

        except Exception as e:
            logger.error(f"Error in process_frames: {str(e)}")
            raise RuntimeError(f"Failed to process frames: {str(e)}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process video generation requests

        Args:
            data: Request data containing:
                - inputs (str): Prompt for video generation
                - parameters (dict): Generation parameters

        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
        # Extract inputs
        inputs = data.pop("inputs", data)
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
            # Optional input image for image-to-video generation.
            # `input_image` was referenced later in the original code but never
            # extracted from the request; the "input_image" key is an assumption.
            input_image = inputs.get("input_image", None)
        else:
            prompt = inputs
            input_image = None

        params = data.get("parameters", {})

        # Create and validate config
        config = GenerationConfig(
            prompt=prompt,
            negative_prompt=params.get("negative_prompt", ""),
            num_frames=params.get("num_frames", 49),
            height=params.get("height", 320),
            width=params.get("width", 576),
            num_inference_steps=params.get("num_inference_steps", 50),
            guidance_scale=params.get("guidance_scale", 7.0),
            seed=params.get("seed", -1),
            fps=params.get("fps", 30),
            double_num_frames=params.get("double_num_frames", False),
            super_resolution=params.get("super_resolution", False),
            grain_amount=params.get("grain_amount", 0.0),
            quality=params.get("quality", 18),
            enable_audio=params.get("enable_audio", False),
            audio_prompt=params.get("audio_prompt", ""),
            audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
            enable_teacache=params.get("enable_teacache", True),
            # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
            teacache_threshold=params.get("teacache_threshold", 0.15),
            enable_enhance_a_video=params.get("enable_enhance_a_video", True),
            enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
            lora_model_name=params.get("lora_model_name", ""),
            lora_model_weight_file=params.get("lora_model_weight_file", ""),
            lora_model_trigger=params.get("lora_model_trigger", ""),
        ).validate_and_adjust()

        try:
            # Set random seeds
            if config.seed != -1:
                torch.manual_seed(config.seed)
                random.seed(config.seed)
                generator = torch.Generator(device=self.device).manual_seed(config.seed)
            else:
                generator = None

            # Configure TeaCache
            # if config.enable_teacache:
            #     enable_teacache(
            #         self.pipeline.transformer,
            #         num_inference_steps=config.num_inference_steps,
            #         rel_l1_thresh=config.teacache_threshold
            #     )
            # else:
            #     disable_teacache(self.pipeline.transformer)

            with torch.inference_mode():
                # Prepare generation parameters
                generation_kwargs = {
                    "prompt": config.prompt,
                    # Failed to generate video: HunyuanVideoPipeline.__call__() got an unexpected keyword argument 'negative_prompt'
                    # "negative_prompt": config.negative_prompt,
                    "num_frames": config.num_frames,
                    "height": config.height,
                    "width": config.width,
                    "num_inference_steps": config.num_inference_steps,
                    "guidance_scale": config.guidance_scale,
                    "generator": generator,
                    "output_type": "pt",
                }

                # Handle LoRA loading/unloading
                if hasattr(self, '_current_lora_model'):
                    if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
                        # Unload previous LoRA if it exists and is different
                        if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
                            self.image_to_video.unload_lora_weights()
                        else:
                            if hasattr(self.text_to_video, 'unload_lora_weights'):
                                self.text_to_video.unload_lora_weights()

                        if config.lora_model_name:
                            # Load new LoRA
                            if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
                                self.image_to_video.load_lora_weights(
                                    config.lora_model_name,
                                    weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                    token=hf_token,
                                )
                            else:
                                if hasattr(self.text_to_video, 'load_lora_weights'):
                                    self.text_to_video.load_lora_weights(
                                        config.lora_model_name,
                                        weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                        token=hf_token,
                                    )

                        self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)

                # Modify prompt if trigger word is provided
                if config.lora_model_trigger:
                    generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"

                # Check if image-to-video generation is requested
                if support_image_prompt and input_image:
                    self._configure_teacache(self.image_to_video, config)
                    processed_image = process_input_image(
                        input_image,
                        config.width,
                        config.height,
                        config.input_image_quality,
                    )
                    generation_kwargs["image"] = processed_image
                    apply_enhance_a_video(self.image_to_video.transformer, EnhanceAVideoConfig(
                        weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
                        num_frames_callback=lambda: (config.num_frames - 1),
                        _attention_type=1
                    ))
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    self._configure_teacache(self.text_to_video, config)
                    apply_enhance_a_video(self.text_to_video.transformer, EnhanceAVideoConfig(
                        weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
                        num_frames_callback=lambda: (config.num_frames - 1),
                        _attention_type=1
                    ))
                    frames = self.text_to_video(**generation_kwargs).frames

            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": metadata
            }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            logger.error(message)
            raise RuntimeError(message)
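
    def _configure_teacache(self, pipeline, config: GenerationConfig):
        """Enable or disable TeaCache on a pipeline's transformer.

        This helper is called from __call__ but was not defined in the original
        file; the body below is a minimal sketch that reuses the
        enable_teacache / disable_teacache calls shown in the commented-out
        block inside __call__, with the same argument names.
        """
        if config.enable_teacache:
            enable_teacache(
                pipeline.transformer,
                num_inference_steps=config.num_inference_steps,
                rel_l1_thresh=config.teacache_threshold,
            )
        else:
            disable_teacache(pipeline.transformer)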
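
# Usage sketch (illustrative only): how an inference endpoint would typically
# invoke this handler, following the request format documented in __call__.
# The weights path and payload values below are assumptions, not taken from
# the original file.
#
# handler = EndpointHandler(path="/repository")
# result = handler({
#     "inputs": {"prompt": "A cat walking on wet grass, cinematic lighting"},
#     "parameters": {"num_frames": 49, "width": 576, "height": 320, "seed": 42},
# })
# print(result["content-type"], result["metadata"])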