Skip to content

Commit

Permalink
Fix linter issues in sqllogic module
Browse files Browse the repository at this point in the history
  • Loading branch information
Felixoid committed Mar 5, 2024
1 parent 711da95 commit 770d710
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 149 deletions.
24 changes: 10 additions & 14 deletions tests/sqllogic/connection.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import datetime
import logging
import pyodbc
import sqlite3
import traceback
import enum
import logging
import random
import sqlite3
import string
from contextlib import contextmanager

import pyodbc # pylint:disable=import-error; for style check
from exceptions import ProgramError


logger = logging.getLogger("connection")
logger.setLevel(logging.DEBUG)

Expand All @@ -22,9 +19,7 @@ def __init__(self, **kwargs):
self._kwargs = kwargs

def __str__(self):
conn_str = ";".join(
["{}={}".format(x, y) for x, y in self._kwargs.items() if y]
)
conn_str = ";".join([f"{x}={y}" for x, y in self._kwargs.items() if y])
return conn_str

def update_database(self, database):
Expand All @@ -49,6 +44,7 @@ def create_from_connection_string(conn_str):
for kv in conn_str.split(";"):
if kv:
k, v = kv.split("=", 1)
# pylint:disable-next=protected-access
args._kwargs[k] = v
return args

Expand Down Expand Up @@ -82,7 +78,7 @@ class KnownDBMS(str, enum.Enum):
clickhouse = "ClickHouse"


class ConnectionWrap(object):
class ConnectionWrap:
def __init__(self, connection=None, factory=None, factory_kwargs=None):
self._factory = factory
self._factory_kwargs = factory_kwargs
Expand Down Expand Up @@ -126,7 +122,7 @@ def drop_all_tables(self):
f"SELECT name FROM system.tables WHERE database='{self.DATABASE_NAME}'"
)
elif self.DBMS_NAME == KnownDBMS.sqlite.value:
list_query = f"SELECT name FROM sqlite_master WHERE type='table'"
list_query = "SELECT name FROM sqlite_master WHERE type='table'"
else:
logger.warning(
"unable to drop all tables for unknown database: %s", self.DBMS_NAME
Expand Down Expand Up @@ -154,7 +150,7 @@ def use_random_database(self):
self._use_database(database)
logger.info(
"currentDatabase : %s",
execute_request(f"SELECT currentDatabase()", self).get_result(),
execute_request("SELECT currentDatabase()", self).get_result(),
)

@contextmanager
Expand All @@ -174,7 +170,7 @@ def with_test_database_scope(self):

def __exit__(self, *args):
if hasattr(self._connection, "close"):
return self._connection.close()
self._connection.close()


def setup_connection(engine, conn_str=None, make_debug_request=True):
Expand Down Expand Up @@ -263,7 +259,7 @@ def has_exception(self):
def assert_no_exception(self):
if self.has_exception():
raise ProgramError(
f"request doesn't have a result set, it has the exception",
"request doesn't have a result set, it has the exception",
parent=self._exception,
)

Expand Down
26 changes: 4 additions & 22 deletions tests/sqllogic/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from enum import Enum


class Error(Exception):
def __init__(
Expand Down Expand Up @@ -45,16 +43,8 @@ def message(self):

@property
def reason(self):
return ", ".join(
(
str(x)
for x in [
super().__str__(),
"details: {}".format(self._details) if self._details else "",
]
if x
)
)
details = f"details: {self._details}" if self._details else ""
return ", ".join((str(x) for x in [super().__str__(), details] if x))

def set_details(self, file=None, name=None, pos=None, request=None, details=None):
if file is not None:
Expand Down Expand Up @@ -88,16 +78,8 @@ def get_parent(self):

@property
def reason(self):
return ", ".join(
(
str(x)
for x in [
super().reason,
"exception: {}".format(str(self._parent)) if self._parent else "",
]
if x
)
)
exception = f"exception: {self._parent}" if self._parent else ""
return ", ".join((str(x) for x in [super().reason, exception] if x))


class ProgramError(ErrorWithParent):
Expand Down
41 changes: 24 additions & 17 deletions tests/sqllogic/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,25 @@
# -*- coding: utf-8 -*-

import argparse
import enum
import os
import logging
import csv
import enum
import json
import logging
import multiprocessing
import os
from functools import reduce
from deepdiff import DeepDiff

from connection import setup_connection, Engines, default_clickhouse_odbc_conn_str
from test_runner import TestRunner, Status, RequestType
# isort: off
from deepdiff import DeepDiff # pylint:disable=import-error; for style check

# isort: on

from connection import Engines, default_clickhouse_odbc_conn_str, setup_connection
from test_runner import RequestType, Status, TestRunner

LEVEL_NAMES = [x.lower() for x in logging._nameToLevel.keys() if x != logging.NOTSET]
LEVEL_NAMES = [ # pylint:disable-next=protected-access
l.lower() for l, n in logging._nameToLevel.items() if n != logging.NOTSET
]


def setup_logger(args):
Expand All @@ -41,7 +46,7 @@ def __write_check_status(status_row, out_dir):
if len(status_row) > 140:
status_row = status_row[0:135] + "..."
check_status_path = os.path.join(out_dir, "check_status.tsv")
with open(check_status_path, "a") as stream:
with open(check_status_path, "a", encoding="utf-8") as stream:
writer = csv.writer(stream, delimiter="\t", lineterminator="\n")
writer.writerow(status_row)

Expand All @@ -60,7 +65,7 @@ def __write_test_result(
):
all_stages = reports.keys()
test_results_path = os.path.join(out_dir, "test_results.tsv")
with open(test_results_path, "a") as stream:
with open(test_results_path, "a", encoding="utf-8") as stream:
writer = csv.writer(stream, delimiter="\t", lineterminator="\n")
for stage in all_stages:
report = reports[stage]
Expand Down Expand Up @@ -182,7 +187,7 @@ def calle(args):
input_dir, f"check statements:: not a dir {input_dir}"
)

reports = dict()
reports = {}

out_stages_dir = os.path.join(out_dir, f"{args.mode}-stages")

Expand Down Expand Up @@ -242,7 +247,7 @@ def calle(args):
input_dir, f"check statements:: not a dir {input_dir}"
)

reports = dict()
reports = {}

out_stages_dir = os.path.join(out_dir, f"{args.mode}-stages")

Expand Down Expand Up @@ -286,23 +291,25 @@ def make_actual_report(reports):
return {stage: report.get_map() for stage, report in reports.items()}


def write_actual_report(actial, out_dir):
with open(os.path.join(out_dir, "actual_report.json"), "w") as f:
f.write(json.dumps(actial))
def write_actual_report(actual, out_dir):
with open(os.path.join(out_dir, "actual_report.json"), "w", encoding="utf-8") as f:
f.write(json.dumps(actual))


def read_canonic_report(input_dir):
file = os.path.join(input_dir, "canonic_report.json")
if not os.path.exists(file):
return {}

with open(os.path.join(input_dir, "canonic_report.json"), "r") as f:
with open(
os.path.join(input_dir, "canonic_report.json"), "r", encoding="utf-8"
) as f:
data = f.read()
return json.loads(data)


def write_canonic_report(canonic, out_dir):
with open(os.path.join(out_dir, "canonic_report.json"), "w") as f:
with open(os.path.join(out_dir, "canonic_report.json"), "w", encoding="utf-8") as f:
f.write(json.dumps(canonic))


Expand Down Expand Up @@ -370,7 +377,7 @@ def calle(args):
if not os.path.isdir(out_dir):
raise NotADirectoryError(out_dir, f"self test: not a dir {out_dir}")

reports = dict()
reports = {}

out_stages_dir = os.path.join(out_dir, f"{args.mode}-stages")

Expand Down
Loading

0 comments on commit 770d710

Please sign in to comment.