Skip to content

Commit

Permalink
Fix AutoML vision classification test
Browse files (browse the repository at this point in the history)
- provide new dataset
- delete redundant steps
  • Loading branch information
Oleg Kachur committed Sep 16, 2024
1 parent f3b238a commit 756abe6
Showing 1 changed file with 6 additions and 37 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,11 +28,6 @@
from google.protobuf.struct_pb2 import Value

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
GCSSynchronizeBucketsOperator,
)
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
CreateAutoMLImageTrainingJobOperator,
DeleteAutoMLTrainingJobOperator,
Expand All @@ -44,16 +39,16 @@
)
from airflow.utils.trigger_rule import TriggerRule

DAG_ID = "example_automl_vision_clss"
DAG_ID = "automl_vision_clss"
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
REGION = "us-central1"
IMAGE_DISPLAY_NAME = f"automl-vision-clss-{ENV_ID}"
MODEL_DISPLAY_NAME = f"automl-vision-clss-model-{ENV_ID}"

RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
IMAGE_GCS_BUCKET_NAME = f"bucket_image_clss_{ENV_ID}".replace("_", "-")

RESOURCE_IMPORT_DATA_URI = (
"gs://airflow-system-tests-resources/automl/datasets/vision/img_classification_short.csv"
)
IMAGE_DATASET = {
"display_name": f"automl-vision-clss-dataset-{ENV_ID}",
"metadata_schema_uri": schema.dataset.metadata.image,
Expand All @@ -62,7 +57,7 @@
IMAGE_DATA_CONFIG = [
{
"import_schema_uri": schema.dataset.ioformat.image.single_label_classification,
"gcs_source": {"uris": [f"gs://{IMAGE_GCS_BUCKET_NAME}/automl/image-dataset-classification.csv"]},
"gcs_source": {"uris": [RESOURCE_IMPORT_DATA_URI]},
},
]

Expand All @@ -74,22 +69,6 @@
catchup=False,
tags=["example", "automl", "vision", "classification"],
) as dag:
create_bucket = GCSCreateBucketOperator(
task_id="create_bucket",
bucket_name=IMAGE_GCS_BUCKET_NAME,
storage_class="REGIONAL",
location=REGION,
)

move_dataset_file = GCSSynchronizeBucketsOperator(
task_id="move_dataset_to_bucket",
source_bucket=RESOURCE_DATA_BUCKET,
source_object="automl/datasets/vision",
destination_bucket=IMAGE_GCS_BUCKET_NAME,
destination_object="automl",
recursive=True,
)

create_image_dataset = CreateDatasetOperator(
task_id="image_dataset",
dataset=IMAGE_DATASET,
Expand Down Expand Up @@ -142,25 +121,15 @@
trigger_rule=TriggerRule.ALL_DONE,
)

delete_bucket = GCSDeleteBucketOperator(
task_id="delete_bucket",
bucket_name=IMAGE_GCS_BUCKET_NAME,
trigger_rule=TriggerRule.ALL_DONE,
)

(
# TEST SETUP
[
create_bucket >> move_dataset_file,
create_image_dataset,
]
create_image_dataset
>> import_image_dataset
# TEST BODY
>> create_auto_ml_image_training_job
# TEST TEARDOWN
>> delete_auto_ml_image_training_job
>> delete_image_dataset
>> delete_bucket
)

from tests.system.utils.watcher import watcher
Expand Down

0 comments on commit 756abe6

Please sign in to comment.