Commit 328e28b
Merge pull request #607 from GraphScope/refactor_workflow
fix bug when chroma db is empty and prepare workflow for extraction
pomelo-nwu authored Nov 28, 2024
2 parents 7cdb58f + 3cf35cd · commit 328e28b
Showing 13 changed files with 700 additions and 336 deletions.
159 changes: 131 additions & 28 deletions python/graphy/README.md
@@ -64,11 +64,139 @@ python paper_scrapper.py --max-workers 4 --max-inspectors 500 --workflow <path_t

- `--max-workers` (optional): Specifies the maximum number of parallel workers (default: 4).
- `--max-inspectors` (optional): Defines the maximum number of papers to fetch (default: 100).
- `--workflow` (optional): Path to a workflow configuration file. If not provided, the default configuration file `config/workflow.json` will be used.
- `<path_to_seed_papers>`: The path to the directory containing seed papers. Each paper is a PDF document.

> If no `workflow` is provided, the default workflow configuration in `config/workflow.json` will be used.
Ensure that the workflow configuration contains your custom LLM model settings by modifying the `llm_model` field.

# Workflow Configuration
Refer to the [example](config/workflow.json) of a workflow with a Paper Inspector and a Reference Navigator. Below are instructions for the key fields of a workflow: `id`, `llm_config`, and `graph`.

## The `id` field
The `id` field uniquely identifies the workflow. This can be any descriptive string or a generated ID.

**Example**:
```json
"id": "test_paper_inspector"
```

## The `llm_config` field
The `llm_config` field configures the large language model (LLM) used in the workflow.
- `llm_model`: Specifies the LLM (e.g., `qwen-plus`).
- `base_url`: The API endpoint for the LLM service.
- `api_key`: The API key for authentication (if required).
- `model_kwargs`: Additional parameters for fine-tuning the model behavior, such as temperature and streaming output.

We currently offer two options for configuring an LLM model:
- **Option 1: Using OpenAI-Compatible APIs**
This option supports OpenAI and other providers offering compatible APIs. To configure, provide the `llm_model`, `base_url`, `api_key`, and any additional model arguments. The examples below demonstrate using [OpenAI](https://platform.openai.com/) and Alibaba’s [DashScope](https://help.aliyun.com/zh/dashscope/developer-reference/compatibility-of-openai-with-dashscope).

**Example of OpenAI**:
```json
"llm_config": {
"llm_model": "gpt4o",
"base_url": "https://api.openai.com/v1",
"api_key": "xx",
"model_kwargs": {
"temperature": 0,
"streaming": true
}
}
```

**Example of DashScope**:
```json
"llm_config": {
"llm_model": "qwen-plus",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"api_key": "xx",
"model_kwargs": {
"temperature": 0,
"streaming": true
}
}
```

- **Option 2: Using Locally Deployed Models with Ollama**
This option supports locally deployed LLM models through [Ollama](https://ollama.com/). Set `llm_model` to `ollama/<ollama_model_name>` to specify a model. For instance, the following settings configure the locally deployed Llama 3.1 model (defaulting to the 8B variant) from Ollama:

**Example**:
```json
"llm_config": {
"llm_model": "ollama/llama3.1",
"base_url": "http://localhost:11434",
"model_kwargs": {
"streaming": true
}
}
```


Note: If no LLM model is specified for a dataset, a default model configuration will be applied. To customize this default, open `models/__init__.py` and modify the `DEFAULT_LLM_MODEL_CONFIG` variable.
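The default follows the same shape as the `llm_config` examples above. A hypothetical sketch of such a default (the actual dictionary lives in `models/__init__.py` and may differ):

```json
{
    "llm_model": "qwen-plus",
    "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
    "api_key": "xx",
    "model_kwargs": {
        "temperature": 0,
        "streaming": true
    }
}
```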

## The `graph` field
The `graph` field defines the structure of the workflow, comprising inspectors and navigators.
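Putting the three key fields together, a workflow file has roughly the following overall shape (a minimal sketch; the exact keys under `graph`, such as `inspectors` and `navigators`, are assumptions here, so refer to [config/workflow.json](config/workflow.json) for the authoritative structure):

```json
{
    "id": "test_paper_inspector",
    "llm_config": { "llm_model": "qwen-plus" },
    "graph": {
        "inspectors": [],
        "navigators": []
    }
}
```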

### Inspectors

- Inspectors define workflows for extracting structured information from unstructured data.
- Each inspector can contain a graph (an inner workflow), sketched below, with:
  - `nodes`: Represent individual extraction tasks.
  - `edges`: Define dependencies between nodes.
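A minimal sketch of an inspector's inner graph (the `source`/`target` edge fields are assumed by analogy with the navigator example later in this section; see [config/workflow.json](config/workflow.json) for the real schema):

```json
{
    "nodes": [
        { "name": "Background", "query": "..." },
        { "name": "Challenge", "query": "..." }
    ],
    "edges": [
        { "source": "Background", "target": "Challenge" }
    ]
}
```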

We further explain the `extract_from` field of a node within an Inspector, which specifies where in a paper to extract the given information:
- `exact`: Explicitly defined page numbers or sections. For example, `{"exact": ["1"]}` means searching in Section 1.
- `match`: Keywords used to search for relevant sections. For example, `"match": ["introduction"]` means searching for sections semantically similar to "Introduction".
- The `exact` and `match` fields may both be present in `extract_from`.
- If `extract_from` is omitted or empty, the entire document will be searched.
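For illustration, each line below is a standalone valid value for `extract_from` under these rules (a sketch, not an exhaustive list):

```json
{"exact": ["1", "2"]}
{"match": ["introduction", "background"]}
{"exact": ["1"], "match": ["introduction"]}
```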

**Example of a node within the Paper Inspector**:
```json
{
"name": "Background",
"query": "**Question**: Please describe the problem studied in this paper...",
"extract_from": {"exact": ["1"], "match": ["introduction"]},
"output_schema": {
"type": "single",
"description": "The background of this paper",
"item": [
{
"name": "problem_definition",
"type": "string",
"description": "The problem studied in this paper."
},
{
"name": "problem_value",
"type": "string",
"description": "Why the problem is worth studying."
},
{
"name": "existing_solutions",
"type": "string",
"description": "What are the existing solutions and their problems."
}
]
}
}
```

### Navigators
Currently, the only configuration a navigator requires is its connected Inspector nodes.
The `navigators` field can be left empty, as in [workflow_inspector](config/workflow_inspector.json),
in which case only the Paper Inspector runs, without the Reference Navigator.

**Example**:
```json
{
"name": "Reference",
"source": "PaperInspector",
"target": "PaperInspector"
}
```


The scraped data will be saved in the directory specified by [WF_OUTPUT_DIR](config/__init__.py), under a subdirectory named after your workflow ID (`<your_workflow_id>`).
- If the default workflow configuration is used, the workflow ID is `test_paper_scrapper`.
@@ -81,10 +209,7 @@ A backend demo application is included in this project, accessible as a standalone
```bash
python apps/demo_app.py
```

The server will be running on `http://localhost:9999` by default.

# Run Frontend Server
Please refer to the [frontend README](../../examples/graphy/README.md) for instructions on how to run the frontend server.
The server will be running on `http://localhost:9999` by default. A GUI [frontend](../../examples/graphy/README.md) is provided to demonstrate the Graphy process.

# Instructions for Backend APIs

@@ -136,28 +261,6 @@ curl -X POST http://0.0.0.0:9999/api/llm/config -H "Content-Type: application/json" -d '{
}'
```

We currently offer two options for configuring an LLM model:

- **Option 1: Using OpenAI-Compatible APIs**
This option supports OpenAI and other providers offering compatible APIs. To configure, provide the `llm_model`, `base_url`, `api_key`, and any additional model arguments. The example below demonstrates using OpenAI-compatible APIs through Alibaba’s DashScope with the `qwen-plus` model.
- **Option 2: Using Locally Deployed Models with Ollama**
This option supports locally deployed LLM models through Ollama. Set `llm_model` to `ollama/<ollama_model_name>` to specify a model. For instance, the following settings configure the locally deployed Llama 3.1 model (defaulting to the 8B variant) from Ollama:

```bash
curl -X POST http://0.0.0.0:9999/api/llm/config -H "Content-Type: application/json" -d '{
"dataset_id": "8547eb64-a106-5d09-8950-8a47fb9292dc",
"llm_model": "ollama/llama3.1",
"base_url": "http://localhost:11434",
"model_kwargs": {
"streaming": true
}
}'
```


Note: If no LLM model is specified for a dataset, a default model configuration will be applied. To customize this default, open `models/__init__.py` and modify the `DEFAULT_LLM_MODEL_CONFIG` variable.


### Get the LLM Config

```bash
```
127 changes: 46 additions & 81 deletions python/graphy/apps/demo_app.py
@@ -165,7 +165,6 @@ def set_meta_config(self, dataset_id, config):
persist_store = self.get_persist_store(dataset_id)
metadata = persist_store.get_state("", "metadata")
if config:
self.set_llm_model(dataset_id, config)
if "api_key" in config:
config["api_key"] = encrypt_key(config["api_key"])
metadata["config"] = config
@@ -176,8 +175,6 @@ def get_meta_config(self, dataset_id):
metadata = persist_store.get_state("", "metadata")
llm_config = metadata.get("config", {})
# must initialize the llm model if it is not configured yet
if llm_config:
self.set_llm_model(dataset_id, llm_config, initialize=False)
return llm_config

def set_meta_interactive_info(self, dataset_id, graph_id):
@@ -209,21 +206,12 @@ def get_status(self, dataset_id):

def set_metadata(self, dataset_id, value):
persist_store = self.get_persist_store(dataset_id)
llm_config = value.get("config", {})
if llm_config:
self.set_llm_model(dataset_id, llm_config)
if "api_key" in llm_config:
llm_config["api_key"] = encrypt_key(llm_config["api_key"])
persist_store.save_state("", "metadata", value)

def get_metadata(self, dataset_id):
persist_store = self.get_persist_store(dataset_id)
metadata = persist_store.get_state("", "metadata")
llm_config = metadata.get("config", {})
# must initialize the llm model if it is not configured yet
if llm_config:
self.set_llm_model(dataset_id, llm_config, initialize=False)

return metadata

def set_cache(self, dataset_id, key, value):
@@ -245,64 +233,33 @@ def get_workflow_node_names(self, dataset_id):
node_names = workflow.graph.get_node_names()
return node_names

def set_llm_model(self, dataset_id, llm_config, initialize=True):
if initialize:
llm = set_llm_model(llm_config)
self.set_cache(dataset_id, "llm_model", llm)
else:
cache = self.cache.get(dataset_id, {})
if cache and not "llm_model" in cache:
llm_config = copy.deepcopy(llm_config)
if "api_key" in llm_config:
llm_config["api_key"] = decrypt_key(llm_config["api_key"])
llm = set_llm_model(llm_config)
self.set_cache(dataset_id, "llm_model", llm)

def set_workflow(self, dataset_id, workflow_dict):
persist_store = self.get_persist_store(dataset_id)
cache = self.cache.get(dataset_id, {})
llm = cache.get("llm_model", None)
if llm:
logger.debug(
f"The model {llm.model_name} is used for the workflow, which has maximum output token limit of {llm.context_size}"
)

if not llm and self.llm:
llm = self.llm
logger.debug(
f"Default model {self.llm.model_name} is used for the workflow, which has maximum output token limit of {llm.context_size}"
)

if not llm:
raise ValueError("LLM model not configured")

embedding_model = self.embedding_model.chroma_embedding_model()
if not embedding_model:
# A safe guarantee to use the default model
embedding_model = embedding_functions.DefaultEmbeddingFunction()

vectordb = chromadb.PersistentClient(path=WF_VECTDB_DIR)
new_workflow_dict = {}
if "graph" not in workflow_dict:
new_workflow_dict["graph"] = workflow_dict
workflow_dict = new_workflow_dict
if "id" not in workflow_dict:
workflow_dict["id"] = dataset_id
if "llm_config" not in workflow_dict:
llm_config = self.get_meta_config(dataset_id)
if llm_config:
if "api_key" in llm_config:
llm_config["api_key"] = decrypt_key(llm_config["api_key"])
workflow_dict["llm_config"] = llm_config
else:
workflow_dict["llm_config"] = DEFAULT_LLM_MODEL_CONFIG

# Initialize the workflow
workflow = SurveyPaperReading(
dataset_id,
llm,
llm,
embedding_model,
vectordb,
workflow_dict,
persist_store,
)
workflow = SurveyPaperReading.from_dict(workflow_dict, persist_store)
self.set_cache(dataset_id, "workflow", workflow)

def get_progress(self, dataset_id, node_names=[]):
"""
def get_paper_data(node, progress, workflow, get_paper_data=False):
output_data = {}
output_data["node_name"] = node
output_data["papers"] = []
progress[node] = workflow.get_progress(node)

"""
for wf in self.cache[dataset_id]["wf_dict"].values():
progress[node].add(wf.get_progress(node))
progress["total"].add(wf.get_progress())
@@ -324,9 +281,9 @@ def get_paper_data(node, progress, workflow, get_paper_data=False):
result["id"] = hash_id(f"{paper_data['id']}_{i}")
paper_data["data"].append(result)
output_data["papers"].append(paper_data)
"""
return output_data
"""

def get_default_paper_data(node, get_paper_data=False):
output_data = {}
@@ -373,29 +330,38 @@ def get_default_paper_data(node, get_paper_data=False):
return output_data

progress = {}
progress["total"] = ProgressInfo()
total_progress = ProgressInfo()
output = []
if dataset_id in self.cache:
workflow = self.cache[dataset_id].get("workflow", None)
if not node_names:
inspector_node = workflow.graph.get_first_node()
node_names = inspector_node.graph.get_node_names()
for node in node_names:
progress[node] = ProgressInfo()
if workflow:
if not node_names:
node_names = workflow.graph.get_node_names()
for node in node_names:
progress[node] = {}
output_data = {}
if dataset_id == DEFAULT_DATASET_ID:
output_data = get_default_paper_data(node, len(node_names) == 1)
output_data["progress"] = 100.0
else:
output_data["node_name"] = node
output_data["progress"] = []
progress[node] = workflow.get_progress(node)
for key, val in progress[node].items():
output_data["progress"].append(
{
"node_name": key,
"progress": val.get_percentage(),
}
)
total_progress.add(val)
# output_data["progress"] = progress[node].get_percentage()
output.append(output_data)

for node in node_names:
if dataset_id == DEFAULT_DATASET_ID:
output_data = get_default_paper_data(node, len(node_names) == 1)
output_data["progress"] = 100.0
else:
output_data = get_paper_data(
node, progress, workflow, len(node_names) == 1
)
output_data["progress"] = progress[node].get_percentage()
output.append(output_data)
progress["total"] = workflow.get_progress("")
if len(output) == 1:
output = output[0]["progress"]

return output, progress["total"].get_percentage()
return output, total_progress.get_percentage()

def check_status(self, dataset_id, status):
metadata = self.get_metadata(dataset_id)
@@ -526,9 +492,6 @@ def get_single_dataset(dataset_id):
if dataset_id != DEFAULT_DATASET_ID: # do not do for default data
status = STATUS.INITIALIZED
if "config" in metadata:
self.set_llm_model(
dataset_id, metadata["config"], initialize=False
)
status = STATUS.WAITING_WORKFLOW_CONFIG.value
if "schema" in metadata:
if "workflow" not in cache:
@@ -657,6 +620,7 @@ def config_workflow():
)

except Exception as e:
traceback.print_exc()
return create_error_response(str(e)), 500

@self.app.route("/api/dataset/workflow/config", methods=["GET"])
@@ -768,6 +732,7 @@ def get_extract():
)
self.add_meta_entities(dataset_id, workflow_node_names)
if total_progress == 100.0:
logger.info("The extraction workflow is completed")
self.set_status(dataset_id, STATUS.WAITING_CLUSTER)

# Return the results as a JSON response