Merge pull request #725 from guardrails-ai/karan/ruffify
Replace black, isort and flake8 with ruff
zsimjee authored Apr 23, 2024
2 parents 6de5641 + 119815e commit 6cb406b
Showing 46 changed files with 739 additions and 769 deletions.
6 changes: 0 additions & 6 deletions .flake8

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -40,7 +40,7 @@ jobs:
run: |
make full
- name: Lint with isort, black, docformatter, flake8
- name: Lint with ruff
run: |
make lint
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,13 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.1
hooks:
# Performs ruff check with safe fixes
- id: ruff
name: ruff
description: "Run 'ruff' for linting"
args: ["--fix"]
# Performs ruff format
- id: ruff-format
name: ruff-format
description: "Run 'ruff format' for formatting"
10 changes: 5 additions & 5 deletions Makefile
@@ -8,8 +8,8 @@ PYDANTIC_VERSION_MAJOR := $(shell poetry run python -c 'import pydantic; print(p
TYPING_CMD := type-pydantic-v$(PYDANTIC_VERSION_MAJOR)-openai-v$(OPENAI_VERSION_MAJOR)

autoformat:
poetry run black guardrails/ tests/
poetry run isort --atomic guardrails/ tests/
poetry run ruff check guardrails/ tests/ --fix
poetry run ruff format guardrails/ tests/
poetry run docformatter --in-place --recursive guardrails tests

.PHONY: type
@@ -37,9 +37,8 @@ type-pydantic-v2-openai-v1:
rm pyrightconfig.json

lint:
poetry run isort -c guardrails/ tests/
poetry run black guardrails/ tests/ --check
poetry run flake8 guardrails/ tests/
poetry run ruff check guardrails/ tests/
poetry run ruff format guardrails/ tests/ --check

test:
poetry run pytest tests/
@@ -66,6 +65,7 @@ docs-deploy:

dev:
poetry install
poetry run pre-commit install

full:
poetry install --all-extras
7 changes: 3 additions & 4 deletions guardrails/classes/history/call.py
@@ -227,7 +227,8 @@ def validation_response(self) -> Optional[Union[str, Dict, ReAsk]]:
self.inputs.full_schema_reask
or number_of_iterations < 2
or isinstance(
self.iterations.last.validation_response, ReAsk # type: ignore
self.iterations.last.validation_response, # type: ignore
ReAsk, # type: ignore
)
or isinstance(self.iterations.last.validation_response, str) # type: ignore
):
@@ -410,9 +411,7 @@ def tree(self) -> Tree:
title="Validated Output",
style="on #F0FFF0",
)
tree.children[
-1
].label.renderable._renderables = previous_panels + ( # type: ignore
tree.children[-1].label.renderable._renderables = previous_panels + ( # type: ignore
validated_outcome_panel,
)

2 changes: 1 addition & 1 deletion guardrails/classes/history/iteration.py
@@ -166,7 +166,7 @@ def status(self) -> str:
@property
def rich_group(self) -> Group:
def create_msg_history_table(
msg_history: Optional[List[Dict[str, Prompt]]]
msg_history: Optional[List[Dict[str, Prompt]]],
) -> Union[str, Table]:
if msg_history is None:
return "No message history."
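
The iteration.py change above is formatting-only: when a parameter list is wrapped onto its own line, the formatting applied in this PR keeps a trailing comma after the final parameter, which the previous black-formatted code did not have here. A minimal, hypothetical sketch of the resulting style (simplified types, not the real signature from this repository):

    from typing import Dict, List, Optional, Union

    def create_msg_history_table(
        msg_history: Optional[List[Dict[str, str]]],
    ) -> Union[str, List[str]]:
        # Sketch only: summarize each message, or report that there is none.
        if msg_history is None:
            return "No message history."
        return [f"{m.get('role', '?')}: {m.get('content', '')}" for m in msg_history]
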
1 change: 1 addition & 0 deletions guardrails/cli/hub/create_validator.py
@@ -1,3 +1,4 @@
# ruff: noqa: E501
import os
from datetime import date
from string import Template
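
The # ruff: noqa: E501 line added at the top of create_validator.py is a file-level suppression: ruff skips the E501 (line-too-long) rule for this entire module. For contrast, a hypothetical module (not from this PR) showing the file-level directive alongside an ordinary line-level suppression:

    # ruff: noqa: E501
    # The directive above exempts the whole file from the E501 (line-too-long) rule.

    LONG_TEMPLATE = "a deliberately long literal that would otherwise be flagged by the line-length check"

    def shout(text: str) -> str:
        # A line-level suppression applies only to the line it sits on.
        return text.upper() + "!!!"  # noqa: E501
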
3 changes: 2 additions & 1 deletion guardrails/cli/hub/install.py
@@ -245,7 +245,8 @@ def install(

console.print(f"\nInstalling {package_uri}...\n")
logger.log(
level=LEVELS.get("SPAM"), msg=f"Installing {package_uri}..." # type: ignore
level=LEVELS.get("SPAM"), # type: ignore
msg=f"Installing {package_uri}...",
)

# Validation
6 changes: 3 additions & 3 deletions guardrails/cli/logger.py
@@ -1,9 +1,9 @@
import logging
import os

os.environ[
"COLOREDLOGS_LEVEL_STYLES"
] = "spam=white,faint;success=green,bold;debug=magenta;verbose=blue;notice=cyan,bold;warning=yellow;error=red;critical=background=red" # noqa
os.environ["COLOREDLOGS_LEVEL_STYLES"] = (
"spam=white,faint;success=green,bold;debug=magenta;verbose=blue;notice=cyan,bold;warning=yellow;error=red;critical=background=red" # noqa
)
LEVELS = {
"SPAM": 5,
"VERBOSE": 15,
7 changes: 2 additions & 5 deletions guardrails/document_store.py
@@ -52,8 +52,7 @@ class DocumentStoreBase(ABC):
The store can be queried by text for similar documents.
"""

def __init__(self, vector_db: VectorDBBase, path: Optional[str] = None):
...
def __init__(self, vector_db: VectorDBBase, path: Optional[str] = None): ...

@abstractmethod
def add_document(self, document: Document) -> None:
@@ -182,9 +181,7 @@ class RealSqlDocument(Base):
__tablename__ = "documents"

id: Mapped[int] = mapped_column(primary_key=True) # type: ignore
page_num: Mapped[int] = mapped_column(
sqlalchemy.Integer, primary_key=True
) # type: ignore
page_num: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True) # type: ignore
text: Mapped[str] = mapped_column(sqlalchemy.String) # type: ignore
meta: Mapped[dict] = mapped_column(sqlalchemy.PickleType) # type: ignore
vector_index: Mapped[int] = mapped_column(sqlalchemy.Integer) # type: ignore
31 changes: 11 additions & 20 deletions guardrails/guard.py
@@ -22,14 +22,14 @@
overload,
)

from guardrails_api_client.models import AnyObject
from guardrails_api_client.models import Guard as GuardModel
from guardrails_api_client.models import (
AnyObject,
History,
HistoryEvent,
ValidatePayload,
ValidationOutput,
)
from guardrails_api_client.models import Guard as GuardModel
from guardrails_api_client.types import UNSET
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
@@ -467,8 +467,7 @@ def __call__(
stream: Optional[bool] = False,
*args,
**kwargs,
) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
...
) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: ...

@overload
def __call__(
@@ -483,8 +482,7 @@ def __call__(
full_schema_reask: Optional[bool] = None,
*args,
**kwargs,
) -> Awaitable[ValidationOutcome[OT]]:
...
) -> Awaitable[ValidationOutcome[OT]]: ...

def __call__(
self,
@@ -811,8 +809,7 @@ def parse(
full_schema_reask: Optional[bool] = None,
*args,
**kwargs,
) -> ValidationOutcome[OT]:
...
) -> ValidationOutcome[OT]: ...

@overload
def parse(
@@ -825,8 +822,7 @@ def parse(
full_schema_reask: Optional[bool] = None,
*args,
**kwargs,
) -> Awaitable[ValidationOutcome[OT]]:
...
) -> Awaitable[ValidationOutcome[OT]]: ...

@overload
def parse(
@@ -839,8 +835,7 @@ def parse(
full_schema_reask: Optional[bool] = None,
*args,
**kwargs,
) -> ValidationOutcome[OT]:
...
) -> ValidationOutcome[OT]: ...

def parse(
self,
@@ -1194,14 +1189,12 @@ def __add_validator(self, validator: Validator, on: str = "output"):
)

@overload
def use(self, validator: Validator, *, on: str = "output") -> "Guard":
...
def use(self, validator: Validator, *, on: str = "output") -> "Guard": ...

@overload
def use(
self, validator: Type[Validator], *args, on: str = "output", **kwargs
) -> "Guard":
...
) -> "Guard": ...

def use(
self,
@@ -1227,8 +1220,7 @@ def use(
return self

@overload
def use_many(self, *validators: Validator, on: str = "output") -> "Guard":
...
def use_many(self, *validators: Validator, on: str = "output") -> "Guard": ...

@overload
def use_many(
Expand All @@ -1239,8 +1231,7 @@ def use_many(
Optional[Dict[str, Any]],
],
on: str = "output",
) -> "Guard":
...
) -> "Guard": ...

def use_many(
self,
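
The repeated guard.py edits collapse stub bodies: an @overload signature whose body is just an ellipsis now keeps the ... on the same line as the def, a style this PR applies consistently. A minimal, self-contained sketch with hypothetical names (not the real Guard API):

    from typing import Union, overload

    @overload
    def use(validator: str, *, on: str = "output") -> str: ...

    @overload
    def use(validator: int, *args: int, on: str = "output") -> str: ...

    def use(validator: Union[str, int], *args: int, on: str = "output") -> str:
        # Single runtime implementation backing both overloads above.
        return f"{validator!r} -> {on}"
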
2 changes: 1 addition & 1 deletion guardrails/llm_providers.py
@@ -946,7 +946,7 @@ def model_is_supported_server_side(

# FIXME: Update with newly supported LLMs
def get_llm_api_enum(
llm_api: Callable[[Any], Awaitable[Any]]
llm_api: Callable[[Any], Awaitable[Any]],
) -> Optional[ValidatePayloadLlmApi]:
# TODO: Distinguish between v1 and v2
if llm_api == get_static_openai_create_func():
1 change: 1 addition & 0 deletions guardrails/prompt/base_prompt.py
@@ -1,4 +1,5 @@
"""Class for representing a prompt entry."""

import re
from string import Template
from typing import Optional
1 change: 1 addition & 0 deletions guardrails/prompt/instructions.py
@@ -1,4 +1,5 @@
"""Instructions to the LLM, to be passed in the prompt."""

from string import Template

from guardrails.utils.parsing_utils import get_template_variables
1 change: 1 addition & 0 deletions guardrails/prompt/prompt.py
@@ -1,4 +1,5 @@
"""The LLM prompt."""

from string import Template

from guardrails.utils.parsing_utils import get_template_variables
1 change: 1 addition & 0 deletions guardrails/rail.py
@@ -1,4 +1,5 @@
"""Rail class."""

import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Type, Union
2 changes: 1 addition & 1 deletion guardrails/utils/json_utils.py
@@ -251,7 +251,7 @@ def verify(


def generate_type_skeleton_from_schema(
schema: Union[Object, ListDataType]
schema: Union[Object, ListDataType],
) -> Placeholder:
"""Generate a JSON skeleton from an XML schema."""

2 changes: 1 addition & 1 deletion guardrails/utils/misc.py
@@ -68,7 +68,7 @@ def generate_test_artifacts(
os.path.join(artifact_dir, f"validated_response_{on_fail_type}{ext}.py"),
"w",
) as f:
f.write("# flake8: noqa: E501\n")
f.write("# ruff: noqa: E501\n")

reasks, _ = gather_reasks(validated_output)
if len(reasks):
23 changes: 13 additions & 10 deletions guardrails/utils/pydantic_utils/v1.py
@@ -1,4 +1,5 @@
"""Utilities for working with Pydantic models."""

import typing
import warnings
from copy import deepcopy
@@ -113,7 +114,9 @@ def prepare_type_annotation(type_annotation: Union[ModelField, Type]) -> Type:
# Strip a Union type annotation to the first non-None type
if get_origin(type_annotation) == Union:
non_none_type_annotation = [
t for t in get_args(type_annotation) if t != type(None) # noqa E721
t
for t in get_args(type_annotation)
if t != type(None) # noqa E721
]
if len(non_none_type_annotation) == 1:
return non_none_type_annotation[0]
@@ -311,7 +314,7 @@ def schema_to_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:


def convert_pydantic_model_to_openai_fn(
model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]],
) -> Dict:
"""Convert a Pydantic BaseModel to an OpenAI function.
@@ -459,14 +462,14 @@ def convert_pydantic_model_to_datatype(
assert typing.get_origin(case_discriminator_type) is typing.Literal
assert len(typing.get_args(case_discriminator_type)) == 1
discriminator_value = typing.get_args(case_discriminator_type)[0]
choice_children[
discriminator_value
] = convert_pydantic_model_to_datatype(
case,
datatype=CaseDataType,
name=discriminator_value,
strict=strict,
excluded_fields=[discriminator],
choice_children[discriminator_value] = (
convert_pydantic_model_to_datatype(
case,
datatype=CaseDataType,
name=discriminator_value,
strict=strict,
excluded_fields=[discriminator],
)
)
children[field_name] = pydantic_field_to_datatype(
Choice,
22 changes: 12 additions & 10 deletions guardrails/utils/pydantic_utils/v2.py
@@ -128,7 +128,7 @@ def is_enum(type_annotation: Any) -> bool:


def convert_pydantic_model_to_openai_fn(
model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]],
) -> Dict:
"""Convert a Pydantic BaseModel to an OpenAI function.
@@ -298,7 +298,9 @@ def prepare_type_annotation(type_annotation: Union[FieldInfo, Type]) -> Type:
# Strip a Union type annotation to the first non-None type
if typing.get_origin(type_annotation) == Union:
non_none_type_annotation = [
t for t in get_args(type_annotation) if t != type(None) # noqa E721
t
for t in get_args(type_annotation)
if t != type(None) # noqa E721
]
if len(non_none_type_annotation) == 1:
return non_none_type_annotation[0]
@@ -424,14 +426,14 @@ def convert_pydantic_model_to_datatype(
assert typing.get_origin(case_discriminator_type) is typing.Literal
assert len(typing.get_args(case_discriminator_type)) == 1
discriminator_value = typing.get_args(case_discriminator_type)[0]
choice_children[
discriminator_value
] = convert_pydantic_model_to_datatype(
case,
datatype=CaseDataType,
name=discriminator_value,
strict=strict,
excluded_fields=[discriminator],
choice_children[discriminator_value] = (
convert_pydantic_model_to_datatype(
case,
datatype=CaseDataType,
name=discriminator_value,
strict=strict,
excluded_fields=[discriminator],
)
)
children[field_name] = pydantic_field_to_datatype(
Choice,
2 changes: 1 addition & 1 deletion guardrails/utils/reask_utils.py
@@ -26,7 +26,7 @@ class NonParseableReAsk(ReAsk):


def gather_reasks(
validated_output: Optional[Union[str, Dict, List, ReAsk]]
validated_output: Optional[Union[str, Dict, List, ReAsk]],
) -> Tuple[List[ReAsk], Union[Dict, List, None]]:
"""Traverse output and gather all ReAsk objects.