diff --git a/samples/ml/airflow/README.md b/samples/ml/airflow/README.md
deleted file mode 100644
index 449275d..0000000
--- a/samples/ml/airflow/README.md
+++ /dev/null
@@ -1,270 +0,0 @@
-# Productionizing an XGBoost Training Script
-
-## Setup
-
-Generate synthetic data using [generate_data.sql](./scripts/generate_data.sql).
-Adjust the `rowcount` value as desired to test performance with different data
-sizes.
-
-### Connecting to Snowflake in Python
-
-The scripts included in this example use the `SnowflakeLoginOptions` utility API
-from `snowflake-ml-python` to retrieve Snowflake connection settings from config
-files, which must be authored before use. See [Configure Connections](https://docs.snowflake.com/developer-guide/snowflake-cli/connecting/configure-connections#define-connections) for information on how to define default
-Snowflake connection(s) in a config.toml file.
-
-```python
-from snowflake.snowpark import Session
-from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
-
-# Requires valid ~/.snowflake/config.toml file
-session = Session.builder.configs(SnowflakeLoginOptions()).create()
-```
-
-## Training Script
-
-The [src directory](./src/) contains the model training code, which will comprise
-the job payload. Note that the script only uses Snowpark APIs for data ingestion
-and model registration; otherwise the script uses vanilla XGBoost and SKLearn for
-model training and evaluation. This is the recommended approach for single-node
-training in container runtimes.
-
-A `requirements.txt` file is not necessary since the script only depends on
-common ML libraries like xgboost, sklearn, and snowflake-ml-python, which are
-installed by default on Container Runtime images.
-
-### Script Parameters
-
-- `--source_data` (OPTIONAL) Training data location. Defaults to `loan_applications`,
-  which is created in the [setup step](#setup)
-- `--save_mode` (OPTIONAL) Controls whether to save the model to a local path or into the Model Registry. Defaults to `local`
-- `--output_dir` (OPTIONAL) Local save path. Only used if `save_mode=local`
-
-## Launch Job
-
-### Set Up Compute Pool
-
-[compute_pool.sql](./scripts/compute_pool.sql) contains helpful SQL commands for
-setting up an SPCS Compute Pool which can be used for this example. The main step
-is simply:
-
-```sql
-CREATE OR REPLACE COMPUTE POOL ML_DEMO_POOL
-  MIN_NODES = 1
-  MAX_NODES = 10
-  INSTANCE_FAMILY = HIGHMEM_X64_S;
-```
-
-This will create a basic compute pool using `HIGHMEM_X64_S` instances as the node type.
-`HIGHMEM_X64_S` gives us nodes with 58 GiB of memory, which can be helpful when operating
-on large datasets such as `loan_applications_1b`. You may consider using a smaller
-INSTANCE_FAMILY such as `CPU_X64_S` during small-scale experimentation to minimize costs.
-Note that we recommend using nodes with at least 2 CPUs to avoid potential deadlocks.
-See the [CREATE COMPUTE POOL](https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool)
-documentation for more information on different instance families and their respective
-resources.
-
-### Manual Job Execution
-
-In this example, we will upload the payload to a stage and mount that stage into
-the job container. This approach removes the need to build and upload your own
-container into the Snowflake Image Registry just to run your training script in
-SPCS.
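The mount itself is just a stage volume in the job's service specification. Below is a trimmed sketch of the relevant fields only; the complete specification used by this example is in [submit_job.sql](./scripts/submit_job.sql):

```yaml
spec:
  containers:
  - name: main
    # image, command, and resource settings omitted; see scripts/submit_job.sql
    volumeMounts:
    - name: stage-volume
      mountPath: /opt/app        # uploaded payload appears here inside the container
  volumes:
  - name: stage-volume
    source: '@ML_DEMO_STAGE'     # the stage holding the training script
```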
-You can upload your payload into a stage using the Snowsight UI or by running
-the script below:
-
-```python
-import os
-from snowflake.snowpark import Session
-from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
-
-session = Session.builder.configs(SnowflakeLoginOptions()).create()
-
-# Create stage
-stage_name = "@ML_DEMO_STAGE"
-session.sql(f"create stage if not exists {stage_name.lstrip('@')}").collect()
-
-# Upload payload to stage
-payload_path = "path/to/headless/scheduled-xgb/src/*"  # NOTE: Update to reflect your local path
-session.file.put(
-    payload_path,
-    stage_name,
-    overwrite=True,
-    auto_compress=False,
-)
-```
-
-Once the payload has been uploaded to a stage, we can then launch the SPCS JOB
-SERVICE. This requires authoring a [service specification](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/specification-reference)
-which we then pass to an [EXECUTE JOB SERVICE](https://docs.snowflake.com/en/sql-reference/sql/execute-job-service)
-query. [submit_job.sql](./scripts/submit_job.sql) contains the SQL query for running
-our job, including the YAML service specification.
-
-#### Helper Script
-
-The included [submit_job.py](./scripts/job_utils/submit_job.py) script can automatically
-generate the service specification and run the SQL query on your behalf. The
-script can be invoked with default arguments which have been configured to work
-with this example:
-
-```bash
-> python headless/scheduled-xgb/scripts/submit_job.py
-Generated service name: JOB_36E39504_6210_495D_A7A2_07816924D31E
-Submitted job id: 01b828e7-0002-9808-0000-da071b345272
-```
-
-You may also inspect the available arguments if you need to customize its behavior:
-
-```bash
-> python headless/scheduled-xgb/scripts/submit_job.py -h
-```
-
-### Scheduled Job Execution
-
-Jobs are commonly used as part of CI/CD pipelines. Pipelines and jobs may be
-launched:
-
-- Explicitly by a manual or external action
-- On a scheduled basis
-- Based on event triggers
-
-Explicit triggers can be used to bridge job/pipeline frameworks, for instance
-allowing [Airflow](#airflow-integration) DAGs to execute SPCS jobs.
-Meanwhile, [Snowflake Tasks](https://docs.snowflake.com/en/user-guide/tasks-intro)
-enable configuring scheduled and triggered DAGs natively in Snowflake.
-
-#### Airflow Integration
-
-In this example, we will explore using [Apache Airflow](https://airflow.apache.org/)
-to build a CD pipeline that runs our training script on a weekly basis.
-We will define a task which executes the following steps in an SPCS container:
-
-1. Pull the latest code from a private GitHub repository
-2. Run the training script on the latest production data
-3. Save the trained model to a stage for downstream consumption
-
-First we will need to generate a GitHub PAT for our job agent to authenticate
-with GitHub. This step is only necessary if we are pulling the code from a
-private repository. See [GitHub's documentation](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens)
-for more information about PATs.
-
-Once we have our PAT, we can register it as a [Snowflake Secret](https://docs.snowflake.com/en/sql-reference/sql/create-secret).
-We will also need to configure an [External Access Integration (EAI)](https://docs.snowflake.com/en/developer-guide/external-network-access/creating-using-external-network-access)
-to enable external network access in our SPCS jobs.
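The job consumes these two pieces as follows: the secret is surfaced inside the container as an environment variable through the `secrets:` field of the service specification, and the EAI is attached at submission time via `EXTERNAL_ACCESS_INTEGRATIONS`. A trimmed fragment mirroring [job_spec.yaml](./airflow/job_spec.yaml):

```yaml
spec:
  containers:
  - name: main
    secrets:
    - snowflakeSecret: ML_DEMO_GIT_TOKEN  # secret created in github_integration.sql
      secretKeyRef: secret_string
      envVarName: GIT_TOKEN               # the container reads the PAT from $GIT_TOKEN
```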
-
-[github_integration.sql](./scripts/github_integration.sql) contains the SQL
-commands necessary to create both the secret and the EAI.
-Be sure to replace fields indicated with `` with appropriate values.
-
-We are finally ready to define our Airflow DAG.
-Configure Airflow's [Snowflake connection](https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html#json-format-example)
-using the `AIRFLOW_CONN_SNOWFLAKE_DEFAULT` environment variable.
-Point Airflow's `dags_folder` setting at the directory containing [train_dag.py](./airflow/train_dag.py),
-which contains the DAG definition and job submission logic. The DAG also
-references our prepared service specification at [job_spec.yaml](./airflow/job_spec.yaml).
-Make sure to modify the GitHub repository URL in the YAML file to point to your
-own GitHub repository.
-
-```bash
-# Pull GitHub repo using secret
-# ACTION REQUIRED: Change the repository URL below to your repo!
-git clone https://${GIT_TOKEN}@github.com/sfc-gh-dhung/test-repo.git $LOCAL_REPO_PATH
-```
-
-The current example DAG only contains a single task, but this can easily be
-chained together with additional tasks such as upstream data preprocessing and
-downstream model registration or inference/evaluation.
-
-## Observability
-
-### Job Monitoring
-
-SPCS JOB SERVICE executions are tied to query execution and can be inspected in
-Snowsight under query history. Active and recent jobs may also be inspected using
-[SHOW JOB SERVICES](https://docs.snowflake.com/en/sql-reference/sql/show-services),
-[DESCRIBE SERVICE](https://docs.snowflake.com/en/sql-reference/sql/desc-service),
-and [SHOW SERVICE CONTAINERS IN SERVICE](https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service).
-
-Success/failure notifications can be enabled through
-[Snowflake Tasks](https://docs.snowflake.com/en/user-guide/tasks-errors).
-Externally triggered jobs will need to manually configure alerting based on
-the job execution result.
-
-### Logging
-
-[SPCS Documentation](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/monitoring-services#accessing-container-logs)
-gives a good overview of logging options for SPCS jobs and services. In short:
-- [SYSTEM$GET_SERVICE_LOGS](https://docs.snowflake.com/en/sql-reference/functions/system_get_service_logs)
-  retrieves container logs of an existing SERVICE or JOB SERVICE.
-
-  ```sql
-  -- Get last 100 logs from container named 'main' in job service 'JOB_36E39504'
-  SELECT SYSTEM$GET_SERVICE_LOGS('JOB_36E39504', '0', 'main', 100)
-  ```
-
-  Previous runs of restarted containers and dropped services cannot be
-  inspected in this way. This includes JOB SERVICE entities which have
-  been automatically cleaned up after completion.
-- Container console logs in SPCS are automatically captured to the account's active
-  [Event Table](https://docs.snowflake.com/en/developer-guide/logging-tracing/event-table-setting-up).
-  Log level may optionally be customized using the [spec.logExporters](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/specification-reference#label-snowpark-containers-spec-reference-spec-logexporters)
-  service specification field. If not set, all logs will be captured by default.
-
-  We recommend that ACCOUNTADMIN create (filtered) VIEWS on top of the
-  Event Table(s) and configure RBAC at the VIEW level:
-  ```sql
-  USE ROLE ACCOUNTADMIN;
-
-  -- Create database and schema to hold our views
-  CREATE DATABASE IF NOT EXISTS TELEMETRY_DB;
-  CREATE SCHEMA IF NOT EXISTS TELEMETRY_DB.SPCS_LOGS;
-  GRANT USAGE ON SCHEMA TELEMETRY_DB.SPCS_LOGS TO ROLE ;
-
-  -- Create and grant VIEW to expose relevant logs to
-  CREATE VIEW TELEMETRY_DB.SPCS_LOGS.ML_DEMO_JOBS_V as (
-    -- Default event table. Replace with your active event table if applicable.
-    SELECT * FROM SNOWFLAKE.TELEMETRY.EVENTS_VIEW
-    where 1=1
-      and resource_attributes:"snow.database.name" = 'ML_DEMO_DB'
-      and resource_attributes:"snow.schema.name" = 'ML_DEMO_SCHEMA'
-      and resource_attributes:"snow.compute_pool.name" = 'ML_DEMO_POOL'
-      and resource_attributes:"snow.service.type" = 'Job'
-  );
-  GRANT SELECT ON TELEMETRY_DB.SPCS_LOGS.ML_DEMO_JOBS_V TO ROLE ;
-  ```
-
-DataDog supports integration with Event Tables. See
-[this blog post](https://www.datadoghq.com/blog/snowflake-snowpark-monitoring-datadog/)
-by DataDog for more information.
-
-### Compute Metrics
-
-Service level metrics such as CPU/GPU utilization, network activity, and disk
-activity can be logged using the `spec.platformMonitor` service specification
-field.
-See [Accessing Event Table Service Metrics](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/monitoring-services#accessing-event-table-service-metrics).
-
-Compute pool level metrics can be monitored and exported to visualizers like
-DataDog and Grafana by setting up a monitor service in SPCS. See
-[Tutorial: Grafana Visualization Service for Compute Pool Metrics](https://github.com/Snowflake-Labs/spcs-templates/blob/main/user-metrics/tutorial%20-%20visualize_metrics_using_grafana/Introduction.md).
-
-### Model Metrics
-
-[Snowflake Model Registry](https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/overview)
-natively supports saving evaluation metrics when logging models,
-as shown in the [example training script](#training-script).
-Such metrics are displayed in the Model Registry UI and are included in the
-`SHOW VERSIONS IN MODEL ` output.
-
-We can also integrate with frameworks like
-[MLflow](https://mlflow.org/) and [Weights and Biases (W&B)](https://wandb.ai/)
-for live training progress monitoring and experiment tracking.
-- We can run an MLflow tracking server on a separate SPCS service and securely
-  connect to it using [service-to-service](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/working-with-services#service-to-service-communications)
-  communication.
-  - We recommend persisting tracking server state using
-    a [block storage volume](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/block-storage-volume)
-    for runs metadata (default `./mlruns`)
-    and a [stage volume](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/snowflake-stage-volume)
-    for artifacts (default `./mlartifacts`).
-    See [Configure Server](https://mlflow.org/docs/latest/tracking/server.html#configure-server)
-    for how to configure your MLflow tracking server.
-
-- We can also connect to externally hosted MLflow tracking servers or W&B with
-  External Access Integrations for external network access.
-  See the [Single Node PyTorch Example](../single-node/README.md#weights-and-biases-integration)
-  for a full example with W&B integration; a short sketch of reporting to a
-  tracking server follows this list.
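To make the MLflow option concrete: from the training script's point of view, either deployment only requires a tracking URI the job can reach (a service-to-service DNS name, or an EAI-allowed external host) plus a few logging calls. A minimal sketch, assuming `mlflow` is installed in the job image and using a hypothetical tracking URI:

```python
import mlflow

# Hypothetical endpoint: an MLflow tracking server running as another SPCS service
# (service-to-service DNS name) or an external host allowed by an EAI.
mlflow.set_tracking_uri("http://mlflow-tracking:5000")
mlflow.set_experiment("loan_default_xgb")

with mlflow.start_run(run_name="weekly-training"):
    # Parameters mirror the defaults in src/train.py
    mlflow.log_params({"max_depth": 6, "learning_rate": 0.1, "n_estimators": 100})
    # Placeholder values; in train.py these would come from evaluate_model()
    mlflow.log_metrics({"accuracy": 0.93, "roc_auc": 0.87})
```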
diff --git a/samples/ml/airflow/airflow/future_dag.py b/samples/ml/airflow/airflow/future_dag.py deleted file mode 100644 index 4f92bea..0000000 --- a/samples/ml/airflow/airflow/future_dag.py +++ /dev/null @@ -1,108 +0,0 @@ -import re -import json -from uuid import uuid4 -from datetime import datetime, timedelta, UTC - -from airflow.decorators import dag, task -from airflow.exceptions import AirflowException -from airflow.models import Variable -from airflow.models.taskinstance import TaskInstance -from snowflake.snowpark.context import get_active_session -from snowflake.ml.jobs import submit_job, get_job - -@dag( - schedule=timedelta(weeks=1), - start_date=datetime(2024, 11, 4, tzinfo=UTC), - catchup=False, -) -def future_ml_training_mockup(): - """ - This is **not** a working sample!! This is a mockup of what a DAG for - MLOps in Snowflake may look like in the future. `snowflake.ml.jobs` - does not currently exist and the APIs below are only for illustrative - purposes. - """ - - @task.snowpark() - def prepare_data(task_instance: TaskInstance | None = None): - session = get_active_session() - run_id = re.sub(r"[\-:+.]", "_", task_instance.run_id) - - # Kick off preprocessing job on SPCS - job = submit_job( - session=session, - repo_url="https://github.com/my_org/my_repo", - repo_tag="ci-verified", - entrypoint="src/prepare_data.py", - args=["--input_table", "DB.SCHEMA.MY_TABLE", "--output_path", "@DB.SCHEMA.MY_STAGE/"], - compute_pool="cpu_pool", - external_access_integrations=["pypi_eai"], - ) - - # Block until job completes - job.result() - - return run_id - - @task.snowpark() - def start_training_job(run_id: str, model_config: dict): - session = get_active_session() - - # We manually generate a job ID that we can also use as the model ID - job_id = str(uuid4()).replace('-', '_') - - job = submit_job( - session=session, - job_id=job_id, # Manually set job ID - repo_url="https://github.com/my_org/my_repo", - repo_tag="ci-verified", - entrypoint="src/train_model.py", - args=[ - "--input_data", f"@DB.SCHEMA.MY_STAGE/{run_id}", - "--output_path", f"@DB.SCHEMA.MODELS/{run_id}/{job_id}", - "--model_config", json.dumps(model_config), - ], - compute_pool="gpu_pool", - num_instances=4, - external_access_integrations=["pypi_eai"], - ) - - assert job.job_id == job_id - return job.job_id - - @task.snowpark_sensor(poke_interval=60, timeout=7200, mode="reschedule") - def wait_for_completion(job_id: str) -> bool: - session = get_active_session() - job = get_job(session, job_id) - if job.status == "COMPLETE": - print("Job completed. Logs:\n", job.get_logs()) - return True - elif job.status == "FAILED": - raise AirflowException("Job failed. 
Logs:\n %s" % job.get_logs()) - return False - - @task.snowpark() - def evaluate_model(run_id: str, model_id: str): - session = get_active_session() - - # Run eval job to completion and retrieve result - eval_result = submit_job( - session=session, - repo_url="https://github.com/my_org/my_repo", - repo_tag="ci-verified", - entrypoint="src/evaluate_model.py", - args=["--model_path", f"@DB.SCHEMA.MODELS/{run_id}/{model_id}", "--eval_data", "DB.SCHEMA.EVAL_DATA"], - compute_pool="gpu_pool", - num_instances=1, - external_access_integrations=["pypi_eai"], - ).result() - - print("Evaluation result:", eval_result) - - run_id = prepare_data() - configs = Variable.get("model_configs", deserialize_json=True) - for config in configs: - job_id = start_training_job(run_id, config) - wait_for_completion(job_id) >> evaluate_model(run_id, job_id) - -future_ml_training_mockup() diff --git a/samples/ml/airflow/airflow/job_spec.yaml b/samples/ml/airflow/airflow/job_spec.yaml deleted file mode 100644 index 4763879..0000000 --- a/samples/ml/airflow/airflow/job_spec.yaml +++ /dev/null @@ -1,50 +0,0 @@ -spec: - containers: - - name: main - image: /snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:0.4.0 - command: - - bash - - -c - - |- - #!/bin/bash - - LOCAL_REPO_PATH="/ci-app" - - # Pull GitHub repo using secret - # ACTION REQUIRED: Change the repository URL below to your repo! - git clone https://${GIT_TOKEN}@github.com/sfc-gh-dhung/test-repo.git $LOCAL_REPO_PATH - - # Run train script - python "${LOCAL_REPO_PATH}/headless/scheduled-xgb/src/train.py" \ - --save_mode local \ - --output_dir /mnt/stage - secrets: - - snowflakeSecret: ML_DEMO_GIT_TOKEN - secretKeyRef: secret_string - envVarName: GIT_TOKEN - resources: - limits: - cpu: 6000m - memory: 58Gi - requests: - cpu: 6000m - memory: 58Gi - volumeMounts: - - mountPath: /var/log/managedservices/system/mlrs - name: system-logs - - mountPath: /var/log/managedservices/user/mlrs - name: user-logs - - mountPath: /dev/shm - name: dshm - - mountPath: /mnt/stage - name: stage-volume - volumes: - - name: system-logs - source: local - - name: user-logs - source: local - - name: dshm - size: 17Gi - source: memory - - name: stage-volume - source: '@ML_DEMO_STAGE' \ No newline at end of file diff --git a/samples/ml/airflow/airflow/requirements.txt b/samples/ml/airflow/airflow/requirements.txt deleted file mode 100644 index 0700752..0000000 --- a/samples/ml/airflow/airflow/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -apache-airflow==2.10.2 -apache-airflow-providers-snowflake==5.8.0 \ No newline at end of file diff --git a/samples/ml/airflow/airflow/train_dag.py b/samples/ml/airflow/airflow/train_dag.py deleted file mode 100644 index 718b6f5..0000000 --- a/samples/ml/airflow/airflow/train_dag.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import re -import time -from datetime import datetime, timedelta, UTC -from textwrap import dedent - -from airflow.decorators import dag, task -from airflow.models.taskinstance import TaskInstance -from snowflake.snowpark import Session -from snowflake.snowpark.context import get_active_session - -@dag( - schedule=timedelta(weeks=1), - start_date=datetime(2024, 11, 4, tzinfo=UTC), - catchup=False, -) -def ml_training_example(): - @task.snowpark() - def run_training_job(session: Session, task_instance: TaskInstance | None = None): - service_name = f"ML_DEMO_{task_instance.run_id}" - service_name = re.sub(r"[\-:+.]", "_", service_name).upper() - query_template = dedent("""\ - EXECUTE JOB SERVICE - IN COMPUTE 
POOL ML_DEMO_POOL - FROM SPECIFICATION $$ - {spec} - $$ - NAME = {service_name} - EXTERNAL_ACCESS_INTEGRATIONS = (GITHUB_EAI) - QUERY_WAREHOUSE = ML_DEMO_WH - """) - - spec_path = os.path.join(os.path.dirname(__file__), "job_spec.yaml") - with open(spec_path, "r") as spec_file: - spec = spec_file.read() - query = query_template.format(service_name=service_name, spec=spec) - - try: - session.sql(query).collect() - return service_name - finally: - (logs,) = session.sql(f"CALL SYSTEM$GET_SERVICE_LOGS('{service_name}', '0', 'main', 500)").collect() - print("Console logs:\n", logs[0]) - - job_id = run_training_job() - -ml_training_example() diff --git a/samples/ml/airflow/scripts/compute_pool.sql b/samples/ml/airflow/scripts/compute_pool.sql deleted file mode 100644 index 2b3b1fb..0000000 --- a/samples/ml/airflow/scripts/compute_pool.sql +++ /dev/null @@ -1,14 +0,0 @@ --- Basic setup -USE ROLE SYSADMIN; -CREATE DATABASE IF NOT EXISTS ML_DEMO_DB; -CREATE SCHEMA IF NOT EXISTS ML_DEMO_SCHEMA; - --- Set up compute pool -CREATE OR REPLACE COMPUTE POOL ML_DEMO_POOL - MIN_NODES = 1 - MAX_NODES = 10 - INSTANCE_FAMILY = HIGHMEM_X64_S; - --- Start compute pool if necessary -SHOW COMPUTE POOLS LIKE 'ML_DEMO_POOL'; -ALTER COMPUTE POOL ML_DEMO_POOL RESUME; \ No newline at end of file diff --git a/samples/ml/airflow/scripts/generate_data.sql b/samples/ml/airflow/scripts/generate_data.sql deleted file mode 100644 index 9c7fdbb..0000000 --- a/samples/ml/airflow/scripts/generate_data.sql +++ /dev/null @@ -1,47 +0,0 @@ --- Basic setup -USE ROLE SYSADMIN; -CREATE OR REPLACE WAREHOUSE ML_DEMO_WH; --by default, this creates an XS Standard Warehouse -CREATE OR REPLACE DATABASE ML_DEMO_DB; -CREATE OR REPLACE SCHEMA ML_DEMO_SCHEMA; - --- Create 10M rows of synthetic data -CREATE OR REPLACE TABLE loan_applications AS -SELECT - ROW_NUMBER() OVER (ORDER BY RANDOM()) as application_id, - ROUND(NORMAL(40, 10, RANDOM())) as age, - ROUND(NORMAL(65000, 20000, RANDOM())) as income, - ROUND(NORMAL(680, 50, RANDOM())) as credit_score, - ROUND(NORMAL(5, 2, RANDOM())) as employment_length, - ROUND(NORMAL(25000, 8000, RANDOM())) as loan_amount, - ROUND(NORMAL(35, 10, RANDOM()), 2) as debt_to_income, - ROUND(NORMAL(5, 2, RANDOM())) as number_of_credit_lines, - GREATEST(0, ROUND(NORMAL(1, 1, RANDOM()))) as previous_defaults, - ARRAY_CONSTRUCT( - 'home_improvement', 'debt_consolidation', 'business', 'education', - 'major_purchase', 'medical', 'vehicle', 'other' - )[UNIFORM(1, 8, RANDOM())] as loan_purpose, - RANDOM() < 0.15 as is_default, - TIMEADD("MINUTE", UNIFORM(-525600, 0, RANDOM()), CURRENT_TIMESTAMP()) as created_at -FROM TABLE(GENERATOR(rowcount => 10000000)) -ORDER BY created_at; - --- Create 1B rows of synthetic data -CREATE OR REPLACE TABLE loan_applications_1b AS -SELECT - ROW_NUMBER() OVER (ORDER BY RANDOM()) as application_id, - ROUND(NORMAL(40, 10, RANDOM())) as age, - ROUND(NORMAL(65000, 20000, RANDOM())) as income, - ROUND(NORMAL(680, 50, RANDOM())) as credit_score, - ROUND(NORMAL(5, 2, RANDOM())) as employment_length, - ROUND(NORMAL(25000, 8000, RANDOM())) as loan_amount, - ROUND(NORMAL(35, 10, RANDOM()), 2) as debt_to_income, - ROUND(NORMAL(5, 2, RANDOM())) as number_of_credit_lines, - GREATEST(0, ROUND(NORMAL(1, 1, RANDOM()))) as previous_defaults, - ARRAY_CONSTRUCT( - 'home_improvement', 'debt_consolidation', 'business', 'education', - 'major_purchase', 'medical', 'vehicle', 'other' - )[UNIFORM(1, 8, RANDOM())] as loan_purpose, - RANDOM() < 0.15 as is_default, - TIMEADD("MINUTE", UNIFORM(-525600, 0, RANDOM()), 
CURRENT_TIMESTAMP()) as created_at -FROM TABLE(GENERATOR(rowcount => 1000000000)) -ORDER BY created_at; \ No newline at end of file diff --git a/samples/ml/airflow/scripts/github_integration.sql b/samples/ml/airflow/scripts/github_integration.sql deleted file mode 100644 index 2e232bf..0000000 --- a/samples/ml/airflow/scripts/github_integration.sql +++ /dev/null @@ -1,15 +0,0 @@ --- Save a GitHub access token as a secret -CREATE OR REPLACE SECRET ML_DEMO_GIT_TOKEN - TYPE = GENERIC_STRING - SECRET_STRING = ''; - --- Create External Access Integration to allow network access to GitHub -CREATE OR REPLACE NETWORK RULE GITHUB_NETWORK_RULE - MODE = EGRESS - TYPE = HOST_PORT - VALUE_LIST = ('github.com:443'); - -CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION GITHUB_EAI - ALLOWED_NETWORK_RULES = (GITHUB_NETWORK_RULE) - ENABLED = true; -GRANT USAGE ON INTEGRATION GITHUB_EAI TO ROLE ; \ No newline at end of file diff --git a/samples/ml/airflow/scripts/job_utils/spec_utils.py b/samples/ml/airflow/scripts/job_utils/spec_utils.py deleted file mode 100644 index 56d667d..0000000 --- a/samples/ml/airflow/scripts/job_utils/spec_utils.py +++ /dev/null @@ -1,362 +0,0 @@ -import os -import re -from math import ceil -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union - -from snowflake.snowpark import Session - -# See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for identifier syntax -UNQUOTED_IDENTIFIER_REGEX = r"([a-zA-Z_])([a-zA-Z0-9_$]{0,254})" -QUOTED_IDENTIFIER_REGEX = r'"((""|[^"]){0,255})"' -VALID_IDENTIFIER_REGEX = f"(?:{UNQUOTED_IDENTIFIER_REGEX}|{QUOTED_IDENTIFIER_REGEX})" -_SECRET_IDENTIFIER_REGEX = rf"(?P(?:(?:{VALID_IDENTIFIER_REGEX})?[.]{VALID_IDENTIFIER_REGEX}[.])?(?P{VALID_IDENTIFIER_REGEX}))" -_SECRET_CONFIG_REGEX = rf"(?:(?P\w+)=)?{_SECRET_IDENTIFIER_REGEX}(?:[.](?Pusername|password))?" - - -@dataclass -class _SecretConfig: - name: str - fqn: str - subkey: str - mount_path: str - # TODO: Add support for file mount - mount_type: Literal["environment"] = "environment" - - -def _parse_secret_config(s: str) -> _SecretConfig: - m = re.fullmatch(_SECRET_CONFIG_REGEX, s) - if not m: - raise ValueError(f"{s} is not a valid secret config string") - name, fqn = m.group("name"), m.group("fqn") - subkey = m.group("subkey") or "secret_string" - mount_path = m.group("mount_path") or name.upper() - - # Validate (inferred) mount_path - # TODO: Do different validation based on mount type (env var vs directory) - if not re.fullmatch(r"\w+", mount_path): - raise ValueError( - f"Failed to infer secret placement. 
Please explicitly specify placement in format 'ENV_VAR_NAME=SECRET_NAME'" - ) - - return _SecretConfig(name=name, fqn=fqn, mount_path=mount_path, subkey=subkey) - -@dataclass -class _ComputeResources: - cpu: float # Number of vCPU cores - memory: float # Memory in GiB - gpu: int = 0 # Number of GPUs - gpu_type: Optional[str] = None - - -@dataclass -class _ImageSpec: - repo: str - arch: str - family: str - tag: str - resource_requests: _ComputeResources - resource_limits: _ComputeResources - - @property - def full_name(self) -> str: - return f"{self.repo}/st_plat/runtime/{self.arch}/{self.family}:{self.tag}" - - -# TODO: Query Snowflake for resource information instead of relying on this hardcoded -# table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool -_COMMON_INSTANCE_FAMILIES = { - "CPU_X64_XS": _ComputeResources(cpu=1, memory=6), - "CPU_X64_S": _ComputeResources(cpu=3, memory=13), - "CPU_X64_M": _ComputeResources(cpu=6, memory=28), - "CPU_X64_L": _ComputeResources(cpu=28, memory=116), - "HIGHMEM_X64_S": _ComputeResources(cpu=6, memory=58), -} -_AWS_INSTANCE_FAMILIES = { - "HIGHMEM_X64_M": _ComputeResources(cpu=28, memory=240), - "HIGHMEM_X64_L": _ComputeResources(cpu=124, memory=984), - "GPU_NV_S": _ComputeResources(cpu=6, memory=27, gpu=1, gpu_type="A10G"), - "GPU_NV_M": _ComputeResources(cpu=44, memory=178, gpu=4, gpu_type="A10G"), - "GPU_NV_L": _ComputeResources(cpu=92, memory=1112, gpu=8, gpu_type="A100"), -} -_AZURE_INSTANCE_FAMILIES = { - "HIGHMEM_X64_M": _ComputeResources(cpu=28, memory=244), - "HIGHMEM_X64_L": _ComputeResources(cpu=92, memory=654), - "GPU_NV_XS": _ComputeResources(cpu=3, memory=26, gpu=1, gpu_type="T4"), - "GPU_NV_SM": _ComputeResources(cpu=32, memory=424, gpu=1, gpu_type="A10"), - "GPU_NV_2M": _ComputeResources(cpu=68, memory=858, gpu=2, gpu_type="A10"), - "GPU_NV_3M": _ComputeResources(cpu=44, memory=424, gpu=2, gpu_type="A100"), - "GPU_NV_SL": _ComputeResources(cpu=92, memory=858, gpu=4, gpu_type="A100"), -} -_CLOUD_INSTANCE_FAMILIES = { - "aws": _AWS_INSTANCE_FAMILIES, - "azure": _AZURE_INSTANCE_FAMILIES, -} - - -def _get_node_resources(session: Session, compute_pool: str) -> _ComputeResources: - """Extract resource information for the specified compute pool""" - # Get the instance family - (row,) = session.sql(f"show compute pools like '{compute_pool}'").collect() - instance_family: str = row["instance_family"] - - # Get the cloud we're using (AWS, Azure, etc) - (row,) = session.sql(f"select current_region()").collect() - region: str = row[0] - region_group, region_name = f".{region}".split(".")[ - -2: - ] # Prepend a period so we always get at least 2 splits - regions = session.sql(f"show regions like '{region_name}'").collect() - if region_group: - regions = [r for r in regions if r["region_group"] == region_group] - cloud = regions[0]["cloud"] - - return ( - _COMMON_INSTANCE_FAMILIES.get(instance_family) - or _CLOUD_INSTANCE_FAMILIES[cloud][instance_family] - ) - - -def _get_image_spec(session: Session, compute_pool: str) -> _ImageSpec: - # Retrieve compute pool node resources - resources = _get_node_resources(session, compute_pool=compute_pool) - - # Use MLRuntime image - # TODO: Build new image if needed - image_repo = "/snowflake/images/snowflake_images" - image_arch = "x86" - image_family = ( - "generic_gpu/runtime_image/snowbooks" - if resources.gpu > 0 - else "runtime_image/snowbooks" - ) - image_tag = "0.4.0" - - # Try to pull latest image tag from server side if possible - query_result = session.sql( - f"SHOW PARAMETERS 
LIKE 'RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT" - ).collect() - if query_result: - image_tag = query_result[0]["value"] - - # TODO: Should each instance consume the entire pod? - return _ImageSpec( - repo=image_repo, - arch=image_arch, - family=image_family, - tag=image_tag, - resource_requests=resources, - resource_limits=resources, - ) - - -def _generate_spec( - image_spec: _ImageSpec, - stage_path: Path, - script_path: Path, - args: Optional[List[str]] = None, - env_vars: Optional[Dict[str, str]] = None, - secrets: Optional[List[str]] = None, -) -> dict: - volumes: List[Dict[str, str]] = [] - volume_mounts: List[Dict[str, str]] = [] - - # Set resource requests/limits, including nvidia.com/gpu quantity if applicable - resource_requests: Dict[str, Union[str, int]] = { - "cpu": f"{image_spec.resource_requests.cpu * 1000}m", - "memory": f"{image_spec.resource_limits.memory}Gi", - } - resource_limits: Dict[str, Union[str, int]] = { - "cpu": f"{image_spec.resource_requests.cpu * 1000}m", - "memory": f"{image_spec.resource_limits.memory}Gi", - } - if image_spec.resource_limits.gpu > 0: - resource_requests["nvidia.com/gpu"] = image_spec.resource_requests.gpu - resource_limits["nvidia.com/gpu"] = image_spec.resource_limits.gpu - - # Create container spec - main_container: Dict[str, Any] = { - "name": "main", - "image": image_spec.full_name, - "volumeMounts": volume_mounts, - "resources": { - "requests": resource_requests, - "limits": resource_limits, - }, - } - - # Add local volumes for ephemeral logs and artifacts - for volume_name, mount_path in [ - ("system-logs", "/var/log/managedservices/system/mlrs"), - ("user-logs", "/var/log/managedservices/user/mlrs"), - ]: - volume_mounts.append( - { - "name": volume_name, - "mountPath": mount_path, - } - ) - volumes.append( - { - "name": volume_name, - "source": "local", - } - ) - - # Mount 30% of memory limit as a memory-backed volume - memory_volume_name = "dshm" - memory_volume_size = min( - ceil(image_spec.resource_limits.memory * 0.3), - image_spec.resource_requests.memory, - ) - volume_mounts.append( - { - "name": memory_volume_name, - "mountPath": "/dev/shm", - } - ) - volumes.append( - { - "name": memory_volume_name, - "source": "memory", - "size": f"{memory_volume_size}Gi", - } - ) - - # Mount payload as volume - stage_mount = "/opt/app" - stage_volume_name = "stage-volume" - volume_mounts.append( - { - "name": stage_volume_name, - "mountPath": stage_mount, - } - ) - volumes.append( - { - "name": stage_volume_name, - "source": str(stage_path), - } - ) - - # TODO: Add hooks for endpoints for integration with TensorBoard, W&B, etc - - # Propagate user payload config - commands = { - ".py": "python", - ".sh": "bash", - ".rb": "ruby", - ".pl": "perl", - ".js": "node", - # Add more formats as needed - } - command = commands[script_path.suffix] - main_container["command"] = [ - command, - os.path.join(stage_mount, script_path), - *(args or []), - ] - - if env_vars: - main_container["env"] = env_vars - - if secrets: - secrets_spec = [] - for s in secrets: - # TODO: Add support for other secret types (e.g. 
username/password) - # TODO: Add support for other mount types - secret = _parse_secret_config(s) - assert secret.mount_type == "environment" - secrets_spec.append( - { - "snowflakeSecret": secret.fqn, - "envVarName": secret.mount_path, - "secretKeyRef": secret.subkey, - } - ) - main_container["secrets"] = secrets_spec - - return { - "spec": { - "containers": [main_container], - "volumes": volumes, - } - } - - -def _prepare_payload( - session: Session, - stage_path: Path, - source: Path, - entrypoint: Path, -) -> Path: - """Load payload onto stage""" - # TODO: Detect if source is a git repo or existing stage - if not entrypoint.exists(): - entrypoint = source / entrypoint - if not (source.exists() and entrypoint.exists()): - raise FileNotFoundError(f"{source} or {entrypoint} does not exist") - - # Create stage if necessary - stage_name = stage_path.parts[0] - session.sql(f"create stage if not exists {stage_name.lstrip('@')}").collect() - - # Upload payload to stage - if source.is_dir(): - # Filter to only files in source since Snowflake PUT can't handle directories - for path in set( - p.parent.joinpath(f"*{p.suffix}") if p.suffix else p - for p in source.rglob("*") - if p.is_file() - ): - session.file.put( - str(path.resolve()), - str(stage_path.joinpath(path.parent.relative_to(source))), - overwrite=True, - auto_compress=False, - ) - else: - session.file.put( - str(source.resolve()), - str(stage_path), - overwrite=True, - auto_compress=False, - ) - - return entrypoint.relative_to(source) - - -def prepare_spec( - session: Session, - service_name: str, - compute_pool: str, - stage_name: str, - payload: Path, - entrypoint: Path, - args: Optional[List[str]] = None, - secrets: Optional[List[str]] = None, - env: Optional[Dict[str, str]] = None, -) -> Dict[str, Any]: - - # Generate image spec based on compute pool - image_spec = _get_image_spec(session, compute_pool=compute_pool) - - # Prepare payload - stage_path = Path(f"@{stage_name}/{service_name}") - script_path = _prepare_payload( - session, - stage_path, - source=payload, - entrypoint=entrypoint, - ) - - spec = _generate_spec( - image_spec=image_spec, - stage_path=stage_path, - script_path=script_path, - args=args, - env_vars=env, - secrets=secrets, - ) - return spec diff --git a/samples/ml/airflow/scripts/job_utils/submit_job.py b/samples/ml/airflow/scripts/job_utils/submit_job.py deleted file mode 100644 index ba006d4..0000000 --- a/samples/ml/airflow/scripts/job_utils/submit_job.py +++ /dev/null @@ -1,112 +0,0 @@ -import os -import yaml -import time -from textwrap import dedent -from uuid import uuid4 -from pathlib import Path -from typing import List, Optional - -from snowflake.snowpark import Session -from snowflake.snowpark.exceptions import SnowparkSessionException -from snowflake.snowpark.context import get_active_session -from snowflake.ml.utils.connection_params import SnowflakeLoginOptions - -from spec_utils import prepare_spec - -def _get_session() -> Session: - try: - return get_active_session() - except SnowparkSessionException: - # Initialize Snowflake session - # See https://docs.snowflake.com/developer-guide/snowflake-cli/connecting/configure-connections#define-connections - # for how to define default connections in a config.toml file - return Session.builder.configs(SnowflakeLoginOptions("preprod8")).create() - -def submit_job( - compute_pool: str, - stage_name: str, - payload_path: Path, - service_name: Optional[str] = None, - entrypoint: Optional[Path] = None, - entrypoint_args: Optional[List[str]] = None, - 
external_access_integrations: Optional[List[str]] = None, - query_warehouse: Optional[str] = None, - comment: Optional[str] = None, - dry_run: bool = False, -) -> Optional[str]: - session = _get_session() - if not service_name: - service_name = "JOB_" + str(uuid4()).replace('-', '_').upper() - print("Generated service name: %s" % service_name) - if not entrypoint: - if payload_path.is_file(): - raise ValueError("entrypoint is required if payload_path is not a file") - entrypoint = payload_path - - spec = prepare_spec( - session=session, - service_name=service_name, - compute_pool=compute_pool, - stage_name=stage_name, - payload=payload_path, - entrypoint=entrypoint, - args=entrypoint_args, - ) - - query_template = dedent(f"""\ - EXECUTE JOB SERVICE - IN COMPUTE POOL {compute_pool} - FROM SPECIFICATION $$ - {{}} - $$ - NAME = {service_name} - """) - query = query_template.format(yaml.dump(spec)).splitlines() - if external_access_integrations: - external_access_integration_list = ",".join( - f"{e}" for e in external_access_integrations - ) - query.append( - f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})" - ) - if query_warehouse: - query.append(f"QUERY_WAREHOUSE = {query_warehouse}") - if comment: - query.append(f"COMMENT = {comment}") - - query_text = "\n".join(line for line in query if line) - if dry_run: - print("\n================= GENERATED SERVICE SPEC =================") - print(query_text) - print("==================== END SERVICE SPEC ====================\n") - return None - - try: - async_job = session.sql(query_text).collect_nowait() - time.sleep(0.1) # Check if query failed "immediately" before exiting - _ = session.connection.get_query_status_throw_if_error(async_job.query_id) - return async_job.query_id - finally: - session.connection.close() - -if __name__ == '__main__': - import argparse - - default_payload_path = os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "..", "src")) - default_entrypoint = "train.py" - - parser = argparse.ArgumentParser("SPCS Job Launcher") - parser.add_argument("-c", "--compute_pool", default="ML_DEMO_POOL", type=str, help="Compute pool to use") - parser.add_argument("-s", "--stage_name", default="ML_DEMO_STAGE", type=str, help="Stage for payload upload and job artifacts") - parser.add_argument("-p", "--payload_path", default=default_payload_path, type=Path, help="Path to local payload") - parser.add_argument("-e", "--entrypoint", default=default_entrypoint, type=Path, help="Relative path to entrypoint in payload") - parser.add_argument("-a", "--entrypoint_args", type=str, nargs='*', help="Arguments to pass to entrypoint") - parser.add_argument("-n", "--service_name", type=str, help="Name for created job service") - parser.add_argument("--external_access_integrations", type=str, nargs='*', help="External access integrations to enable in service") - parser.add_argument("--query_warehouse", type=str, help="Warehouse to use for queries executed by service") - parser.add_argument("--comment", type=str, help="Comment to add to created service") - parser.add_argument("--dry_run", action='store_true', help="Flag to enable dry run mode") - args = parser.parse_args() - - job_id = submit_job(**vars(args)) - print("Submitted job id: " + (job_id or "")) \ No newline at end of file diff --git a/samples/ml/airflow/scripts/submit_job.sql b/samples/ml/airflow/scripts/submit_job.sql deleted file mode 100644 index 3a83944..0000000 --- a/samples/ml/airflow/scripts/submit_job.sql +++ /dev/null @@ -1,40 +0,0 @@ -EXECUTE JOB SERVICE 
-IN COMPUTE POOL ML_DEMO_POOL -FROM SPECIFICATION $$ -spec: - containers: - - command: - - python - - /opt/app/train.py - image: /snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:0.6.0 - name: main - resources: - limits: - cpu: 6000m - memory: 58Gi - requests: - cpu: 6000m - memory: 58Gi - volumeMounts: - - mountPath: /var/log/managedservices/system/mlrs - name: system-logs - - mountPath: /var/log/managedservices/user/mlrs - name: user-logs - - mountPath: /dev/shm - name: dshm - - mountPath: /opt/app - name: stage-volume - volumes: - - name: system-logs - source: local - - name: user-logs - source: local - - name: dshm - size: 17Gi - source: memory - - name: stage-volume - source: '@ML_DEMO_STAGE' -$$ -NAME = ML_DEMO_JOB; - -CALL SYSTEM$GET_SERVICE_LOGS('ML_DEMO_JOB', '0', 'main', 500); \ No newline at end of file diff --git a/samples/ml/airflow/src/train.py b/samples/ml/airflow/src/train.py deleted file mode 100644 index f92c5ae..0000000 --- a/samples/ml/airflow/src/train.py +++ /dev/null @@ -1,229 +0,0 @@ -import json -import os -import pickle -from time import perf_counter -from typing import Literal, Optional - -import pandas as pd -import xgboost as xgb -from sklearn.compose import ColumnTransformer -from sklearn.impute import SimpleImputer -from sklearn.metrics import accuracy_score, classification_report, roc_auc_score -from sklearn.model_selection import train_test_split -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import OneHotEncoder, StandardScaler -from snowflake.ml.data.data_connector import DataConnector -from snowflake.ml.registry import Registry as ModelRegistry -from snowflake.ml.utils.connection_params import SnowflakeLoginOptions -from snowflake.snowpark import Session - - -def create_data_connector(session, table_name: str) -> DataConnector: - """Load data from Snowflake table""" - # Example query - modify according to your schema - query = f""" - SELECT - age, - income, - credit_score, - employment_length, - loan_amount, - debt_to_income, - number_of_credit_lines, - previous_defaults, - loan_purpose, - is_default - FROM {table_name} - """ - sp_df = session.sql(query) - return DataConnector.from_dataframe(sp_df) - - -def build_pipeline(model_params: dict = None) -> Pipeline: - """Create pipeline with preprocessors and model""" - # Define column types - categorical_cols = ["LOAN_PURPOSE"] - numerical_cols = [ - "AGE", - "INCOME", - "CREDIT_SCORE", - "EMPLOYMENT_LENGTH", - "LOAN_AMOUNT", - "DEBT_TO_INCOME", - "NUMBER_OF_CREDIT_LINES", - "PREVIOUS_DEFAULTS", - ] - - # Numerical preprocessing pipeline - numeric_transformer = Pipeline( - steps=[ - ("imputer", SimpleImputer(strategy="median")), - ("scaler", StandardScaler()), - ] - ) - - # Categorical preprocessing pipeline - categorical_transformer = Pipeline( - steps=[ - ("imputer", SimpleImputer(strategy="constant", fill_value="missing")), - ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False)), - ] - ) - - # Combine transformers - preprocessor = ColumnTransformer( - transformers=[ - ("num", numeric_transformer, numerical_cols), - ("cat", categorical_transformer, categorical_cols), - ] - ) - - # Define model parameters - default_params = { - "objective": "binary:logistic", - "eval_metric": "auc", - "max_depth": 6, - "learning_rate": 0.1, - "n_estimators": 100, - "subsample": 0.8, - "colsample_bytree": 0.8, - "random_state": 42, - } - model = xgb.XGBClassifier(**(model_params or default_params)) - - return Pipeline([("preprocessor", preprocessor), 
("classifier", model)]) - - -def evaluate_model(model: Pipeline, X_test: pd.DataFrame, y_test: pd.DataFrame): - """Evaluate model performance""" - # Make predictions - y_pred = model.predict(X_test) - y_pred_proba = model.predict_proba(X_test)[:, 1] - - # Calculate metrics - metrics = { - "accuracy": accuracy_score(y_test, y_pred), - "roc_auc": roc_auc_score(y_test, y_pred_proba), - "classification_report": classification_report(y_test, y_pred), - } - - return metrics - - -def save_to_registry( - session: Session, - model: Pipeline, - model_name: str, - metrics: dict, - sample_input_data: pd.DataFrame, -): - """Save model and artifacts to Snowflake Model Registry""" - # Initialize model registry - registry = ModelRegistry(session) - - # Save to registry - registry.log_model( - model=model, - model_name=model_name, - metrics=metrics, - sample_input_data=sample_input_data[:5], - conda_dependencies=["xgboost"], - ) - - -def main(source_data: str, save_mode: Literal["local", "registry"] = "local", output_dir: Optional[str] = None): - # Initialize Snowflake session - # See https://docs.snowflake.com/developer-guide/snowflake-cli/connecting/configure-connections#define-connections - # for how to define default connections in a config.toml file - session = Session.builder.configs(SnowflakeLoginOptions()).create() - - # Load data - dc = create_data_connector(session, table_name=source_data) - print("Loading data...", end="", flush=True) - start = perf_counter() - df = dc.to_pandas() - elapsed = perf_counter() - start - print(f" done! Loaded {len(df)} rows, elapsed={elapsed:.3f}s") - - # Split data - X = df.drop("IS_DEFAULT", axis=1) - y = df["IS_DEFAULT"] - X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=0.2, random_state=42 - ) - - # Train model - model = build_pipeline() - print("Training model...", end="") - start = perf_counter() - model.fit(X_train, y_train) - elapsed = perf_counter() - start - print(f" done! Elapsed={elapsed:.3f}s") - - # Evaluate model - print("Evaluating model...", end="") - start = perf_counter() - metrics = evaluate_model( - model, - X_test, - y_test, - ) - elapsed = perf_counter() - start - print(f" done! Elapsed={elapsed:.3f}s") - - # Print evaluation results - print("\nModel Performance Metrics:") - print(f"Accuracy: {metrics['accuracy']:.4f}") - print(f"ROC AUC: {metrics['roc_auc']:.4f}") - # Uncomment below for full classification report - # print("\nClassification Report:") - # print(metrics["classification_report"]) - - start = perf_counter() - if save_mode == "local": - # Save model locally - print("Saving model to disk...", end="") - output_dir = output_dir or os.path.dirname(__file__) - model_subdir = os.environ.get("SNOWFLAKE_SERVICE_NAME", "output") - model_dir = os.path.join(output_dir, model_subdir) if not output_dir.endswith(model_subdir) else output_dir - os.makedirs(model_dir, exist_ok=True) - with open(os.path.join(model_dir, "model.pkl"), "wb") as f: - pickle.dump(model, f) - with open(os.path.join(model_dir, "metrics.json"), "w") as f: - json.dump(metrics, f, indent=2) - elif save_mode == "registry": - # Save model to registry - print("Logging model to Model Registry...", end="") - save_to_registry( - session, - model=model, - model_name="loan_default_predictor", - metrics=metrics, - sample_input_data=X_train, - ) - elapsed = perf_counter() - start - print(f" done! 
Elapsed={elapsed:.3f}s") - - # Close Snowflake session - session.close() - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--source_data", default="loan_applications", help="Name of input data table" - ) - parser.add_argument( - "--save_mode", - choices=["local", "registry"], - default="local", - help="Model save mode", - ) - parser.add_argument( - "--output_dir", type=str, help="Local save path. Only relevant if save_mode=local" - ) - args = parser.parse_args() - - main(**vars(args))