Merge pull request #5 from guardrails-ai/jc/change_behavior
Fix 'OnFix' behavior.
JosephCatrambone authored Dec 6, 2024
2 parents 7b504ed + 680a8f8 commit 90061b5
Showing 6 changed files with 127 additions and 44 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/publish_pypi.yml
@@ -0,0 +1,18 @@
name: Publish to Guardrails Hub

on:
  workflow_dispatch:
  push:
    # Publish when new releases are tagged.
    tags:
      - '*'

jobs:
  setup:
    runs-on: ubuntu-latest
    steps:
      - name: Build & Deploy
        uses: guardrails-ai/guardrails/.github/actions/validator_pypi_publish@main
        with:
          guardrails_token: ${{ secrets.GR_GUARDRAILS_TOKEN }}
          validator_id: guardrails/bias_check
10 changes: 7 additions & 3 deletions README.md
@@ -17,7 +17,11 @@ This validator can be used to ensure fairness of model output across various dem

* Dependencies:
    - guardrails-ai>=0.5.0
    - transformers>=4.40
    - transformers>=4.37
    - tf-keras
    - sentencepiece
    - tensorflow>=2.16
    - sentence-splitter

* Dev Dependencies:
    - pytest
@@ -50,5 +54,5 @@ guard = Guard().use(
    BiasCheck(threshold=0.9, on_fail="exception")
)

guard.validate("The movie was great!") # Validator passes
guard.validate("Why do men always think the movie was great?") # Validator fails
guard.validate("The movie was great.") # Validator passes
guard.validate("Why do men think that movie was great?") # Validator fails
13 changes: 8 additions & 5 deletions pyproject.toml
@@ -1,17 +1,20 @@
[project]
name = "bias_check"
version = "0.2.0"
version = "0.3.0"
description = "Wrapping a bias check model into a validator."
authors = [
    {name = "Guardrails AI", email = "[email protected]"}
]
license = {file = "LICENSE"}
readme = "README.md"
requires-python = ">= 3.8.1"
requires-python = ">= 3.9"
dependencies = [
    "guardrails-ai>=0.4.0",
    "transformers>=4.40.2",
    "tensorflow>=2.18.0"
    "guardrails-ai>=0.5.15",
    "transformers>=4.37.0",
    "tf-keras",
    "sentencepiece",
    "tensorflow>=2.16.0",  # Required for the dbias model, but not as a direct dependency.
    "sentence-splitter>=1.4"
]

[project.optional-dependencies]
34 changes: 20 additions & 14 deletions tests/test_validator.py
@@ -3,20 +3,26 @@

from validator import BiasCheck

# These tests call validator.validate() directly, so they receive
# PassResult/FailResult objects rather than triggering on_fail actions.
# Learn more about corrective actions here:
# https://www.guardrailsai.com/docs/concepts/output/#%EF%B8%8F-specifying-corrective-actions

def test_success_case():
    validator = BiasCheck()
    input_text = "The sun rises in the morning."
    result = validator.validate(input_text, {})
    assert isinstance(result, PassResult)


def test_failure_case():
    validator = BiasCheck(threshold=0.5)
    input_text = "The sun only rises for Humanists."
    result = validator.validate(input_text, {})
    assert isinstance(result, FailResult)
    assert result.error_message == "The original response contains potentially biased messages:\nThe sun only rises for Humanists.\n (Message scores: 0.6546186208724976)"
    assert result.fix_value == ""
    validator = BiasCheck(threshold=0.0)
    input_text = "Men these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh"
    result = validator.validate(input_text, {})
    assert isinstance(result, FailResult)
    assert result.error_message == "The original response contains potentially biased messages:\nMen these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh\n (Message scores: 0.9477301836013794)"
    assert result.fix_value == ""


def test_sentence_fix():
    v = BiasCheck(on_fail='fix', threshold=0.9)
    input_text = "Men these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh"
    out = v.validate(input_text)
    assert isinstance(out, FailResult)
    assert out.fix_value == "Men these days don't care about my arbitrary and deletarious standards of gender."
90 changes: 69 additions & 21 deletions validator/main.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Optional
from typing import Callable, Dict, List, Optional, Union

from guardrails.validator_base import (
    FailResult,
@@ -7,9 +7,11 @@
    Validator,
    register_validator,
)

from guardrails.types import OnFailAction
from sentence_splitter import split_text_into_sentences
from transformers import pipeline


@register_validator(name="guardrails/bias_check", data_type="string")
class BiasCheck(Validator):
"""Validates that the text is free from biases related to age, gender, sex, ethnicity, religion, etc.
Expand All @@ -23,61 +25,107 @@ class BiasCheck(Validator):
| Programmatic fix | The debiased text if bias is detected |
Args:
threshold (float): Higher is more likely to allow bias. Lower is more sensitive and more likely to flag biased messages.
on_fail (Callable): The policy to enact when a validator fails. If `str`, must be one of `filter`, `noop`, or `exception`. Otherwise, must be a function that is called when the validator fails.
threshold (float): Higher is more likely to allow bias. Lower is more sensitive and more likely to flag biased messages.
on_fail (Callable): The policy to enact when a validator fails. If `str`, must be one of `noop`, `fix`, or `exception`. Otherwise, must be a function that is called when the validator fails.
""" # noqa

    def __init__(
        self,
        threshold: float = 0.9,
        on_fail: Optional[Callable] = None,
        on_fail: Optional[Union[str, Callable]] = None,
    ):
        super().__init__(on_fail=on_fail)
        valid_on_fail_operations = {"filter", "noop", "exception"}
        super().__init__(on_fail=on_fail)  # type: ignore
        valid_on_fail_operations = {"fix", "noop", "exception"}
        if isinstance(on_fail, str) and on_fail not in valid_on_fail_operations:
            raise Exception(
                f"on_fail value ({on_fail}) not in list of allowable operations: {valid_on_fail_operations}"
            )
        self.threshold = threshold
        self.model = pipeline(

        # There are some spurious loading complaints with TFDistilBert models.
        # See https://discuss.huggingface.co/t/message-some-layers-from-the-model-were-not-used/1972/7
        self.classification_model = pipeline(
            'text-classification',
            model="d4data/bias-detection-model",
            tokenizer="d4data/bias-detection-model",
        )
        # There are some spurious loading complaints with TFDistilBert models.
        # See https://discuss.huggingface.co/t/message-some-layers-from-the-model-were-not-used/1972/7

    def validate(self, value: Any, metadata: Optional[Dict] = None) -> ValidationResult:
    def validate(
        self,
        value: Union[str, List[str]],
        metadata: Optional[Dict] = None
    ) -> ValidationResult:
        """Validates that the text is free from biases related to age, gender, sex, ethnicity, religion, etc."""
        single_sentence_passed = False
        if isinstance(value, str):
            single_sentence_passed = True
            value = [value,]  # Ensure we're always passing lists of strings into the classifier.

        classified_examples = self.model(value)
        scores = self._inference(value)
        passing_outputs = list()
        passing_scores = list()
        failing_outputs = list()
        failing_scores = list()
        for text, prediction in zip(value, classified_examples):
            if prediction['label'] == 'Biased':
                score = prediction['score']
            elif prediction['label'] == 'Non-biased':
                score = -prediction['score']  # Note the negation!
            else:
                raise Exception(f"Got unexpected prediction label: {prediction['label']}")
        all_outputs = list()  # A list of (needs_fix, sentence) tuples.
        for text, score in zip(value, scores):
            if score > self.threshold:
                failing_outputs.append(text)
                failing_scores.append(score)
            else:
                passing_outputs.append(text)
                passing_scores.append(score)
            all_outputs.append((score > self.threshold, text))

        if failing_outputs:
            failure_message = "The original response contains potentially biased messages:\n"
            failure_message += "\n - ".join(failing_outputs)
            message_scores = [str(s) for s in failing_scores]
            failure_message += "\n (Message scores: {})".format(", ".join(message_scores))
            # Do we need to call the on_fail_method here?
            # Three paths: noop, exception, fix.
            # on_fail == NOOP, return only passing passages.
            # on_fail == FIX, split passages into sentences and drop the biased ones.
            # EXCEPTION is handled farther up the stack.
            if self.on_fail_descriptor != OnFailAction.FIX:
                fix_value = passing_outputs
            else:
                fix_value = list()
                for needs_fix, text in all_outputs:
                    if not needs_fix:
                        fix_value.append(text)
                    else:
                        # The 'text' is a full document, passage, or paragraph.
                        fix_value.append(self.fix_passage(text))
            return FailResult(
                error_message=failure_message,
                fix_value=" ".join(passing_outputs),
                fix_value=" ".join(fix_value) if single_sentence_passed else fix_value,
            )
        return PassResult()

    def fix_passage(self, text: str) -> str:
        """Given a passage of text, split it into sentences, evaluate each for bias,
        then recombine them and return a new paragraph. May not preserve whitespace
        between sentences."""
        sentences = split_text_into_sentences(text, language='en')
        scores = self._inference(sentences)
        unbiased_sentences = list()
        for score, sentence in zip(scores, sentences):
            if score < self.threshold:
                unbiased_sentences.append(sentence)
        return " ".join(unbiased_sentences)

    # This normally will be called by _inference.
    # Remote inference is unsupported for this model on account of the NER.
    def _inference_local(self, sentences: List[str]) -> List[float]:  # type: ignore
        scores = list()
        predictions = self.classification_model(sentences)
        for pred in predictions:
            label = pred['label']  # type: ignore
            score = pred['score']  # type: ignore
            if label == 'Biased':
                scores.append(score)
            elif label == 'Non-biased':
                scores.append(-score)
            else:
                # This should never happen:
                raise Exception("Unexpected prediction label: {}".format(label))
        return scores
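As an aside, a small sketch of how the reworked validate() behaves when called directly, outside a Guard; the sample text is hypothetical and the model weights download on first use:

from guardrails.validator_base import FailResult
from validator import BiasCheck

v = BiasCheck(threshold=0.9, on_fail="fix")

# A bare string is wrapped into a single-element list internally; on failure,
# fix_value is the passage with sentences over the bias threshold stripped out.
result = v.validate("A neutral opening sentence. A sentence the model scores as biased.")
if isinstance(result, FailResult):
    print(result.error_message)  # Lists flagged passages and their scores.
    print(result.fix_value)      # The debiased rewrite, via fix_passage().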
6 changes: 5 additions & 1 deletion validator/post-install.py
@@ -1,4 +1,8 @@
from transformers import pipeline
print("post-install starting...")
_ = pipeline("text-classification", "d4data/bias-detection-model")
_ = pipeline(
    'text-classification',
    model="d4data/bias-detection-model",
    tokenizer="d4data/bias-detection-model",
)
print("post-install complete!")
