Skip to content

Commit

Permalink
Fixes for evaluation:
Browse files: browse the repository at this point in the history
- default retriever (should achieve 100% recall)
- evaluator waits for scroll (fixes long webpages)
  • Loading branch information
turboNinja2 committed Aug 28, 2024
1 parent e2dfaa6 commit 369142a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
12 changes: 6 additions & 6 deletions lavague-core/lavague/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import yaml
from llama_index.core import QueryBundle
import traceback
import base64
import ast
from bs4 import BeautifulSoup
from tempfile import NamedTemporaryFile
import time


class Evaluator(ABC):
Expand Down Expand Up @@ -86,7 +86,7 @@ def load_website_in_driver(driver, html, viewport_size, action):
driver.get(f"file:{f.name}")
driver.wait_for_idle()
element = driver.resolve_xpath(action["args"]["xpath"])
driver.execute_script("arguments[0].scrollIntoView({block: 'center'});", element)
driver.execute_script("arguments[0].scrollIntoView({block: 'center', behavior: 'instant'});", element)


FAIL_ACTION = {"args": {"xpath": "(string)"}, "name": "fail"}
Expand All @@ -99,6 +99,7 @@ def evaluate(
dataset: pd.DataFrame,
driver: SeleniumDriver = None, # Optional, the driver passed to the retriever
retriever_name: str = "",
wait_for_scroll: int = 1
) -> pd.DataFrame:
result_filename = (
(retriever_name if retriever_name else type(retriever).__name__)
Expand All @@ -119,14 +120,13 @@ def evaluate(
action = yaml.safe_load(row["action"])
instruction = row["instruction"]
try:
if (
driver
): # artificially get the page if the retriever needs a driver
driver.__init__() # reinit the driver
if driver:
driver.__init__()
viewport_size = parse_viewport_size(row["viewport_size"])
load_website_in_driver(
driver, row["html"], viewport_size, action
)
time.sleep(wait_for_scroll)
t_begin = datetime.now()
nodes = retriever.retrieve(
QueryBundle(query_str=instruction), [driver.get_html()]
Expand Down
6 changes: 6 additions & 0 deletions lavague-core/lavague/core/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def get_default_retriever(
)


def get_trivial_retriever(
    driver: BaseDriver, embedding: Optional[BaseEmbedding] = None
) -> BaseHtmlRetriever:
    """Build a retriever that wraps ``driver`` in an ``InteractiveXPathRetriever``.

    Args:
        driver: Passed straight through to ``InteractiveXPathRetriever``.
        embedding: Accepted but never read by this function.
            NOTE(review): presumably kept so the signature lines up with
            ``get_default_retriever`` -- confirm against that function.

    Returns:
        The newly constructed ``InteractiveXPathRetriever``.
    """
    xpath_retriever = InteractiveXPathRetriever(driver)
    return xpath_retriever


class BaseHtmlRetriever(ABC):
@abstractmethod
def retrieve(
Expand Down

0 comments on commit 369142a

Please sign in to comment.