"""Gradio demo: rank tags by similarity of their classifier-head vectors.

Each row of the wd-eva02-large-tagger-v3 head weight matrix is treated as the
vector of the tag with the same row index in ``selected_tags.csv``.  Given one
or more query tags, every tag is scored by its cosine similarity to the query
vectors, averaged over the query tags.
"""

import gradio as gr
import timm
import torch
import pandas as pd

TITLE = "wd-eva02-large-tagger-v3-vector"
# Rendered by Gradio under the title; links to the model and to the
# English -> Japanese tag-name dataset used for the labels.
DESCRIPTION = (
    " モデル:[SmilingWolf/wd-eva02-large-tagger-v3]"
    "(https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3)"
    " 日本語訳?:[p1atdev/danbooru-ja-tag-pair-20241015]"
    "(https://huggingface.co/datasets/p1atdev/danbooru-ja-tag-pair-20241015) "
)

# Only the head weights are needed; drop the rest of the model right away so
# it does not stay resident in memory.
model = timm.create_model("hf_hub:SmilingWolf/wd-eva02-large-tagger-v3", pretrained=True)
head = model.head.weight.data
del model

# Tag metadata: the row position in this CSV is used as the tag id.
df = pd.read_csv("https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3/resolve/main/selected_tags.csv")
id2label = df["name"].to_dict()
label2id = {name: tag_id for tag_id, name in id2label.items()}
general_tags = df[df["category"] == 0].index
character_tags = df[df["category"] == 4].index
all_tags = df.index

# Map each tag title to the first of its alternative names; tags without an
# entry in the dataset get an empty string so label formatting never fails.
tag_pair_df = pd.read_parquet("hf://datasets/p1atdev/danbooru-ja-tag-pair-20241015/data/train-00000-of-00001.parquet")
tag_pair = {
    title: other_names[0]
    for title, other_names in zip(tag_pair_df["title"], tag_pair_df["other_names"])
}
for tag in df["name"]:
    tag_pair.setdefault(tag, "")


def predict(target_tags: str, search_in: str) -> dict[str, float]:
    """Score tags by mean cosine similarity to the given query tags.

    Args:
        target_tags: Comma-separated tag names.  Surrounding whitespace is
            stripped and inner spaces become underscores; empty entries
            (e.g. from a trailing comma) are ignored.
        search_in: ``"general"`` or ``"character"`` to restrict the result to
            that tag category; any other value returns all tags.

    Returns:
        A mapping of ``"<tag>(<alternative name>)"`` to the similarity score,
        one entry per tag in the selected category.

    Raises:
        gr.Error: If no tag was entered or a tag is not in the tag list, so
            the UI shows a readable message instead of a bare ``KeyError``.
    """
    names = [tag.strip().replace(" ", "_") for tag in (target_tags or "").split(",")]
    names = [name for name in names if name]
    if not names:
        raise gr.Error("Enter at least one tag.")

    unknown = [name for name in names if name not in label2id]
    if unknown:
        raise gr.Error(f"Unknown tag(s): {', '.join(unknown)}")

    target_ids = [label2id[name] for name in names]
    # (k, 1, D) against (1, N, D) broadcasts to a (k, N) similarity matrix;
    # averaging over dim 0 yields one score per tag.
    query = head[target_ids].unsqueeze(1)
    sim = torch.cosine_similarity(query, head.unsqueeze(0), dim=2).mean(dim=0)
    # Convert once instead of calling .item() for every tag.
    scores = sim.tolist()

    candidates = {"general": general_tags, "character": character_tags}.get(search_in, all_tags)
    return {f"{id2label[i]}({tag_pair[id2label[i]]})": scores[i] for i in candidates}


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Text(value="pink hair, braid", label="Target tags"),
        gr.Dropdown(["all", "general", "character"], label="Search in", value="all"),
    ],
    outputs=gr.Label(num_top_classes=50),
    title=TITLE,
    description=DESCRIPTION,
)

demo.launch()