Skip to content

Commit

Permalink
feat: ability to write to database
Browse files Browse the repository at this point in the history
  • Loading branch information
sg-s committed Oct 29, 2024
1 parent 5d36d9a commit 46e684e
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 21 deletions.
6 changes: 5 additions & 1 deletion src/data_hub/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,11 @@ def _row_to_dict(
elif field.type in ["float", "integer", "boolean"]:
value = field.value
elif field.type == "select":
value = field.value.selected_options
value = field.value
if isinstance(value, dict):
value = value["selectedOptions"]
else:
value = value.selected_options

elif field.type == "reference":
value = field.value.row_ids
Expand Down
4 changes: 2 additions & 2 deletions src/data_hub/axes_callback.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ const sizes = size_data.map(value => {
});

// Update the data source
source.data = { 'x': x_data, 'y': y_data, 'size': sizes };
scatter_source.data = { 'x': x_data, 'y': y_data, 'size': sizes };

// Update the axis labels
x_axis.axis_label = x_select.value;
y_axis.axis_label = y_select.value;

// Trigger data change
source.change.emit();
scatter_source.change.emit();
35 changes: 35 additions & 0 deletions src/data_hub/button_callback.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Callback when a user presses the button to update labels in the scatter plot Bokeh figure

const selectedData = lasso_selection_source.data;
const label = label_select.value;
const labelColumn = label_column_select.value;

console.log("Will write to these rows:", selectedData.ids);
console.log("Will write this label:", label);
console.log("Will write to this column:", labelColumn);

const rowIds = selectedData.ids;

if (typeof window.deeporigin !== "undefined") {
// Reset selections for all the rows in question
const resetChanges = rowIds.map(rowId => ({
rowId: rowId,
fieldChangeEvents: [{
columnId: labelColumn,
newValue: { selectedOptions: [] }
}]
}));

deeporigin.dataHub.primaryDatabase.editRows({ changes: resetChanges });

// Write the new value
const updateChanges = rowIds.map(rowId => ({
rowId: rowId,
fieldChangeEvents: [{
columnId: labelColumn,
newValue: { selectedOptions: [label] }
}]
}));

deeporigin.dataHub.primaryDatabase.editRows({ changes: updateChanges });
}
5 changes: 3 additions & 2 deletions src/data_hub/hover_callback.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ if (cb_data.index.indices.length > 0) {
marker_source.change.emit();

// callback to update selection in database
var id = scatter_source.data['id'][chosen_index];
if (typeof window.deeporigin !== "undefined") {
var id = source.data['id'][index];
deeporigin.dataHub.primaryDatabase.addSelection({ selections: [{ rowId: id }] })
deeporigin.dataHub.primaryDatabase.clearRangeSelection();
deeporigin.dataHub.primaryDatabase.addSelection({ selections: [{ rowId: id }] });

}

Expand Down
32 changes: 32 additions & 0 deletions src/data_hub/lasso_callback.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// callback on lasso selection

function debounce(func, wait) {
let timeout;
return function (...args) {
const context = this;
clearTimeout(timeout);
timeout = setTimeout(() => func.apply(context, args), wait);
};
}

const processSelection = () => {
const selected_indices = scatter_source.selected.indices;
let selected_data = selected_indices.map(i => {
let row = {};
for (const key in scatter_source.data) {
row[key] = scatter_source.data[key][i];
}
return row;
});

const ids = selected_data.map(item => item.id);
console.log("Lasso selected points data:", ids);

lasso_selection_source.data.ids = ids;
lasso_selection_source.change.emit();

};


const debouncedSelection = debounce(processSelection, 50);
debouncedSelection();
135 changes: 119 additions & 16 deletions src/data_hub/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
from beartype.typing import Optional
from bokeh.io import show
from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, CustomJS, HoverTool, Select
from bokeh.models import (
Button,
ColumnDataSource,
CustomJS,
HoverTool,
LassoSelectTool,
Select,
)
from bokeh.plotting import figure
from deeporigin.data_hub.dataframe import DataFrame
from deeporigin.exceptions import DeepOriginException
Expand Down Expand Up @@ -34,19 +41,29 @@ def scatter(
js_code = _read_js_code()

cols = df.attrs["metadata"]["cols"]
cols = [col["name"] for col in cols if col["type"] in ["float", "integer"]]
numeric_cols = [col["name"] for col in cols if col["type"] in ["float", "integer"]]

if len(cols) < 2:
label_data = None

select_cols = [
col for col in df.attrs["metadata"]["cols"] if col["type"] == "select"
]
if len(select_cols) > 0:
select_col_names = [col["name"] for col in select_cols]
select_col_options = [col["configSelect"]["options"] for col in select_cols]
label_data = dict(zip(select_col_names, select_col_options))

if len(numeric_cols) < 2:
raise DeepOriginException(
"DataFrame must contain at least two numeric columns."
)

if x is None:
x = cols[0]
x = numeric_cols[0]
if y is None:
y = cols[1]
y = numeric_cols[1]
if size is None:
size = cols[0]
size = numeric_cols[0]

# normalize sizes. this should match what's in
# axes_callback.js
Expand All @@ -60,12 +77,22 @@ def scatter(
x=list(df[x]),
y=list(df[y]),
size=sizes,
id=df.index,
)

# CDS for scatter plot
scatter_source = ColumnDataSource(data)

# CDS for marker
marker_source = ColumnDataSource(_first_element_in_dict(data))

# CDS to store lasso selection data
lasso_selection_source = ColumnDataSource(dict(ids=[]))

# CDS to store labels
if label_data:
label_source = ColumnDataSource(label_data)

# Create the scatter plot figure
p = figure(
width=figure_width,
Expand Down Expand Up @@ -103,27 +130,66 @@ def scatter(
x_select = Select(
title="X-Axis",
value=x,
options=cols,
options=numeric_cols,
width=select_width,
)
y_select = Select(
title="Y-Axis",
value=y,
options=cols,
options=numeric_cols,
width=select_width,
)
size_select = Select(
title="Size",
value=size,
options=cols,
options=numeric_cols,
width=select_width,
)

# create dropdown selectors for label column
if label_data:
first_col = list(label_data.keys())[0]
label_column_select = Select(
title="Label Column",
value=first_col,
options=list(label_data.keys()),
width=select_width,
)

label_select = Select(
title="Label",
value=label_data[first_col][0],
options=list(set(label_data[first_col])),
width=select_width,
)

# JavaScript callback to update the second select tool based on the selected column
label_select_callback = CustomJS(
args=dict(
label_source=label_source,
label_column_select=label_column_select,
label_select=label_select,
),
code="""
// Get selected column name from first select widget
const column = label_column_select.value;
// Update options of the second select based on unique values of the selected column
const column_data = label_source.data[column];
const unique_values = Array.from(new Set(column_data));
label_select.options = unique_values;
label_select.value = unique_values[0]; // Set default to the first unique value
""",
)

label_column_select.js_on_change("value", label_select_callback)

# JavaScript callback to update data, axis labels, and point sizes on select change

axes_callback = CustomJS(
args=dict(
source=scatter_source,
scatter_source=scatter_source,
x_select=x_select,
y_select=y_select,
size_select=size_select,
Expand All @@ -138,7 +204,7 @@ def scatter(
# this updates the value of the slider to the currently
# hovered point

callback = CustomJS(
hover_callback = CustomJS(
code=js_code["hover_callback"],
args=dict(
marker_source=marker_source,
Expand All @@ -151,7 +217,7 @@ def scatter(
# https://discourse.bokeh.org/t/deactivate-hovertool-for-specific-glyphs/9931/2
hvr = HoverTool(
tooltips=None,
callback=callback,
callback=hover_callback,
)
hvr.renderers = [scatter_glyphs]
p.add_tools(hvr)
Expand All @@ -161,11 +227,47 @@ def scatter(
y_select.js_on_change("value", axes_callback)
size_select.js_on_change("value", axes_callback)

# Layout widgets and plot
layout = column(
row(x_select, y_select, size_select),
p,
lasso_callback = CustomJS(
args=dict(
scatter_source=scatter_source,
lasso_selection_source=lasso_selection_source,
),
code=js_code["lasso_callback"],
)

lasso_tool = LassoSelectTool()
p.add_tools(lasso_tool)
scatter_source.selected.js_on_change("indices", lasso_callback)

# Button to access selected data from selected_source
if label_data:
label_button = Button(label="+Label")
button_callback = CustomJS(
args=dict(
lasso_selection_source=lasso_selection_source,
label_select=label_select,
label_column_select=label_column_select,
),
code=js_code["button_callback"],
)
label_button.js_on_click(button_callback)

# Layout widgets and plot
if label_data:
layout = column(
row(x_select, y_select, size_select),
p,
row(
label_column_select,
label_select,
label_button,
),
)
else:
layout = column(
row(x_select, y_select, size_select),
p,
)
show(layout)


Expand All @@ -182,6 +284,7 @@ def _first_element_in_dict(data: dict) -> dict:
return out_data


@beartype
def _read_js_code() -> dict:
"""utility function to read JS code"""

Expand Down

0 comments on commit 46e684e

Please sign in to comment.