feat: add support for artifact URL pipeline specs
Jonny Browning (Datatonic) authored Jul 7, 2022
1 parent 278eb95 commit 47a8f96
Showing 9 changed files with 170 additions and 11 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/pr-checks.yml
@@ -57,12 +57,21 @@ jobs:
- name: Copy pipeline.json to GCS
run: "gsutil cp test/pipeline.json gs://${{ secrets.TEST_BUCKET }}/terraform-google-scheduled-vertex-pipelines/${{ github.run_id }}/pipeline.json"

- name: Copy pipeline.yaml to AR
run: >
curl -X POST
-H "Authorization: Bearer $(gcloud auth print-access-token)"
-F tags=latest
-F content=@test/pipeline.yaml
"https://europe-west2-kfp.pkg.dev/${{ secrets.TEST_PROJECT_ID }}/${{ secrets.TEST_AR_REPO }}"
- name: Run Terratest
run: make test
env:
TF_VAR_project: ${{ secrets.TEST_PROJECT_ID }}
TF_VAR_gcs_bucket: ${{ secrets.TEST_BUCKET }}
TF_VAR_object_name: "terraform-google-scheduled-vertex-pipelines/${{ github.run_id }}/pipeline.json"
TF_VAR_ar_repository: ${{ secrets.TEST_AR_REPO }}

- name: Delete pipeline.json from GCS after test
if: always()
6 changes: 5 additions & 1 deletion README.md
@@ -41,12 +41,14 @@ You can customise the template (including this text for example) in `.github/wor
| Name | Version |
|------|---------|
| <a name="requirement_google"></a> [google](#requirement\_google) | >= 4.0.0 |
| <a name="requirement_http"></a> [http](#requirement\_http) | >= 2.2.0 |

## Providers

| Name | Version |
|------|---------|
| <a name="provider_google"></a> [google](#provider\_google) | >= 4.0.0 |
| <a name="provider_http"></a> [http](#provider\_http) | >= 2.2.0 |

## Modules

@@ -57,8 +59,10 @@ No modules.
| Name | Type |
|------|------|
| [google_cloud_scheduler_job.job](https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/cloud_scheduler_job) | resource |
| [google_client_config.default](https://registry.terraform.io/providers/hashicorp/google/latest/docs/data-sources/client_config) | data source |
| [google_compute_default_service_account.default](https://registry.terraform.io/providers/hashicorp/google/latest/docs/data-sources/compute_default_service_account) | data source |
| [google_storage_bucket_object_content.pipeline_spec](https://registry.terraform.io/providers/hashicorp/google/latest/docs/data-sources/storage_bucket_object_content) | data source |
| [http_http.pipeline_spec](https://registry.terraform.io/providers/hashicorp/http/latest/docs/data-sources/http) | data source |

## Inputs

@@ -76,7 +80,7 @@ No modules.
| <a name="input_labels"></a> [labels](#input\_labels) | The labels with user-defined metadata to organize PipelineJob. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. | `map(string)` | `{}` | no |
| <a name="input_network"></a> [network](#input\_network) | The full name of the Compute Engine network to which the Pipeline Job's workload should be peered. For example, projects/12345/global/networks/myVPC. Format is of the form projects/{project}/global/networks/{network}. Where {project} is a project number, as in 12345, and {network} is a network name. Private services access must already be configured for the network. Pipeline job will apply the network configuration to the GCP resources being launched, if applied, such as Vertex AI Training or Dataflow job. If left unspecified, the workload is not peered with any network. | `string` | `null` | no |
| <a name="input_parameter_values"></a> [parameter\_values](#input\_parameter\_values) | The runtime parameters of the PipelineJob. The parameters will be passed into PipelineJob.pipeline\_spec to replace the placeholders at runtime. This field is used by pipelines built using PipelineJob.pipeline\_spec.schema\_version 2.1.0, such as pipelines built using Kubeflow Pipelines SDK 1.9 or higher and the v2 DSL. | `map(any)` | `{}` | no |
| <a name="input_pipeline_spec_path"></a> [pipeline\_spec\_path](#input\_pipeline\_spec\_path) | Path to the KFP pipeline spec file (YAML or JSON). This can be a local or a GCS path. | `string` | n/a | yes |
| <a name="input_pipeline_spec_path"></a> [pipeline\_spec\_path](#input\_pipeline\_spec\_path) | Path to the KFP pipeline spec file (YAML or JSON). This can be a local file, GCS path, or Artifact Registry path. | `string` | n/a | yes |
| <a name="input_project"></a> [project](#input\_project) | The GCP project ID where the cloud scheduler job and Vertex Pipeline should be deployed. | `string` | n/a | yes |
| <a name="input_schedule"></a> [schedule](#input\_schedule) | Describes the schedule on which the job will be executed. | `string` | n/a | yes |
| <a name="input_time_zone"></a> [time\_zone](#input\_time\_zone) | Specifies the time zone to be used in interpreting schedule. The value of this field must be a time zone name from the tz database. | `string` | `"UTC"` | no |
16 changes: 16 additions & 0 deletions examples/hello_world_ar/main.tf
@@ -0,0 +1,16 @@
module "hello_world_pipeline" {
source = "../../"
project = var.project
vertex_region = "europe-west2"
cloud_scheduler_region = "europe-west2"
pipeline_spec_path = "https://europe-west2-kfp.pkg.dev/${var.project}/${var.ar_repository}/hello-world/latest"
parameter_values = {
"text" = "Hello, world!"
}
gcs_output_directory = "gs://my-bucket/my-output-directory"
vertex_service_account_email = "my-vertex-service-account@my-gcp-project-id.iam.gserviceaccount.com"
time_zone = "UTC"
schedule = "0 0 * * *"
cloud_scheduler_job_name = "pipeline-from-ar-spec"

}
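A variant sketch (illustrative, not part of this commit): since each upload to the repository is tagged, `pipeline_spec_path` can pin a specific tag instead of `latest`; the `v1` tag below is hypothetical.

module "hello_world_pipeline_pinned" {
  source                 = "../../"
  project                = var.project
  vertex_region          = "europe-west2"
  cloud_scheduler_region = "europe-west2"

  # "v1" is a hypothetical tag; any tag pushed to the repository works here
  pipeline_spec_path = "https://europe-west2-kfp.pkg.dev/${var.project}/${var.ar_repository}/hello-world/v1"

  parameter_values = {
    "text" = "Hello, world!"
  }
  gcs_output_directory         = "gs://my-bucket/my-output-directory"
  vertex_service_account_email = "my-vertex-service-account@my-gcp-project-id.iam.gserviceaccount.com"
  time_zone                    = "UTC"
  schedule                     = "0 0 * * *"
  cloud_scheduler_job_name     = "pipeline-from-pinned-ar-spec"
}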
4 changes: 4 additions & 0 deletions examples/hello_world_ar/outputs.tf
@@ -0,0 +1,4 @@
output "id" {
value = module.hello_world_pipeline.id
description = "an identifier for the Cloud Scheduler job resource with format projects/{{project}}/locations/{{region}}/jobs/{{name}}"
}
9 changes: 9 additions & 0 deletions examples/hello_world_ar/variables.tf
@@ -0,0 +1,9 @@
variable "project" {
type = string
description = "The GCP project ID where the cloud scheduler job and Vertex Pipeline should be deployed."
}

variable "ar_repository" {
type = string
description = "Name of the Artifact Registry repository used to store the pipeline definition."
}
57 changes: 48 additions & 9 deletions main.tf
@@ -4,23 +4,45 @@ terraform {
source = "hashicorp/google"
version = ">= 4.0.0"
}

http = {
source = "hashicorp/http"
version = ">= 2.2.0"
}
}
}

locals {
  # var.pipeline_spec_path minus gs:// prefix (if prefix exists)
  pipeline_spec_path_no_gcs_prefix = trimprefix(var.pipeline_spec_path, "gs://")
  # is var.pipeline_spec_path a GCS path? (i.e. has trimming the prefix made a difference?)
  pipeline_spec_path_is_gcs_path = (var.pipeline_spec_path != local.pipeline_spec_path_no_gcs_prefix)
  # split the path into parts by "/"
  pipeline_spec_path_no_gcs_prefix_parts = split("/", local.pipeline_spec_path_no_gcs_prefix)

  # Regex explanation:
  # Starts with named group "scheme":
  #   either "https://" ("http_scheme") (for an Artifact Registry pipeline spec)
  #   or "gs://" ("gs_scheme") (for a GCS pipeline spec)
  #   or nothing
  # Next part is named group "root":
  #   for a GCS path, "root" = bucket name
  #   otherwise it's just the first part of the path (minus prefix)
  # Next named group is "rest_of_path_including_slash", which consists of two named groups:
  #   1) a forward slash (named group "slash")
  #   2) the rest of the string (named group "rest_of_path")
  #   for a GCS pipeline spec, "rest_of_path" = the GCS object name
  pipeline_spec_path = regex("^(?P<scheme>(?P<http_scheme>https\\:\\/\\/)|(?P<gs_scheme>gs\\:\\/\\/))?(?P<root>[\\w.-]*)?(?P<rest_of_path_including_slash>(?P<slash>\\/)(?P<rest_of_path>.*))*", var.pipeline_spec_path)
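  # Illustrative decomposition (example inputs, not part of the module): with
  # the named groups above, the three supported path styles break down as:
  #   "gs://my-bucket/path/to/pipeline.json"
  #     -> scheme = "gs://", root = "my-bucket", rest_of_path = "path/to/pipeline.json"
  #   "https://europe-west2-kfp.pkg.dev/my-project/my-repo/hello-world/latest"
  #     -> scheme = "https://", root = "europe-west2-kfp.pkg.dev", rest_of_path = "my-project/my-repo/hello-world/latest"
  #   "test/pipeline.json"
  #     -> scheme = null, root = "test", rest_of_path = "pipeline.json"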

pipeline_spec_path_is_gcs_path = local.pipeline_spec_path.scheme == "gs://"
pipeline_spec_path_is_ar_path = local.pipeline_spec_path.scheme == "https://"
pipeline_spec_path_is_local_path = local.pipeline_spec_path.scheme == null

# Load the pipeline spec from YAML/JSON
# If it's a GCS path, load it from the GCS object content
# If it's an AR path, load it from Artifact Registry
# If it's a local path, load from the local file
pipeline_spec = yamldecode(local.pipeline_spec_path_is_gcs_path ? data.google_storage_bucket_object_content.pipeline_spec[0].content : file(var.pipeline_spec_path))
pipeline_spec = yamldecode(
local.pipeline_spec_path_is_gcs_path ? data.google_storage_bucket_object_content.pipeline_spec[0].content :
(local.pipeline_spec_path_is_ar_path ? data.http.pipeline_spec[0].response_body :
file(var.pipeline_spec_path)))
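  # yamldecode covers both spec formats because JSON is a subset of YAML,
  # so .json and .yaml pipeline specs decode through the same call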

# If var.kms_key_name is provided, construct the encryption_spec object
encryption_spec = (var.kms_key_name == null) ? null : { "kmsKeyName" : var.kms_key_name }
@@ -48,8 +70,25 @@ locals {
# Load the pipeline spec from the GCS path
data "google_storage_bucket_object_content" "pipeline_spec" {
count = local.pipeline_spec_path_is_gcs_path ? 1 : 0
name = join("/", slice(local.pipeline_spec_path_no_gcs_prefix_parts, 1, length(local.pipeline_spec_path_no_gcs_prefix_parts)))
bucket = local.pipeline_spec_path_no_gcs_prefix_parts[0]
name = local.pipeline_spec_path.rest_of_path
bucket = local.pipeline_spec_path.root
}

# If var.pipeline_spec_path is an Artifact Registry (https) path,
# we will need an authorization token
data "google_client_config" "default" {
count = local.pipeline_spec_path_is_ar_path ? 1 : 0
}

# If var.pipeline_spec_path is an Artifact Registry (https) path,
# load the pipeline spec from AR (over HTTPS) using the authorization token
data "http" "pipeline_spec" {
count = local.pipeline_spec_path_is_ar_path ? 1 : 0
url = var.pipeline_spec_path

request_headers = {
Authorization = "Bearer ${data.google_client_config.default[0].access_token}"
}
}
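# Note: google_client_config exposes the OAuth2 access token of the credentials
# the google provider runs with, so the identity running terraform plan/apply
# needs read access to the Artifact Registry repository
# (e.g. roles/artifactregistry.reader)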

# If a service account is not specified for Cloud Scheduler, use the default compute service account
45 changes: 45 additions & 0 deletions test/pipeline.yaml
@@ -0,0 +1,45 @@
components:
comp-hello-world:
executorLabel: exec-hello-world
deploymentSpec:
executors:
exec-hello-world:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- hello_world
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.0'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef hello_world():\n print(\"Hello, world!\")\n\n"
image: python:3.7
pipelineInfo:
name: hello-world
root:
dag:
tasks:
hello-world:
cachingOptions:
enableCache: true
componentRef:
name: comp-hello-world
taskInfo:
name: hello-world
schemaVersion: 2.1.0
sdkVersion: kfp-2.0.0-beta.0
33 changes: 33 additions & 0 deletions test/terraform_google_scheduled_vertex_pipelines_test.go
@@ -75,3 +75,36 @@ func TestHelloWorldGCS(t *testing.T) {
assert.Equal(t, schedulerpb.Job_ENABLED, resp.State)

}

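// TestHelloWorldAR applies the hello_world_ar example and asserts that the
// resulting Cloud Scheduler job exists and is enabled. TF_VAR_project and
// TF_VAR_ar_repository are expected in the environment (the pr-checks
// workflow above sets them).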
func TestHelloWorldAR(t *testing.T) {

terraformOptions := terraform.WithDefaultRetryableErrors(t, &terraform.Options{
// Directory where main.tf for test is
TerraformDir: "../examples/hello_world_ar",

// Missing variables will come from TF_VAR env variables

})

defer terraform.Destroy(t, terraformOptions)

terraform.InitAndApply(t, terraformOptions)

// Get cloud scheduler job ID from terraform output
cloud_scheduler_job_id := terraform.Output(t, terraformOptions, "id")

// set up Google Cloud SDK connection
ctx := context.Background()
c, err := scheduler.NewCloudSchedulerClient(ctx)
if err != nil {
	t.Fatalf("NewCloudSchedulerClient: %v", err)
}
defer c.Close()

// Get cloud scheduler job using Google Cloud SDK
req := &schedulerpb.GetJobRequest{
Name: cloud_scheduler_job_id,
}
resp, err := c.GetJob(ctx, req)
if err != nil {
	t.Fatalf("GetJob: %v", err)
}

// Assert Cloud Scheduler job exists and is enabled
assert.Equal(t, schedulerpb.Job_ENABLED, resp.State)

}
2 changes: 1 addition & 1 deletion variables.tf
@@ -21,7 +21,7 @@ variable "display_name" {

variable "pipeline_spec_path" {
type = string
description = "Path to the KFP pipeline spec file (YAML or JSON). This can be a local or a GCS path."
description = "Path to the KFP pipeline spec file (YAML or JSON). This can be a local file, GCS path, or Artifact Registry path."
}
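# For reference, the three accepted forms (all values illustrative):
#   pipeline_spec_path = "test/pipeline.json"                                                     # local file
#   pipeline_spec_path = "gs://my-bucket/path/to/pipeline.json"                                   # GCS object
#   pipeline_spec_path = "https://europe-west2-kfp.pkg.dev/my-project/my-repo/hello-world/latest" # Artifact Registry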

variable "labels" {
