Skip to content

Commit

Permalink
Remove code that extracts patterns outside of function
Browse files (browse the repository at this point in the history)
dragon-dxw committed Feb 14, 2024
1 parent 79bde2a commit c10e5b9
Showing 3 changed files with 7 additions and 10 deletions.
9 changes: 2 additions & 7 deletions src/lambdas/update_rules_processor/index.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
#!env/bin/python

import json
import logging
import urllib.parse
from io import StringIO
@@ -70,15 +69,11 @@ def lambda_handler(event: S3Event, context: LambdaContext) -> None:
csv_file = response["Body"].read().decode("utf-8")
df = pd.read_csv(StringIO(csv_file))

# used by determine_replacements_caselaw
create_test_jsonl(source_bucket, df)
jsonl_key = "test_citation_patterns.jsonl"

patterns_resp = s3.get_object(Bucket=source_bucket, Key=jsonl_key)
patterns = patterns_resp["Body"]
pattern_list = [json.loads(line) for line in patterns.iter_lines()]

try:
- test_manifest(df, pattern_list)
+ test_manifest(df)
except AssertionError:
LOGGER.error("Exception: Manifest test failed")
raise
5 changes: 4 additions & 1 deletion src/utils/validate_patterns.py
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
import json

import pandas as pd
import spacy


- def test_manifest(df: pd.DataFrame, patterns: list[str]) -> None:
+ def test_manifest(df: pd.DataFrame) -> None:
"""
Test for the rules manifest: given a dataframe of the CSV file, and the patterns
(which are also derived directly from that CSV file), check that the number of
@@ -13,6 +15,7 @@ def test_manifest(df: pd.DataFrame, patterns: list[str]) -> None:
"en_core_web_sm", exclude=["tok2vec", "attribute_ruler", "lemmatizer", "ner"]
)
nlp.max_length = 2500000
patterns = [json.loads(s) for s in df["pattern"]]

citation_ruler = nlp.add_pipe("entity_ruler")
citation_ruler.add_patterns(patterns)
3 changes: 1 addition & 2 deletions validate_match_csv.py
Original file line number | Diff line number | Diff line change
@@ -55,5 +55,4 @@ def get_patterns(csv_dict):
if len(match) > 1:
raise RuntimeError(f"{len(match)} matches for {item['match_example']!r}")


- test_manifest(df, patterns)
+ test_manifest(df)

0 comments on commit c10e5b9

Please sign in to comment.