From c5520b08cf6df4fd0511cbe471292d121af7a469 Mon Sep 17 00:00:00 2001 From: Oliver Mannion <125105+tekumara@users.noreply.github.com> Date: Sat, 24 Jun 2023 15:14:33 +1000 Subject: [PATCH] feat: support commit and rollback on connection resolves #6 --- fakesnow/fakes.py | 6 ++++++ tests/test_fakes.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/fakesnow/fakes.py b/fakesnow/fakes.py index eaec2e9..30a7a12 100644 --- a/fakesnow/fakes.py +++ b/fakesnow/fakes.py @@ -396,6 +396,9 @@ def __exit__( ) -> bool: return False + def commit(self) -> None: + self.cursor().execute("COMMIT") + def cursor(self, cursor_class: Type[SnowflakeCursor] = SnowflakeCursor) -> FakeSnowflakeCursor: return FakeSnowflakeCursor(conn=self, duck_conn=self._duck_conn, use_dict_result=cursor_class == DictCursor) @@ -414,6 +417,9 @@ def execute_string( ] return cursors if return_cursors else [] + def rollback(self) -> None: + self.cursor().execute("ROLLBACK") + def _insert_df( self, df: pd.DataFrame, table_name: str, database: str | None = None, schema: str | None = None ) -> int: diff --git a/tests/test_fakes.py b/tests/test_fakes.py index b602e2a..f0481cb 100644 --- a/tests/test_fakes.py +++ b/tests/test_fakes.py @@ -462,6 +462,23 @@ def test_timestamp_to_date(cur: snowflake.connector.cursor.SnowflakeCursor): assert cur.fetchall() == [(datetime.date(1970, 1, 1), datetime.date(1970, 1, 1))] +def test_transactions(conn: snowflake.connector.SnowflakeConnection): + conn.execute_string( + """CREATE TABLE table1 (i int); + BEGIN TRANSACTION; + INSERT INTO table1 (i) VALUES (1);""" + ) + conn.rollback() + conn.execute_string( + """BEGIN TRANSACTION; + INSERT INTO table1 (i) VALUES (2);""" + ) + conn.commit() + with conn.cursor() as cur: + cur.execute("SELECT * FROM table1") + assert cur.fetchall() == [(2,)] + + def test_unquoted_identifiers_are_upper_cased(conn: snowflake.connector.SnowflakeConnection): with conn.cursor(snowflake.connector.cursor.DictCursor) as cur: cur.execute("create table customers (id int, first_name varchar, last_name varchar)")