Skip to content

Commit

Permalink
Add new steps
Browse files Browse the repository at this point in the history
  • Loading branch information
CTY-git committed Jan 9, 2025
1 parent 3e81f25 commit 3626991
Show file tree
Hide file tree
Showing 9 changed files with 460 additions and 111 deletions.
44 changes: 28 additions & 16 deletions patchwork/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@

from enum import Enum

from typing_extensions import Any, Dict, List, Optional, Union, is_typeddict
from typing_extensions import (
Any,
Collection,
Dict,
List,
Optional,
Type,
Union,
is_typeddict,
)

from patchwork.logger import logger

Expand Down Expand Up @@ -45,10 +54,9 @@ def __init__(self, inputs: DataPoint):
"""

# check if the inputs have the required keys
if self.__input_class is not None:
missing_keys = self.__input_class.__required_keys__.difference(inputs.keys())
if len(missing_keys) > 0:
raise ValueError(f"Missing required data: {list(missing_keys)}")
missing_keys = self.find_missing_inputs(inputs)
if len(missing_keys) > 0:
raise ValueError(f"Missing required data: {list(missing_keys)}")

# store the inputs
self.inputs = inputs
Expand All @@ -64,19 +72,23 @@ def __init__(self, inputs: DataPoint):
self.original_run = self.run
self.run = self.__managed_run

def __init_subclass__(cls, **kwargs):
input_class = kwargs.get("input_class", None) or getattr(cls, "input_class", None)
output_class = kwargs.get("output_class", None) or getattr(cls, "output_class", None)
def __init_subclass__(cls, input_class: Optional[Type] = None, output_class: Optional[Type] = None, **kwargs):
input_class = input_class or getattr(cls, "input_class", None)
if input_class is not None and not is_typeddict(input_class):
input_class = None

if input_class is not None and is_typeddict(input_class):
cls.__input_class = input_class
else:
cls.__input_class = None
output_class = output_class or getattr(cls, "output_class", None)
if output_class is not None and not is_typeddict(output_class):
output_class = None

if output_class is not None and is_typeddict(output_class):
cls.__output_class = output_class
else:
cls.__output_class = None
cls.__input_class = input_class
cls.__output_class = output_class

@classmethod
def find_missing_inputs(cls, inputs: DataPoint) -> Collection:
    """Return the required input keys that are absent from ``inputs``.

    Args:
        inputs: The data point the step is about to be constructed with.

    Returns:
        An empty list when the class has no TypedDict input class registered,
        otherwise the set of required keys missing from ``inputs``.
    """
    # ``cls.__input_class`` is name-mangled by the compiler, but a string
    # handed to ``getattr`` is not. Looking the attribute up via
    # ``getattr(cls, "__input_class", None)`` therefore never finds the value
    # stored by ``__init_subclass__`` and silently disables validation.
    # Access the attribute directly so the same mangling applies, and treat a
    # class that never had it assigned as "no input class".
    try:
        input_class = cls.__input_class
    except AttributeError:
        input_class = None

    if input_class is None:
        return []
    return input_class.__required_keys__.difference(inputs.keys())

def __managed_run(self, *args, **kwargs) -> Any:
self.debug(self.inputs)
Expand Down
60 changes: 60 additions & 0 deletions patchwork/steps/CallCommand/CallCommand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import shlex
import shutil
import subprocess
from pathlib import Path

from patchwork.logger import logger
from patchwork.step import Step, StepStatus
from patchwork.steps.CallCommand.typed import CallCommandInputs, CallCommandOutputs


class CallCommand(Step, input_class=CallCommandInputs, output_class=CallCommandOutputs):
    """Step that runs an executable resolved from ``PATH`` and captures its stdout.

    Inputs:
        command: Name of the executable; resolved with ``shutil.which``.
        command_args: Optional argument string, split with ``shlex.split``.
        working_dir: Optional working directory; defaults to the current one.
        env: Optional ``NAME=value`` assignments separated by whitespace or ``;``.

    Raises:
        ValueError: If ``command`` cannot be found on ``PATH``.
    """

    def __init__(self, inputs: dict):
        super().__init__(inputs)
        self.command = shutil.which(inputs["command"])
        if self.command is None:
            raise ValueError(f"Command `{inputs['command']}` not found in PATH")
        self.command_args = shlex.split(inputs.get("command_args", ""))
        self.working_dir = inputs.get("working_dir", Path.cwd())
        self.env = self.__parse_env_text(inputs.get("env", ""))

    @staticmethod
    def __parse_env_text(env_text: str) -> dict[str, str]:
        """Parse ``NAME=value`` assignments into a dict.

        Assignments are separated by whitespace or ``;``. Shell-style quoting
        is honoured, so ``FOO="a b"`` yields ``{"FOO": "a b"}``. Assignments
        without a value, or whose value contains a further ``=``, are logged
        and skipped.
        """
        splitter = shlex.shlex(env_text, posix=True)
        splitter.whitespace_split = True
        splitter.whitespace += ";"

        env: dict[str, str] = {}
        for assignment in splitter:
            # The token has already been unquoted/unescaped by shlex above.
            # Split it textually: running a second shlex pass over it would
            # break values containing spaces, quotes or backslashes.
            name, separator, value = assignment.partition("=")
            if not name:
                # Empty token or nothing before the "=": no variable to set.
                continue
            if not separator or not value:
                logger.error(f"{name} is not assigned anything, skipping...")
                continue
            if "=" in value:
                logger.error(f"{name} has more than 1 assignment, skipping...")
                continue
            env[name] = value

        return env

    def run(self) -> dict:
        """Execute the command and return its stdout.

        Returns:
            ``{"stdout_output": <stdout>}`` on success. On a non-zero exit
            code the step status is set to ``FAILED`` and ``stdout_output``
            is an empty string.
        """
        import os

        cmd = [self.command, *self.command_args]
        # Overlay the requested variables on the inherited environment.
        # Passing only ``self.env`` would replace the environment entirely,
        # so an omitted ``env`` input would start the child with no PATH,
        # HOME, etc.
        child_env = {**os.environ, **self.env}
        p = subprocess.run(cmd, capture_output=True, text=True, cwd=self.working_dir, env=child_env)
        if p.returncode != 0:
            self.set_status(
                StepStatus.FAILED,
                f"`{self.command} {self.command_args}` failed with stdout:\n{p.stdout}\nstderr:\n{p.stderr}",
            )
            return dict(stdout_output="")
        return dict(stdout_output=p.stdout)
Empty file.
19 changes: 19 additions & 0 deletions patchwork/steps/CallCommand/typed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing_extensions import Annotated, TypedDict

from patchwork.common.utils.step_typing import StepTypeConfig


class __RequiredCallCommandInputs(TypedDict):
    """Input keys that must always be supplied to the CallCommand step."""

    # Name of the executable to run.
    command: str


class CallCommandInputs(__RequiredCallCommandInputs, total=False):
    """All CallCommand inputs; keys declared here are optional (``total=False``)."""

    # Arguments for the command, given as a single shell-style string.
    command_args: str
    # Directory to run the command in; flagged as a path via StepTypeConfig.
    working_dir: Annotated[str, StepTypeConfig(is_path=True)]
    # Environment assignments in NAME=value form, given as a single string.
    env: str


class CallCommandOutputs(TypedDict):
    """Outputs produced by the CallCommand step."""

    # Captured standard output of the command.
    stdout_output: str
43 changes: 43 additions & 0 deletions patchwork/steps/CallSQL/CallSQL.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from sqlalchemy import URL, create_engine, exc, text

from patchwork.step import Step, StepStatus
from patchwork.steps.CallSQL.typed import CallSQLInputs, CallSQLOutputs


class CallSQL(Step, input_class=CallSQLInputs, output_class=CallSQLOutputs):
    """Step that executes a SQL query through SQLAlchemy and returns its rows.

    Inputs:
        dialect: SQLAlchemy dialect name (required).
        username: Database user (required).
        query: SQL text to execute (required).
        driver: Optional DBAPI driver, appended to the dialect as ``dialect+driver``.
        password: Optional password.
        host: Optional host; defaults to ``"localhost"``.
        port: Optional port; defaults to ``5432``.
        database: Optional database name.
    """

    def __init__(self, inputs: dict):
        super().__init__(inputs)
        self.query = inputs["query"]
        self.__build_engine(inputs)

    def __build_engine(self, inputs: dict):
        """Create ``self.engine`` from the inputs and verify it can connect.

        A trivial ``SELECT 1`` is executed so that connection problems surface
        at construction time rather than when the step runs.
        """
        dialect = inputs["dialect"]
        driver = inputs.get("driver")
        dialect_plus_driver = f"{dialect}+{driver}" if driver is not None else dialect

        url_kwargs = dict(
            username=inputs["username"],
            host=inputs.get("host", "localhost"),
            port=inputs.get("port", 5432),
        )
        # Forward optional URL components only when supplied. ``database`` is
        # a declared input and must reach the URL, otherwise the connection
        # silently targets the server's default database.
        for optional_key in ("password", "database"):
            optional_value = inputs.get(optional_key)
            if optional_value is not None:
                url_kwargs[optional_key] = optional_value

        connection_url = URL.create(
            dialect_plus_driver,
            **url_kwargs,
        )
        self.engine = create_engine(connection_url)
        with self.engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return self.engine

    def run(self) -> dict:
        """Execute the query and return its rows.

        Returns:
            ``{"result": [...]}`` with one ``dict`` per row, keyed by column
            name. If SQLAlchemy raises ``InvalidRequestError`` the step status
            is set to ``FAILED`` and ``result`` is an empty list.
        """
        try:
            with self.engine.begin() as conn:
                cursor = conn.execute(text(self.query))
                # Convert Row objects to plain dicts so the output matches the
                # declared ``list[dict[str, Any]]`` type.
                result = [row._asdict() for row in cursor.fetchall()]
            return dict(result=result)
        except exc.InvalidRequestError as e:
            self.set_status(StepStatus.FAILED, f"`{self.query}` failed with message:\n{e}")
            return dict(result=[])
Empty file.
21 changes: 21 additions & 0 deletions patchwork/steps/CallSQL/typed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from typing_extensions import Any, TypedDict


class __RequiredCallSQLInputs(TypedDict):
    """Input keys that must always be supplied to the CallSQL step."""

    # SQLAlchemy dialect name used to build the connection URL.
    dialect: str
    # User name used to connect.
    username: str
    # SQL text to execute.
    query: str


class CallSQLInputs(__RequiredCallSQLInputs, total=False):
    """All CallSQL inputs; keys declared here are optional (``total=False``)."""

    # DBAPI driver; combined with the dialect as "dialect+driver" when given.
    driver: str
    # Password for the connection; left out of the URL when not provided.
    password: str
    # Database host; the step falls back to "localhost" when omitted.
    host: str
    # Database port; the step falls back to 5432 when omitted.
    port: int
    # Database name.
    # NOTE(review): confirm the step forwards this to the connection URL.
    database: str


class CallSQLOutputs(TypedDict):
    """Outputs produced by the CallSQL step."""

    # Rows fetched for the query, annotated as one dict per row.
    result: list[dict[str, Any]]
Loading

0 comments on commit 3626991

Please sign in to comment.