### ----------------- ###
# Standard library imports
import os
import re
import sys
import copy
import warnings
from typing import Optional

# Third-party imports
import numpy as np
import torch
import torch.distributed as dist
import uvicorn
import librosa
import whisper
import requests
from fastapi import FastAPI
from pydantic import BaseModel
from decord import VideoReader, cpu
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import spaces

# Local imports
from egogpt.model.builder import load_pretrained_model
from egogpt.mm_utils import get_model_name_from_path, process_images
from egogpt.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    IGNORE_INDEX,
    SPEECH_TOKEN_INDEX,
    DEFAULT_SPEECH_TOKEN,
)
from egogpt.conversation import conv_templates, SeparatorStyle
import subprocess

# Install flash-attn at startup; skip the CUDA extension build so the prebuilt wheel is used.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
)
from huggingface_hub import snapshot_download

# Download the whole model folder locally to ./llava-onevision-qwen2-7b-ov
# snapshot_download(
#     repo_id="lmms-lab/llava-onevision-qwen2-7b-ov",
#     local_dir="./llava-onevision-qwen2-7b-ov",  # local target directory
#     ignore_patterns=["*.md", "*.txt"]  # skip files that are not needed (optional)
# )
from huggingface_hub import hf_hub_download

# Download the model checkpoint file (large-v3.pt)
ego_gpt_path = hf_hub_download(
    repo_id="EgoLife-v1/EgoGPT",
    filename="large-v3.pt",
    local_dir="./"
)
# pretrained = "/mnt/sfs-common/jkyang/EgoGPT/checkpoints/EgoGPT-llavaov-7b-EgoIT-109k-release"
# pretrained = "/mnt/sfs-common/jkyang/EgoGPT/checkpoints/EgoGPT-llavaov-7b-EgoIT-EgoLife-Demo"
pretrained = 'EgoLife-v1/EgoGPT'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_map = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize a minimal single-process distributed group before loading the model
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12377'
    # Initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

setup(0, 1)

tokenizer, model, max_length = load_pretrained_model(pretrained, device_map=device_map)
model.to(device).eval()
title_markdown = """
<div style="display: flex; justify-content: space-between; align-items: center; background: linear-gradient(90deg, rgba(72,219,251,0.1), rgba(29,209,161,0.1)); border-radius: 20px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); padding: 20px; margin-bottom: 20px;">
    <div style="display: flex; align-items: center;">
        <a href="https://egolife-ntu.github.io/" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
            <img src="https://egolife-ntu.github.io/egolife.png" alt="EgoLife" style="max-width: 100px; height: auto; border-radius: 15px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
        </a>
        <div>
            <h1 style="margin: 0; background: linear-gradient(90deg, #48dbfb, #1dd1a1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 2.5em; font-weight: 700;">EgoLife</h1>
            <h2 style="margin: 10px 0; color: #2d3436; font-weight: 500;">Towards Egocentric Life Assistant</h2>
            <div style="display: flex; gap: 15px; margin-top: 10px;">
                <a href="https://egolife-ntu.github.io/" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Project Page</a> |
                <a href="https://github.com/egolife-ntu/EgoGPT" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Github</a> |
                <a href="https://huggingface.co/lmms-lab" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Huggingface</a> |
                <a href="https://arxiv.org/" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Paper</a> |
                <a href="https://x.com/" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Twitter (X)</a>
            </div>
        </div>
    </div>
    <div style="text-align: right; margin-left: 20px;">
        <h1 style="margin: 0; background: linear-gradient(90deg, #48dbfb, #1dd1a1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 2.5em; font-weight: 700;">EgoGPT</h1>
        <h2 style="margin: 10px 0; background: linear-gradient(90deg, #48dbfb, #1dd1a1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 1.8em; font-weight: 600;">An Egocentric Video-Audio-Text Model<br>from EgoLife Project</h2>
    </div>
</div>
"""
notice_html = """
<div style="background-color: #f9f9f9; border-left: 5px solid #48dbfb; padding: 20px; margin-top: 20px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);">
    <ul style="list-style-type: none; padding-left: 0; font-size: 1.1em; color: #555;">
        <li>- Due to hardware limitations of this demo page, we recommend keeping videos to about 10 seconds.</li>
        <li>- The demo model performs the egocentric video captioning step of the EgoRAG framework. Recommended prompts include:</li>
        <ul style="padding-left: 20px; margin-top: 10px; color: #333;">
            <li>Can you help me log everything I do and the key things I see, like a personal journal? Describe them in a natural style.</li>
            <li>Please provide your response using the first person, with "I" as the subject. Make sure the descriptions are detailed and natural.</li>
            <li>Can you write down important things I notice or interact with? Please respond in the first person, using "I" as the subject. Describe them in a natural style.</li>
        </ul>
    </ul>
</div>
"""
bibtext = """
### Citation
```
@article{yang2025egolife,
    title={EgoLife: Towards Egocentric Life Assistant},
    author={The EgoLife Team},
    journal={arXiv preprint arXiv:25xxx},
    year={2025}
}
```
"""
# cur_dir = os.path.dirname(os.path.abspath(__file__))
cur_dir = '.'
def time_to_frame_idx(time_int: int, fps: int) -> int:
    """
    Convert time in HHMMSSFF format (integer or string) to frame index.
    :param time_int: Time in HHMMSSFF format, e.g., 10483000 (10:48:30.00) or "10483000".
    :param fps: Frames per second of the video.
    :return: Frame index corresponding to the given time.
    """
    # Ensure time_int is an 8-digit string for slicing (pad with leading zeros if necessary)
    time_str = str(time_int).zfill(8)
    hours = int(time_str[:2])
    minutes = int(time_str[2:4])
    seconds = int(time_str[4:6])
    frames = int(time_str[6:8])
    total_seconds = hours * 3600 + minutes * 60 + seconds
    total_frames = total_seconds * fps + frames  # Convert to total frames
    return total_frames
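
# Worked example (assuming a 30 fps video):
#   time_to_frame_idx(10483000, 30)  # 10:48:30.00
#   -> (10 * 3600 + 48 * 60 + 30) * 30 + 0 = 1167300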
def split_text(text, keywords):
    # Build a regex pattern that joins all keywords with | inside a capturing group
    pattern = '(' + '|'.join(map(re.escape, keywords)) + ')'
    # re.split keeps the delimiters because of the capturing group
    parts = re.split(pattern, text)
    # Drop empty strings
    parts = [part for part in parts if part]
    return parts
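
# Example: split_text("<image>\nDescribe what I am doing.", ["<image>", "<speech>"])
# returns ['<image>', '\nDescribe what I am doing.'].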
warnings.filterwarnings("ignore")

# Create FastAPI instance
app = FastAPI()
def load_video(
    video_path: Optional[str] = None,
    max_frames_num: int = 16,
    fps: int = 1,
    video_start_time: Optional[float] = None,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    time_based_processing: bool = False
) -> tuple:
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    target_sr = 16000
    # Time-based processing logic
    if time_based_processing:
        # Re-initialize the video reader for the time-based path
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_frame_num = len(vr)

        # Get the actual FPS of the video
        video_fps = vr.get_avg_fps()

        # Convert HHMMSSFF times to frame indices based on the actual video FPS
        video_start_frame = int(time_to_frame_idx(video_start_time, video_fps))
        start_frame = int(time_to_frame_idx(start_time, video_fps))
        end_frame = int(time_to_frame_idx(end_time, video_fps))

        print("start frame", start_frame)
        print("end frame", end_frame)

        # Ensure the requested span does not exceed the total frame number
        if end_frame - start_frame > total_frame_num:
            end_frame = total_frame_num + start_frame

        # Adjust start_frame and end_frame by the video's own start time
        start_frame -= video_start_frame
        end_frame -= video_start_frame
        start_frame = max(0, int(round(start_frame)))  # Never below 0
        end_frame = min(total_frame_num, int(round(end_frame)))  # Never beyond the total frame count

        # Sample frames based on the provided fps (e.g., 1 frame per second)
        frame_idx = [i for i in range(start_frame, end_frame) if (i - start_frame) % int(video_fps / fps) == 0]

        # Get the video frames for the sampled indices
        video = vr.get_batch(frame_idx).asnumpy()

        target_sr = 16000  # Set target sample rate to 16 kHz

        # Load audio from the video with resampling
        y, _ = librosa.load(video_path, sr=target_sr)

        # Convert time to audio samples (using the 16 kHz sample rate)
        start_sample = int(start_time * target_sr)
        end_sample = int(end_time * target_sr)

        # Extract the audio segment
        speech = y[start_sample:end_sample]
    else:
        # Original processing logic
        speech, _ = librosa.load(video_path, sr=target_sr)
        total_frame_num = len(vr)
        avg_fps = round(vr.get_avg_fps() / fps)
        frame_idx = [i for i in range(0, total_frame_num, avg_fps)]

        if max_frames_num > 0:
            if len(frame_idx) > max_frames_num:
                uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
                frame_idx = uniform_sampled_frames.tolist()
        video = vr.get_batch(frame_idx).asnumpy()

    # Process audio
    speech = whisper.pad_or_trim(speech.astype(np.float32))
    speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0)
    speech_lengths = torch.LongTensor([speech.shape[0]])

    return video, speech, speech_lengths
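
# Note (indicative shapes, not guaranteed): with the defaults above, `video` is a
# (num_frames, H, W, 3) uint8 array and `speech` is the Whisper log-mel spectrogram
# transposed to (num_audio_frames, 128) after pad_or_trim to 30 seconds of audio.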
class PromptRequest(BaseModel):
    prompt: str
    video_path: Optional[str] = None
    max_frames_num: int = 16
    fps: int = 1
    video_start_time: Optional[float] = None
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    time_based_processing: bool = False
# @spaces.GPU(duration=120)
def generate_text(video_path, audio_track, prompt):
    max_frames_num = 30
    fps = 1
    # model.eval()

    # Video + speech branch
    conv_template = "qwen_1_5"  # Make sure to use the correct chat template for different models
    question = f"<image>\n{prompt}"
    conv = copy.deepcopy(conv_templates[conv_template])
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    video, speech, speech_lengths = load_video(
        video_path=video_path,
        max_frames_num=max_frames_num,
        fps=fps,
    )
    speech = torch.stack([speech]).to(device).half()
    processor = model.get_vision_tower().image_processor
    processed_video = processor.preprocess(video, return_tensors="pt")["pixel_values"]
    image = [(processed_video, video[0].size, "video")]

    print(prompt_question)

    # Interleave text token ids with the special image/speech token indices
    parts = split_text(prompt_question, ["<image>", "<speech>"])
    input_ids = []
    for part in parts:
        if part == "<image>":
            input_ids += [IMAGE_TOKEN_INDEX]
        elif part == "<speech>":
            input_ids += [SPEECH_TOKEN_INDEX]
        else:
            input_ids += tokenizer(part).input_ids
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
    image_tensor = [image[0][0].half()]
    image_sizes = [image[0][1]]
    generate_kwargs = {"eos_token_id": tokenizer.eos_token_id}

    print(input_ids)
    cont = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=image_sizes,
        speech=speech,
        speech_lengths=speech_lengths,
        do_sample=False,
        temperature=0.5,
        max_new_tokens=4096,
        modalities=["video"],
        **generate_kwargs
    )

    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
    return text_outputs[0]
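
# Example call (hypothetical local file, mirroring the Gradio examples below; the
# audio_track argument is unused by this function, so None is fine):
# caption = generate_text("./bike.mp4", None, "Can you tell me what I'm doing in short words?")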
def extract_audio_from_video(video_path, audio_path=None):
    if audio_path:
        try:
            y, sr = librosa.load(audio_path, sr=8000, mono=True, res_type='kaiser_fast')
            return (sr, y)
        except Exception as e:
            print(f"Error loading audio from {audio_path}: {e}")
            return None
    if video_path is None:
        return None
    try:
        y, sr = librosa.load(video_path, sr=8000, mono=True, res_type='kaiser_fast')
        return (sr, y)
    except Exception as e:
        print(f"Error extracting audio from video: {e}")
        return None
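
# The (sample_rate, waveform) tuple returned above matches the value format gr.Audio accepts,
# so the extracted track can be passed directly to the audio component defined below.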
head = """ | |
<style> | |
/* Submit按钮默认和悬停效果 */ | |
button.lg.secondary.svelte-1gz44hr { | |
background-color: #ff9933 !important; | |
transition: background-color 0.3s ease !important; | |
} | |
button.lg.secondary.svelte-1gz44hr:hover { | |
background-color: #ff7777 !important; /* 悬停时颜色加深 */ | |
} | |
/* 确保按钮文字始终清晰可见 */ | |
button.lg.secondary.svelte-1gz44hr span { | |
color: white !important; | |
} | |
/* 隐藏表头中的第二列 */ | |
.table-wrap .svelte-p5q82i th:nth-child(2) { | |
display: none; | |
} | |
/* 隐藏表格内容中的第二列 */ | |
.table-wrap .svelte-p5q82i td:nth-child(2) { | |
display: none; | |
} | |
.table-wrap { | |
max-height: 300px; | |
overflow-y: auto; | |
} | |
</style> | |
<script>
    // Sync control code for the video and audio players
    function syncMediaElements() {
        // Grab the video and audio elements
        const video = document.querySelector('[data-testid="Video-player"] video');
        const waveform = document.querySelector('#waveform');
        const audio = waveform?.querySelector('audio') || waveform?.shadowRoot?.querySelector('audio');

        // Bail out if either element is missing
        if (!video || !audio) return;

        // Remove old listeners to avoid double-binding
        video.removeEventListener('play', syncPlay);
        audio.removeEventListener('play', syncPlay);
        video.removeEventListener('timeupdate', syncVideoTime);
        audio.removeEventListener('timeupdate', syncAudioTime);

        // Sync helpers
        function syncPlay(e) {
            if (e.target === video && audio.paused) audio.play();
            if (e.target === audio && video.paused) video.play();
        }

        function syncVideoTime() {
            if (!audio.seeking && Math.abs(video.currentTime - audio.currentTime) > 0.1) {
                audio.currentTime = video.currentTime;
            }
        }

        function syncAudioTime() {
            if (!video.seeking && Math.abs(audio.currentTime - video.currentTime) > 0.1) {
                video.currentTime = audio.currentTime;
            }
        }

        // Attach the new listeners
        video.addEventListener('play', syncPlay);
        audio.addEventListener('play', syncPlay);
        video.addEventListener('timeupdate', syncVideoTime);
        audio.addEventListener('timeupdate', syncAudioTime);

        // Keep pause state in sync
        video.addEventListener('pause', () => audio.pause());
        audio.addEventListener('pause', () => video.pause());

        console.log('Media elements synced successfully!');
    }

    // DOM observer that waits for the media components to appear
    const observer = new MutationObserver((mutations) => {
        mutations.forEach((mutation) => {
            if (mutation.addedNodes.length) {
                mutation.addedNodes.forEach((node) => {
                    // Inspect newly added nodes
                    if (node.nodeType === 1) { // Element node
                        // Check whether the node contains the video component
                        if (node.querySelector?.('[data-testid="Video-player"]')) {
                            // Once the video component appears, start looking for the audio element
                            const audioObserver = new MutationObserver(() => {
                                if (document.querySelector('#waveform audio')) {
                                    audioObserver.disconnect();
                                    setTimeout(syncMediaElements, 500); // Wait for the components to finish loading
                                }
                            });
                            audioObserver.observe(document.body, {
                                childList: true,
                                subtree: true
                            });
                        }
                    }
                });
            }
        });
    });

    // Observe the whole document
    observer.observe(document.body, {
        childList: true,
        subtree: true
    });

    // Initial check (in case the components already exist)
    setTimeout(() => {
        if (document.querySelector('[data-testid="Video-player"] video') &&
            document.querySelector('#waveform audio')) {
            syncMediaElements();
        }
    }, 1000);
</script>
"""
with gr.Blocks(head=head) as demo:
    gr.HTML(title_markdown)
    gr.HTML(notice_html)

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Video", autoplay=True, loop=True, format="mp4", width=600, height=400, show_label=False, elem_id='video')
            # Audio input synchronized with video playback
            audio_display = gr.Audio(label="Video Audio Track", autoplay=False, show_label=True, visible=True, interactive=False, elem_id="audio")
            text_input = gr.Textbox(label="Question", placeholder="Enter your message here...")

        with gr.Column():  # Separate column for the output and examples
            output_text = gr.Textbox(label="Response", lines=14, max_lines=14)
            gr.Examples(
                examples=[
                    [f"{cur_dir}/bike.mp4", f"{cur_dir}/bike.mp3", "Can you tell me what I'm doing in short words? Describe them in a natural style."],
                    [f"{cur_dir}/bike.mp4", f"{cur_dir}/bike.mp3", "Can you tell me what I'm doing in short words? Describe them in a natural style."],
                    [f"{cur_dir}/bike.mp4", f"{cur_dir}/bike.mp3", "Can you tell me what I'm doing in short words? Describe them in a natural style."],
                    [f"{cur_dir}/bike.mp4", f"{cur_dir}/bike.mp3", "Can you tell me what I'm doing in short words? Describe them in a natural style."]
                ],
                inputs=[video_input, audio_display, text_input],
                outputs=[output_text]
            )
    # Add event handler for video changes
    video_input.change(
        fn=lambda video_path: extract_audio_from_video(video_path, audio_path=None),
        inputs=[video_input],
        outputs=[audio_display]
    )

    # Add event handler for video clear/delete
    def clear_outputs(video):
        if video is None:  # Video was cleared/deleted
            return ""
        return gr.skip()  # Keep the existing text if a video is present

    video_input.change(
        fn=clear_outputs,
        inputs=[video_input],
        outputs=[output_text]
    )

    # Add submit button and its event handler
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=generate_text,
        inputs=[video_input, audio_display, text_input],
        outputs=[output_text]
    )

    gr.Markdown(bibtext)

# Launch the Gradio app
if __name__ == "__main__":
demo.launch(share=True) | |