Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import base64 | |
import gzip | |
import json | |
from dataclasses import dataclass, fields | |
from io import BytesIO | |
from pathlib import Path | |
from urllib.parse import parse_qsl | |
import altair as alt | |
import ipywidgets as widgets | |
import numpy as np | |
import polars as pl | |
import solara | |
import solara.lab | |
from cmap import Colormap | |
from ipymolstar.widget import PDBeMolstar | |
from pydantic import BaseModel | |
from make_link import decode_data | |
# Element-wise ``np.base_repr``: converts every integer of an array to its
# string representation in the requested base.
base_v = np.vectorize(pyfunc=np.base_repr)
# Relative y-axis padding (fraction of the value range) intended for charts
# whose y-axis is not autoscaled.
PAD_SIZE = 5e-2
def norm(x, vmin, vmax):
    """Linearly rescale ``x`` so that ``vmin`` maps to 0 and ``vmax`` maps to 1.

    Works element-wise for anything supporting ``-`` and ``/`` (scalars,
    arrays, series). No guard is applied for ``vmin == vmax``.
    """
    shifted = x - vmin
    width = vmax - vmin
    return shifted / width
class ColorTransform(BaseModel):
    """Color-mapping settings shared by the Mol* structure view and the chart."""

    name: str = "tol:rainbow_PuRd"  # colormap name passed to ``cmap.Colormap``
    # "categorical": values are fed to the colormap unchanged; anything else:
    # values are linearly normalized between ``vmin`` and ``vmax`` first.
    norm_type: str = "linear"
    vmin: float = 0.0
    vmax: float = 1.0
    # Used as the colormap's ``bad`` color and as Mol*'s ``nonSelectedColor``.
    missing_data_color: str = "#8c8c8c"
    # Color used to mark the hovered/selected residue in both views.
    highlight_color: str = "#e933f8"

    def molstar_colors(self, data: pl.DataFrame) -> dict:
        """Build the ``color_data`` payload for ``PDBeMolstar``.

        Rows containing nulls are dropped; residues without an entry are drawn
        in ``missing_data_color`` through ``nonSelectedColor``.

        Args:
            data: frame with ``residue_number`` and ``value`` columns.

        Returns:
            ``{"data": [{"residue_number": ..., "color": "#rrggbb"}, ...],
            "nonSelectedColor": missing_data_color}`` with lowercase hex colors.
        """
        data = data.drop_nulls()
        color_data: dict = {"data": [], "nonSelectedColor": self.missing_data_color}
        if data.is_empty():
            # Nothing to color; avoids running the colormap on an empty input.
            return color_data

        if self.norm_type == "categorical":
            values = data["value"]
        else:
            values = norm(data["value"], vmin=self.vmin, vmax=self.vmax)

        # Force an (N, 4) uint8 RGBA array so a single-row frame is not
        # collapsed to a 0-d array.
        rgba = np.asarray(self.cmap(values, bytes=True), dtype=np.uint8).reshape(-1, 4)
        # Format straight from the R, G, B bytes: independent of byte order.
        hex_colors = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in rgba[:, :3].tolist()]

        color_data["data"] = [
            {"residue_number": resi, "color": hcolor}
            for resi, hcolor in zip(data["residue_number"], hex_colors)
        ]
        return color_data

    @property
    def cmap(self) -> Colormap:
        """``Colormap`` for ``name`` with its ``bad`` color set to ``missing_data_color``."""
        return Colormap(self.name, bad=self.missing_data_color)

    @property
    def altair_scale(self) -> alt.Scale:
        """Altair color scale matching the colors used in the Mol* view.

        Categorical: one color per colormap entry over domain ``0..num_colors-1``.
        Otherwise: 256 stops spread evenly between ``vmin`` and ``vmax``.
        Out-of-domain values are clamped.
        """
        cmap = self.cmap  # build the Colormap once
        if self.norm_type == "categorical":
            colors = cmap.to_altair(N=cmap.num_colors)
            domain = range(cmap.num_colors)
        else:
            colors = cmap.to_altair()
            domain = np.linspace(self.vmin, self.vmax, 256, endpoint=True)
        return alt.Scale(domain=list(domain), range=colors, clamp=True)
class AxisProperties(BaseModel):
    """Labelling and y-scaling options for the value axis."""

    label: str = "x"  # quantity name used in axis titles and Mol* tooltips
    unit: str = "au"  # unit appended after the label
    # When False, make_chart pins the y-domain to the color range instead.
    autoscale_y: bool = True

    @property
    def title(self) -> str:
        """Axis title in the form ``"label (unit)"``.

        A property because callers read it as an attribute
        (``axis_properties.title``).
        """
        return f"{self.label} ({self.unit})"
def make_chart(
    data: pl.DataFrame, colors: ColorTransform, axis_properties: AxisProperties
) -> alt.LayerChart:
    """Build the layered Altair chart of ``value`` per ``residue_number``.

    Layers (bottom to top):
      * scatter of the values, colored with ``colors.altair_scale``, with
        mouse-wheel zoom on x (not while Shift is held);
      * a vertical rule positioned by the ``line_position`` param and shown or
        hidden by ``line_opacity`` (both set from outside the chart);
      * an invisible point layer carrying the nearest-point hover selection
        named ``"point"``;
      * a rule drawn at the residue currently selected by that hover.

    Args:
        data: frame with ``residue_number`` and ``value`` columns.
        colors: color settings; also fix the y-domain when autoscaling is off.
        axis_properties: axis title and the y-autoscaling flag.

    Returns:
        The layered chart, sized to the container width and 480 px high.
    """
    xmin, xmax = data["residue_number"].min(), data["residue_number"].max()
    xpad = (xmax - xmin) * 0.05
    if xpad == 0:
        # Single residue (or identical numbers): avoid a zero-width x-domain.
        xpad = 1
    xscale = alt.Scale(domain=(xmin - xpad, xmax + xpad))

    if axis_properties.autoscale_y:
        y_scale = alt.Scale()
    elif colors.norm_type == "categorical":
        n_colors = colors.cmap.num_colors
        ypad = n_colors * PAD_SIZE
        y_scale = alt.Scale(domain=(0 - ypad, n_colors - 1 + ypad))
    else:
        ypad = (colors.vmax - colors.vmin) * PAD_SIZE
        y_scale = alt.Scale(domain=(colors.vmin - ypad, colors.vmax + ypad))

    zoom_x = alt.selection_interval(
        bind="scales",
        encodings=["x"],
        zoom="wheel![!event.shiftKey]",
    )
    scatter = (
        alt.Chart(data)
        .mark_circle(size=200)
        .encode(
            x=alt.X("residue_number:Q", title="Residue Number", scale=xscale),
            y=alt.Y(
                "value:Q",
                title=axis_properties.title,
                scale=y_scale,
            ),
            color=alt.Color(
                f"value:{'O' if colors.norm_type == 'categorical' else 'Q'}",
                scale=colors.altair_scale,
                title=axis_properties.title,
            ),
        )
        .add_params(zoom_x)
    )

    # Selection that picks the nearest point on hover, keyed by residue number.
    nearest = alt.selection_point(
        name="point",
        nearest=True,
        on="pointerover",
        fields=["residue_number"],
        empty=False,
        clear="mouseout",
    )
    select_residue = (
        alt.Chart(data)
        .mark_point()
        .encode(
            x="residue_number:Q",
            opacity=alt.value(0),
        )
        .add_params(nearest)
    )

    # Rule at the residue picked by the hover selection.
    rule = (
        alt.Chart(data)
        .mark_rule(color=colors.highlight_color, size=2)
        .encode(
            x="residue_number:Q",
        )
        .transform_filter(nearest)
    )

    # Externally driven vertical rule: x = 1.0 * line_position, so the param
    # alone sets its position; line_opacity toggles its visibility.
    line_position = alt.param(name="line_position", value=0.0)
    line_opacity = alt.param(name="line_opacity", value=1)
    df_line = pl.DataFrame({"x": [1.0]})
    vline = (
        alt.Chart(df_line)
        .mark_rule(color=colors.highlight_color, opacity=line_opacity, size=2)
        .encode(x=alt.X("p", type="quantitative"))
        .transform_calculate(p=alt.datum.x * line_position)
        .add_params(line_position, line_opacity)
    )

    # Combine the four layers.
    chart = alt.layer(scatter, vline, select_residue, rule).properties(
        width="container",
        height=480,  # TODO: consider autosizing the height
    )
    return chart
@solara.component
def ScatterChart(
    data: pl.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    on_selections,
    line_value,
):
    """Render the residue/value chart as an Altair ``JupyterChart`` widget.

    Args:
        data: frame with ``residue_number`` and ``value`` columns.
        colors: color settings forwarded to ``make_chart``.
        axis_properties: axis settings forwarded to ``make_chart``.
        on_selections: traitlets observer called when the chart's ``"point"``
            selection changes.
        line_value: x position of the externally driven vertical rule, or
            ``None`` to hide it.
    """

    def build_chart():
        return make_chart(data, colors, axis_properties)

    chart = solara.use_memo(build_chart, dependencies=[data, colors, axis_properties])

    if line_value is not None:
        params = {"line_position": line_value, "line_opacity": 1}
    else:
        params = {"line_position": 0.0, "line_opacity": 0}

    dark_effective = solara.lab.use_dark_effective()
    if dark_effective:
        options = {"actions": False, "theme": "dark"}
    else:
        options = {"actions": False}

    view = alt.JupyterChart.element(  # type: ignore
        chart=chart,
        embed_options=options,
        _params=params,
    )

    def bind():
        real = solara.get_widget(view)
        # Keep a reference to the object we observe so cleanup detaches from
        # that same object even if the widget later replaces it.
        selections = real.selections  # type: ignore
        selections.observe(on_selections, "point")

        def cleanup():
            # Detach before re-binding (or on unmount) so observers do not pile up.
            selections.unobserve(on_selections, "point")

        return cleanup

    # NOTE(review): JupyterChart appears to rebuild its `selections` object
    # whenever `chart` changes, so rebind whenever the memoized chart can change
    # (same dependencies as the memo above).
    solara.use_effect(bind, [data, colors, axis_properties])
def is_numeric(val) -> bool:
    """Return True when ``val`` holds an actual number, i.e. is neither None nor NaN."""
    if val is None:
        return False
    return not np.isnan(val)
@solara.component
def ProteinView(
    title: str,
    molecule_id: str,
    data: pl.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    dark_effective: bool,
    description: str = "",
):
    """Page showing a Mol* structure colored by ``data`` next to the value chart.

    Hovering a residue in Mol* draws a vertical rule in the chart; hovering a
    point in the chart highlights that residue in Mol*. The app bar offers a
    fullscreen toggle (which hides the chart), an optional About dialog
    (when ``description`` is given) and a theme toggle.

    Args:
        title: app bar title.
        molecule_id: PDB id to load (lower-cased before use).
        data: frame with ``residue_number`` and ``value`` columns.
        colors: color settings for both views.
        axis_properties: label/unit for the tooltips and chart axis.
        dark_effective: whether the dark theme is active.
        description: Markdown shown in the About dialog.
    """
    about_dialog = solara.use_reactive(False)
    fullscreen = solara.use_reactive(False)
    # Residue number to draw as a rule in the Altair chart (set by Mol* hover).
    line_number = solara.use_reactive(None)
    # Residue number to highlight in the Mol* view (set by chart hover).
    highlight_number = solara.use_reactive(None)

    if data.is_empty():
        color_data = {}
    else:
        color_data = colors.molstar_colors(data)

    tooltips = {
        "data": [
            {
                "residue_number": resi,
                "tooltip": f"{axis_properties.label}: {value:.2g} {axis_properties.unit}"
                if is_numeric(value)
                else "No data",
            }
            for resi, value in zip(data["residue_number"], data["value"])
        ]
    }

    # Residue numbers can legitimately be 0, so test for None, not truthiness.
    highlight = (
        {"data": [{"residue_number": int(highlight_number.value)}]}
        if highlight_number.value is not None
        else None
    )

    def on_molstar_mouseover(value):
        line_number.set(value.get("residueNumber", None))

    def on_molstar_mouseout(value):
        on_molstar_mouseover({})

    def on_chart_selection(event):
        try:
            r = event["new"].value[0]["residue_number"]
        except (IndexError, KeyError, TypeError):
            # Empty, cleared or missing selection: remove the highlight.
            highlight_number.set(None)
        else:
            highlight_number.set(r)

    with solara.AppBar():
        solara.AppBarTitle(title)
        with solara.Tooltip("Fullscreen"):
            solara.Button(
                icon_name="mdi-fullscreen",
                icon=True,
                on_click=lambda: fullscreen.set(not fullscreen.value),
            )
        if description:
            with solara.Tooltip("About"):
                solara.Button(
                    icon_name="mdi-information-outline",
                    icon=True,
                    on_click=lambda: about_dialog.set(True),
                )
        solara.lab.ThemeToggle()

    with solara.v.Dialog(
        v_model=about_dialog.value, on_v_model=lambda _ignore: about_dialog.set(False)
    ):
        with solara.Card("About", margin=0):
            solara.Markdown(description)

    with solara.ColumnsResponsive([4, 8]):
        with solara.Card(style={"height": "550px"}):
            PDBeMolstar.element(  # type: ignore
                theme="dark" if dark_effective else "light",
                molecule_id=molecule_id.lower(),
                color_data=color_data,
                hide_water=True,
                tooltips=tooltips,
                height="525px",
                highlight=highlight,
                highlight_color=colors.highlight_color,
                on_mouseover_event=on_molstar_mouseover,
                on_mouseout_event=on_molstar_mouseout,
                hide_controls_icon=True,
                hide_expand_icon=True,
                hide_settings_icon=True,
                expanded=fullscreen.value,
            ).key(f"molstar-{dark_effective}")  # remount when the theme flips
        if not fullscreen.value:
            with solara.Card(style={"height": "550px"}):
                if data.is_empty():
                    solara.Text("No data")
                else:
                    ScatterChart(
                        data,
                        colors,
                        axis_properties,
                        on_chart_selection,
                        line_number.value,
                    )
@solara.component
def RoutedView():
    """Top-level page: build a ``ProteinView`` from the URL query string.

    Expected query keys: ``data`` (passed to ``decode_data``), ``title`` and
    ``molecule_id``; optional ``description``. Any ``ColorTransform`` /
    ``AxisProperties`` fields present are used as overrides. Missing or
    invalid parameters produce a warning instead of an exception.
    """
    route = solara.use_router()
    dark_effective = solara.lab.use_dark_effective()

    try:
        query_dict = dict(parse_qsl(route.search))
        colors = ColorTransform(**query_dict)  # type: ignore
        axis_properties = AxisProperties(**query_dict)  # type: ignore
        data = decode_data(query_dict["data"])
        title = query_dict["title"]
        molecule_id = query_dict["molecule_id"]
    except (KeyError, ValueError) as err:
        # KeyError: required key missing. ValueError covers pydantic's
        # ValidationError (invalid field values) and malformed data payloads.
        solara.Warning(f"Error: {err}")
        return

    ProteinView(
        title,
        molecule_id=molecule_id,
        data=data,
        colors=colors,
        axis_properties=axis_properties,
        dark_effective=dark_effective,
        description=query_dict.get("description", ""),
    )