Skip to content

Commit

Permalink
feat: Implements basic snowflake session variables via SET/UNSET (#111)
Browse files Browse the repository at this point in the history
Implements basic usage of SQL session variables:
```
SET var2 = 'blue';
SELECT * from table where color = $var2;                  Produces: SELECT * from table where color = 'blue';
UNSET var2;
```
as described in
https://docs.snowflake.com/en/sql-reference/session-variables.

List of features supported and not supported by this PR:

- [x] Variables are scoped to the session (i.e. the connection, not the
cursor)
- [x] Simple scalar variables: SET var1 = 1;
- [x] Unset variables: UNSET var1;
- [x] Simple SQL expression variables: SET INCREMENTAL_DATE = DATEADD(
'DAY', -7, CURRENT_DATE());
- [x] Basic use of variables in SQL using $ syntax: SELECT $var1;
- [ ] Multiple variables: SET (var1, var2) = (1, 'hello');
- [ ] Variables set via 'properties' on the connection
https://docs.snowflake.com/en/sql-reference/session-variables#setting-variables-on-connection
- [ ] Using variables via the IDENTIFIER function: INSERT INTO
IDENTIFIER($my_table_name) (i) VALUES (42);
- [ ] Session variable functions:
https://docs.snowflake.com/en/sql-reference/session-variables#session-variable-functions
  • Loading branch information
jsibbison-square authored Jul 8, 2024
1 parent 7656ab9 commit 7696cbd
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 0 deletions.
7 changes: 7 additions & 0 deletions fakesnow/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import fakesnow.macros as macros
import fakesnow.transforms as transforms
from fakesnow.global_database import create_global_database
from fakesnow.variables import Variables

SCHEMA_UNSET = "schema_unset"
SQL_SUCCESS = "SELECT 'Statement executed successfully.' as 'status'"
Expand Down Expand Up @@ -134,6 +135,7 @@ def execute(
if os.environ.get("FAKESNOW_DEBUG") == "snowflake":
print(f"{command};{params=}" if params else f"{command};", file=sys.stderr)

command = self._inline_variables(command)
command, params = self._rewrite_with_params(command, params)
if self._conn.nop_regexes and any(re.match(p, command, re.IGNORECASE) for p in self._conn.nop_regexes):
transformed = transforms.SUCCESS_NOP
Expand All @@ -148,6 +150,7 @@ def execute(
def _transform(self, expression: exp.Expression) -> exp.Expression:
return (
expression.transform(transforms.upper_case_unquoted_identifiers)
.transform(transforms.update_variables, variables=self._conn.variables)
.transform(transforms.set_schema, current_database=self._conn.database)
.transform(transforms.create_database, db_path=self._conn.db_path)
.transform(transforms.extract_comment_on_table)
Expand Down Expand Up @@ -501,6 +504,9 @@ def convert(param: Any) -> Any: # noqa: ANN401

return command, params

def _inline_variables(self, sql: str) -> str:
return self._conn.variables.inline_variables(sql)


class FakeSnowflakeConnection:
def __init__(
Expand All @@ -525,6 +531,7 @@ def __init__(
self.db_path = Path(db_path) if db_path else None
self.nop_regexes = nop_regexes
self._paramstyle = snowflake.connector.paramstyle
self.variables = Variables()

create_global_database(duck_conn)

Expand Down
11 changes: 11 additions & 0 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlglot import exp

from fakesnow.global_database import USERS_TABLE_FQ_NAME
from fakesnow.variables import Variables

MISSING_DATABASE = "missing_database"
SUCCESS_NOP = sqlglot.parse_one("SELECT 'Statement executed successfully.'")
Expand Down Expand Up @@ -1407,6 +1408,16 @@ def show_keys(
return expression


def update_variables(
    expression: exp.Expression,
    variables: Variables,
) -> exp.Expression:
    """Apply a SET/UNSET statement to *variables* and replace it with a no-op.

    Args:
        expression: The expression node being visited by the transform.
        variables: The session's variable store, mutated in place.

    Returns:
        ``SUCCESS_NOP`` when the node was a SET/UNSET (nothing is left to
        execute once the store has been updated), otherwise *expression*
        unchanged.
    """
    # Guard clause: anything that does not modify variables passes straight through.
    if not Variables.is_variable_modifier(expression):
        return expression

    variables.update_variables(expression)
    return SUCCESS_NOP


class SHA256(exp.Func):
_sql_names: ClassVar = ["SHA256"]
arg_types: ClassVar = {"this": True}
Expand Down
57 changes: 57 additions & 0 deletions fakesnow/variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import re

import snowflake.connector.errors
from sqlglot import exp


# Implements snowflake variables: https://docs.snowflake.com/en/sql-reference/session-variables#using-variables-in-sql
class Variables:
    """Session-scoped store for Snowflake-style SQL variables.

    Handles ``SET name = value``, ``UNSET name`` and textual inlining of
    ``$name`` references into a SQL string before it is executed. Values are
    stored as SQL text (the rendered right-hand side of the SET), not as
    evaluated Python values.
    """

    @classmethod
    def is_variable_modifier(cls, expr: exp.Expression) -> bool:
        """Return True if *expr* is a SET statement or is recognised as an UNSET."""
        return isinstance(expr, exp.Set) or cls._is_unset_expression(expr)

    @classmethod
    def _is_unset_expression(cls, expr: exp.Expression) -> bool:
        """Return True if *expr* is the Alias node that ``UNSET name`` parses to.

        The check looks for an Alias whose aliased expression wraps something
        named ``UNSET``.
        """
        if isinstance(expr, exp.Alias):
            this_expr = expr.this.args.get("this")
            return isinstance(this_expr, exp.Expression) and this_expr.this == "UNSET"
        return False

    def __init__(self) -> None:
        # Maps variable name (as rendered by sqlglot) -> SQL text of its value.
        self._variables: dict[str, str] = {}

    def update_variables(self, expr: exp.Expression) -> None:
        """Apply a SET or UNSET expression to the store.

        Only the first item of a SET is used; multi-variable SET is not handled.
        Expressions that are neither SET nor UNSET are ignored.

        Raises:
            NotImplementedError: if a ``Set`` node carries the ``unset`` flag.
            AssertionError: if the parsed expression lacks the expected parts.
        """
        if isinstance(expr, exp.Set):
            unset = expr.args.get("unset")
            if not unset:  # SET varname = value;
                set_expressions = expr.args.get("expressions")
                assert set_expressions, "SET without values in expression(s) is unexpected."
                eq = set_expressions[0].this
                name = eq.this.sql()
                value = eq.args.get("expression").sql()
                self._set(name, value)
            else:
                # Haven't been able to produce this in tests yet due to UNSET being parsed as an Alias expression.
                raise NotImplementedError("UNSET not supported yet")
        elif self._is_unset_expression(expr):  # Unfortunately UNSET varname; is parsed as an Alias expression :(
            alias = expr.args.get("alias")
            assert alias, "UNSET without value in alias attribute is unexpected."
            self._unset(alias.this)

    def _set(self, name: str, value: str) -> None:
        """Store *value* (SQL text) under *name*, replacing any previous value."""
        self._variables[name] = value

    def _unset(self, name: str) -> None:
        """Remove *name* from the store; unknown names are ignored.

        A bare ``KeyError`` is never a meaningful error for a SQL client, so
        removing a variable that is not set is treated as a no-op.
        NOTE(review): assumes Snowflake accepts UNSET of an undefined
        variable without error - confirm against a real account.
        """
        self._variables.pop(name, None)

    def inline_variables(self, sql: str) -> str:
        """Return *sql* with every ``$name`` reference replaced by its stored SQL text.

        Matching is case-insensitive. A reference only matches a whole
        variable name (``$var1`` does not match inside ``$var10``), and a
        ``$`` immediately preceded by another ``$`` is left alone so that
        ``$$...$$`` delimiters are not mistaken for variables.

        Raises:
            snowflake.connector.errors.ProgrammingError: if a ``$name``
                reference remains after substitution.
        """
        for name, value in self._variables.items():
            # re.escape: the name is data, not a pattern.
            # (?!\w): stop $var1 from rewriting the prefix of $var10.
            pattern = rf"(?<!\$)\${re.escape(name)}(?!\w)"
            # A callable replacement inserts the value literally; a plain string
            # would have its backslashes interpreted as escapes/group references.
            sql = re.sub(pattern, lambda _match, replacement=value: replacement, sql, flags=re.IGNORECASE)

        remaining_variables = re.search(r"(?<!\$)\$\w+", sql)
        if remaining_variables:
            raise snowflake.connector.errors.ProgrammingError(
                msg=f"Session variable '{remaining_variables.group().upper()}' does not exist"
            )
        return sql
48 changes: 48 additions & 0 deletions tests/test_fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# pyright: reportOptionalMemberAccess=false
import datetime
import json
import re
import tempfile
from decimal import Decimal

Expand Down Expand Up @@ -1557,6 +1558,53 @@ def test_use_invalid_schema(_fakesnow: None):
)


# Snowflake SQL variables: https://docs.snowflake.com/en/sql-reference/session-variables#using-variables-in-sql
#
# Variables are scoped to the session (i.e. the connection, not the cursor)
# [x] Simple scalar variables: SET var1 = 1;
# [x] Unset variables: UNSET var1;
# [x] Simple SQL expression variables: SET INCREMENTAL_DATE = DATEADD( 'DAY', -7, CURRENT_DATE());
# [x] Basic use of variables in SQL using $ syntax: SELECT $var1;
# [ ] Multiple variables: SET (var1, var2) = (1, 'hello');
# [ ] Variables set via 'properties' on the connection https://docs.snowflake.com/en/sql-reference/session-variables#setting-variables-on-connection
# [ ] Using variables via the IDENTIFIER function: INSERT INTO IDENTIFIER($my_table_name) (i) VALUES (42);
# [ ] Session variable functions: https://docs.snowflake.com/en/sql-reference/session-variables#session-variable-functions
def test_variables(conn: snowflake.connector.SnowflakeConnection):
    """Check SET / UNSET / ``$name`` behaviour across cursors and connections."""
    setup_statements = (
        "SET var1 = 1;",
        "SET var2 = 'hello';",
        "SET var3 = DATEADD( 'DAY', -7, '2024-10-09');",
    )
    var3_missing = re.escape("Session variable '$VAR3' does not exist")
    var1_missing = re.escape("Session variable '$VAR1' does not exist")

    with conn.cursor() as cur:
        for statement in setup_statements:
            cur.execute(statement)

        # All three variables can be read back through the $ syntax.
        cur.execute("select $var1, $var2, $var3;")
        selected = cur.fetchall()
        assert selected == [(1, "hello", datetime.datetime(2024, 10, 2, 0, 0))]

        # A variable can be used inside a predicate.
        cur.execute("CREATE TABLE example (id int, name varchar);")
        cur.execute("INSERT INTO example VALUES (10, 'hello'), (20, 'world');")
        cur.execute("select id, name from example where name = $var2;")
        filtered = cur.fetchall()
        assert filtered == [(10, "hello")]

        # Once unset, referencing the variable is an error.
        cur.execute("UNSET var3;")
        with pytest.raises(snowflake.connector.errors.ProgrammingError, match=var3_missing):
            cur.execute("select $var3;")

    # variables are scoped to the session, so they should be available in a new cursor.
    with conn.cursor() as second_cur:
        second_cur.execute("select $var1, $var2")
        assert second_cur.fetchall() == [(1, "hello")]

    # but not in a new connection.
    with snowflake.connector.connect() as other_conn:
        with other_conn.cursor() as other_cur:
            with pytest.raises(snowflake.connector.errors.ProgrammingError, match=var1_missing):
                other_cur.execute("select $var1;")


def test_values(conn: snowflake.connector.SnowflakeConnection):
with conn.cursor(snowflake.connector.cursor.DictCursor) as cur:
cur.execute("select * from VALUES ('Amsterdam', 1), ('London', 2)")
Expand Down

0 comments on commit 7696cbd

Please sign in to comment.