import pandas as pd
import numpy as np
from PIL import Image
import torch
import torchvision
import clip
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model_name = 'ViT-B/16' #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
# Load straight onto DEVICE so model weights and every input tensor share a device.
model, preprocess = clip.load(model_name, device=DEVICE)
model.to(DEVICE).eval()

# Solid-colour "images" are upsampled to the square input size of CLIP's vision encoder.
resolution = model.visual.input_resolution
resizer = torchvision.transforms.Resize(size=(resolution, resolution))


def create_rgb_tensor(color):
    """Return a 1x1 RGB "image" tensor of shape (1, 3, 1, 1) on DEVICE.

    color is a sequence of three numbers, e.g. [1, 0, 0] for red. The tensor is
    always float32, so integer colour lists interpolate and resize exactly like
    float ones.
    """
    return torch.tensor(color, dtype=torch.float32, device=DEVICE).reshape((1, 3, 1, 1))


def encode_color(color):
    """Return CLIP's image embedding for a solid image of one RGB colour.

    color is e.g. [1, 0, 0]; see create_rgb_tensor.
    """
    rgb = create_rgb_tensor(color)
    return model.encode_image(resizer(rgb))


def encode_text(text):
    """Tokenize ``text`` with CLIP's tokenizer and return its text embedding."""
    tokenized_text = clip.tokenize(text).to(DEVICE)
    return model.encode_text(tokenized_text)


def lerp(x, y, steps=11):
    """Linearly interpolate from ``x`` to ``y`` in ``steps`` evenly spaced steps.

    The weights have shape (steps, 1, 1, 1), so the result stacks the
    interpolants along dim 0. The first entry equals ``x`` and the last equals
    ``y`` (only ``x`` when steps == 1). x and y are expected to be (1, C, H, W)
    image tensors, such as those returned by create_rgb_tensor.
    """
    weights = torch.linspace(0, 1, int(steps), device=DEVICE).reshape(-1, 1, 1, 1)
    return x * (1 - weights) + y * weights


@torch.no_grad()
def get_interpolated_scores(x, y, encoded_text, steps=11):
    """Score every colour on the line from ``x`` to ``y`` against a text embedding.

    Returns a DataFrame with these columns:

    - ``step``: 0-based interpolation index.
    - ``similarity``: cosine similarity between the CLIP image embedding of
      that colour and ``encoded_text``.
    - ``color``: the interpolated colour as a '#rrggbb' string.

    Runs without autograd, since the scores are only plotted.
    """
    interpolated = lerp(x, y, steps)
    image_embeddings = model.encode_image(resizer(interpolated))
    # .float() so the scores are plain float32 even when the model runs in fp16.
    scores = torch.cosine_similarity(image_embeddings, encoded_text).float().cpu().numpy()
    # One RGB triple per step (assumes 3-channel 1x1 inputs from create_rgb_tensor).
    step_colors = interpolated.cpu().numpy().reshape(-1, 3)
    data = pd.DataFrame({
        'similarity': scores,
        'color': [rgb2hex(step_rgb) for step_rgb in step_colors],
    })
    return data.reset_index().rename(columns={'index': 'step'})


def rgb2hex(rgb):
    """Convert an RGB triple with components in [0, 1] to a '#rrggbb' string.

    Components are rounded to the nearest 8-bit value and clamped to [0, 255].
    Float noise or out-of-range input therefore still yields a valid hex colour
    (plain truncation would map 0.5 to 0x7f and 1.2 to an invalid 3-digit field).
    """
    r, g, b = np.clip(np.rint(np.asarray(rgb, dtype=float) * 255), 0, 255).astype(int)
    return "#{:02x}{:02x}{:02x}".format(r, g, b)


def similarity_plot(data, text_prompt):
    """Draw ``data['similarity']`` as a bar chart, each bar filled with its ``data['color']``.

    Returns the matplotlib Figure so Gradio can render it.
    """
    title = f'CLIP Cosine Similarity Prompt="{text_prompt}"'
    fig, ax = plt.subplots()
    # A list of colours on a single Series colours each bar individually.
    # The x-limits are left to pandas so every interpolation step stays visible.
    data['similarity'].plot(
        kind='bar',
        ax=ax,
        title=title,
        color=data['color'].tolist(),
        width=1.0,
        grid=False,
    )
    ax.get_xaxis().set_visible(False)
    return fig


def interpolation_experiment(rgb_start, rgb_end, text_prompt, steps=11):
    """Plot CLIP similarity between ``text_prompt`` and colours blended from rgb_start to rgb_end.

    rgb_start and rgb_end are 3-element colour sequences (see create_rgb_tensor).
    """
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        start = create_rgb_tensor(rgb_start)
        end = create_rgb_tensor(rgb_end)
        encoded_text = encode_text(text_prompt)
        data = get_interpolated_scores(start, end, encoded_text, steps)
    return similarity_plot(data, text_prompt)


_RGB_HINT = ' (Comma separated numbers between 0 and 1)'

start_input = gr.inputs.Textbox(lines=1, default="1, 0, 0", label="Start RGB" + _RGB_HINT)
end_input = gr.inputs.Textbox(lines=1, default="0, 1, 0", label="End RGB" + _RGB_HINT)
text_input = gr.inputs.Textbox(lines=1, label="Text Prompt", default='A solid red square')
steps_input = gr.inputs.Slider(minimum=1, maximum=100, step=1, default=11, label="Interpolation Steps")


def _parse_rgb(text, label):
    """Parse 'r, g, b' into three floats in [0, 1]; raise ValueError naming ``label`` otherwise."""
    try:
        # float() tolerates surrounding whitespace, so no explicit strip() is needed.
        values = [float(part) for part in text.split(',')]
    except ValueError as err:
        raise ValueError(f"{label}: expected comma separated numbers, got {text!r}") from err
    if len(values) != 3:
        raise ValueError(f"{label}: expected exactly 3 values (R, G, B), got {len(values)}")
    if not all(0.0 <= value <= 1.0 for value in values):
        raise ValueError(f"{label}: every value must be between 0 and 1, got {text!r}")
    return values


def gradio_fn(rgb_start, rgb_end, text_prompt, steps=11):
    """Gradio callback: parse the textbox inputs and return the similarity plot."""
    start = _parse_rgb(rgb_start, "Start RGB")
    end = _parse_rgb(rgb_end, "End RGB")
    # The slider may deliver a float; the interpolation needs an integer step count.
    return interpolation_experiment(start, end, text_prompt, int(steps))


iface = gr.Interface(
    fn=gradio_fn,
    inputs=[start_input, end_input, text_input, steps_input],
    outputs="plot",
)
iface.launch(debug=True, share=False)