-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
460 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from __future__ import annotations | ||
|
||
import shlex | ||
import shutil | ||
import subprocess | ||
from pathlib import Path | ||
|
||
from patchwork.logger import logger | ||
from patchwork.step import Step, StepStatus | ||
from patchwork.steps.CallCommand.typed import CallCommandInputs, CallCommandOutputs | ||
|
||
|
||
class CallCommand(Step, input_class=CallCommandInputs, output_class=CallCommandOutputs): | ||
def __init__(self, inputs: dict): | ||
super().__init__(inputs) | ||
self.command = shutil.which(inputs["command"]) | ||
if self.command is None: | ||
raise ValueError(f"Command `{inputs['command']}` not found in PATH") | ||
self.command_args = shlex.split(inputs.get("command_args", "")) | ||
self.working_dir = inputs.get("working_dir", Path.cwd()) | ||
self.env = self.__parse_env_text(inputs.get("env", "")) | ||
|
||
@staticmethod | ||
def __parse_env_text(env_text: str) -> dict[str, str]: | ||
env_spliter = shlex.shlex(env_text, posix=True) | ||
env_spliter.whitespace_split = True | ||
env_spliter.whitespace += ";" | ||
|
||
env: dict[str, str] = dict() | ||
for env_assign in env_spliter: | ||
env_assign_spliter = shlex.shlex(env_assign, posix=True) | ||
env_assign_spliter.whitespace_split = True | ||
env_assign_spliter.whitespace += "=" | ||
env_parts = list(env_assign_spliter) | ||
if len(env_parts) < 1: | ||
continue | ||
|
||
env_assign_target = env_parts[0] | ||
if len(env_parts) < 2: | ||
logger.error(f"{env_assign_target} is not assigned anything, skipping...") | ||
continue | ||
if len(env_parts) > 2: | ||
logger.error(f"{env_assign_target} has more than 1 assignment, skipping...") | ||
continue | ||
env[env_assign_target] = env_parts[1] | ||
|
||
return env | ||
|
||
def run(self) -> dict: | ||
cmd = [self.command, *self.command_args] | ||
p = subprocess.run(cmd, capture_output=True, text=True, cwd=self.working_dir, env=self.env) | ||
try: | ||
p.check_returncode() | ||
return dict(stdout_output=p.stdout) | ||
except subprocess.CalledProcessError as e: | ||
self.set_status( | ||
StepStatus.FAILED, | ||
f"`{self.command} {self.command_args}` failed with stdout:\n{p.stdout}\nstderr:\n{e.stderr}", | ||
) | ||
return dict(stdout_output="") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from __future__ import annotations | ||
|
||
from typing_extensions import Annotated, TypedDict | ||
|
||
from patchwork.common.utils.step_typing import StepTypeConfig | ||
|
||
|
||
class __RequiredCallCommandInputs(TypedDict): | ||
command: str | ||
|
||
|
||
class CallCommandInputs(__RequiredCallCommandInputs, total=False): | ||
command_args: str | ||
working_dir: Annotated[str, StepTypeConfig(is_path=True)] | ||
env: str | ||
|
||
|
||
class CallCommandOutputs(TypedDict): | ||
stdout_output: str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import annotations | ||
|
||
from sqlalchemy import URL, create_engine, exc, text | ||
|
||
from patchwork.step import Step, StepStatus | ||
from patchwork.steps.CallSQL.typed import CallSQLInputs, CallSQLOutputs | ||
|
||
|
||
class CallSQL(Step, input_class=CallSQLInputs, output_class=CallSQLOutputs): | ||
def __init__(self, inputs: dict): | ||
super().__init__(inputs) | ||
self.query = inputs["query"] | ||
self.__build_engine(inputs) | ||
|
||
def __build_engine(self, inputs: dict): | ||
dialect = inputs["dialect"] | ||
driver = inputs.get("driver") | ||
dialect_plus_driver = f"{dialect}+{driver}" if driver is not None else dialect | ||
kwargs = dict( | ||
username=inputs["username"], | ||
host=inputs.get("host", "localhost"), | ||
port=inputs.get("port", 5432), | ||
) | ||
if inputs.get("password") is not None: | ||
kwargs["password"] = inputs.get("password") | ||
connection_url = URL.create( | ||
dialect_plus_driver, | ||
**kwargs, | ||
) | ||
self.engine = create_engine(connection_url) | ||
with self.engine.connect() as conn: | ||
conn.execute(text("SELECT 1")) | ||
return self.engine | ||
|
||
def run(self) -> dict: | ||
try: | ||
with self.engine.begin() as conn: | ||
cursor = conn.execute(text(self.query)) | ||
result = cursor.fetchall() | ||
return dict(result=result) | ||
except exc.InvalidRequestError as e: | ||
self.set_status(StepStatus.FAILED, f"`{self.query}` failed with message:\n{e}") | ||
return dict(result=[]) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from __future__ import annotations | ||
|
||
from typing_extensions import Any, TypedDict | ||
|
||
|
||
class __RequiredCallSQLInputs(TypedDict): | ||
dialect: str | ||
username: str | ||
query: str | ||
|
||
|
||
class CallSQLInputs(__RequiredCallSQLInputs, total=False): | ||
driver: str | ||
password: str | ||
host: str | ||
port: int | ||
database: str | ||
|
||
|
||
class CallSQLOutputs(TypedDict): | ||
result: list[dict[str, Any]] |
Oops, something went wrong.