Skip to content

Commit

Permalink
bugfixes and changes
Browse files Browse the repository at this point in the history
  • Loading branch information
CTY-git committed Jan 14, 2025
1 parent 5b4daf7 commit 07b68d1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 24 deletions.
33 changes: 21 additions & 12 deletions patchwork/steps/CallSQL/CallSQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,33 @@
from sqlalchemy import URL, create_engine, exc, text

from patchwork.common.utils.utils import mustache_render
from patchwork.logger import logger
from patchwork.step import Step, StepStatus
from patchwork.steps.CallSQL.typed import CallSQLInputs, CallSQLOutputs


class CallSQL(Step, input_class=CallSQLInputs, output_class=CallSQLOutputs):
def __init__(self, inputs: dict):
super().__init__(inputs)
query_template_data = inputs.get("query_template_values", {})
self.query = mustache_render(inputs["query"], query_template_data)
query_template_data = inputs.get("db_query_template_values", {})
self.query = mustache_render(inputs["db_query"], query_template_data)
self.__build_engine(inputs)

def __build_engine(self, inputs: dict):
dialect = inputs["dialect"]
driver = inputs.get("driver")
dialect = inputs["db_dialect"]
driver = inputs.get("db_driver")
dialect_plus_driver = f"{dialect}+{driver}" if driver is not None else dialect
kwargs = dict(
username=inputs["username"],
host=inputs.get("host", "localhost"),
port=inputs.get("port", 5432),
username=inputs["db_username"],
host=inputs.get("db_host", "localhost"),
port=inputs.get("db_port", 5432),
)
if inputs.get("password") is not None:
kwargs["password"] = inputs.get("password")
if inputs.get("db_password") is not None:
kwargs["password"] = inputs.get("db_password")
if inputs.get("db_name") is not None:
kwargs["database"] = inputs.get("db_name")
if inputs.get("db_params") is not None:
kwargs["query"] = inputs.get("db_params")
connection_url = URL.create(
dialect_plus_driver,
**kwargs,
Expand All @@ -36,10 +41,14 @@ def __build_engine(self, inputs: dict):

def run(self) -> dict:
try:
rv = []
with self.engine.begin() as conn:
cursor = conn.execute(text(self.query))
result = cursor.fetchall()
return dict(result=result)
for row in cursor:
result = row._asdict()
rv.append(result)
logger.info(f"Retrieved {len(rv)} rows!")
return dict(results=rv)
except exc.InvalidRequestError as e:
self.set_status(StepStatus.FAILED, f"`{self.query}` failed with message:\n{e}")
return dict(result=[])
return dict(results=[])
20 changes: 10 additions & 10 deletions patchwork/steps/CallSQL/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@


class __RequiredCallSQLInputs(TypedDict):
dialect: str
username: str
query: str
db_dialect: str
db_username: str
db_query: str


class CallSQLInputs(__RequiredCallSQLInputs, total=False):
driver: str
password: str
host: str
port: int
database: str
query_template_values: dict[str, Any]
db_driver: str
db_password: str
db_host: str
db_port: int
db_name: str
db_query_template_values: dict[str, Any]


class CallSQLOutputs(TypedDict):
result: list[dict[str, Any]]
results: list[dict[str, Any]]
4 changes: 3 additions & 1 deletion patchwork/steps/CallShell/CallShell.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def run(self) -> dict:
except subprocess.CalledProcessError as e:
self.set_status(
StepStatus.FAILED,
f"script failed with stdout:\n{p.stdout}\nstderr:\n{e.stderr}",
f"Script failed.",
)
logger.info(f"stdout: \n{p.stdout}")
logger.info(f"stderr:\n{p.stderr}")
return dict(stdout_output=p.stdout, stderr_output=p.stderr)
1 change: 0 additions & 1 deletion patchwork/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@
"ReadPRDiffs",
"ReadPRDiffsPB",
"ReadPRs",
"ResolveIssue",
"ScanDepscan",
"ScanSemgrep",
"ScanSonar",
Expand Down

0 comments on commit 07b68d1

Please sign in to comment.