From 439d75a1dae503596550cc01efcfd1c3275d71d8 Mon Sep 17 00:00:00 2001 From: janezd Date: Mon, 11 Jul 2022 17:02:19 +0200 Subject: [PATCH] Split: Refactor for discrete values, add tests, rename --- .../widgets/icons/TextToColumns.svg | 33 +++ .../prototypes/widgets/owtexttocolumns.py | 182 ++++++++++++++++ .../widgets/tests/test_owtexttocolumns.py | 197 ++++++++++++++++++ 3 files changed, 412 insertions(+) create mode 100644 orangecontrib/prototypes/widgets/icons/TextToColumns.svg create mode 100644 orangecontrib/prototypes/widgets/owtexttocolumns.py create mode 100644 orangecontrib/prototypes/widgets/tests/test_owtexttocolumns.py diff --git a/orangecontrib/prototypes/widgets/icons/TextToColumns.svg b/orangecontrib/prototypes/widgets/icons/TextToColumns.svg new file mode 100644 index 00000000..bc42545f --- /dev/null +++ b/orangecontrib/prototypes/widgets/icons/TextToColumns.svg @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/orangecontrib/prototypes/widgets/owtexttocolumns.py b/orangecontrib/prototypes/widgets/owtexttocolumns.py new file mode 100644 index 00000000..4097e506 --- /dev/null +++ b/orangecontrib/prototypes/widgets/owtexttocolumns.py @@ -0,0 +1,182 @@ +from functools import partial + +import numpy as np + +from AnyQt.QtCore import Qt + +from Orange.widgets import gui +from Orange.widgets.settings import ContextSetting, DomainContextHandler +from Orange.widgets.widget import OWWidget, Msg, Output, Input +from Orange.widgets.utils.itemmodels import DomainModel +from Orange.widgets.utils.widgetpreview import WidgetPreview +from Orange.data import Table, Domain, DiscreteVariable, StringVariable +from Orange.data.util import SharedComputeValue, get_unique_names + +from orangewidget.settings import Setting + + +def get_substrings(values, delimiter): + return sorted({ss.strip() for s in values for ss in s.split(delimiter)} + - {""}) + + +# TODO: Replace with Table.get_column after merging and releasing +# https://github.com/biolab/orange3/pull/6058 +def get_column(data, attr): + if attr not in data.domain and attr.compute_value is not None: + return attr.compute_value(data) + return data.get_column_view(attr)[0] + + +class SplitColumn: + def __init__(self, data, attr, delimiter): + self.attr = attr + self.delimiter = delimiter + column = set(get_column(data, self.attr)) + self.new_values = tuple(get_substrings(column, self.delimiter)) + + def __call__(self, data): + column = get_column(data, self.attr) + values = [{ss.strip() for ss in s.split(self.delimiter)} + for s in column] + return {v: np.array([i for i, xs in enumerate(values) if v in xs]) + for v in self.new_values} + + def __eq__(self, other): + return self.attr == other.attr \ + and self.delimiter == other.delimiter \ + and self.new_values == other.new_values + + def __hash__(self): + return hash((self.attr, self.delimiter, self.new_values)) + + +class OneHotStrings(SharedComputeValue): + def __init__(self, fn, new_feature): + super().__init__(fn) + self.new_feature = new_feature + + def compute(self, data, shared_data): + indices = shared_data[self.new_feature] + col = np.zeros(len(data)) + col[indices] = 1 + return col + + def __eq__(self, other): + return super().__eq__(other) and self.new_feature == other.new_feature + + def __hash__(self): + return super().__hash__() ^ hash(self.new_feature) + + +class OneHotDiscrete: + def __init__(self, variable, delimiter, value): + self.variable = variable + self.value = value + self.delimiter = delimiter + + def __call__(self, data): + column = get_column(data, self.variable).astype(float) + col = np.zeros(len(column)) + col[np.isnan(column)] = np.nan + for val_idx, value in enumerate(self.variable.values): + if self.value in value.split(self.delimiter): + col[column == val_idx] = 1 + return col + + def __eq__(self, other): + return self.variable == other.variable \ + and self.value == other.value \ + and self.delimiter == other.delimiter + + def __hash__(self): + return hash((self.variable, self.value, self.delimiter)) + + +class OWTextToColumns(OWWidget): + name = "Text to Columns" + description = "Split text or categorical variables into binary indicators" + icon = "icons/TextToColumns.svg" + priority = 700 + + class Inputs: + data = Input("Data", Table) + + class Outputs: + data = Output("Data", Table) + + class Warning(OWWidget.Warning): + no_disc = Msg("Data contains only numeric variables.") + + want_main_area = False + resizing_enabled = False + + settingsHandler = DomainContextHandler() + attribute = ContextSetting(None) + delimiter = ContextSetting(";") + auto_apply = Setting(True) + + def __init__(self): + super().__init__() + self.data = None + + variable_select_box = gui.vBox(self.controlArea, "Variable") + + gui.comboBox(variable_select_box, self, "attribute", + orientation=Qt.Horizontal, searchable=True, + callback=self.apply.deferred, + model=DomainModel(valid_types=(StringVariable, + DiscreteVariable))) + gui.lineEdit( + variable_select_box, self, "delimiter", + orientation=Qt.Horizontal, callback=self.apply.deferred) + + gui.auto_apply(self.buttonsArea, self, commit=self.apply) + + @Inputs.data + def set_data(self, data): + self.closeContext() + self.data = data + + model = self.controls.attribute.model() + model.set_domain(data.domain if data is not None else None) + self.Warning.no_disc(shown=data is not None and not model) + if not model: + self.attribute = None + self.data = None + return + self.attribute = model[0] + self.openContext(data) + self.apply.now() + + @gui.deferred + def apply(self): + if self.attribute is None: + self.Outputs.data.send(None) + return + var = self.data.domain[self.attribute] + + if var.is_discrete: + values = get_substrings(var.values, self.delimiter) + computer = partial(OneHotDiscrete, var, self.delimiter) + else: + sc = SplitColumn(self.data, var, self.delimiter) + values = sc.new_values + computer = partial(OneHotStrings, sc) + names = get_unique_names(self.data.domain, values, equal_numbers=False) + + new_columns = tuple(DiscreteVariable( + name, values=("0", "1"), compute_value=computer(value) + ) for value, name in zip(values, names)) + + new_domain = Domain( + self.data.domain.attributes + new_columns, + self.data.domain.class_vars, self.data.domain.metas + ) + extended_data = self.data.transform(new_domain) + self.Outputs.data.send(extended_data) + + +if __name__ == "__main__": # pragma: no cover + WidgetPreview(OWTextToColumns).run(Table.from_file( + "tests/orange-in-education.tab")) diff --git a/orangecontrib/prototypes/widgets/tests/test_owtexttocolumns.py b/orangecontrib/prototypes/widgets/tests/test_owtexttocolumns.py new file mode 100644 index 00000000..3fca2264 --- /dev/null +++ b/orangecontrib/prototypes/widgets/tests/test_owtexttocolumns.py @@ -0,0 +1,197 @@ +# pylint: disable=missing-docstring,unsubscriptable-object +import os +import unittest + +import numpy as np + +from Orange.data import Table, StringVariable, Domain, DiscreteVariable +from Orange.widgets.tests.base import WidgetTest + +from orangecontrib.prototypes.widgets.owtexttocolumns import \ + OWTextToColumns, SplitColumn, get_substrings, OneHotStrings, OneHotDiscrete + + +class TestComputation(unittest.TestCase): + def setUp(self): + domain = Domain([DiscreteVariable("x", values=("a c d", "bb d"))], None, + [StringVariable("foo"), StringVariable("bar")]) + self.data = Table.from_numpy( + domain, + np.array([1, 0, np.nan])[:, None], None, + [["a,bbb,d", "e;f o"], ["", "f o"], ["bbb,d", "e;a;o"]] + ) + + def test_get_string_values(self): + np.testing.assert_equal( + set(get_substrings({"a bc", "d,e", "", "f,a t", "t"}, " ")), + {"a", "bc", "d,e", "f,a", "t"}) + np.testing.assert_equal( + set(get_substrings({"a bc", "d,e", "", "f,a t", "t"}, ",")), + {"a bc", "d", "e", "f", "a t", "t"}) + + def test_split_column(self): + sc = SplitColumn(self.data, self.data.domain.metas[0], ",") + shared = sc(self.data) + self.assertEqual(set(sc.new_values), {"a", "bbb", "d"}) + self.assertEqual(set(shared), set(sc.new_values)) + np.testing.assert_equal(shared["a"], [0]) + np.testing.assert_equal(shared["bbb"], [0, 2]) + np.testing.assert_equal(shared["d"], [0, 2]) + + sc = SplitColumn(self.data, self.data.domain.metas[1], ";") + shared = sc(self.data) + self.assertEqual(set(sc.new_values), {"a", "e", "f o", "o"}) + self.assertEqual(set(shared), set(sc.new_values)) + np.testing.assert_equal(shared["a"], [2]) + np.testing.assert_equal(shared["e"], [0, 2]) + np.testing.assert_equal(shared["f o"], [0, 1]) + np.testing.assert_equal(shared["o"], [2]) + + def test_one_hot_strings(self): + attr = self.data.domain.metas[0] + sc = SplitColumn(self.data, attr, ",") + + oh = OneHotStrings(sc, "a") + np.testing.assert_equal(oh(self.data), [1, 0, 0]) + + oh = OneHotStrings(sc, "bbb") + np.testing.assert_equal(oh(self.data), [1, 0, 1]) + + data = Table.from_numpy( + Domain([], None, [attr]), + np.zeros((5, 0)), None, + np.array(["bbb,x,y", "", "bbb", "bbb,a", "foo"])[:, None]) + np.testing.assert_equal(oh(data), [1, 0, 1, 1, 0]) + + def test_one_hot_discrete(self): + attr = self.data.domain.attributes[0] + + oh = OneHotDiscrete(attr, " ", "a") + np.testing.assert_equal(oh(self.data), [0, 1, np.nan]) + + oh = OneHotDiscrete(attr, " ", "d") + np.testing.assert_equal(oh(self.data), [1, 1, np.nan]) + + data = Table.from_numpy( + Domain([attr], None), + np.array([1, 0, 1, 0, np.nan])[:, None]) + + oh = OneHotDiscrete(attr, " ", "a") + np.testing.assert_equal(oh(data), [0, 1, 0, 1, np.nan]) + + oh = OneHotDiscrete(attr, " ", "d") + np.testing.assert_equal(oh(data), [1, 1, 1, 1, np.nan]) + + def test_discrete_metas(self): + attr = DiscreteVariable("x", values=("a c d", "bb d")) + domain = Domain([], None, [attr]) + data = Table.from_numpy(domain, np.zeros((3, 0)), None, + np.array([1, 0, np.nan])[:, None]) + oh = OneHotDiscrete(attr, " ", "a") + np.testing.assert_equal(oh(data), [0, 1, np.nan]) + + +class TestOWTextToColumns(WidgetTest): + def setUp(self): + self.widget = self.create_widget(OWTextToColumns) + test_path = os.path.dirname(os.path.abspath(__file__)) + self.data = Table.from_file(os.path.join(test_path, "orange-in-education.tab")) + self._create_simple_corpus() + + def _set_attr(self, attr, widget=None): + if widget is None: + widget = self.widget + attr_combo = widget.controls.attribute + idx = attr_combo.model().indexOf(attr) + attr_combo.setCurrentIndex(idx) + attr_combo.activated.emit(idx) + + def _create_simple_corpus(self) -> None: + """ + Create a simple dataset with 4 documents. + """ + metas = np.array( + [ + ["foo,"], + ["bar,baz "], + ["foo,bar"], + [""], + ] + ) + text_var = StringVariable("foo") + domain = Domain([], metas=[text_var]) + self.small_table = Table.from_numpy( + domain, + X=np.empty((len(metas), 0)), + metas=metas, + ) + + def test_data(self): + """Basic functionality""" + self.send_signal(self.widget.Inputs.data, self.data) + self._set_attr(self.data.domain.attributes[1]) + output = self.get_output(self.widget.Outputs.data) + self.assertEqual(len(output.domain.attributes), + len(self.data.domain.attributes) + 3) + + def test_empty_data(self): + """Do not crash on empty data""" + self.send_signal(self.widget.Inputs.data, None) + + def test_discrete(self): + """No crash on data attributes of different types""" + self.send_signal(self.widget.Inputs.data, self.data) + self.assertEqual(self.widget.attribute, self.data.domain.metas[1]) + self._set_attr(self.data.domain.attributes[1]) + self.assertEqual(self.widget.attribute, self.data.domain.attributes[1]) + + def test_numeric_only(self): + """Error raised when only numeric variables given""" + housing = Table.from_file("housing") + self.send_signal(self.widget.Inputs.data, housing) + self.assertTrue(self.widget.Warning.no_disc.is_shown()) + + def test_split_nonexisting(self): + """Test splitting when delimiter doesn't exist""" + self.widget.delimiter = "|" + self.send_signal(self.widget.Inputs.data, self.data) + new_cols = set(self.data.get_column_view("Country")[0]) + self.assertFalse(any(self.widget.delimiter in v for v in new_cols)) + self.assertEqual(len(self.get_output( + self.widget.Outputs.data).domain.attributes), + len(self.data.domain.attributes) + len(new_cols)) + + def test_output_string(self): + "Test outputs; at the same time, test for duplicate variables" + self.widget.delimiter = "," + self.send_signal(self.widget.Inputs.data, self.small_table) + out = self.get_output(self.widget.Outputs.data) + self.assertEqual([attr.name for attr in out.domain.attributes], + ["bar", "baz", "foo (1)"]) + np.testing.assert_equal(out.X, + [[0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 0, 0]]) + + def test_output_discrete(self): + self.widget.delimiter = " " + attr = DiscreteVariable("x", values=("bar foo", "bar baz", "crux")) + data = Table.from_numpy( + Domain([attr], None), + np.array([1, 1, 0, 1, 2, np.nan])[:, None], None) + self.send_signal(self.widget.Inputs.data, data) + out = self.get_output(self.widget.Outputs.data) + self.assertEqual([attr.name for attr in out.domain.attributes], + ["x", "bar", "baz", "crux", "foo"]) + np.testing.assert_equal(out.X, + [[1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [0, 1, 0, 0, 1], + [1, 1, 1, 0, 0], + [2, 0, 0, 1, 0], + [np.nan, np.nan, np.nan, np.nan, np.nan]]) + + +if __name__ == "__main__": + unittest.main()