Skip to content

Commit

Permalink
fix: fixed a bug where sizes weren't scaled on first render
Browse files Browse the repository at this point in the history
  • Loading branch information
sg-s committed Oct 29, 2024
1 parent f459334 commit 5d36d9a
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions src/data_hub/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from beartype.typing import Optional
from bokeh.io import show
from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, CustomJS, HoverTool, Select, TapTool
from bokeh.models import ColumnDataSource, CustomJS, HoverTool, Select
from bokeh.plotting import figure
from deeporigin.data_hub.dataframe import DataFrame
from deeporigin.exceptions import DeepOriginException
Expand All @@ -17,13 +17,21 @@ def scatter(
y: Optional[str] = None,
size: Optional[str] = None,
):
"""function to make a scatter plot from a Deep Origin dataframe, with support for interactivity"""
"""function to make a scatter plot from a Deep Origin dataframe, with support for interactivity
Args:
df (DataFrame): Deep Origin dataframe
x (Optional[str], optional): name of column to use for x axis. Defaults to None.
y (Optional[str], optional): name of column to use for y axis. Defaults to None.
size (Optional[str], optional): name of column to use for size. Defaults to None.
`df` should be a Deep Origin dataframe with at least two numeric columns.
"""

figure_width = 500
select_width = int(figure_width * 0.3)

if df.shape[1] < 2:
raise ValueError("DataFrame must contain at least two columns.")
js_code = _read_js_code()

cols = df.attrs["metadata"]["cols"]
cols = [col["name"] for col in cols if col["type"] in ["float", "integer"]]
Expand All @@ -40,16 +48,18 @@ def scatter(
if size is None:
size = cols[0]

# Set up initial data for scatter plot
initial_x = x
initial_y = y
initial_size = size
# normalize sizes. this should match what's in
# axes_callback.js
sizes = list(df[size])
min_size = min(sizes)
max_size = max(sizes)
sizes = [2 + 15 * (value - min_size) / (max_size - min_size) for value in sizes]

# CDS for scatter data
data = dict(
x=list(df[initial_x]),
y=list(df[initial_y]),
size=list(df[initial_size]),
x=list(df[x]),
y=list(df[y]),
size=sizes,
)

scatter_source = ColumnDataSource(data)
Expand Down Expand Up @@ -86,34 +96,31 @@ def scatter(
)

# Set initial axis labels
p.xaxis.axis_label = initial_x
p.yaxis.axis_label = initial_y
p.xaxis.axis_label = x
p.yaxis.axis_label = y

# Create dropdown selectors for X and Y axes
x_select = Select(
title="X-Axis",
value=initial_x,
value=x,
options=cols,
width=select_width,
)
y_select = Select(
title="Y-Axis",
value=initial_y,
value=y,
options=cols,
width=select_width,
)
size_select = Select(
title="Size",
value=initial_size,
value=size,
options=cols,
width=select_width,
)

# JavaScript callback to update data, axis labels, and point sizes on select change

with importlib.resources.open_text("deeporigin.data_hub", "axes_callback.js") as f:
axes_callback_js = f.read()

axes_callback = CustomJS(
args=dict(
source=scatter_source,
Expand All @@ -124,18 +131,15 @@ def scatter(
x_axis=p.xaxis[0],
y_axis=p.yaxis[0],
),
code=axes_callback_js,
code=js_code["axes_callback"],
)

# JS code, will run in browser
# this updates the value of the slider to the currently
# hovered point
code = """
"""

callback = CustomJS(
code=code,
code=js_code["hover_callback"],
args=dict(
marker_source=marker_source,
scatter_source=scatter_source,
Expand Down

0 comments on commit 5d36d9a

Please sign in to comment.