From 4fa6cadd374ac54418158131ca8a4e5b6546caca Mon Sep 17 00:00:00 2001 From: pmishchenko-ua <65538066+pmishchenko-ua@users.noreply.github.com> Date: Tue, 20 Aug 2024 21:01:25 +0300 Subject: [PATCH] Add client_found_rows option to connect() (#32) * Add client_found_rows option to connect() --------- Co-authored-by: Pavlo Mishchenko Co-authored-by: Kevin D Smith --- singlestoredb/config.py | 7 ++++++ singlestoredb/connection.py | 1 + singlestoredb/mysql/connection.py | 3 +++ singlestoredb/tests/test_connection.py | 33 ++++++++++++++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/singlestoredb/config.py b/singlestoredb/config.py index 1dd120dfa..80bb292f9 100644 --- a/singlestoredb/config.py +++ b/singlestoredb/config.py @@ -103,6 +103,13 @@ environ='SINGLESTOREDB_MULTI_STATEMENTS', ) +register_option( + 'client_found_rows', 'bool', check_bool, False, + 'Should affected_rows in OK_PACKET indicate the ' + 'number of matched rows instead of changed?', + environ='SINGLESTOREDB_CLIENT_FOUND_ROWS', +) + register_option( 'ssl_key', 'str', check_str, None, 'File containing SSL key', diff --git a/singlestoredb/connection.py b/singlestoredb/connection.py index 663f9ec97..d840fbd23 100644 --- a/singlestoredb/connection.py +++ b/singlestoredb/connection.py @@ -1298,6 +1298,7 @@ def connect( program_name: Optional[str] = None, conn_attrs: Optional[Dict[str, str]] = None, multi_statements: Optional[bool] = None, + client_found_rows: Optional[bool] = None, connect_timeout: Optional[int] = None, nan_as_null: Optional[bool] = None, inf_as_null: Optional[bool] = None, diff --git a/singlestoredb/mysql/connection.py b/singlestoredb/mysql/connection.py index b388ee4b1..928685a05 100644 --- a/singlestoredb/mysql/connection.py +++ b/singlestoredb/mysql/connection.py @@ -347,6 +347,7 @@ def __init__( # noqa: C901 driver=None, # internal use conn_attrs=None, multi_statements=None, + client_found_rows=None, nan_as_null=None, inf_as_null=None, encoding_errors='strict', @@ -380,6 +381,8 @@ def __init__( # noqa: C901 client_flag |= CLIENT.LOCAL_FILES if multi_statements: client_flag |= CLIENT.MULTI_STATEMENTS + if client_found_rows: + client_flag |= CLIENT.FOUND_ROWS if read_default_group and not read_default_file: if sys.platform.startswith('win'): diff --git a/singlestoredb/tests/test_connection.py b/singlestoredb/tests/test_connection.py index 75f9c5f96..c0e73e815 100755 --- a/singlestoredb/tests/test_connection.py +++ b/singlestoredb/tests/test_connection.py @@ -2776,6 +2776,39 @@ def test_multi_statements(self): self.assertEqual([(2,)], list(cur)) self.assertIsNone(cur.nextset()) + def test_client_found_rows(self): + if self.conn.driver not in ['http', 'https']: + with s2.connect(database=type(self).dbname, client_found_rows=False) as conn: + with conn.cursor() as cur: + tag = str(uuid.uuid4()).replace('-', '_') + table_name = f'test_client_found_rows_{tag}' + cur.execute(f"CREATE TABLE {table_name} (id BIGINT \ + PRIMARY KEY, s TEXT DEFAULT 'def');") + cur.execute(f'INSERT INTO {table_name} (id) \ + VALUES (1), (2), (3);') + cur.execute(f"UPDATE {table_name} SET s = 'def' \ + WHERE id = 1;") + # UPDATE statement above is not changing any rows, + # so affected_rows is 0 if client_found_rows is False (default) + self.assertEqual(0, conn.affected_rows()) + cur.execute(f'DROP TABLE {table_name};') + + with s2.connect(database=type(self).dbname, client_found_rows=True) as conn: + with conn.cursor() as cur: + tag = str(uuid.uuid4()).replace('-', '_') + table_name = f'test_client_found_rows_{tag}' + cur.execute(f"CREATE TABLE {table_name} (id BIGINT \ + PRIMARY KEY, s TEXT DEFAULT 'def');") + cur.execute(f'INSERT INTO {table_name} (id) \ + VALUES (1), (2), (3);') + cur.execute(f"UPDATE {table_name} SET s = 'def' \ + WHERE id = 1;") + # UPDATE statement above is not changing any rows, + # but affected_rows is 1 as 1 row is subject to update, and + # this is what affected_rows return when client_found_rows is True + self.assertEqual(1, conn.affected_rows()) + cur.execute(f'DROP TABLE {table_name};') + def test_connect_timeout(self): with s2.connect(database=type(self).dbname, connect_timeout=8) as conn: with conn.cursor() as cur: