diff --git a/python-sdk/pyproject.toml b/python-sdk/pyproject.toml
index 62a770353..fe9dbff6d 100644
--- a/python-sdk/pyproject.toml
+++ b/python-sdk/pyproject.toml
@@ -57,7 +57,8 @@ google = [
     "protobuf",
     "apache-airflow-providers-google>=10.15.0",
     "sqlalchemy-bigquery>=1.3.0",
-    "smart-open[gcs]>=5.2.1,<7.0.0"
+    "smart-open[gcs]>=5.2.1,<7.0.0",
+    "google-cloud-bigquery<3.21.0"
 ]
 snowflake = [
     "apache-airflow-providers-snowflake>=5.3.0",
@@ -126,7 +127,8 @@ all = [
     "azure-storage-blob",
     "apache-airflow-providers-microsoft-mssql>=3.2",
     "airflow-provider-duckdb>=0.0.2",
-    "apache-airflow-providers-mysql"
+    "apache-airflow-providers-mysql",
+    "google-cloud-bigquery<3.21.0"
 ]
 doc = [
     "myst-parser>=0.17",
diff --git a/python-sdk/src/astro/sql/operators/data_validations/check_table.py b/python-sdk/src/astro/sql/operators/data_validations/check_table.py
index ef8b180da..dbe26f27b 100644
--- a/python-sdk/src/astro/sql/operators/data_validations/check_table.py
+++ b/python-sdk/src/astro/sql/operators/data_validations/check_table.py
@@ -56,12 +56,7 @@ def execute(self, context: "Context"):
         db = create_database(self.dataset.conn_id)
         self.table = db.get_table_qualified_name(self.dataset)
         self.conn_id = self.dataset.conn_id
-        # apache-airflow-providers-common-sql == 1.2.0 which is compatible with airflow 2.2.5 implements the self.sql
-        # differently compared to apache-airflow-providers-common-sql == 1.3.3
-        try:
-            self.sql = f"SELECT check_name, check_result FROM ({self._generate_sql_query()}) AS check_table"
-        except AttributeError:
-            self.sql = f"SELECT * FROM {self.table};"
+        self.sql = f"SELECT check_name, check_result FROM ({self._generate_sql_query()}) AS check_table"
         super().execute(context)
 
     def get_db_hook(self) -> Any:
diff --git a/python-sdk/tests/airflow_tests/test_datasets.py b/python-sdk/tests/airflow_tests/test_datasets.py
index dc21c1a96..5ecde4707 100644
--- a/python-sdk/tests/airflow_tests/test_datasets.py
+++ b/python-sdk/tests/airflow_tests/test_datasets.py
@@ -104,7 +104,6 @@ def test_kwargs_with_temp_table():
 @pytest.mark.skipif(airflow.__version__ < "2.4.0", reason="Require Airflow version >= 2.4.0")
 def test_example_dataset_dag():
     from airflow.datasets import Dataset
-    from airflow.models.dataset import DatasetModel
 
     dir_path = os.path.dirname(os.path.realpath(__file__))
     db = DagBag(dir_path + "/../../example_dags/example_datasets.py")
@@ -115,9 +114,8 @@ def test_example_dataset_dag():
     outlets = producer_dag.tasks[-1].outlets
     assert isinstance(outlets[0], Dataset)
     # Test that dataset_triggers is only set if all the instances passed to the DAG object are Datasets
-    assert consumer_dag.dataset_triggers == outlets
+    assert consumer_dag.dataset_triggers.objects[0] == outlets[0]
     assert outlets[0].uri == "astro://postgres_conn@?table=imdb_movies"
-    assert DatasetModel.from_public(outlets[0]) == Dataset("astro://postgres_conn@?table=imdb_movies")
 
 
 def test_disable_auto_inlets_outlets():
diff --git a/python-sdk/tests/sql/operators/test_cleanup.py b/python-sdk/tests/sql/operators/test_cleanup.py
index dde4ac724..a0b59c099 100644
--- a/python-sdk/tests/sql/operators/test_cleanup.py
+++ b/python-sdk/tests/sql/operators/test_cleanup.py
@@ -105,19 +105,20 @@ def test_error_raised_with_blocking_op_executors(
     reason="BackfillJobRunner and Job classes are only available in airflow >= 2.6",
 )
 @pytest.mark.parametrize(
-    "executor_in_job,executor_in_cfg,expected_val",
+    "executor_in_job, executor_in_cfg, expected_val",
     [
-        (SequentialExecutor(), "LocalExecutor", True),
+        (SequentialExecutor(), "SequentialExecutor", True),
         (LocalExecutor(), "LocalExecutor", False),
         (None, "LocalExecutor", False),
         (None, "SequentialExecutor", True),
     ],
 )
-def test_single_worker_mode_backfill(executor_in_job, executor_in_cfg, expected_val):
+def test_single_worker_mode_backfill(monkeypatch, executor_in_job, executor_in_cfg, expected_val):
     """Test that if we run Backfill Job it should be marked as single worker node"""
     from airflow.jobs.backfill_job_runner import BackfillJobRunner
     from airflow.jobs.job import Job
 
+    monkeypatch.setattr("airflow.executors.executor_loader._executor_names", [])
     dag = DAG("test_single_worker_mode_backfill", start_date=datetime(2022, 1, 1))
     dr = DagRun(dag_id=dag.dag_id)
 
@@ -175,17 +176,18 @@ def test_single_worker_mode_backfill_airflow_2_5(executor_in_job, executor_in_cf
 @pytest.mark.parametrize(
     "executor_in_job,executor_in_cfg,expected_val",
     [
-        (SequentialExecutor(), "LocalExecutor", True),
+        (SequentialExecutor(), "SequentialExecutor", True),
         (LocalExecutor(), "LocalExecutor", False),
         (None, "LocalExecutor", False),
         (None, "SequentialExecutor", True),
     ],
 )
-def test_single_worker_mode_scheduler_job(executor_in_job, executor_in_cfg, expected_val):
+def test_single_worker_mode_scheduler_job(monkeypatch, executor_in_job, executor_in_cfg, expected_val):
     """Test that if we run Scheduler Job it should be marked as single worker node"""
     from airflow.jobs.job import Job
     from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
 
+    monkeypatch.setattr("airflow.executors.executor_loader._executor_names", [])
     dag = DAG("test_single_worker_mode_scheduler_job", start_date=datetime(2022, 1, 1))
     dr = DagRun(dag_id=dag.dag_id)
 
diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_check_column.py b/python-sdk/tests_integration/sql/operators/data_validation/test_check_column.py
index a8763e220..eb48ecf4d 100644
--- a/python-sdk/tests_integration/sql/operators/data_validation/test_check_column.py
+++ b/python-sdk/tests_integration/sql/operators/data_validation/test_check_column.py
@@ -5,7 +5,6 @@
 from astro import sql as aql
 from astro.constants import Database
 from astro.files import File
-from astro.table import Table
 from tests.sql.operators import utils as test_utils
 
 CWD = pathlib.Path(__file__).parent
@@ -22,7 +21,6 @@
     {
         "database": Database.BIGQUERY,
         "file": File(path=str(CWD) + "/../../../data/data_validation.csv"),
-        "table": Table(conn_id="gcp_conn_project"),
     },
     {
         "database": Database.POSTGRES,
diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_check_table.py b/python-sdk/tests_integration/sql/operators/data_validation/test_check_table.py
index c5af04d3b..204d22e2d 100644
--- a/python-sdk/tests_integration/sql/operators/data_validation/test_check_table.py
+++ b/python-sdk/tests_integration/sql/operators/data_validation/test_check_table.py
@@ -5,7 +5,6 @@
 from astro import sql as aql
 from astro.constants import Database
 from astro.files import File
-from astro.table import Table
 from tests.sql.operators import utils as test_utils
 
 CWD = pathlib.Path(__file__).parent
@@ -22,7 +21,6 @@
     {
         "database": Database.BIGQUERY,
         "file": File(path=str(CWD) + "/../../../data/homes_main.csv"),
-        "table": Table(conn_id="gcp_conn_project"),
     },
     {
         "database": Database.POSTGRES,