Merge branch 'dev' into dev
creatorrr authored Oct 17, 2024
2 parents 981df1a + aa44bfd commit 8c1243f
Showing 12 changed files with 1,265 additions and 170 deletions.
3 changes: 2 additions & 1 deletion agents-api/agents_api/activities/execute_system.py
@@ -109,7 +109,8 @@ async def execute_system(
        search_params = HybridDocSearchRequest(
            text=arguments.pop("text"),
            vector=arguments.pop("vector"),
-            confidence=arguments.pop("confidence", 0.7),
+            alpha=arguments.pop("alpha", 0.75),
+            confidence=arguments.pop("confidence", 0.5),
            limit=arguments.get("limit", 10),
        )

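The hunk above changes the defaults that execute_system falls back to when a hybrid doc search is requested without explicit parameters. A minimal sketch of that fallback behaviour, assuming HybridDocSearchRequest is importable from the generated models and using a made-up arguments payload:

from agents_api.autogen.openapi_model import HybridDocSearchRequest  # assumed import path

# Hypothetical payload from a system-tool call; the caller omitted alpha and confidence.
arguments = {"text": "vacation policy", "vector": [0.12, -0.08, 0.33], "limit": 5}

search_params = HybridDocSearchRequest(
    text=arguments.pop("text"),
    vector=arguments.pop("vector"),
    alpha=arguments.pop("alpha", 0.75),  # falls back to the newly added 0.75 default
    confidence=arguments.pop("confidence", 0.5),  # falls back to the lowered 0.5 default
    limit=arguments.get("limit", 10),  # 5, taken from the caller
)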
@@ -1,21 +1,19 @@
import base64
+from typing import Any

from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import CreateTransitionRequest
-from ...common.protocol.tasks import (
-    StepContext,
-    StepOutcome,
-)
+from ...common.protocol.tasks import StepContext
from ...common.storage_handler import auto_blob_store
from .transition_step import original_transition_step


@activity.defn
@auto_blob_store
@beartype
-async def raise_complete_async(context: StepContext, output: StepOutcome) -> None:
+async def raise_complete_async(context: StepContext, output: Any) -> None:
    activity_info = activity.info()

    captured_token = base64.b64encode(activity_info.task_token).decode("ascii")
14 changes: 13 additions & 1 deletion agents-api/agents_api/clients/temporal.py
@@ -2,6 +2,11 @@
from uuid import UUID

from temporalio.client import Client, TLSConfig
+from temporalio.common import (
+    SearchAttributeKey,
+    SearchAttributePair,
+    TypedSearchAttributes,
+)

from ..autogen.openapi_model import TransitionTarget
from ..common.protocol.tasks import ExecutionInput
@@ -48,6 +53,7 @@ async def run_task_execution_workflow(
    from ..workflows.task_execution import TaskExecutionWorkflow

    client = client or (await get_client())
+    execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")

    return await client.start_workflow(
        TaskExecutionWorkflow.run,
@@ -56,7 +62,13 @@
        id=str(job_id),
        run_timeout=timedelta(days=31),
        retry_policy=DEFAULT_RETRY_POLICY,
-        # TODO: Should add search_attributes for queryability
+        search_attributes=TypedSearchAttributes(
+            [
+                SearchAttributePair(
+                    execution_id_key, str(execution_input.execution.id)
+                ),
+            ]
+        ),
    )


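With the typed search attribute attached at workflow start, each task execution can be located through Temporal's visibility queries rather than only by workflow id. A rough sketch of such a lookup with the Python SDK, assuming the CustomStringField keyword attribute is registered on the namespace (the server address and execution id below are placeholders):

import asyncio

from temporalio.client import Client


async def find_execution_workflow(execution_id: str) -> None:
    client = await Client.connect("localhost:7233")  # placeholder address
    # Visibility query over the keyword attribute set in run_task_execution_workflow
    async for workflow in client.list_workflows(f"CustomStringField = '{execution_id}'"):
        print(workflow.id, workflow.status)


asyncio.run(find_execution_workflow("00000000-0000-0000-0000-000000000000"))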
216 changes: 216 additions & 0 deletions agents-api/agents_api/common/nlp.py
@@ -0,0 +1,216 @@
import re
from collections import Counter, defaultdict

import spacy

# Load spaCy English model
spacy.prefer_gpu()
nlp = spacy.load("en_core_web_sm")


def extract_keywords(text: str, top_n: int = 10) -> list[str]:
    """
    Extracts significant keywords and phrases from the text.
    Args:
        text (str): The input text to process.
        top_n (int): Number of top keywords to extract based on frequency.
    Returns:
        List[str]: A list of extracted keywords/phrases.
    """
    doc = nlp(text)

    # Extract named entities
    entities = [
        ent.text.strip()
        for ent in doc.ents
        if ent.label_
        not in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
    ]

    # Extract nouns and proper nouns
    nouns = [
        chunk.text.strip().lower()
        for chunk in doc.noun_chunks
        if not chunk.root.is_stop
    ]

    # Combine entities and nouns
    combined = entities + nouns

    # Normalize and count frequency
    normalized = [re.sub(r"\s+", " ", kw).strip().lower() for kw in combined]
    freq = Counter(normalized)

    # Get top_n keywords
    keywords = [item for item, count in freq.most_common(top_n)]

    return keywords


def find_keyword_positions(doc, keyword: str) -> list[int]:
    """
    Finds all start indices of the keyword in the tokenized doc.
    Args:
        doc (spacy.tokens.Doc): The tokenized document.
        keyword (str): The keyword or phrase to search for.
    Returns:
        List[int]: List of starting token indices where the keyword appears.
    """
    keyword_tokens = keyword.split()
    n = len(keyword_tokens)
    positions = []
    for i in range(len(doc) - n + 1):
        window = doc[i : i + n]
        window_text = " ".join([token.text.lower() for token in window])
        if window_text == keyword:
            positions.append(i)
    return positions


def find_proximity_groups(
    text: str, keywords: list[str], n: int = 10
) -> list[set[str]]:
    """
    Groups keywords that appear within n words of each other.
    Args:
        text (str): The input text.
        keywords (List[str]): List of keywords to consider.
        n (int): The proximity window in words.
    Returns:
        List[Set[str]]: List of sets, each containing keywords that are proximate.
    """
    doc = nlp(text.lower())
    keyword_positions = defaultdict(list)

    for kw in keywords:
        positions = find_keyword_positions(doc, kw)
        keyword_positions[kw].extend(positions)

    # Initialize Union-Find structure
    parent = {}

    def find(u):
        while parent[u] != u:
            parent[u] = parent[parent[u]]
            u = parent[u]
        return u

    def union(u, v):
        u_root = find(u)
        v_root = find(v)
        if u_root == v_root:
            return
        parent[v_root] = u_root

    # Initialize each keyword as its own parent
    for kw in keywords:
        parent[kw] = kw

    # Compare all pairs of keywords
    for i in range(len(keywords)):
        for j in range(i + 1, len(keywords)):
            kw1 = keywords[i]
            kw2 = keywords[j]
            positions1 = keyword_positions[kw1]
            positions2 = keyword_positions[kw2]
            # Check if any positions are within n words
            for pos1 in positions1:
                for pos2 in positions2:
                    distance = abs(pos1 - pos2)
                    if distance <= n:
                        union(kw1, kw2)
                        break
                else:
                    continue
                break

    # Group keywords by their root parent
    groups = defaultdict(set)
    for kw in keywords:
        root = find(kw)
        groups[root].add(kw)

    # Convert to list of sets
    group_list = list(groups.values())

    return group_list


def build_query(groups: list[set[str]], keywords: list[str], n: int = 10) -> str:
    """
    Builds a query string using the custom query language.
    Args:
        groups (List[Set[str]]): List of keyword groups.
        keywords (List[str]): Original list of keywords.
        n (int): The proximity window for NEAR.
    Returns:
        str: The constructed query string.
    """
    grouped_keywords = set()
    clauses = []

    for group in groups:
        if len(group) == 1:
            clauses.append(f'"{list(group)[0]}"')
        else:
            sorted_group = sorted(
                group, key=lambda x: -len(x)
            )  # Sort by length to prioritize phrases
            escaped_keywords = [f'"{kw}"' for kw in sorted_group]
            near_clause = f"NEAR/{n}(" + " ".join(escaped_keywords) + ")"
            clauses.append(near_clause)
        grouped_keywords.update(group)

    # Identify keywords not in any group (if any)
    remaining = set(keywords) - grouped_keywords
    for kw in remaining:
        clauses.append(f'"{kw}"')

    # Combine all clauses with OR
    query = " OR ".join(clauses)

    return query


def text_to_custom_query(text: str, top_n: int = 10, proximity_n: int = 10) -> str:
    """
    Converts arbitrary text to the custom query language.
    Args:
        text (str): The input text to convert.
        top_n (int): Number of top keywords to extract.
        proximity_n (int): The proximity window for NEAR/n.
    Returns:
        str: The custom query string.
    """
    keywords = extract_keywords(text, top_n)
    if not keywords:
        return ""
    groups = find_proximity_groups(text, keywords, proximity_n)
    query = build_query(groups, keywords, proximity_n)
    return query


def paragraph_to_custom_queries(paragraph: str) -> list[str]:
    """
    Converts a paragraph to a list of custom query strings.
    Args:
        paragraph (str): The input paragraph to convert.
    Returns:
        List[str]: The list of custom query strings.
    """

    queries = [text_to_custom_query(sentence.text) for sentence in nlp(paragraph).sents]

    return queries
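
To show how the new nlp module composes end to end, here is a small usage sketch. It assumes the agents_api package and spaCy's en_core_web_sm model are installed; the sample paragraph, and hence the resulting queries, are invented:

from agents_api.common.nlp import paragraph_to_custom_queries, text_to_custom_query

paragraph = (
    "Acme Corp announced a new travel policy. "
    "Employees should book flights through the internal portal."
)

# One FTS query string per sentence: keywords that occur within the proximity
# window are grouped into NEAR/10(...) clauses, the rest are OR'ed quoted terms.
for query in paragraph_to_custom_queries(paragraph):
    print(query)

# A single sentence can also be converted directly, with custom windows.
print(text_to_custom_query("book flights through the internal portal", top_n=5, proximity_n=5))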
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/docs/search_docs_by_embedding.py
@@ -47,8 +47,8 @@ def search_docs_by_embedding(
    owners: list[tuple[Literal["user", "agent"], UUID]],
    query_embedding: list[float],
    k: int = 3,
-    confidence: float = 0.7,
-    ef: int = 128,
+    confidence: float = 0.5,
+    ef: int = 32,
    mmr_lambda: float = 0.25,
    embedding_size: int = 1024,
) -> tuple[list[str], dict]:
9 changes: 5 additions & 4 deletions agents-api/agents_api/models/docs/search_docs_by_text.py
@@ -1,6 +1,5 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries."""

-import json
from typing import Any, Literal, TypeVar
from uuid import UUID

@@ -10,6 +9,7 @@
from pydantic import ValidationError

from ...autogen.openapi_model import DocReference
+from ...common.nlp import paragraph_to_custom_queries
from ..utils import (
    cozo_query,
    partialclass,
@@ -64,7 +64,7 @@ def search_docs_by_text(

    # Need to use NEAR/3($query) to search for arbitrary text within 3 words of each other
    # See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
-    query = f"NEAR/3({json.dumps(query)})"
+    fts_queries = paragraph_to_custom_queries(query)

    # Construct the datalog query for searching document snippets
    search_query = f"""
@@ -112,11 +112,12 @@ def search_docs_by_text(
        index,
        content
        |
-        query: $query,
+        query: query,
        k: {k},
        score_kind: 'tf_idf',
        bind_score: score,
    }},
+    query in $fts_queries,
    distance = -score,
    snippet_data = [index, content]
@@ -183,5 +184,5 @@ def search_docs_by_text(

    return (
        queries,
-        {"owners": owners, "query": query},
+        {"owners": owners, "query": query, "fts_queries": fts_queries},
    )
4 changes: 3 additions & 1 deletion agents-api/agents_api/models/execution/create_execution.py
@@ -60,7 +60,9 @@ def create_execution(
    data["metadata"] = data.get("metadata", {})
    execution_data = data

-    if execution_data["output"] is not None and not isinstance(execution_data["output"], dict):
+    if execution_data["output"] is not None and not isinstance(
+        execution_data["output"], dict
+    ):
        execution_data["output"] = {OUTPUT_UNNEST_KEY: execution_data["output"]}

    columns, values = cozo_process_mutate_data(
4 changes: 3 additions & 1 deletion agents-api/agents_api/models/execution/get_execution.py
@@ -32,7 +32,9 @@
    one=True,
    transform=lambda d: {
        **d,
-        "output": d["output"][OUTPUT_UNNEST_KEY] if OUTPUT_UNNEST_KEY in d["output"] else d["output"],
+        "output": d["output"][OUTPUT_UNNEST_KEY]
+        if OUTPUT_UNNEST_KEY in d["output"]
+        else d["output"],
    },
)
@cozo_query
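The create_execution and get_execution changes are two halves of the same convention: non-dict outputs are wrapped under OUTPUT_UNNEST_KEY on write and unwrapped on read. A self-contained sketch of that round trip (the key value below is a stand-in; the repo defines the real constant elsewhere):

OUTPUT_UNNEST_KEY = "$$output$$"  # stand-in; the actual constant lives in the repo


def wrap_output(output):
    # create_execution side: store non-dict outputs as a single-key dict
    if output is not None and not isinstance(output, dict):
        return {OUTPUT_UNNEST_KEY: output}
    return output


def unwrap_output(stored: dict):
    # get_execution side: restore the original value when the wrapper key is present
    return stored[OUTPUT_UNNEST_KEY] if OUTPUT_UNNEST_KEY in stored else stored


assert unwrap_output(wrap_output(42)) == 42
assert unwrap_output(wrap_output({"answer": 42})) == {"answer": 42}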
8 changes: 0 additions & 8 deletions agents-api/agents_api/workflows/task_execution/__init__.py
@@ -126,14 +126,6 @@ async def run(
        start: TransitionTarget = TransitionTarget(workflow="main", step=0),
        previous_inputs: list[Any] = [],
    ) -> Any:
-        # Add metadata to the workflow run
-        workflow.upsert_search_attributes(
-            {
-                "task_id": execution_input.task.id,
-                "execution_id": execution_input.execution.id,
-            }
-        )
-
        workflow.logger.info(
            f"TaskExecutionWorkflow for task {execution_input.task.id}"
            f" [LOC {start.workflow}.{start.step}]"
(Diffs for the remaining changed files are not shown here.)
