Skip to content

Commit

Permalink
Eval for new labelizer (#582)
Browse files Browse the repository at this point in the history
* update eval

* fix evaluator

* close file

* init driver in condition

* Fixes for evaluation:
- default retriever (should achieve 100% recall)
- evaluator waits for scroll (fixes long webpages)

* Simple HTML cleaner for data which is not used by the embedding

* Added destruction of driver after each step

* Minor improvement of the regex to filter base64 images

* format

---------

Co-authored-by: Julien <[email protected]>
Co-authored-by: DanyWin <[email protected]>
  • Loading branch information
3 people authored Aug 29, 2024
1 parent 9df7f7b commit 0e77f9b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 40 deletions.
81 changes: 41 additions & 40 deletions lavague-core/lavague/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import yaml
from llama_index.core import QueryBundle
import traceback
import base64
import ast
from bs4 import BeautifulSoup
from tempfile import NamedTemporaryFile
import time


class Evaluator(ABC):
Expand Down Expand Up @@ -47,17 +48,6 @@ def compare(
return fig


def parse_action(action: str) -> dict:
    """Parse a serialized action plan and return its first action dict.

    Expects the YAML shape
    ``[{"actions": [{"action": {"name": ..., "args": {"xpath": ...}}}]}]``.
    Returns None when the parsed document does not have that shape.
    Malformed YAML itself still raises (safe_load is outside the try,
    matching the original behavior).
    """
    parsed = yaml.safe_load(action)
    try:
        first = parsed[0]["actions"][0]["action"]
        # Probe the fields callers rely on so malformed actions are
        # rejected here rather than deep inside the evaluator.
        _ = first["args"]["xpath"]
        _ = first["name"]
        return first
    except (TypeError, KeyError, IndexError):
        # Was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit; only shape errors should map to None.
        return None


def parse_yaml(action):
try:
return yaml.safe_load(action)[0]["actions"][0]["action"]
Expand All @@ -84,11 +74,21 @@ def validate_action(action):
return False


def remove_img(html):
    """Return *html* serialized back out with every <img> element removed."""
    soup = BeautifulSoup(html, "html.parser")
    image_tags = soup.find_all("img")
    for image_tag in image_tags:
        image_tag.extract()
    return soup.decode()
def normalize_xpath(xpath: str):
    """Drop explicit first-child indices so '/a[1]/b' compares equal to '/a/b'."""
    return "".join(xpath.split("[1]"))


def load_website_in_driver(driver, html, viewport_size, action):
    """Load an HTML snapshot into *driver* and center the target element.

    Writes *html* to a temporary file, resizes the driver to
    *viewport_size* (a {"width", "height"} dict) when given, opens the
    file, waits for the page to become idle, then scrolls the element
    addressed by ``action["args"]["xpath"]`` to the center of the
    viewport.
    """
    import os

    # delete=False so the file survives the `with` (the driver reads it
    # after the handle is closed); utf-8 keeps non-ASCII pages intact.
    with NamedTemporaryFile(
        delete=False, mode="w", suffix=".html", encoding="utf-8"
    ) as f:
        f.write(html)
    try:
        if viewport_size:
            driver.resize_driver(viewport_size["width"], viewport_size["height"])
        driver.get(f"file:{f.name}")
        driver.wait_for_idle()
        element = driver.resolve_xpath(action["args"]["xpath"])
        driver.execute_script(
            "arguments[0].scrollIntoView({block: 'center', behavior: 'instant'});",
            element,
        )
    finally:
        # The original leaked one temp file per evaluation step; remove it
        # once the page has been loaded and scrolled.
        os.remove(f.name)


FAIL_ACTION = {"args": {"xpath": "(string)"}, "name": "fail"}
Expand All @@ -101,14 +101,16 @@ def evaluate(
dataset: pd.DataFrame,
driver: SeleniumDriver = None, # Optional, the driver passed to the retriever
retriever_name: str = "",
wait_for_scroll: int = 1,
) -> pd.DataFrame:
result_filename = (
(retriever_name if retriever_name else type(retriever).__name__)
+ "_evaluation_"
+ datetime.now().strftime("%Y-%m-%d_%H-%M")
+ ".csv"
)
results = dataset.loc[dataset["is_verified"]].copy()
results = dataset.loc[dataset["validated"]].copy()
results.insert(len(results.columns), "result_nodes", None)
results.insert(len(results.columns), "recall", None)
results.insert(len(results.columns), "output_size", None)
results.insert(len(results.columns), "time", None)
Expand All @@ -117,19 +119,16 @@ def evaluate(

try:
for i, row in tqdm(results.iterrows()):
driver.__init__() # reinit the driver
action = parse_action(row["action"])
if driver: # artificially get the page if the retriever needs a driver
html_bs64 = base64.b64encode(
remove_img(row["preaction_html_bundle"]).encode()
).decode()
driver.get("data:text/html;base64," + html_bs64)
viewport_size = parse_viewport_size(row["viewport_size"])
driver.resize_driver(
width=viewport_size["width"], height=viewport_size["height"]
)
action = yaml.safe_load(row["action"])
instruction = row["instruction"]
try:
if driver:
driver.__init__()
viewport_size = parse_viewport_size(row["viewport_size"])
load_website_in_driver(
driver, row["html"], viewport_size, action
)
time.sleep(wait_for_scroll)
t_begin = datetime.now()
nodes = retriever.retrieve(
QueryBundle(query_str=instruction), [driver.get_html()]
Expand All @@ -139,8 +138,13 @@ def evaluate(
print("ERROR: ", i)
traceback.print_exc()
nodes = []
if driver:
driver.destroy()
nodes = "\n".join(nodes)
results.at[i, "recall"] = 1 if action["args"]["xpath"] in nodes else 0
results.at[i, "result_nodes"] = nodes
results.at[i, "recall"] = (
1 if normalize_xpath(action["args"]["xpath"]) in nodes else 0
)
results.at[i, "output_size"] = len(nodes)
results.at[i, "time"] = pd.Timedelta(t_end - t_begin).total_seconds()
print("Evaluation terminated successfully.")
Expand Down Expand Up @@ -177,7 +181,7 @@ def evaluate(
+ datetime.now().strftime("%Y-%m-%d_%H-%M")
+ ".csv"
)
results = dataset.loc[dataset["is_verified"]].copy()
results = dataset.loc[dataset["validated"]].copy()
results.insert(len(results.columns), "recall", None)
results.insert(len(results.columns), "correct_action", None)
results.insert(len(results.columns), "correct_xpath", None)
Expand All @@ -187,17 +191,13 @@ def evaluate(

try:
for i, row in tqdm(results.iterrows()):
action = parse_action(row["action"])
html_bs64 = base64.b64encode(
remove_img(row["preaction_html_bundle"]).encode()
).decode()
navigation_engine.driver.get("data:text/html;base64," + html_bs64)
action = yaml.safe_load(row["action"])
viewport_size = parse_viewport_size(row["viewport_size"])
navigation_engine.driver.resize_driver(
width=viewport_size["width"], height=viewport_size["height"]
)
instruction = row["instruction"]
try:
load_website_in_driver(
navigation_engine.driver, row["html"], viewport_size, action
)
t_begin = datetime.now()
test_action = navigation_engine.execute_instruction(
instruction
Expand All @@ -212,7 +212,8 @@ def evaluate(
test_action = FAIL_ACTION
results.at[i, "correct_action"] = action["name"] == test_action["name"]
results.at[i, "correct_xpath"] = (
action["args"]["xpath"] == test_action["args"]["xpath"]
normalize_xpath(action["args"]["xpath"])
== test_action["args"]["xpath"]
)
results.at[i, "recall"] = (
results.at[i, "correct_action"] and results.at[i, "correct_xpath"]
Expand Down
24 changes: 24 additions & 0 deletions lavague-core/lavague/core/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def get_default_retriever(
)


def get_trivial_retriever(
    driver: BaseDriver, embedding: Optional[BaseEmbedding] = None
) -> BaseHtmlRetriever:
    """Build the no-filtering baseline retriever.

    *embedding* is accepted for signature parity with get_default_retriever
    but is not used by InteractiveXPathRetriever.
    """
    retriever = InteractiveXPathRetriever(driver)
    return retriever


class BaseHtmlRetriever(ABC):
@abstractmethod
def retrieve(
Expand Down Expand Up @@ -585,6 +591,24 @@ def retrieve(
return get_nodes_text(results)


class CleanHTMLRetriever(BaseHtmlRetriever):
    """Pass-through "retriever" that strips embedding-irrelevant payloads.

    Each HTML chunk is returned as-is except that base64-encoded image
    sources and inline <svg> bodies are removed, shrinking the text fed to
    downstream embedding/retrieval stages.
    """

    def __init__(self, drop_base_64: bool = True, drop_svg: bool = True) -> None:
        # drop_base_64: remove src="data:image/...;base64,..." attributes.
        # drop_svg: remove inline <svg>...</svg> elements.
        self.drop_base_64 = drop_base_64
        self.drop_svg = drop_svg

    def _clean_chunk(self, html: str) -> str:
        """Return *html* with the configured heavy payloads removed."""
        if self.drop_base_64:
            # Generalized from image/png only: match any image mime type
            # (jpeg, gif, webp, ...) carried as a base64 data URI.
            html = re.sub(r'src="data:image/[^;"]+;base64,[^"]*"', "", html)
        if self.drop_svg:
            # DOTALL so multi-line <svg> bodies match; (.*?) also drops
            # empty elements, which (.+?) previously left in place.
            html = re.sub(r"<svg[^>]*>(.*?)</svg>", "", html, flags=re.DOTALL)
        return html

    def retrieve(
        self, query: QueryBundle, html_nodes: List[str], viewport_only=True
    ) -> List[str]:
        """Clean each chunk; no query-based filtering is performed."""
        return [self._clean_chunk(html) for html in html_nodes]


def filter_for_xpathed_nodes(nodes: List):
pattern = re.compile(r'xpath="([^"]+)"')
compatibles = []
Expand Down

0 comments on commit 0e77f9b

Please sign in to comment.