-
Notifications
You must be signed in to change notification settings - Fork 308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Show the number of not classified instances for multi-label models (#3964) #4150
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -120,7 +120,7 @@ def classification_report_imbalanced_values( | |
return result | ||
|
||
|
||
def print_labeled_confusion_matrix(confusion_matrix, labels, is_multilabel=False): | ||
def print_labeled_confusion_matrix(confusion_matrix, labels, is_multilabel=False, test_classes=None): | ||
confusion_matrix_table = confusion_matrix.tolist() | ||
|
||
# Don't show the Not classified row in the table output | ||
|
@@ -144,8 +144,13 @@ def print_labeled_confusion_matrix(confusion_matrix, labels, is_multilabel=False | |
if table_labels[i] != "__NOT_CLASSIFIED__" | ||
else "Not classified" | ||
) | ||
if is_multilabel and test_classes is not None: | ||
confusion_matrix_header.append("Not classified") | ||
for i in range(len(table)): | ||
table[i].insert(0, f"{table_labels[i]} (Actual)") | ||
if is_multilabel and test_classes is not None: | ||
y_true_count = (test_classes[:, num].tolist()).count(i) | ||
table[i].append(y_true_count - sum(table[i][1:])) | ||
print( | ||
tabulate(table, headers=confusion_matrix_header, tablefmt="fancy_grid"), | ||
end="\n\n", | ||
|
@@ -499,8 +504,9 @@ def train(self, importance_cutoff=0.15, limit=None): | |
|
||
tracking_metrics["report"] = report | ||
|
||
# no confidence threshold - no need to handle 'Not classified' instances | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this comment is helpful in the context of reviewing this PR. However, without this context, the comment feels out of place (after merging the PR). |
||
print_labeled_confusion_matrix( | ||
confusion_matrix, self.class_names, is_multilabel=is_multilabel | ||
confusion_matrix, self.class_names, is_multilabel=is_multilabel, test_classes=None, | ||
) | ||
|
||
tracking_metrics["confusion_matrix"] = confusion_matrix.tolist() | ||
|
@@ -567,8 +573,9 @@ def train(self, importance_cutoff=0.15, limit=None): | |
labels=confidence_class_names, | ||
) | ||
) | ||
# with confidence threshold - handle 'Not classified' instances by passing y_test | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same issue applies here as in the comment above. |
||
print_labeled_confusion_matrix( | ||
confusion_matrix, confidence_class_names, is_multilabel=is_multilabel | ||
confusion_matrix, confidence_class_names, is_multilabel=is_multilabel, test_classes=y_test | ||
) | ||
|
||
self.evaluation() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The number of "Not classified" instances was originally part of the table but was dropped earlier.
It may be cleaner to avoid dropping it in the first place.