diff --git a/src/data_hub/plotting.py b/src/data_hub/plotting.py index d7f5f84..5d1dac8 100644 --- a/src/data_hub/plotting.py +++ b/src/data_hub/plotting.py @@ -4,7 +4,7 @@ from beartype.typing import Optional from bokeh.io import show from bokeh.layouts import column, row -from bokeh.models import ColumnDataSource, CustomJS, HoverTool, Select, TapTool +from bokeh.models import ColumnDataSource, CustomJS, HoverTool, Select from bokeh.plotting import figure from deeporigin.data_hub.dataframe import DataFrame from deeporigin.exceptions import DeepOriginException @@ -17,13 +17,21 @@ def scatter( y: Optional[str] = None, size: Optional[str] = None, ): - """function to make a scatter plot from a Deep Origin dataframe, with support for interactivity""" + """function to make a scatter plot from a Deep Origin dataframe, with support for interactivity + + Args: + df (DataFrame): Deep Origin dataframe + x (Optional[str], optional): name of column to use for x axis. Defaults to None. + y (Optional[str], optional): name of column to use for y axis. Defaults to None. + size (Optional[str], optional): name of column to use for size. Defaults to None. + + `df` should be a Deep Origin dataframe with at least two numeric columns. + + """ figure_width = 500 select_width = int(figure_width * 0.3) - - if df.shape[1] < 2: - raise ValueError("DataFrame must contain at least two columns.") + js_code = _read_js_code() cols = df.attrs["metadata"]["cols"] cols = [col["name"] for col in cols if col["type"] in ["float", "integer"]] @@ -40,16 +48,18 @@ def scatter( if size is None: size = cols[0] - # Set up initial data for scatter plot - initial_x = x - initial_y = y - initial_size = size + # normalize sizes. this should match what's in + # axes_callback.js + sizes = list(df[size]) + min_size = min(sizes) + max_size = max(sizes) + sizes = [2 + 15 * (value - min_size) / (max_size - min_size) for value in sizes] # CDS for scatter data data = dict( - x=list(df[initial_x]), - y=list(df[initial_y]), - size=list(df[initial_size]), + x=list(df[x]), + y=list(df[y]), + size=sizes, ) scatter_source = ColumnDataSource(data) @@ -86,34 +96,31 @@ def scatter( ) # Set initial axis labels - p.xaxis.axis_label = initial_x - p.yaxis.axis_label = initial_y + p.xaxis.axis_label = x + p.yaxis.axis_label = y # Create dropdown selectors for X and Y axes x_select = Select( title="X-Axis", - value=initial_x, + value=x, options=cols, width=select_width, ) y_select = Select( title="Y-Axis", - value=initial_y, + value=y, options=cols, width=select_width, ) size_select = Select( title="Size", - value=initial_size, + value=size, options=cols, width=select_width, ) # JavaScript callback to update data, axis labels, and point sizes on select change - with importlib.resources.open_text("deeporigin.data_hub", "axes_callback.js") as f: - axes_callback_js = f.read() - axes_callback = CustomJS( args=dict( source=scatter_source, @@ -124,18 +131,15 @@ def scatter( x_axis=p.xaxis[0], y_axis=p.yaxis[0], ), - code=axes_callback_js, + code=js_code["axes_callback"], ) # JS code, will run in browser # this updates the value of the slider to the currently # hovered point - code = """ - - """ callback = CustomJS( - code=code, + code=js_code["hover_callback"], args=dict( marker_source=marker_source, scatter_source=scatter_source,