Skip to content

Commit

Permalink
Merge branch 'sqlfluff:main' into main
Browse files Browse the repository at this point in the history
troshnev authored Sep 24, 2024
2 parents 2ff3678 + 6dd6433 commit 4979286
Showing 99 changed files with 4,071 additions and 1,887 deletions.
4 changes: 3 additions & 1 deletion src/sqlfluff/cli/commands.py
Original file line number Diff line number Diff line change
@@ -763,7 +763,9 @@ def lint(
if not nofail:
if not non_human_output:
formatter.completion_message()
sys.exit(result.stats(EXIT_FAIL, EXIT_SUCCESS)["exit code"])
exit_code = result.stats(EXIT_FAIL, EXIT_SUCCESS)["exit code"]
assert isinstance(exit_code, int), "result.stats error code must be integer."
sys.exit(exit_code)
else:
sys.exit(EXIT_SUCCESS)

4 changes: 2 additions & 2 deletions src/sqlfluff/cli/formatters.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
from sqlfluff.cli.outputstream import OutputStream
from sqlfluff.core import FluffConfig, Linter, SQLBaseError, TimingSummary
from sqlfluff.core.enums import Color
from sqlfluff.core.linter import LintedFile, ParsedString
from sqlfluff.core.linter import FormatterInterface, LintedFile, ParsedString


def split_string_on_spaces(s: str, line_length: int = 100) -> List[str]:
@@ -65,7 +65,7 @@ def format_linting_result_header() -> str:
return text_buffer.getvalue()


class OutputStreamFormatter:
class OutputStreamFormatter(FormatterInterface):
"""Formatter which writes to an OutputStream.
On instantiation, this formatter accepts a function to
2 changes: 2 additions & 0 deletions src/sqlfluff/core/linter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Linter class and helper classes."""

from sqlfluff.core.linter.common import ParsedString, RenderedFile, RuleTuple
from sqlfluff.core.linter.formatter import FormatterInterface
from sqlfluff.core.linter.linted_file import LintedFile
from sqlfluff.core.linter.linter import Linter
from sqlfluff.core.linter.linting_result import LintingResult

__all__ = (
"FormatterInterface",
"RuleTuple",
"ParsedString",
"LintedFile",
4 changes: 2 additions & 2 deletions src/sqlfluff/core/linter/discovery.py
Original file line number Diff line number Diff line change
@@ -134,8 +134,8 @@ def _match_file_extension(filepath: str, valid_extensions: Sequence[str]) -> boo
Returns:
True if the file has an extension in `valid_extensions`.
"""
_, file_ext = os.path.splitext(filepath)
return file_ext.lower() in valid_extensions
filepath = filepath.lower()
return any(filepath.endswith(ext) for ext in valid_extensions)


def _process_exact_path(
22 changes: 22 additions & 0 deletions src/sqlfluff/core/linter/formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Defines the formatter interface which can be used by the CLI.
The linter module provides an optional formatter input which effectively
allows callbacks at various points of the linting process. This is primarily
to allow printed output at various points by the CLI, but could also be used
for logging our other processes looking to report back as the linting process
continues.
In this module we only define the interface. Any modules wishing to use the
interface should override with their own implementation.
"""

from abc import ABC, abstractmethod


class FormatterInterface(ABC):
"""Generic formatter interface."""

@abstractmethod
def dispatch_persist_filename(self, filename: str, result: str) -> None:
"""Called after a formatted file as been persisted to disk."""
...
95 changes: 53 additions & 42 deletions src/sqlfluff/core/linter/linted_dir.py
Original file line number Diff line number Diff line change
@@ -4,17 +4,23 @@
"""

from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union, overload
from typing import Dict, Iterable, List, Optional, Tuple, Type, TypedDict, Union

from sqlfluff.core.errors import CheckTuple, SQLLintError
from sqlfluff.core.errors import (
CheckTuple,
SerializedObject,
SQLBaseError,
SQLLintError,
)
from sqlfluff.core.linter.formatter import FormatterInterface
from sqlfluff.core.linter.linted_file import TMP_PRS_ERROR_TYPES, LintedFile
from sqlfluff.core.parser.segments.base import BaseSegment

LintingRecord = TypedDict(
"LintingRecord",
{
"filepath": str,
"violations": List[dict],
"violations": List[SerializedObject],
# Things like file length
"statistics": Dict[str, int],
# Raw timings, in seconds, for both rules and steps
@@ -132,53 +138,56 @@ def add(self, file: LintedFile) -> None:
if self.retain_files:
self.files.append(file)

@overload
def check_tuples(self, by_path: Literal[False]) -> List[CheckTuple]:
"""Return a List of CheckTuples when by_path is False."""

@overload
def check_tuples(self, by_path: Literal[True]) -> Dict[str, List[CheckTuple]]:
"""Return a Dict of paths and CheckTuples when by_path is True."""

@overload
def check_tuples(self, by_path: bool = False):
"""Default overload method."""

def check_tuples(
self, by_path=False, raise_on_non_linting_violations=True
) -> Union[List[CheckTuple], Dict[str, List[CheckTuple]]]:
self, raise_on_non_linting_violations: bool = True
) -> List[CheckTuple]:
"""Compress all the tuples into one list.
NB: This is a little crude, as you can't tell which
file the violations are from. Good for testing though.
For more control set the `by_path` argument to true.
For more control use `check_tuples_by_path`.
"""
assert self.retain_files, "cannot `check_tuples()` without `retain_files`"
if by_path:
return {
file.path: file.check_tuples(
raise_on_non_linting_violations=raise_on_non_linting_violations
)
for file in self.files
}
else:
tuple_buffer: List[CheckTuple] = []
for file in self.files:
tuple_buffer += file.check_tuples(
raise_on_non_linting_violations=raise_on_non_linting_violations
)
return tuple_buffer

def num_violations(self, **kwargs) -> int:
return [
check_tuple
for file in self.files
for check_tuple in file.check_tuples(
raise_on_non_linting_violations=raise_on_non_linting_violations
)
]

def check_tuples_by_path(
self, raise_on_non_linting_violations: bool = True
) -> Dict[str, List[CheckTuple]]:
"""Fetch all check_tuples from all contained `LintedDir` objects.
Returns:
A dict, with lists of tuples grouped by path.
"""
assert (
self.retain_files
), "cannot `check_tuples_by_path()` without `retain_files`"
return {
file.path: file.check_tuples(
raise_on_non_linting_violations=raise_on_non_linting_violations
)
for file in self.files
}

def num_violations(
self,
types: Optional[Union[Type[SQLBaseError], Iterable[Type[SQLBaseError]]]] = None,
fixable: Optional[bool] = None,
) -> int:
"""Count the number of violations in the path."""
return sum(file.num_violations(**kwargs) for file in self.files)
return sum(
file.num_violations(types=types, fixable=fixable) for file in self.files
)

def get_violations(self, **kwargs) -> list:
def get_violations(
self, rules: Optional[Union[str, Tuple[str, ...]]] = None
) -> List[SQLBaseError]:
"""Return a list of violations in the path."""
buff: list = []
for file in self.files:
buff += file.get_violations(**kwargs)
return buff
return [v for file in self.files for v in file.get_violations(rules=rules)]

def as_records(self) -> List[LintingRecord]:
"""Return the result as a list of dictionaries.
@@ -199,7 +208,9 @@ def stats(self) -> Dict[str, int]:
}

def persist_changes(
self, formatter: Any = None, fixed_file_suffix: str = ""
self,
formatter: Optional[FormatterInterface] = None,
fixed_file_suffix: str = "",
) -> Dict[str, Union[bool, str]]:
"""Persist changes to files in the given path.
26 changes: 20 additions & 6 deletions src/sqlfluff/core/linter/linted_file.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
import tempfile
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Type, Union
from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Type, Union

from sqlfluff.core.errors import (
CheckTuple,
@@ -21,6 +21,7 @@
SQLParseError,
SQLTemplaterError,
)
from sqlfluff.core.linter.formatter import FormatterInterface
from sqlfluff.core.linter.patch import FixPatch, generate_source_patches

# Classes needed only for type checking
@@ -44,7 +45,7 @@ class FileTimings:
# process this as we wish later.
rule_timings: List[Tuple[str, str, float]]

def __repr__(self): # pragma: no cover
def __repr__(self) -> str: # pragma: no cover
return "<FileTimings>"

def get_rule_timing_dict(self) -> Dict[str, float]:
@@ -166,12 +167,23 @@ def get_violations(
violations += self.ignore_mask.generate_warnings_for_unused()
return violations

def num_violations(self, **kwargs) -> int:
def num_violations(
self,
types: Optional[Union[Type[SQLBaseError], Iterable[Type[SQLBaseError]]]] = None,
filter_ignore: bool = True,
filter_warning: bool = True,
fixable: Optional[bool] = None,
) -> int:
"""Count the number of violations.
Optionally now with filters.
"""
violations = self.get_violations(**kwargs)
violations = self.get_violations(
types=types,
filter_ignore=filter_ignore,
filter_warning=filter_warning,
fixable=fixable,
)
return len(violations)

def is_clean(self) -> bool:
@@ -369,7 +381,9 @@ def _build_up_fixed_source_string(
str_buff += raw_source_string[source_slice]
return str_buff

def persist_tree(self, suffix: str = "", formatter: Any = None) -> bool:
def persist_tree(
self, suffix: str = "", formatter: Optional[FormatterInterface] = None
) -> bool:
"""Persist changes to the given path."""
if self.num_violations(fixable=True) > 0:
write_buff, success = self.fix_string()
@@ -398,7 +412,7 @@ def persist_tree(self, suffix: str = "", formatter: Any = None) -> bool:
@staticmethod
def _safe_create_replace_file(
input_path: str, output_path: str, write_buff: str, encoding: str
):
) -> None:
# Write to a temporary file first, so in case of encoding or other
# issues, we don't delete or corrupt the user's existing file.

Loading

0 comments on commit 4979286

Please sign in to comment.