Skip to content

Commit

Permalink
1. 将 csv 文件重命名,防止冲突
Browse files Browse the repository at this point in the history
2. 修改 csvreader, 使之完美支持 sql
3. 适配 by_csv 引擎
4. 优化 sqllite 工具类, 使之支持上述功能
  • Loading branch information
kem wan committed Dec 8, 2023
1 parent e7f974b commit 3894213
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 169 deletions.
77 changes: 66 additions & 11 deletions bricks/db/sqllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@
import contextlib
import json
import sqlite3
import subprocess


class SqlLite:

def __init__(self, name, structure: dict):
def __init__(self, name, structure: dict = None, database=":memory:"):
sqlite3.register_adapter(bool, int)
sqlite3.register_adapter(list, json.dumps)
sqlite3.register_adapter(list, json.dumps)
sqlite3.register_adapter(dict, json.dumps)
sqlite3.register_converter("BOOLEAN", lambda v: bool(int(v)))
sqlite3.register_converter("OBJECT", json.loads)

self._db = sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES)

self._columns = []
self.database = database
self.name = name
self._db = sqlite3.connect(database, detect_types=sqlite3.PARSE_DECLTYPES)
self._columns = []
self.structure: dict = structure
self.create_table(name, structure)
structure and self.create_table(name, structure)

@contextlib.contextmanager
def cursor(self) -> sqlite3.Cursor:
Expand Down Expand Up @@ -78,9 +78,16 @@ def delete(self, query: str):
cur.execute(sql)

def create_table(self, name, structure: dict):
python_to_sqlite_types = {
type(None): "NULL",
int: "INTEGER",
float: "REAL",
str: "TEXT",
bytes: "BLOB"
}
with self.cursor() as cur:
sql = f"CREATE TABLE IF NOT EXISTS {name}(" + ",".join(
[f'{k} {v}' if v else k for k, v in structure.items()]) + ")"
[f'{k} {python_to_sqlite_types.get(v, "TEXT")}' if v else k for k, v in structure.items()]) + ")"
cur.execute(sql)

@property
Expand All @@ -92,7 +99,55 @@ def columns(self):

return self._columns

def run_sql(self, sql: str):
with self.cursor() as cur:
cur.execute(sql)
return cur.fetchall()
def run_sql(self, sql: str) -> sqlite3.Cursor:
with self.cursor() as cursor:
cursor.execute(sql)
rows = cursor.fetchall()
columns = [description[0] for description in cursor.description]
for row in rows:
yield dict(zip(columns, row))

def execute(self, sql: str):
with self.cursor() as cursor:
cursor.execute(sql)
return cursor.fetchall()

@classmethod
def load_csv(
cls,
database: str,
table: str,
path: str,
structure: dict = None,
reload=True,
debug=False
):
"""
从 csv 中加载数据
:param structure:
:param database: 数据库名称
:param table: 表名
:param path: 路径
:param reload: 是否重新加载, 是的话如果数据库存在, 会先删除数据库
:param debug: debug 模式会
:return:
"""
conn = cls(database=database + ".db", name=table)

if reload:
conn.execute(f'DROP TABLE IF EXISTS {table};')
if structure:
conn.create_table(table, structure=structure)
cmd = f'sqlite3 {database}.db ".mode csv" ".import {path} {table}"'

options = {}
if not debug:
options.update({
"stdout": subprocess.DEVNULL,
"stderr": subprocess.DEVNULL,
})
subprocess.run(cmd, shell=True, text=True, **options)


return conn
108 changes: 61 additions & 47 deletions bricks/plugins/make_seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,85 +5,99 @@
@Author : yintian
@Desc :
"""
import math
import re
import time
from typing import Optional, Union, List
from typing import Optional, Union
from urllib.parse import urlencode

from bricks.utils.csv import CsvReader
from bricks.utils.pandora import get_files_by_path
from bricks.utils.csv_ import CsvReader

LIMIT_PATTERN = re.compile(r"(LIMIT\s+)(\d+)", flags=re.IGNORECASE)
OFFSET_PATTERN = re.compile(r"(OFFSET\s+)(\d+)", flags=re.IGNORECASE)


def _make_key(query: dict):
return urlencode(query)


def by_csv(
path: Union[str, List[str]],
fields: list = None,
batch_size: int = 10000,
skip: Union[str, int] = None,
path: str,
query: str = None,
mapping: Optional[dict] = None,
stop: Optional[int] = math.inf,
reader_kwargs: Optional[dict] = None,
batch_size: int = 10000,
skip: Union[str, int] = ...,
reader_options: Optional[dict] = None,
record: Optional[dict] = None,
):
"""
从 CSV 中获取种子
从 CSV 中获取种子, csv 必须有表头
:param path: 文件路径
:param fields: 暂不支持
:param batch_size: 一次投多少种子
:param query: 查询 sql
:param batch_size: 一次获取多少条数据
:param skip: 跳过初始多少种子
:param query: 查询条件, 为 python 伪代码, 如 "a < 10 and 'bbb' in b.lower()"
:param mapping: 暂不支持
:param stop: 投到多少停止
:param reader_kwargs: 初始化 csv reader 的其他参数
:param reader_options: 初始化 csv reader 的其他参数
:param record:
:return:
"""
record = record or {}
reader_kwargs = reader_kwargs or {}
mapping = mapping or {}
if mapping or fields:
raise ValueError(f'暂不支持 mapping/fields 参数')
if skip is None:
reader_options = reader_options or {}
if skip is ...:
if record:
skip = 'auto'
else:
skip = 0

raw_skip = skip
total = 0
for file in get_files_by_path(path):
_record = {
query = query or "select * from <TABLE>"

for file in CsvReader.get_files(path):
reader = CsvReader(file, **reader_options)

record_key = _make_key({
"path": path,
"file": file,
"query": query,
}
record_key = urlencode(_record)
"skip": raw_skip,
})
if raw_skip == 'auto':
total = int(record.get('total', 0))
skip = int(record.get(record_key, 0))
else:
skip = raw_skip
with CsvReader(file_path=file, **reader_kwargs) as reader:
for row in reader.iter_data(
count=batch_size,
skip=skip,
fields=fields,
query=query
):
total += len(row)
skip += len(row)
if total > stop:
return
record.update({record_key: skip})
yield row

if skip != 0:
# 原来就有 offset
if OFFSET_PATTERN.search(query):
def add_skip(match):
# 将捕获的数字转换为整数,加上 skip,然后格式化回字符串
return f"{match.group(1)}{int(match.group(2)) + skip}"

query = OFFSET_PATTERN.sub(add_skip, query)

# 没有 offset 但是有 limit
elif not LIMIT_PATTERN.search(query):
query = query + f"LIMIT -1 OFFSET {skip}"

# offset 和 limit 都没有
else:
query = LIMIT_PATTERN.sub(r"\1\2 OFFSET " + str(skip), query)

gen = reader.iter_data(query)
seeds = []
for _ in range(batch_size):
for data in gen:
seeds.append(data)
skip += 1
else:
record.update({record_key: skip})
yield seeds


if __name__ == '__main__':
st = time.time()
for __ in range(1000):
with CsvReader(file_path='../../files/e.csv') as r:
for _ in r.iter_data(skip=10, query='int(a) % 3 == 0'):
# print(_)
pass
for d in by_csv(
path='/Users/Kem/Documents/bricks/bricks/utils/test.csv',
query="select cast(a as INTEGER) as a, b from <TABLE> where a =0",
# skip=1
):
print(d)
print(time.time() - st)
111 changes: 0 additions & 111 deletions bricks/utils/csv.py

This file was deleted.

Loading

0 comments on commit 3894213

Please sign in to comment.