Skip to content

Commit

Permalink
descriptive interaction variables
Browse files Browse the repository at this point in the history
  • Loading branch information
noahnovsak committed Oct 14, 2022
1 parent 8743c45 commit 3d8faa0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 33 deletions.
21 changes: 12 additions & 9 deletions orangecontrib/prototypes/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def entropy(ar):
return -np.sum(p * np.log2(p))


class Interaction:
class InteractionScorer:
def __init__(self, data):
# Keep a reference to the dataset; the code in this class reads data.X and data.Y.
self.data = data
# NOTE(review): this diff view shows both sides of a rename without +/- markers;
# class_h / gains appear to be the pre-commit names of the two attributes below.
self.class_h = 0
self.gains = np.zeros(data.X.shape[1])
# Placeholders: a scalar for the class entropy and one information-gain slot
# per column of data.X. precompute() overwrites both.
self.class_entropy = 0
self.information_gain = np.zeros(data.X.shape[1])

# Fill in the placeholders above once, at construction time.
self.precompute()

Expand All @@ -61,16 +61,19 @@ def precompute(self):
well as negative interactions with greater magnitude than the
combined information gain.
"""
self.class_h = entropy(self.data.Y)
for attr in range(self.gains.size):
self.gains[attr] = self.class_h \
self.class_entropy = entropy(self.data.Y)
for attr in range(self.information_gain.size):
self.information_gain[attr] = self.class_entropy \
+ entropy(self.data.X[:, attr]) \
- entropy(np.column_stack((self.data.X[:, attr], self.data.Y)))

def __call__(self, attr1, attr2):
# Score the pair of data.X columns with indices attr1 and attr2:
#   class entropy - gain[attr1] - gain[attr2]
#   + entropy(both columns) - entropy(both columns together with data.Y)
# NOTE(review): two return statements follow because this diff view shows the
# pre-rename (class_h / gains) and post-rename (class_entropy / information_gain)
# versions of the same expression back to back.
attrs = np.column_stack((self.data.X[:, attr1], self.data.X[:, attr2]))
return self.class_h \
- self.gains[attr1] \
- self.gains[attr2] \
return self.class_entropy \
- self.information_gain[attr1] \
- self.information_gain[attr2] \
+ entropy(attrs) \
- entropy(np.column_stack((attrs, self.data.Y)))

def normalize(self, score):
return score / self.class_entropy
20 changes: 9 additions & 11 deletions orangecontrib/prototypes/widgets/owinteractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from Orange.preprocess import Discretize, Remove
import Orange.widgets.data.owcorrelations

from orangecontrib.prototypes.interactions import Interaction
from orangecontrib.prototypes.interactions import InteractionScorer


SIZE_LIMIT = 1000000
Expand Down Expand Up @@ -149,7 +149,7 @@ class InteractionRank(Orange.widgets.data.owcorrelations.CorrelationRank):

def __init__(self, *args):
VizRankDialogAttrPair.__init__(self, *args)
self.interaction = None
self.scorer = None
self.heuristic = None
self.use_heuristic = False
self.sel_feature_index = None
Expand All @@ -175,19 +175,17 @@ def initialize(self):
self.use_heuristic = False
self.sel_feature_index = self.master.feature and data.domain.index(self.master.feature)
if data:
if self.interaction is None or self.interaction.data != data:
self.interaction = Interaction(data)
if self.scorer is None or self.scorer.data != data:
self.scorer = InteractionScorer(data)
self.use_heuristic = len(data) * len(self.attrs) ** 2 > SIZE_LIMIT
if self.use_heuristic and not self.sel_feature_index:
self.heuristic = Heuristic(self.interaction.gains, self.master.heuristic_type)
self.heuristic = Heuristic(self.scorer.information_gain, self.master.heuristic_type)

def compute_score(self, state):
# `state` is a pair of attribute indices. The result is a 3-tuple: the interaction
# score of the pair, then the information gain of the first and of the second
# attribute, each divided by the class entropy.
# NOTE(review): this diff view shows two implementations back to back; the first
# (self.interaction, dividing by h inline) appears to be the pre-commit version
# and the second (self.scorer with normalize()) its replacement.
attr1, attr2 = state
h = self.interaction.class_h
score = self.interaction(attr1, attr2) / h
gain1 = self.interaction.gains[attr1] / h
gain2 = self.interaction.gains[attr2] / h
return score, gain1, gain2
scores = (self.scorer(*state),
self.scorer.information_gain[state[0]],
self.scorer.information_gain[state[1]])
return tuple(self.scorer.normalize(score) for score in scores)

def row_for_state(self, score, state):
attrs = sorted((self.attrs[x] for x in state), key=attrgetter("name"))
Expand Down
26 changes: 13 additions & 13 deletions orangecontrib/prototypes/widgets/tests/test_owinteractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from Orange.widgets.widget import AttributeList
from orangecontrib.prototypes.widgets.owinteractions import \
OWInteractions, Heuristic, HeuristicType, InteractionRank
from orangecontrib.prototypes.interactions import Interaction
from orangecontrib.prototypes.interactions import InteractionScorer


class TestOWInteractions(WidgetTest):
Expand Down Expand Up @@ -276,23 +276,23 @@ def test_compute_score(self):
y = np.array([0, 1, 1, 1])
domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3"))
data = Table(domain, x, y)
self.interaction = Interaction(data)
npt.assert_almost_equal(self.interaction(0, 1), -0.1226, 4)
npt.assert_almost_equal(self.interaction.class_h, 0.8113, 4)
npt.assert_almost_equal(self.interaction.gains[0], 0.3113, 4)
npt.assert_almost_equal(self.interaction.gains[1], 0.1226, 4)
self.scorer = InteractionScorer(data)
npt.assert_almost_equal(self.scorer(0, 1), -0.1226, 4)
npt.assert_almost_equal(self.scorer.class_entropy, 0.8113, 4)
npt.assert_almost_equal(self.scorer.information_gain[0], 0.3113, 4)
npt.assert_almost_equal(self.scorer.information_gain[1], 0.1226, 4)

def test_nans(self):
"""Check score calculation when some feature values are missing (NaN)"""
# Seven rows, two features; the last three rows have NaN in one or both features.
x = np.array([[1, 1], [0, 1], [1, 1], [0, 0], [1, np.nan], [np.nan, 0], [np.nan, np.nan]])
y = np.array([0, 1, 1, 1, 0, 0, 1])
domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3"))
data = Table(domain, x, y)
# NOTE(review): the assertions appear twice because this diff view shows the
# pre-rename (Interaction) and post-rename (InteractionScorer) lines together;
# the expected values are the same in both.
self.interaction = Interaction(data)
npt.assert_almost_equal(self.interaction(0, 1), 0.0167, 4)
npt.assert_almost_equal(self.interaction.class_h, 0.9852, 4)
npt.assert_almost_equal(self.interaction.gains[0], 0.4343, 4)
npt.assert_almost_equal(self.interaction.gains[1], 0.0343, 4)
self.scorer = InteractionScorer(data)
npt.assert_almost_equal(self.scorer(0, 1), 0.0167, 4)
npt.assert_almost_equal(self.scorer.class_entropy, 0.9852, 4)
npt.assert_almost_equal(self.scorer.information_gain[0], 0.4343, 4)
npt.assert_almost_equal(self.scorer.information_gain[1], 0.0343, 4)


class TestHeuristic(unittest.TestCase):
Expand All @@ -302,8 +302,8 @@ def setUpClass(cls):

def test_heuristic(self):
"""Check attribute pairs returned by heuristic"""
score = Interaction(self.zoo)
heuristic = Heuristic(score.gains, heuristic_type=HeuristicType.INFOGAIN)
scorer = InteractionScorer(self.zoo)
heuristic = Heuristic(scorer.information_gain, heuristic_type=HeuristicType.INFOGAIN)
self.assertListEqual(
list(heuristic.get_states(None))[:9],
[(14, 6), (14, 10), (14, 15), (6, 10), (14, 5), (6, 15), (14, 11), (6, 5), (10, 15)]
Expand Down

0 comments on commit 3d8faa0

Please sign in to comment.