Skip to content

Commit

Permalink
Let SQLAlchemy float
Browse files Browse the repository at this point in the history
  • Loading branch information
prescod committed Apr 6, 2024
1 parent c148699 commit 80076fc
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 13 deletions.
2 changes: 1 addition & 1 deletion requirements/prod.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SQLAlchemy<2.0
SQLAlchemy
Faker
faker-nonprofit
faker-edu
Expand Down
10 changes: 8 additions & 2 deletions snowfakery/output_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ def write_single_row(self, tablename: str, row: Dict) -> None:
self.buffered_rows[tablename].append(row)

def flush(self):
with self.session.begin():
self._flush_rows()
self.session.flush()

def _flush_rows(self):
for tablename, (insert_statement, fallback_dict) in self.table_info.items():
# Make sure every row has the same records per SQLAlchemy's rules

Expand All @@ -350,16 +355,15 @@ def flush(self):
if values:
self.session.execute(insert_statement, values)
self.buffered_rows[tablename] = []
self.session.flush()

def commit(self):
if any(self.buffered_rows):
self.flush()
self.session.commit()

def close(self, **kwargs) -> Optional[Sequence[str]]:
self.commit()
self.session.close()
self.engine.dispose()

def create_or_validate_tables(self, inferred_tables: Dict[str, TableInfo]) -> None:
try:
Expand Down Expand Up @@ -437,6 +441,8 @@ def _dump_db(self):
assert self.text_output.stream
self.text_output.stream.write("%s\n" % line)

con.close()

def close(self, *args, **kwargs):
self._dump_db()
self.sql_db.close(*args, **kwargs)
Expand Down
8 changes: 4 additions & 4 deletions snowfakery/tools/snowbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import locale

import click
from sqlalchemy import create_engine, inspect
from sqlalchemy import create_engine, inspect, text

from snowfakery import generate_data

Expand Down Expand Up @@ -50,7 +50,6 @@ def snowbench(
with TemporaryDirectory() as tempdir, click.progressbar(
label="Benchmarking", length=num_records, show_eta=False
) as progress_bar:

start = time()
Thread(
daemon=True,
Expand Down Expand Up @@ -163,14 +162,15 @@ def count_database(filename, counts):
dburl = f"sqlite:///{filename}?mode=ro"
engine = create_engine(dburl)
insp = inspect(engine)
tables = insp.get_table_names() # type: ignore
tables = insp.get_table_names() # type: ignore
for table in tables:
counts[table] += count_table(engine, table)
return counts


def count_table(engine, tablename):
return engine.execute(f"select count(Id) from '{tablename}'").first()[0]
with engine.connect() as c:
return c.execute(text(f"select count(Id) from '{tablename}'")).first()[0]


def snowfakery(recipe, num_records, tablename, outputfile):
Expand Down
19 changes: 15 additions & 4 deletions tests/test_output_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,16 @@ def do_output(self, yaml, url=None):
with engine.connect() as connection:
tables = {
table_name: [
row._mapping
dict(row._mapping)
for row in connection.execute(
text(f"select * from {table_name}")
)
]
for table_name in table_names
}
return tables
engine.dispose()
del engine
return tables

def test_null(self):
yaml = """
Expand All @@ -206,10 +208,14 @@ def test_table_already_exists(self):
metadata.create_all(bind=engine)
with engine.begin() as c:
c.execute(t.insert().values([[5]]))
engine.dispose()

with pytest.raises(exc.DataGenError, match="Table already exists"):
output_stream = SqlDbOutputStream.from_url(url)
generate(StringIO(yaml), {}, output_stream)
try:
generate(StringIO(yaml), {}, output_stream)
finally:
output_stream.close()

def test_bad_database_connection(self):
yaml = """
Expand All @@ -222,7 +228,7 @@ def test_bad_database_connection(self):
self.do_output(yaml, "unknowndb://foo/bar/baz")

# missing driver
with pytest.raises(exc.DataGenError, match="fdb"):
with pytest.raises(exc.DataGenError, match="(fdb)|(firebird)"):
self.do_output(yaml, "firebird://foo/bar/baz")

# cannot connect
Expand Down Expand Up @@ -288,6 +294,7 @@ def test_json_output_mocked(self):
def test_from_cli(self):
x = StringIO()
with redirect_stdout(x):
assert generate_cli.callback
generate_cli.callback(yaml_file=sample_yaml, output_format="json")
data = json.loads(x.getvalue())
print(data)
Expand Down Expand Up @@ -363,6 +370,7 @@ def test_csv_output(self):
output_stream = CSVOutputStream(Path(t) / "csvoutput")
generate(StringIO(yaml), {}, output_stream)
messages = output_stream.close()
assert messages
assert "foo.csv" in messages[0]
assert "bar.csv" in messages[1]
assert "csvw" in messages[2]
Expand Down Expand Up @@ -417,6 +425,7 @@ class TestExternalOutputStream:
def test_external_output_stream(self):
x = StringIO()
with redirect_stdout(x):
assert generate_cli.callback
generate_cli.callback(
yaml_file=sample_yaml, output_format="package1.TestOutputStream"
)
Expand All @@ -430,6 +439,7 @@ def test_external_output_stream(self):
def test_external_output_stream_yaml(self):
x = StringIO()
with redirect_stdout(x):
assert generate_cli.callback
generate_cli.callback(
yaml_file=sample_yaml, output_format="examples.YamlOutputStream"
)
Expand All @@ -451,6 +461,7 @@ def test_external_output_stream_yaml(self):

def test_external_output_stream__failure(self):
with pytest.raises(ClickException, match="no.such.output.Stream"):
assert generate_cli.callback
generate_cli.callback(
yaml_file=sample_yaml, output_format="no.such.output.Stream"
)
4 changes: 2 additions & 2 deletions tests/test_with_cci.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from snowfakery.standard_plugins import Salesforce

try:
import cumulusci
import cumulusci # type: ignore
except ImportError:
cumulusci = False

Expand All @@ -43,7 +43,6 @@ def test_mapping_file(self):
],
standalone_mode=False,
)

engine = create_engine(url)
with engine.connect() as connection:
result = [
Expand All @@ -52,6 +51,7 @@ def test_mapping_file(self):
]
assert result[0]["id"] == 1
assert result[0]["BillingCountry"] == "Canada"
engine.dispose()


class FakeSimpleSalesforce:
Expand Down

0 comments on commit 80076fc

Please sign in to comment.