
Commit

fix bugs
mfz-ant committed Jul 19, 2024
1 parent e72dc7a commit 2c163cb
Showing 4 changed files with 21 additions and 15 deletions.
12 changes: 7 additions & 5 deletions python/knext/knext/ca/common/base.py
@@ -10,7 +10,7 @@ class Question(object):
     and can also have sub-questions. It helps in structuring a complex problem
     into manageable parts, ensuring that each part can be addressed in a logical
     sequence.
 
     There are two possible relationships between questions:
     1. The content of one question depends on the answer to another question.
     2. One question can be broken down into several sub-questions.
@@ -82,7 +82,7 @@ class KagBaseModule(object):
     This class handles the processing flow from the input question to the output generated by the LLM.
     It supports intermediate processing tools and additional information fetching tools.
 
     A significant feature of this class is the management of prompt templates used to communicate with the LLM.
     The class allows for default or custom prompt templates to be loaded, processed, and saved.
     If a computational context is indicated, it initializes and manages the state dictionary containing the prompt template.
     """
@@ -103,14 +103,14 @@ def __init__(
             use_default_prompt_template (bool): Flag to use the default prompt template.
             prompt_template_dir (str): Directory to load the prompt template from, if not using the default.
             is_computational (bool): Indicates if the module operates in a computational context, impacting prompt template usage.
                 If the module is computational, it initializes the state dictionary with the prompt template.
         """
         self.llm_module = llm_module
         self.is_computational = is_computational
         self.intermediate_process_tools = []
         self.extra_info_fetch_tools = []
 
         if is_computational:
             self.use_default_prompt_template = use_default_prompt_template
             self.is_prompt_template_cn = is_prompt_template_cn
@@ -160,6 +160,7 @@ def postprocess(self, question: Question, llm_output):
 
     def get_ca_default_prompt_template_dir(self):
         directory = os.path.dirname(os.path.abspath(__file__))
+        directory = os.path.join(directory, '..', 'logic/modules')
         if self.is_prompt_template_cn:
             return os.path.join(directory, 'default_prompt_template')
         else:
@@ -170,6 +171,7 @@ def load_prompt_template(self, prompt_dir):
             prompt_dir,
             f'{self.get_module_name()}.txt'
         )
+        logger.info(f"##### {self.get_module_name()} prompt_file_path: {prompt_file_path} {os.path.exists(prompt_file_path)}")
         if os.path.exists(prompt_file_path):
             with open(prompt_file_path, 'r') as f:
                 template_string = f.read()
@@ -222,7 +224,7 @@ def init_state_dict(self):
         Returns:
             state_dict (dict): The state dictionary containing the prompt template.
-        """
+        """
         if self.use_default_prompt_template:
             return self.create_default_state_dict()
         else:
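
The one-line addition in get_ca_default_prompt_template_dir moves the default template lookup out of knext/ca/common/ and into knext/ca/logic/modules/, where load_prompt_template expects one <module_name>.txt file per module. A minimal sketch of the fixed resolution, assuming that layout; the English-template branch is cut off in the hunk, so the directory name in the final return is an assumption:

    import os

    def ca_default_prompt_template_dir(is_prompt_template_cn: bool) -> str:
        # Stand-in for knext/ca/common/base.py locating itself on disk.
        directory = os.path.dirname(os.path.abspath(__file__))
        # The fix: step out of common/ into logic/modules/, where the
        # per-module templates (e.g. FetchSubject.txt) actually live.
        directory = os.path.join(directory, '..', 'logic/modules')
        if is_prompt_template_cn:
            return os.path.join(directory, 'default_prompt_template')
        # Assumed name for the English variant; elided in the diff.
        return os.path.join(directory, 'default_prompt_template_en')
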
20 changes: 12 additions & 8 deletions python/knext/knext/ca/logic/modules/reasoner.py
@@ -105,8 +105,9 @@ class ExtractTriplesFromTextModule(Reasoner):
     """
 
-    def __init__(self, llm_module, prompt_template_dir=None, is_prompt_template_cn=True):
-        super().__init__(llm_module, prompt_template_dir, is_prompt_template_cn)
+    def __init__(self, llm_module, use_default_prompt_template=True, prompt_template_dir=None,
+                 is_prompt_template_cn=True):
+        super().__init__(llm_module, use_default_prompt_template, prompt_template_dir, is_prompt_template_cn)
 
     def get_module_name(self):
         return "ExtractTriplesFromTextModule"
@@ -130,8 +131,9 @@ class FetchSubject(Reasoner):
     """
 
-    def __init__(self, llm_module, prompt_template_dir=None, is_prompt_template_cn=True):
-        super().__init__(llm_module, prompt_template_dir, is_prompt_template_cn)
+    def __init__(self, llm_module, use_default_prompt_template=True, prompt_template_dir=None,
+                 is_prompt_template_cn=True):
+        super().__init__(llm_module, use_default_prompt_template, prompt_template_dir, is_prompt_template_cn)
 
     def preprocess(self, question: Question):
         prompt = self.state_dict['prompt_template'].substitute(
@@ -153,8 +155,9 @@ class FetchPredicate(Reasoner):
     """
 
-    def __init__(self, llm_module, prompt_template_dir=None, is_prompt_template_cn=True):
-        super().__init__(llm_module, prompt_template_dir, is_prompt_template_cn)
+    def __init__(self, llm_module, use_default_prompt_template=True, prompt_template_dir=None,
+                 is_prompt_template_cn=True):
+        super().__init__(llm_module, use_default_prompt_template, prompt_template_dir, is_prompt_template_cn)
 
     def preprocess(self, question: Question):
         prompt = self.state_dict['prompt_template'].substitute(
@@ -176,8 +179,9 @@ class FetchObject(Reasoner):
     """
 
-    def __init__(self, llm_module, prompt_template_dir=None, is_prompt_template_cn=True):
-        super().__init__(llm_module, prompt_template_dir, is_prompt_template_cn)
+    def __init__(self, llm_module, use_default_prompt_template=True, prompt_template_dir=None,
+                 is_prompt_template_cn=True):
+        super().__init__(llm_module, use_default_prompt_template, prompt_template_dir, is_prompt_template_cn)
 
     def preprocess(self, question: Question):
         prompt = self.state_dict['prompt_template'].substitute(
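
All four Reasoner subclasses gain a use_default_prompt_template parameter and now forward it to the base class. The old two-argument super() call bound its arguments positionally, so prompt_template_dir landed in the base class's use_default_prompt_template slot and is_prompt_template_cn in prompt_template_dir. A minimal sketch of that failure mode and the fix, with Reasoner reduced to a stub whose signature is inferred from the new super() calls:

    class Reasoner:
        def __init__(self, llm_module, use_default_prompt_template=True,
                     prompt_template_dir=None, is_prompt_template_cn=True):
            self.llm_module = llm_module
            self.use_default_prompt_template = use_default_prompt_template
            self.prompt_template_dir = prompt_template_dir
            self.is_prompt_template_cn = is_prompt_template_cn

    class FetchSubject(Reasoner):
        def __init__(self, llm_module, use_default_prompt_template=True,
                     prompt_template_dir=None, is_prompt_template_cn=True):
            # Forward all four arguments in the base class's positional order;
            # the pre-fix two-argument call shifted everything one slot left.
            super().__init__(llm_module, use_default_prompt_template,
                             prompt_template_dir, is_prompt_template_cn)

    fetcher = FetchSubject(llm_module=None, prompt_template_dir='./templates')
    assert fetcher.use_default_prompt_template is True  # pre-fix: './templates'
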
2 changes: 1 addition & 1 deletion python/knext/knext/examples/musique/musique_agent.py
@@ -143,7 +143,7 @@ def __init__(self,
                  debug_mode,
                  max_depth=1,
                  ):
-        use_default_prompt_template = False
+        use_default_prompt_template = True
         intermediate_process_tools = []
         intermediate_process_tools.append(
             LoggerIntermediateProcessTool(debug_mode=debug_mode)
2 changes: 1 addition & 1 deletion python/knext/knext/examples/musique/run_experiment.py
@@ -129,7 +129,7 @@ def create_divide_and_conquer_agent_hierarchical_info_tool(musique_data, get_llm
 
     answer_parent_question = SolveQuestionWithContext(
         llm_module=llm,
-        use_default_prompt_template=False,
+        use_default_prompt_template=True,
         prompt_template_dir=prompt_template_dir
     )
 
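
Both example entry points flip use_default_prompt_template from False to True, so the agents now rely on the packaged templates that the base.py path fix makes findable, instead of requiring a local prompt_template_dir. For context, a condensed sketch of how the flag is consumed by init_state_dict in base.py; the else branch is elided in its hunk, so it is left as a placeholder here:

    class KagBaseModule:
        def init_state_dict(self):
            if self.use_default_prompt_template:
                # True (the new setting in both examples): use the packaged
                # templates resolved via get_ca_default_prompt_template_dir().
                return self.create_default_state_dict()
            else:
                # Elided in the diff; presumably builds the state dict from
                # the user-supplied prompt_template_dir.
                ...
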
