Skip to content

Commit

Permalink
Merge branch 'main' into jsummer/add-generation-sproc
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jsummer committed Jan 27, 2025
2 parents 2bef25f + 3070eeb commit 3621ac4
Show file tree
Hide file tree
Showing 17 changed files with 1,034 additions and 85 deletions.
31 changes: 25 additions & 6 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ name: Semantic Model Format & Lint

on:
pull_request:
- "*"
branches:
- "*"

jobs:
build:
Expand All @@ -20,39 +21,57 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

# Caching dependencies using Poetry
- name: Cache Poetry virtualenv
uses: actions/cache@v2
with:
path: ~/.cache/pypoetry/virtualenvs
key: ${{ runner.os }}-poetry-${{ hashFiles('**/poetry.lock') }}
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
restore-keys: |
${{ runner.os }}-poetry-
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
python3 -m pip install --user pipx
python3 -m pipx ensurepath
pipx install poetry
- name: Configure Poetry
run: |
$HOME/.local/bin/poetry config virtualenvs.create false
export PATH="$HOME/.local/bin:$PATH"
poetry config virtualenvs.create false
- name: Install dependencies using Poetry
run: |
$HOME/.local/bin/poetry install --no-interaction
poetry install --no-interaction
- name: Run mypy
id: mypy
run: |
make run_mypy
continue-on-error: true

- name: Check with black
id: black
run: |
make check_black
continue-on-error: true

- name: Check with isort
id: isort
run: |
make check_isort
continue-on-error: true

- name: Run flake8
id: flake8
run: |
make run_flake8
continue-on-error: true

- name: Report failures
run: |
if [ "${{ steps.black.outcome }}" != "success" ]; then echo "black failed"; FAIL=1; fi
if [ "${{ steps.isort.outcome }}" != "success" ]; then echo "isort failed"; FAIL=1; fi
if [ "${{ steps.flake8.outcome }}" != "success" ]; then echo "flake8 failed"; FAIL=1; fi
if [ "$FAIL" == "1" ]; then exit 1; fi
continue-on-error: false
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ pyvenv

# VSCode
.vscode/settings.json
.vscode/launch.json
.vscode/.ropeproject
.vscode/*.log
.vscode/*.json

# Jetbrains
.idea/*
Expand Down
12 changes: 6 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,25 @@ run_mypy: ## Run mypy
mypy --config-file=mypy.ini .

run_flake8: ## Run flake8
flake8 --ignore=E203,E501,W503 --exclude=venv,pyvenv,tmp,*_pb2.py,*_pb2.pyi,images/*/src .
flake8 --ignore=E203,E501,W503 --exclude=venv,.venv,pyvenv,tmp,*_pb2.py,*_pb2.pyi,images/*/src .

check_black: ## Check to see if files would be updated with black.
# Exclude pyvenv and all generated protobuf code.
black --check --exclude="venv|pyvenv|.*_pb2.py|.*_pb2.pyi" .
black --check --exclude=".venv|venv|pyvenv|.*_pb2.py|.*_pb2.pyi" .

run_black: ## Run black to format files.
# Exclude pyvenv, tmp, and all generated protobuf code.
black --exclude="venv|pyvenv|tmp|.*_pb2.py|.*_pb2.pyi" .
black --exclude=".venv|venv|pyvenv|tmp|.*_pb2.py|.*_pb2.pyi" .

check_isort: ## Check if files would be updated with isort.
isort --profile black --check --skip=venv --skip=pyvenv --skip-glob='*_pb2.py*' .
isort --profile black --check --skip=venv --skip=pyvenv --skip=.venv --skip-glob='*_pb2.py*' .

run_isort: ## Run isort to update imports.
isort --profile black --skip=pyvenv --skip=venv --skip=tmp --skip-glob='*_pb2.py*' .
isort --profile black --skip=pyvenv --skip=venv --skip=tmp --skip=.venv --skip-glob='*_pb2.py*' .


fmt_lint: shell ## lint/fmt in current python environment
make run_mypy run_black run_isort run_flake8
make run_black run_isort run_flake8

# Test below
test: shell ## Run tests.
Expand Down
8 changes: 4 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from app_utils.shared_utils import ( # noqa: E402
GeneratorAppScreen,
get_snowflake_connection,
set_sit_query_tag,
set_account_name,
set_host_name,
set_user_name,
set_streamlit_location,
set_sit_query_tag,
set_snowpark_session,
set_streamlit_location,
set_user_name,
)
from semantic_model_generator.snowflake_utils.env_vars import ( # noqa: E402
SNOWFLAKE_ACCOUNT_LOCATOR,
Expand All @@ -28,7 +28,7 @@ def failed_connection_popup() -> None:
Renders a dialog box detailing that the credentials provided could not be used to connect to Snowflake.
"""
st.markdown(
f"""It looks like the credentials provided could not be used to connect to the account."""
"""It looks like the credentials provided could not be used to connect to the account."""
)
st.stop()

Expand Down
Empty file added app_utils/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions app_utils/chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import re
from typing import Dict, Any
from typing import Any, Dict

import requests
import streamlit as st
Expand Down Expand Up @@ -32,7 +32,7 @@ def send_message(

resp = _snowflake.send_snow_api_request( # type: ignore
"POST",
f"/api/v2/cortex/analyst/message",
"/api/v2/cortex/analyst/message",
{},
{},
request_body,
Expand Down
161 changes: 139 additions & 22 deletions app_utils/shared_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

import json
import os
import tempfile
import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import StringIO
from typing import Any, Optional, List, Union
from typing import Any, Dict, List, Optional, Union

import pandas as pd
import streamlit as st
from snowflake.snowpark import Session
from PIL import Image
from snowflake.connector import ProgrammingError
from snowflake.connector.connection import SnowflakeConnection
from snowflake.snowpark import Session

from semantic_model_generator.data_processing.proto_utils import (
proto_to_yaml,
Expand All @@ -26,23 +26,20 @@
)
from semantic_model_generator.protos import semantic_model_pb2
from semantic_model_generator.protos.semantic_model_pb2 import Dimension, Table
from semantic_model_generator.snowflake_utils.env_vars import ( # noqa: E402
assert_required_env_vars,
)
from semantic_model_generator.snowflake_utils.snowflake_connector import (
SnowflakeConnector,
fetch_databases,
fetch_schemas_in_database,
fetch_stages_in_schema,
fetch_table_schema,
fetch_tables_views_in_schema,
fetch_warehouses,
fetch_stages_in_schema,
fetch_yaml_names_in_stage,
)

from semantic_model_generator.snowflake_utils.env_vars import ( # noqa: E402
SNOWFLAKE_ACCOUNT_LOCATOR,
SNOWFLAKE_HOST,
SNOWFLAKE_USER,
assert_required_env_vars,
)

SNOWFLAKE_ACCOUNT = os.environ.get("SNOWFLAKE_ACCOUNT_LOCATOR", "")

# Add a logo on the top-left corner of the app
Expand Down Expand Up @@ -103,6 +100,7 @@ def get_snowflake_connection() -> SnowflakeConnection:
if st.session_state["sis"]:
# Import SiS-required modules
import sys

from snowflake.snowpark.context import get_active_session

# Non-Anaconda supported packages must be added to path to import from stage
Expand Down Expand Up @@ -200,6 +198,132 @@ def get_available_stages(schema: str) -> List[str]:
return fetch_stages_in_schema(get_snowflake_connection(), schema)


@st.cache_resource(show_spinner=False)
def validate_table_schema(table: str, schema: Dict[str, str]) -> bool:
table_schema = fetch_table_schema(get_snowflake_connection(), table)
# validate columns names
if set(schema.keys()) != set(table_schema.keys()):
return False
# validate column types
for col_name, col_type in table_schema.items():
if not (schema[col_name] in col_type):
return False
return True


@st.cache_resource(show_spinner=False)
def validate_table_exist(schema: str, table_name: str) -> bool:
"""
Validate table exist in the Snowflake account.
Returns:
List[str]: A list of available stages.
"""
table_names = fetch_tables_views_in_schema(get_snowflake_connection(), schema)
table_names = [table.split(".")[2] for table in table_names]
if table_name.upper() in table_names:
return True
return False


def schema_selector_container(
db_selector: Dict[str, str], schema_selector: Dict[str, str]
) -> List[str]:
"""
Common component that encapsulates db/schema/table selection for the admin app.
When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
Returns: None
"""
available_schemas = []
available_tables = []

# First, retrieve all databases that the user has access to.
eval_database = st.selectbox(
db_selector["label"],
options=get_available_databases(),
index=None,
key=db_selector["key"],
)
if eval_database:
# When a valid database is selected, fetch the available schemas in that database.
try:
available_schemas = get_available_schemas(eval_database)
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected database.")
st.stop()

eval_schema = st.selectbox(
schema_selector["label"],
options=available_schemas,
index=None,
key=schema_selector["key"],
format_func=lambda x: format_snowflake_context(x, -1),
)
if eval_schema:
# When a valid schema is selected, fetch the available tables in that schema.
try:
available_tables = get_available_tables(eval_schema)
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected schema.")
st.stop()

return available_tables


def table_selector_container(
db_selector: Dict[str, str],
schema_selector: Dict[str, str],
table_selector: Dict[str, str],
) -> Optional[str]:
"""
Common component that encapsulates db/schema/table selection for the admin app.
When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
Returns: None
"""
available_schemas = []
available_tables = []

# First, retrieve all databases that the user has access to.
eval_database = st.selectbox(
db_selector["label"],
options=get_available_databases(),
index=None,
key=db_selector["key"],
)
if eval_database:
# When a valid database is selected, fetch the available schemas in that database.
try:
available_schemas = get_available_schemas(eval_database)
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected database.")
st.stop()

eval_schema = st.selectbox(
schema_selector["label"],
options=available_schemas,
index=None,
key=schema_selector["key"],
format_func=lambda x: format_snowflake_context(x, -1),
)
if eval_schema:
# When a valid schema is selected, fetch the available tables in that schema.
try:
available_tables = get_available_tables(eval_schema)
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected schema.")
st.stop()

tables = st.selectbox(
table_selector["label"],
options=available_tables,
index=None,
key=table_selector["key"],
format_func=lambda x: format_snowflake_context(x, -1),
)

return tables


def stage_selector_container() -> Optional[List[str]]:
"""
Common component that encapsulates db/schema/stage selection for the admin app.
Expand Down Expand Up @@ -986,15 +1110,12 @@ def show_yaml_in_dialog() -> None:

def upload_yaml(file_name: str) -> None:
"""util to upload the semantic model."""
import os
import tempfile

yaml = proto_to_yaml(st.session_state.semantic_model)

with tempfile.TemporaryDirectory() as temp_dir:
tmp_file_path = os.path.join(temp_dir, f"{file_name}.yaml")

with open(tmp_file_path, "w", encoding='utf-8') as temp_file:
with open(tmp_file_path, "w", encoding="utf-8") as temp_file:
temp_file.write(yaml)

st.session_state.session.file.put(
Expand Down Expand Up @@ -1047,17 +1168,13 @@ def model_is_validated() -> bool:

def download_yaml(file_name: str, stage_name: str) -> str:
"""util to download a semantic YAML from a stage."""
import os
import tempfile

with tempfile.TemporaryDirectory() as temp_dir:
# Downloads the YAML to {temp_dir}/{file_name}.
st.session_state.session.file.get(
f"@{stage_name}/{file_name}", temp_dir
)
st.session_state.session.file.get(f"@{stage_name}/{file_name}", temp_dir)

tmp_file_path = os.path.join(temp_dir, f"{file_name}")
with open(tmp_file_path, "r", encoding='utf-8') as temp_file:
with open(tmp_file_path, "r", encoding="utf-8") as temp_file:
# Read the raw contents from {temp_dir}/{file_name} and return it as a string.
yaml_str = temp_file.read()
return yaml_str
Expand Down Expand Up @@ -1263,7 +1380,7 @@ def model(self) -> Optional[str]:
return st.session_state.semantic_model.name # type: ignore
return None

def to_dict(self) -> dict[str, Union[str,None]]:
def to_dict(self) -> dict[str, Union[str, None]]:
return {
"User": self.user,
"Stage": self.stage,
Expand Down
Loading

0 comments on commit 3621ac4

Please sign in to comment.