diff --git a/cmd/airflow.go b/cmd/airflow.go index ac3eef1c0..b607a6d77 100644 --- a/cmd/airflow.go +++ b/cmd/airflow.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "os" "path/filepath" "regexp" "slices" @@ -102,7 +103,8 @@ astro dev init --from-template dockerfile = "Dockerfile" configReinitProjectConfigMsg = "Reinitialized existing Astro project in %s\n" - configInitProjectConfigMsg = "Initialized empty Astro project in %s" + configInitProjectConfigMsg = "Initialized empty Astro project in %s\n" + changeDirectoryMsg = "To begin developing, change to your project directory with `cd %s`\n" // this is used to monkey patch the function in order to write unit test cases containerHandlerInit = airflow.ContainerHandlerInit @@ -117,6 +119,7 @@ astro dev init --from-template errNoCompose = errors.New("cannot use '--compose-file' without '--compose' flag") TemplateList = airflow.FetchTemplateList defaultWaitTime = 1 * time.Minute + directoryPermissions = 0o755 ) func newDevRootCmd(platformCoreClient astroplatformcore.CoreClient, astroCoreClient astrocore.CoreClient) *cobra.Command { @@ -463,22 +466,16 @@ func newObjectExportCmd() *cobra.Command { // Use project name for image name func airflowInit(cmd *cobra.Command, args []string) error { - // Validate project name - if projectName != "" { - // error if project name has spaces - if len(args) > 0 { - return errProjectNameSpaces - } - projectNameValid := regexp. - MustCompile(`^(?i)[a-z0-9]([a-z0-9_-]*[a-z0-9])$`). - MatchString + name, err := ensureProjectName(args, projectName) + if err != nil { + return err + } + projectName = name - if !projectNameValid(projectName) { - return errConfigProjectName - } - } else { - projectDirectory := filepath.Base(config.WorkingPath) - projectName = strings.Replace(strcase.ToSnake(projectDirectory), "_", "-", -1) + // Save the directory we are in when the init command is run. + initialDir, err := fileutil.GetWorkingDir() + if err != nil { + return err } if fromTemplate == "select-template" { @@ -507,7 +504,6 @@ func airflowInit(cmd *cobra.Command, args []string) error { } // If user provides a runtime version, use it, otherwise retrieve the latest one (matching Airflow Version if provided) - var err error defaultImageTag := runtimeVersion if defaultImageTag == "" { httpClient := airflowversions.NewClient(httputil.NewHTTPClient(), useAstronomerCertified) @@ -517,16 +513,22 @@ func airflowInit(cmd *cobra.Command, args []string) error { defaultImageName := airflow.AstroRuntimeImageName if useAstronomerCertified { defaultImageName = airflow.AstronomerCertifiedImageName - fmt.Printf("Initializing Astro project\nPulling Airflow development files from Astronomer Certified Airflow Version %s\n", defaultImageTag) - } else { - fmt.Printf("Initializing Astro project\nPulling Airflow development files from Astro Runtime %s\n", defaultImageTag) } + // Ensure the project directory is created if a positional argument is provided. + newProjectPath, err := ensureProjectDirectory(args, config.WorkingPath, projectName) + if err != nil { + return err + } + + // Update the config setting. + config.WorkingPath = newProjectPath + emptyDir := fileutil.IsEmptyDir(config.WorkingPath) if !emptyDir { i, _ := input.Confirm( - fmt.Sprintf("%s \nYou are not in an empty directory. Are you sure you want to initialize a project?", config.WorkingPath)) + fmt.Sprintf("%s is not an empty directory. Are you sure you want to initialize a project here?", config.WorkingPath)) if !i { fmt.Println("Canceling project initialization...") @@ -534,7 +536,10 @@ func airflowInit(cmd *cobra.Command, args []string) error { } } - exists := config.ProjectConfigExists() + exists, err := config.IsProjectDir(config.WorkingPath) + if err != nil { + return err + } if !exists { config.CreateProjectConfig(config.WorkingPath) } @@ -554,14 +559,80 @@ func airflowInit(cmd *cobra.Command, args []string) error { } if exists { - fmt.Printf(configReinitProjectConfigMsg+"\n", config.WorkingPath) + fmt.Printf(configReinitProjectConfigMsg, config.WorkingPath) } else { - fmt.Printf(configInitProjectConfigMsg+"\n", config.WorkingPath) + fmt.Printf(configInitProjectConfigMsg, config.WorkingPath) + } + + // If we started in a different directory, that means the positional argument for projectName was used. + // This means the users shell pwd is not the project directory, so we print a message + // to cd into the project directory. + if initialDir != config.WorkingPath { + fmt.Printf(changeDirectoryMsg, projectName) } return nil } +// ensureProjectDirectory creates a new project directory if a positional argument is provided. +func ensureProjectDirectory(args []string, workingPath, projectName string) (string, error) { + // Return early if no positional argument was provided. + if len(args) == 0 { + return workingPath, nil + } + + // Construct the path to our desired project directory. + newProjectPath := filepath.Join(workingPath, projectName) + + // Determine if the project directory already exists. + projectDirExists, err := fileutil.Exists(newProjectPath, nil) + if err != nil { + return "", err + } + + // If the project directory does not exist, create it. + if !projectDirExists { + err := os.Mkdir(newProjectPath, os.FileMode(directoryPermissions)) + if err != nil { + return "", err + } + } + + // Return the path we just created. + return newProjectPath, nil +} + +func ensureProjectName(args []string, projectName string) (string, error) { + // If the project name is specified with the --name flag, + // it cannot be specified as a positional argument as well, so return an error. + if projectName != "" && len(args) > 0 { + return "", errConfigProjectNameSpecifiedTwice + } + + // The first positional argument is the project name. + // If the project name is provided in this way, we'll + // attempt to create a directory with that name. + if projectName == "" && len(args) > 0 { + projectName = args[0] + } + + // Validate project name + if projectName != "" { + projectNameValid := regexp. + MustCompile(`^(?i)[a-z0-9]([a-z0-9_-]*[a-z0-9])$`). + MatchString + + if !projectNameValid(projectName) { + return "", errConfigProjectName + } + } else { + projectDirectory := filepath.Base(config.WorkingPath) + projectName = strings.Replace(strcase.ToSnake(projectDirectory), "_", "-", -1) + } + + return projectName, nil +} + func airflowUpgradeTest(cmd *cobra.Command, platformCoreClient astroplatformcore.CoreClient) error { //nolint:gocognit // Validate runtimeVersion and airflowVersion if airflowVersion != "" && runtimeVersion != "" { diff --git a/cmd/airflow_test.go b/cmd/airflow_test.go index 154482f21..6a1688e58 100644 --- a/cmd/airflow_test.go +++ b/cmd/airflow_test.go @@ -3,6 +3,7 @@ package cmd import ( "bytes" "errors" + "fmt" "io" "net/http" "os" @@ -300,20 +301,6 @@ func (s *AirflowSuite) TestAirflowInit() { s.True(strings.Contains(dockerfileContents, "FROM quay.io/astronomer/astro-runtime:")) }) - s.Run("invalid args", func() { - cmd := newAirflowInitCmd() - cmd.Flag("name").Value.Set("test-project-name") - args := []string{"invalid-arg"} - - r, stdin := s.mockUserInput("y") - - // Restore stdin right after the test. - defer func() { os.Stdin = stdin }() - os.Stdin = r - err := airflowInit(cmd, args) - s.ErrorIs(err, errProjectNameSpaces) - }) - s.Run("invalid project name", func() { cmd := newAirflowInitCmd() cmd.Flag("name").Value.Set("test@project-name") @@ -387,16 +374,19 @@ func (s *AirflowSuite) TestAirflowInit() { orgStdout := os.Stdout defer func() { os.Stdout = orgStdout }() - r, w, _ := os.Pipe() + _, w, _ := os.Pipe() os.Stdout = w err := airflowInit(cmd, args) w.Close() - out, _ := io.ReadAll(r) s.NoError(err) - s.Contains(string(out), "Pulling Airflow development files from Astronomer Certified Airflow Version") + + // Check the Dockerfile to ensure we are using the AC image. + b, _ := os.ReadFile(filepath.Join(s.tempDir, "Dockerfile")) + dockerfileContents := string(b) + s.True(strings.Contains(dockerfileContents, airflow.AstronomerCertifiedImageName)) }) s.Run("cancel non empty dir warning", func() { @@ -458,6 +448,52 @@ func (s *AirflowSuite) TestAirflowInit() { s.NoError(err) s.Contains(string(out), "Reinitialized existing Astro project in") }) + + s.Run("specify positional argument for project name", func() { + cmd := newAirflowInitCmd() + args := []string{"test-project-name"} + + r, stdin := s.mockUserInput("n") + + // Restore stdin right after the test. + defer func() { os.Stdin = stdin }() + os.Stdin = r + + orgStdout := os.Stdout + defer func() { os.Stdout = orgStdout }() + r, w, _ := os.Pipe() + os.Stdout = w + + err := airflowInit(cmd, args) + + w.Close() + out, _ := io.ReadAll(r) + + s.NoError(err) + s.Contains(string(out), fmt.Sprintf(changeDirectoryMsg, args[0])) + }) + + s.Run("specify flag and positional argument for project name, resulting in error", func() { + cmd := newAirflowInitCmd() + args := []string{"test-project-name"} + cmd.Flag("name").Value.Set("test-project-name") + + r, stdin := s.mockUserInput("n") + + // Restore stdin right after the test. + defer func() { os.Stdin = stdin }() + os.Stdin = r + + orgStdout := os.Stdout + defer func() { os.Stdout = orgStdout }() + _, w, _ := os.Pipe() + os.Stdout = w + + err := airflowInit(cmd, args) + + w.Close() + s.ErrorIs(err, errConfigProjectNameSpecifiedTwice) + }) } func (s *AirflowSuite) TestAirflowStart() { diff --git a/cmd/errors.go b/cmd/errors.go index 9f08184de..97807fddb 100644 --- a/cmd/errors.go +++ b/cmd/errors.go @@ -9,8 +9,8 @@ var ( errInvalidBothAirflowAndRuntimeVersionsUpgrade = errors.New("you provided both a runtime version and an Airflow version. You have to provide only one of these to upgrade") //nolint errInvalidBothCustomImageandVersion = errors.New("you provided both a Custom image and a version. You have to provide only one of these to upgrade") //nolint - errConfigProjectName = errors.New("project name is invalid") - errProjectNameSpaces = errors.New("this project name is invalid, a project name cannot contain spaces. Try using '-' instead") + errConfigProjectName = errors.New("project name is invalid") + errConfigProjectNameSpecifiedTwice = errors.New("project name cannot be set with the --name flag and positional argument, please choose one") errInvalidSetArgs = errors.New("must specify exactly two arguments (key value) when setting a config") errInvalidConfigPath = errors.New("config does not exist, check your config key")