Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
created get_llm_call_args to create arguments for LLM call
Browse files Browse the repository at this point in the history
  • Loading branch information
chandralegend committed Aug 1, 2024
1 parent 48c97a7 commit 16e9b27
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 168 deletions.
185 changes: 17 additions & 168 deletions jaclang/compiler/passes/main/pyast_gen_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import jaclang.compiler.absyntree as ast
from jaclang.compiler.constant import Constants as Con, EdgeDir, Tokens as Tok
from jaclang.compiler.passes import Pass
from jaclang.runtimelib.utils import extract_params, extract_type

T = TypeVar("T", bound=ast3.AST)

Expand Down Expand Up @@ -2883,6 +2882,13 @@ def by_llm_call(
exclude_info,
)

def get_by_llm_call_args(self, node: ast.FuncCall) -> tuple:
"""Get the arguments for the by_llm_call."""
# to avoid circular import
from jaclang.plugin.feature import JacFeature

return JacFeature.get_by_llm_call_args(self, node)

def exit_func_call(self, node: ast.FuncCall) -> None:
"""Sub objects.
Expand All @@ -2908,173 +2914,16 @@ def exit_func_call(self, node: ast.FuncCall) -> None:
self.ice("Invalid Parameter")
if node.genai_call:
self.needs_jac_feature()
model = node.genai_call.target.gen.py_ast[0]
model_params, include_info, exclude_info = extract_params(node.genai_call)
action = self.sync(
ast3.Constant(
value="Create an object of the specified type, using the specifically "
" provided input value(s) and look up any missing attributes from reliable"
" online sources to fill them in accurately."
)
)
_output_ = "".join(extract_type(node.target))
include_info.append(
(
_output_.split(".")[0],
self.sync(ast3.Name(id=_output_.split(".")[0], ctx=ast3.Load())),
)
)
scope = self.sync(
ast3.Call(
func=self.sync(
ast3.Attribute(
value=self.sync(
ast3.Name(
id=Con.JAC_FEATURE.value,
ctx=ast3.Load(),
)
),
attr="obj_scope",
ctx=ast3.Load(),
)
),
args=[
self.sync(
ast3.Name(
id="__file__",
ctx=ast3.Load(),
)
),
self.sync(ast3.Constant(value=_output_)),
],
keywords=[],
)
)
outputs = self.sync(
ast3.Call(
func=self.sync(
ast3.Attribute(
value=self.sync(
ast3.Name(
id=Con.JAC_FEATURE.value,
ctx=ast3.Load(),
)
),
attr="get_sem_type",
ctx=ast3.Load(),
)
),
args=[
self.sync(
ast3.Name(
id="__file__",
ctx=ast3.Load(),
)
),
self.sync(ast3.Constant(value=str(_output_))),
],
keywords=[],
)
)
if node.params and node.params.items:
inputs = [
self.sync(
ast3.Tuple(
elts=[
self.sync(
ast3.Call(
func=self.sync(
ast3.Attribute(
value=self.sync(
ast3.Name(
id=Con.JAC_FEATURE.value,
ctx=ast3.Load(),
)
),
attr="get_semstr_type",
ctx=ast3.Load(),
)
),
args=[
self.sync(
ast3.Name(
id="__file__", ctx=ast3.Load()
)
),
scope,
self.sync(
ast3.Constant(
value=(
kw_pair.key.value
if isinstance(
kw_pair.key, ast.Name
)
else None
)
)
),
self.sync(ast3.Constant(value=True)),
],
keywords=[],
)
),
self.sync(
ast3.Call(
func=self.sync(
ast3.Attribute(
value=self.sync(
ast3.Name(
id=Con.JAC_FEATURE.value,
ctx=ast3.Load(),
)
),
attr="get_semstr_type",
ctx=ast3.Load(),
)
),
args=[
self.sync(
ast3.Name(
id="__file__", ctx=ast3.Load()
)
),
scope,
self.sync(
ast3.Constant(
value=(
kw_pair.key.value
if isinstance(
kw_pair.key, ast.Name
)
else None
)
)
),
self.sync(ast3.Constant(value=False)),
],
keywords=[],
)
),
self.sync(
ast3.Constant(
value=(
kw_pair.key.value
if isinstance(kw_pair.key, ast.Name)
else None
)
)
),
kw_pair.value.gen.py_ast[0],
],
ctx=ast3.Load(),
)
)
for kw_pair in node.params.items
if isinstance(kw_pair, ast.KWPair)
]
else:
inputs = []

(
model,
model_params,
scope,
inputs,
outputs,
action,
include_info,
exclude_info,
) = self.get_by_llm_call_args(node)
node.gen.py_ast = [
self.sync(
ast3.Call(
Expand Down
176 changes: 176 additions & 0 deletions jaclang/plugin/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,182 @@ def by_llm_call(
)
)

@staticmethod
@hookimpl
def get_by_llm_call_args(_pass: PyastGenPass, node: ast.FuncCall) -> tuple:
"""Get the by LLM call args."""
if node.genai_call is None:
raise _pass.ice("No genai_call")
model = node.genai_call.target.gen.py_ast[0]
model_params, include_info, exclude_info = extract_params(node.genai_call)
action = _pass.sync(
ast3.Constant(
value="Create an object of the specified type, using the specifically "
" provided input value(s) and look up any missing attributes from reliable"
" online sources to fill them in accurately."
)
)
_output_ = "".join(extract_type(node.target))
include_info.append(
(
_output_.split(".")[0],
_pass.sync(ast3.Name(id=_output_.split(".")[0], ctx=ast3.Load())),
)
)
scope = _pass.sync(
ast3.Call(
func=_pass.sync(
ast3.Attribute(
value=_pass.sync(
ast3.Name(
id=Con.JAC_FEATURE.value,
ctx=ast3.Load(),
)
),
attr="obj_scope",
ctx=ast3.Load(),
)
),
args=[
_pass.sync(
ast3.Name(
id="__file__",
ctx=ast3.Load(),
)
),
_pass.sync(ast3.Constant(value=_output_)),
],
keywords=[],
)
)
outputs = _pass.sync(
ast3.Call(
func=_pass.sync(
ast3.Attribute(
value=_pass.sync(
ast3.Name(
id=Con.JAC_FEATURE.value,
ctx=ast3.Load(),
)
),
attr="get_sem_type",
ctx=ast3.Load(),
)
),
args=[
_pass.sync(
ast3.Name(
id="__file__",
ctx=ast3.Load(),
)
),
_pass.sync(ast3.Constant(value=str(_output_))),
],
keywords=[],
)
)
if node.params and node.params.items:
inputs = [
_pass.sync(
ast3.Tuple(
elts=[
_pass.sync(
ast3.Call(
func=_pass.sync(
ast3.Attribute(
value=_pass.sync(
ast3.Name(
id=Con.JAC_FEATURE.value,
ctx=ast3.Load(),
)
),
attr="get_semstr_type",
ctx=ast3.Load(),
)
),
args=[
_pass.sync(
ast3.Name(id="__file__", ctx=ast3.Load())
),
scope,
_pass.sync(
ast3.Constant(
value=(
kw_pair.key.value
if isinstance(kw_pair.key, ast.Name)
else None
)
)
),
_pass.sync(ast3.Constant(value=True)),
],
keywords=[],
)
),
_pass.sync(
ast3.Call(
func=_pass.sync(
ast3.Attribute(
value=_pass.sync(
ast3.Name(
id=Con.JAC_FEATURE.value,
ctx=ast3.Load(),
)
),
attr="get_semstr_type",
ctx=ast3.Load(),
)
),
args=[
_pass.sync(
ast3.Name(id="__file__", ctx=ast3.Load())
),
scope,
_pass.sync(
ast3.Constant(
value=(
kw_pair.key.value
if isinstance(kw_pair.key, ast.Name)
else None
)
)
),
_pass.sync(ast3.Constant(value=False)),
],
keywords=[],
)
),
_pass.sync(
ast3.Constant(
value=(
kw_pair.key.value
if isinstance(kw_pair.key, ast.Name)
else None
)
)
),
kw_pair.value.gen.py_ast[0],
],
ctx=ast3.Load(),
)
)
for kw_pair in node.params.items
if isinstance(kw_pair, ast.KWPair)
]
else:
inputs = []

return (
model,
model_params,
scope,
inputs,
outputs,
action,
include_info,
exclude_info,
)


class JacBuiltin:
"""Jac Builtins."""
Expand Down
Loading

0 comments on commit 16e9b27

Please sign in to comment.