Skip to content

Commit

Permalink
Refactor GuardrailsPII to support entity type mapping and update .gitignore.
Browse files (browse the repository at this point in the history)
  • Loading branch information
brianlai98 committed Dec 10, 2024
1 parent fe36f6d commit 5b96bac
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
Empty file removed .env
Empty file.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ build
.pytest_cache
.ruff_cache
.vscode
.idea
.idea
.ropeproject
40 changes: 34 additions & 6 deletions validator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_entity_threshold(entity: str) -> float:
return 0.5
else:
return 0.0

class InferenceInput(BaseModel):
text: str
entities: List[str]
Expand All @@ -40,17 +40,41 @@ class InferenceOutputResult(BaseModel):
start: int
end: int
score: float

class InferenceOutput(BaseModel):
results: List[InferenceOutputResult]
results: List[InferenceOutputResult]
anonymized_text: str


@register_validator(name="guardrails/guardrails_pii", data_type="string")
class GuardrailsPII(Validator):
PII_ENTITIES_MAP = {
"pii": [
"EMAIL_ADDRESS",
"PHONE_NUMBER",
"DOMAIN_NAME",
"IP_ADDRESS",
"DATE_TIME",
"LOCATION",
"PERSON",
"URL",
],
"spi": [
"CREDIT_CARD",
"CRYPTO",
"IBAN_CODE",
"NRP",
"MEDICAL_LICENSE",
"US_BANK_NUMBER",
"US_DRIVER_LICENSE",
"US_ITIN",
"US_PASSPORT",
"US_SSN",
],
}
def __init__(
self,
entities: List[str],
entities: str | List[str],
model_name: str = "urchade/gliner_small-v2.1",
get_entity_threshold: Callable = get_entity_threshold,
on_fail: Optional[Callable] = None,
Expand Down Expand Up @@ -85,7 +109,11 @@ def __init__(
**kwargs,
)

self.entities = entities
if isinstance(entities, str):
assert entities in self.PII_ENTITIES_MAP, f"Invalid entity type: {entities}"
self.entities = self.PII_ENTITIES_MAP[entities]
else:
self.entities = entities
self.model_name = model_name
self.get_entity_threshold = get_entity_threshold

Expand Down Expand Up @@ -169,7 +197,7 @@ def anonymize(self, text: str, entities: list[str]) -> Tuple[str, list[ErrorSpan
]

return output.anonymized_text, error_spans


def _validate(self, value: Any, metadata: Dict = {}) -> ValidationResult:
entities = metadata.get("entities", self.entities)
Expand Down

0 comments on commit 5b96bac

Please sign in to comment.