import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
import os
import time
import spaces

# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@torch.inference_mode()
@spaces.GPU(duration=120)
def multimodal_understanding(image, question, seed, top_p, temperature):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # Set seeds for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    pil_images = [Image.fromarray(image)]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=4000,
        do_sample=False if temperature == 0 else True,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer


def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # Build interleaved conditional (even rows) / unconditional (odd rows) prompts
    # for classifier-free guidance; unconditional rows keep only the first and last tokens.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                                  use_cache=True,
                                                  past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])

            # Classifier-free guidance: mix conditional and unconditional logits
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            # Duplicate each sampled token for the cond/uncond pair and feed it back in
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the generated image tokens back into pixel patches
    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])

    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img


@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()

    # Set the seed for reproducible results
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

    width = 384
    height = 384
    parallel_size = 5

    with torch.no_grad():
        messages = [{'role': '<|User|>', 'content': prompt},
                    {'role': '<|Assistant|>', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='')
        text = text + vl_chat_processor.image_start_tag

        input_ids = torch.LongTensor(tokenizer.encode(text))
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=guidance,
                                   parallel_size=parallel_size,
                                   temperature=t2i_temperature)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16,
                        parallel_size=parallel_size)

        return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]


# Custom CSS as a string
custom_css = """
.gradio-container {
    font-family: 'Inter', -apple-system, sans-serif;
}
.image-preview {
    min-height: 300px;
    max-height: 500px;
    width: 100%;
    object-fit: contain;
    border-radius: 8px;
    border: 2px solid #eee;
}
.tab-nav {
    background: white;
    padding: 1rem;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
.examples-table {
    font-size: 0.9rem;
}
.gr-button.gr-button-lg {
    padding: 12px 24px;
    font-size: 1.1rem;
}
.gr-input, .gr-select {
    border-radius: 6px;
}
.gr-form {
    background: white;
    padding: 20px;
    border-radius: 12px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.05);
}
.gr-panel {
    border: none;
    background: transparent;
}
.footer {
    text-align: center;
    margin-top: 2rem;
    padding: 1rem;
    color: #666;
}
"""

# Gradio interface with improved UI
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo"),
    css=custom_css
) as demo:
    gr.Markdown(
        """
        # Deepseek Multimodal
        ### Advanced AI for Visual Understanding and Generation

        This powerful multimodal AI system combines:
        * **Visual Analysis**: Advanced image understanding and medical image interpretation
        * **Creative Generation**: High-quality image generation from text descriptions
        * **Interactive Chat**: Natural conversation about visual content
        """
    )

    with gr.Tabs():
        # Visual Chat Tab
        with gr.Tab("Visual Understanding"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="Upload Image",
                        type="numpy",
                        elem_classes="image-preview"
                    )

                with gr.Column(scale=1):
                    question_input = gr.Textbox(
                        label="Question or Analysis Request",
                        placeholder="Ask a question about the image or request detailed analysis...",
                        lines=3
                    )

                    with gr.Row():
                        und_seed_input = gr.Number(
                            label="Seed",
                            precision=0,
                            value=42,
                            container=False
                        )
                        top_p = gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=0.95,
                            step=0.05,
                            label="Top-p",
                            container=False
                        )
                        temperature = gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=0.1,
                            step=0.05,
                            label="Temperature",
                            container=False
                        )

                    understanding_button = gr.Button(
                        "Analyze Image",
                        variant="primary"
                    )

            understanding_output = gr.Textbox(
label="Analysis Results", lines=10, show_copy_button=True ) with gr.Accordion("Medical Analysis Examples", open=False): gr.Examples( examples=[ [ """You are an AI assistant trained to analyze medical images...""", "fundus.webp", ], ], inputs=[question_input, image_input], ) # Image Generation Tab with gr.Tab("Image Generation"): with gr.Column(): prompt_input = gr.Textbox( label="Image Description", placeholder="Describe the image you want to create in detail...", lines=3 ) with gr.Row(): cfg_weight_input = gr.Slider( minimum=1, maximum=10, value=5, step=0.5, label="Guidance Scale", info="Higher values create images that more closely match your prompt" ) t2i_temperature = gr.Slider( minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature", info="Controls randomness in generation" ) seed_input = gr.Number( label="Seed (Optional)", precision=0, value=12345, info="Set for reproducible results" ) generation_button = gr.Button( "Generate Images", variant="primary" ) image_output = gr.Gallery( label="Generated Images", columns=3, rows=2, height=500, object_fit="contain" ) with gr.Accordion("Generation Examples", open=False): gr.Examples( examples=[ "Master shifu racoon wearing drip attire as a street gangster.", "The face of a beautiful girl", "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "A glass of red wine on a reflective surface.", "A cute and adorable baby fox with big brown eyes...", ], inputs=prompt_input, ) # Connect components understanding_button.click( multimodal_understanding, inputs=[image_input, question_input, und_seed_input, top_p, temperature], outputs=understanding_output ) generation_button.click( fn=generate_image, inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature], outputs=image_output ) # Launch the demo if __name__ == "__main__": demo.launch(share=True)