From 03e8fa8f66a478d6fb13959b293c10620bf33333 Mon Sep 17 00:00:00 2001 From: Neel Dalsania Date: Tue, 8 Mar 2022 22:37:15 +0530 Subject: [PATCH] Fix command parsing logic (#507) * fixed command parsing logic in podman * code review changes * added more test cases * code review changes --- airflow/container.go | 2 +- airflow/docker.go | 6 +- airflow/mocks/ContainerHandler.go | 11 +- airflow/podman.go | 32 +++-- airflow/podman_test.go | 37 +++++ go.mod | 2 +- settings/settings.go | 87 +++++++----- settings/settings_test.go | 220 ++++++++++++++++++++++++++++++ settings/types.go | 6 +- 9 files changed, 352 insertions(+), 51 deletions(-) create mode 100644 settings/settings_test.go diff --git a/airflow/container.go b/airflow/container.go index e3bb66c15..d69bbf317 100644 --- a/airflow/container.go +++ b/airflow/container.go @@ -30,7 +30,7 @@ type ContainerHandler interface { Stop() error PS() error Run(args []string, user string) error - ExecCommand(containerID, command string) string + ExecCommand(containerID, command string) (string, error) GetContainerID(containerName string) (string, error) } diff --git a/airflow/docker.go b/airflow/docker.go index e32300780..3e84d6115 100644 --- a/airflow/docker.go +++ b/airflow/docker.go @@ -240,18 +240,18 @@ func (d *DockerCompose) Run(args []string, user string) error { } // ExecCommand executes a command on webserver container, and sends the response as string, this can be clubbed with Run() -func (d *DockerCompose) ExecCommand(containerID, command string) string { +func (d *DockerCompose) ExecCommand(containerID, command string) (string, error) { cmd := exec.Command("docker", "exec", "-it", containerID, "bash", "-c", command) cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr out, err := cmd.Output() if err != nil { - _ = errors.Wrapf(err, "error encountered") + return "", err } stringOut := string(out) - return stringOut + return stringOut, nil } func (d *DockerCompose) GetContainerID(containerName string) (string, error) { diff --git a/airflow/mocks/ContainerHandler.go b/airflow/mocks/ContainerHandler.go index 19b40228a..e56ee0a8f 100644 --- a/airflow/mocks/ContainerHandler.go +++ b/airflow/mocks/ContainerHandler.go @@ -10,7 +10,7 @@ type ContainerHandler struct { } // ExecCommand provides a mock function with given fields: containerID, command -func (_m *ContainerHandler) ExecCommand(containerID string, command string) string { +func (_m *ContainerHandler) ExecCommand(containerID string, command string) (string, error) { ret := _m.Called(containerID, command) var r0 string @@ -20,7 +20,14 @@ func (_m *ContainerHandler) ExecCommand(containerID string, command string) stri r0 = ret.Get(0).(string) } - return r0 + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(containerID, command) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // GetContainerID provides a mock function with given fields: containerName diff --git a/airflow/podman.go b/airflow/podman.go index d26faec56..2c936d52a 100644 --- a/airflow/podman.go +++ b/airflow/podman.go @@ -15,6 +15,7 @@ import ( "github.com/astronomer/astro-cli/airflow/include" "github.com/astronomer/astro-cli/config" "github.com/astronomer/astro-cli/messages" + "github.com/google/shlex" "github.com/containers/podman/v3/pkg/api/handlers" "github.com/containers/podman/v3/pkg/bindings/containers" @@ -35,6 +36,11 @@ var ( webserverHealthCheckInterval = 5 * time.Second ) +const ( + webserverHealthCheckCmd = "curl --fail http://127.0.0.1:8080/health" + webserverHealthStatus = "\"healthy\"" +) + type Podman struct { projectDir string envFile string @@ -176,29 +182,31 @@ func (p *Podman) Stop() error { return err } -func (p *Podman) ExecCommand(containerID, command string) string { - command = strings.TrimLeft(command, " ") - command = strings.TrimRight(command, " ") +func (p *Podman) ExecCommand(containerID, command string) (string, error) { execConfig := new(handlers.ExecCreateConfig) execConfig.AttachStdout = true execConfig.AttachStderr = true - execConfig.Cmd = strings.Split(command, " ") + cmd, err := parseCommand(command) + if err != nil { + return "", err + } + execConfig.Cmd = cmd execID, err := p.podmanBind.ExecCreate(p.conn, containerID, execConfig) if err != nil { - return err.Error() + return "", err } r, w, err := os.Pipe() if err != nil { - return err.Error() + return "", err } defer r.Close() streams := new(containers.ExecStartAndAttachOptions).WithOutputStream(w).WithErrorStream(w).WithAttachOutput(true).WithAttachError(true) err = p.podmanBind.ExecStartAndAttach(p.conn, execID, streams) if err != nil { - return err.Error() + return "", err } outputC := make(chan string) @@ -211,7 +219,7 @@ func (p *Podman) ExecCommand(containerID, command string) string { }() w.Close() - return <-outputC + return <-outputC, nil } func (p *Podman) Run(args []string, user string) error { @@ -406,8 +414,8 @@ func (p *Podman) webserverHealthCheck() { if err != nil { goto sleep } - resp = p.ExecCommand(containerID, "airflow db check") - if strings.Contains(resp, "Connection successful.") { + resp, _ = p.ExecCommand(containerID, webserverHealthCheckCmd) + if strings.Contains(resp, webserverHealthStatus) { break } sleep: @@ -419,3 +427,7 @@ func (p *Podman) webserverHealthCheck() { func (p *Podman) getWebserverContainerID() (string, error) { return p.GetContainerID(config.CFG.WebserverContainerName.GetString()) } + +func parseCommand(command string) ([]string, error) { + return shlex.Split(command) +} diff --git a/airflow/podman_test.go b/airflow/podman_test.go index be7cbb971..83e9dd00f 100644 --- a/airflow/podman_test.go +++ b/airflow/podman_test.go @@ -289,3 +289,40 @@ func TestPodmanRunFailure(t *testing.T) { err := podmanMock.Run([]string{"db", "check"}, "user") assert.Contains(t, err.Error(), errPodman.Error()) } + +func TestParseCommand(t *testing.T) { + tests := []struct { + name string + command string + expectedOutput []string + expectedError string + }{ + { + name: "test with intentional extra white spaces", + command: ` airflow connections add "local_postgres" --conn-type "postgres" --conn-host 'test.db.sql.com' --conn-login 'user' --conn-password 'pass' --conn-schema 'schema' --conn-port 5432`, + expectedOutput: []string{"airflow", "connections", "add", "local_postgres", "--conn-type", "postgres", "--conn-host", "test.db.sql.com", "--conn-login", "user", "--conn-password", "pass", "--conn-schema", "schema", "--conn-port", "5432"}, + expectedError: "", + }, + { + name: "test with spaces inside quoted strings", + command: `airflow connections add "azure_batch_default" --conn-type "azure_batch" --conn-extra '{"account_url": ""}' --conn-host '' --conn-login '' --conn-password ''`, + expectedOutput: []string{"airflow", "connections", "add", "azure_batch_default", "--conn-type", "azure_batch", "--conn-extra", "{\"account_url\": \"\"}", "--conn-host", "", "--conn-login", "", "--conn-password", ""}, + expectedError: "", + }, + { + name: "test with intentional comments, new line and back slash", + command: "one two \"three four\" \"five \\\"six\\\"\" seven#eight # nine # ten\n eleven 'twelve\\' thirteen=13 fourteen/14", + expectedOutput: []string{"one", "two", "three four", "five \"six\"", "seven#eight", "eleven", "twelve\\", "thirteen=13", "fourteen/14"}, + expectedError: "", + }, + } + for _, tt := range tests { + output, err := parseCommand(tt.command) + if tt.expectedError != "" { + assert.EqualError(t, err, tt.expectedError) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.expectedOutput, output) + } +} diff --git a/go.mod b/go.mod index 6fc49984f..cc5a1e6d2 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/docker/compose/v2 v2.2.2 github.com/docker/docker v20.10.8+incompatible github.com/fatih/camelcase v1.0.0 + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/gorilla/websocket v1.4.2 github.com/iancoleman/strcase v0.0.0-20180726023541-3605ed457bf7 github.com/joho/godotenv v1.4.0 @@ -93,7 +94,6 @@ require ( github.com/google/go-cmp v0.5.6 // indirect github.com/google/go-intervals v0.0.2 // indirect github.com/google/gofuzz v1.1.0 // indirect - github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.3.0 // indirect github.com/gorilla/mux v1.8.0 // indirect github.com/gorilla/schema v1.2.0 // indirect diff --git a/settings/settings.go b/settings/settings.go index cc2b27ddb..80bb5fa3a 100644 --- a/settings/settings.go +++ b/settings/settings.go @@ -87,8 +87,12 @@ func AddVariables(containerEngine airflow.ContainerHandler, id string, version u airflowCommand += fmt.Sprintf("'%s'", variable.VariableValue) - containerEngine.ExecCommand(id, airflowCommand) - fmt.Printf("Added Variable: %s\n", variable.VariableName) + _, err := containerEngine.ExecCommand(id, airflowCommand) + if err != nil { + fmt.Printf("Error adding variable %s: %s\n", variable.VariableName, err.Error()) + } else { + fmt.Printf("Added Variable: %s\n", variable.VariableName) + } } } } @@ -131,7 +135,7 @@ func AddConnections(containerEngine airflow.ContainerHandler, id string, airflow } airflowCommand := baseListCmd - out := containerEngine.ExecCommand(id, airflowCommand) + out, _ := containerEngine.ExecCommand(id, airflowCommand) for idx := range connections { conn := connections[idx] if !objectValidator(0, conn.ConnID) { @@ -142,40 +146,23 @@ func AddConnections(containerEngine airflow.ContainerHandler, id string, airflow if strings.Contains(out, quotedConnID) { fmt.Printf("Found Connection: \"%s\"...replacing...\n", conn.ConnID) airflowCommand = fmt.Sprintf("%s %s \"%s\"", baseRmCmd, connIDArg, conn.ConnID) - containerEngine.ExecCommand(id, airflowCommand) + _, err := containerEngine.ExecCommand(id, airflowCommand) + if err != nil { + fmt.Printf("Error replacing connection %s: %s\n", conn.ConnID, err.Error()) + } } if !objectValidator(1, conn.ConnType, conn.ConnURI) { fmt.Printf("Skipping %s: conn_type or conn_uri must be specified.\n", conn.ConnID) continue } - airflowCommand = fmt.Sprintf("%s %s \"%s\" ", baseAddCmd, connIDArg, conn.ConnID) - if objectValidator(0, conn.ConnType) { - airflowCommand += fmt.Sprintf("%s \"%s\" ", connTypeArg, conn.ConnType) - } - if objectValidator(0, conn.ConnURI) { - airflowCommand += fmt.Sprintf("%s '%s' ", connURIArg, conn.ConnURI) - } - if objectValidator(0, conn.ConnExtra) { - airflowCommand += fmt.Sprintf("%s '%s' ", connExtraArg, conn.ConnExtra) - } - if objectValidator(0, conn.ConnHost) { - airflowCommand += fmt.Sprintf("%s '%s' ", connHostArg, conn.ConnHost) + airflowCommand = prepareAddConnCmd(baseAddCmd, connIDArg, connTypeArg, connURIArg, connExtraArg, connHostArg, connLoginArg, connPasswordArg, connSchemaArg, connPortArg, &conn) + _, err := containerEngine.ExecCommand(id, airflowCommand) + if err != nil { + fmt.Printf("Error adding connection %s: %s\n", conn.ConnID, err.Error()) + } else { + fmt.Printf("Added Connection: %s\n", conn.ConnID) } - if objectValidator(0, conn.ConnLogin) { - airflowCommand += fmt.Sprintf("%s '%s' ", connLoginArg, conn.ConnLogin) - } - if objectValidator(0, conn.ConnPassword) { - airflowCommand += fmt.Sprintf("%s '%s' ", connPasswordArg, conn.ConnPassword) - } - if objectValidator(0, conn.ConnSchema) { - airflowCommand += fmt.Sprintf("%s '%s' ", connSchemaArg, conn.ConnSchema) - } - if conn.ConnPort != 0 { - airflowCommand += fmt.Sprintf("%s %v", connPortArg, conn.ConnPort) - } - containerEngine.ExecCommand(id, airflowCommand) - fmt.Printf("Added Connection: %s\n", conn.ConnID) } } @@ -204,8 +191,12 @@ func AddPools(containerEngine airflow.ContainerHandler, id string, airflowVersio } else { airflowCommand += "" } - containerEngine.ExecCommand(id, airflowCommand) - fmt.Printf("Added Pool: %s\n", pool.PoolName) + _, err := containerEngine.ExecCommand(id, airflowCommand) + if err != nil { + fmt.Printf("Error adding pool %s: %s\n", pool.PoolName, err.Error()) + } else { + fmt.Printf("Added Pool: %s\n", pool.PoolName) + } } else { fmt.Printf("Skipping %s: Pool Slot must be set.\n", pool.PoolName) } @@ -213,6 +204,38 @@ func AddPools(containerEngine airflow.ContainerHandler, id string, airflowVersio } } +func prepareAddConnCmd(baseAddCmd, connIDArg, connTypeArg, connURIArg, connExtraArg, connHostArg, connLoginArg, connPasswordArg, connSchemaArg, connPortArg string, conn *Connection) string { + if conn == nil { + return "" + } + airflowCommand := fmt.Sprintf("%s %s \"%s\" ", baseAddCmd, connIDArg, conn.ConnID) + if objectValidator(0, conn.ConnType) { + airflowCommand += fmt.Sprintf("%s \"%s\" ", connTypeArg, conn.ConnType) + } + if objectValidator(0, conn.ConnURI) { + airflowCommand += fmt.Sprintf("%s '%s' ", connURIArg, conn.ConnURI) + } + if objectValidator(0, conn.ConnExtra) { + airflowCommand += fmt.Sprintf("%s '%s' ", connExtraArg, conn.ConnExtra) + } + if objectValidator(0, conn.ConnHost) { + airflowCommand += fmt.Sprintf("%s '%s' ", connHostArg, conn.ConnHost) + } + if objectValidator(0, conn.ConnLogin) { + airflowCommand += fmt.Sprintf("%s '%s' ", connLoginArg, conn.ConnLogin) + } + if objectValidator(0, conn.ConnPassword) { + airflowCommand += fmt.Sprintf("%s '%s' ", connPasswordArg, conn.ConnPassword) + } + if objectValidator(0, conn.ConnSchema) { + airflowCommand += fmt.Sprintf("%s '%s' ", connSchemaArg, conn.ConnSchema) + } + if conn.ConnPort != 0 { + airflowCommand += fmt.Sprintf("%s %v", connPortArg, conn.ConnPort) + } + return airflowCommand +} + func objectValidator(bound int, args ...string) bool { count := 0 for _, arg := range args { diff --git a/settings/settings_test.go b/settings/settings_test.go new file mode 100644 index 000000000..a482bb357 --- /dev/null +++ b/settings/settings_test.go @@ -0,0 +1,220 @@ +package settings + +import ( + "errors" + "io/ioutil" + "os" + "testing" + + "github.com/astronomer/astro-cli/airflow/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var errContainer = errors.New("some container engine error") + +func TestAddConnectionsAirflowOne(t *testing.T) { + testConn := Connection{ + ConnID: "test-id", + ConnType: "test-type", + ConnHost: "test-host", + ConnSchema: "test-schema", + ConnLogin: "test-login", + ConnPassword: "test-password", + ConnPort: 1, + ConnURI: "test-uri", + ConnExtra: "test-extras", + } + settings.Airflow.Connections = []Connection{testConn} + + mockContainer := new(mocks.ContainerHandler) + mockContainer.On("ExecCommand", mock.Anything, mock.Anything).Return("no connections found", nil).Once() + expectedAddCmd := "airflow connections -a --conn_id \"test-id\" --conn_type \"test-type\" --conn_uri 'test-uri' --conn_extra 'test-extras' --conn_host 'test-host' --conn_login 'test-login' --conn_password 'test-password' --conn_schema 'test-schema' --conn_port 1" + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("connections added", nil).Once() + AddConnections(mockContainer, "test-conn-id", 1) + mockContainer.AssertExpectations(t) + + mockContainer.On("ExecCommand", mock.Anything, mock.Anything).Return("'test-id' connection exists", nil).Once() + expectedDelCmd := "airflow connections -d --conn_id \"test-id\"" + mockContainer.On("ExecCommand", mock.Anything, expectedDelCmd).Return("connection deleted", nil).Once() + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("connections added", nil).Once() + AddConnections(mockContainer, "test-conn-id", 1) + mockContainer.AssertExpectations(t) +} + +func TestAddConnectionsAirflowTwo(t *testing.T) { + testConn := Connection{ + ConnID: "test-id", + ConnType: "test-type", + ConnHost: "test-host", + ConnSchema: "test-schema", + ConnLogin: "test-login", + ConnPassword: "test-password", + ConnPort: 1, + ConnURI: "test-uri", + ConnExtra: "test-extras", + } + settings.Airflow.Connections = []Connection{testConn} + + mockContainer := new(mocks.ContainerHandler) + mockContainer.On("ExecCommand", mock.Anything, mock.Anything).Return("no connections found", nil).Once() + expectedAddCmd := "airflow connections add \"test-id\" --conn-type \"test-type\" --conn-uri 'test-uri' --conn-extra 'test-extras' --conn-host 'test-host' --conn-login 'test-login' --conn-password 'test-password' --conn-schema 'test-schema' --conn-port 1" + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("connections added", nil).Once() + AddConnections(mockContainer, "test-conn-id", 2) + mockContainer.AssertExpectations(t) + + mockContainer.On("ExecCommand", mock.Anything, mock.Anything).Return("'test-id' connection exists", nil).Once() + expectedDelCmd := "airflow connections delete \"test-id\"" + mockContainer.On("ExecCommand", mock.Anything, expectedDelCmd).Return("connection deleted", nil).Once() + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("connections added", nil).Once() + AddConnections(mockContainer, "test-conn-id", 2) + mockContainer.AssertExpectations(t) +} + +func TestAddConnectionsFailure(t *testing.T) { + testConn := Connection{ + ConnID: "test-id", + ConnType: "test-type", + ConnHost: "test-host", + ConnSchema: "test-schema", + ConnLogin: "test-login", + ConnPassword: "test-password", + ConnPort: 1, + ConnURI: "test-uri", + ConnExtra: "test-extras", + } + settings.Airflow.Connections = []Connection{testConn} + + mockContainer := new(mocks.ContainerHandler) + mockContainer.On("ExecCommand", mock.Anything, mock.Anything).Return("no connections found", nil).Once() + expectedAddCmd := "airflow connections add \"test-id\" --conn-type \"test-type\" --conn-uri 'test-uri' --conn-extra 'test-extras' --conn-host 'test-host' --conn-login 'test-login' --conn-password 'test-password' --conn-schema 'test-schema' --conn-port 1" + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("", errContainer).Once() + + stdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + AddConnections(mockContainer, "test-conn-id", 2) + + w.Close() + out, _ := ioutil.ReadAll(r) + os.Stdout = stdout + mockContainer.AssertExpectations(t) + assert.Contains(t, string(out), errContainer.Error()) + assert.Contains(t, string(out), "Error adding connection") +} + +func TestAddVariableAirflowOne(t *testing.T) { + settings.Airflow.Variables = Variables{ + { + VariableName: "test-var-name", + VariableValue: "test-var-val", + }, + } + + mockContainer := new(mocks.ContainerHandler) + expectedAddCmd := "airflow variables -s test-var-name 'test-var-val'" + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("variable added", nil).Once() + AddVariables(mockContainer, "test-conn-id", 1) + mockContainer.AssertExpectations(t) +} + +func TestAddVariableAirflowTwo(t *testing.T) { + settings.Airflow.Variables = Variables{ + { + VariableName: "test-var-name", + VariableValue: "test-var-val", + }, + } + + mockContainer := new(mocks.ContainerHandler) + expectedAddCmd := "airflow variables set test-var-name 'test-var-val'" + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("variable added", nil).Once() + AddVariables(mockContainer, "test-conn-id", 2) + mockContainer.AssertExpectations(t) +} + +func TestAddVariableFailure(t *testing.T) { + settings.Airflow.Variables = Variables{ + { + VariableName: "test-var-name", + VariableValue: "test-var-val", + }, + } + + mockContainer := new(mocks.ContainerHandler) + expectedAddCmd := "airflow variables set test-var-name 'test-var-val'" + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("", errContainer).Once() + + stdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + AddVariables(mockContainer, "test-conn-id", 2) + + w.Close() + out, _ := ioutil.ReadAll(r) + os.Stdout = stdout + mockContainer.AssertExpectations(t) + assert.Contains(t, string(out), errContainer.Error()) + assert.Contains(t, string(out), "Error adding variable") +} + +func TestAddPoolsAirflowOne(t *testing.T) { + settings.Airflow.Pools = Pools{ + { + PoolName: "test-pool-name", + PoolSlot: 1, + PoolDescription: "test-pool-description", + }, + } + + mockContainer := new(mocks.ContainerHandler) + expectedAddCmd := "airflow pool -s test-pool-name 1 'test-pool-description' " + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("pool added", nil).Once() + AddPools(mockContainer, "test-conn-id", 1) + mockContainer.AssertExpectations(t) +} + +func TestAddPoolsAirflowTwo(t *testing.T) { + settings.Airflow.Pools = Pools{ + { + PoolName: "test-pool-name", + PoolSlot: 1, + PoolDescription: "test-pool-description", + }, + } + + mockContainer := new(mocks.ContainerHandler) + expectedAddCmd := "airflow pools set test-pool-name 1 'test-pool-description' " + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("variable added", nil).Once() + AddPools(mockContainer, "test-conn-id", 2) + mockContainer.AssertExpectations(t) +} + +func TestAddPoolsFailure(t *testing.T) { + settings.Airflow.Pools = Pools{ + { + PoolName: "test-pool-name", + PoolSlot: 1, + PoolDescription: "test-pool-description", + }, + } + + mockContainer := new(mocks.ContainerHandler) + expectedAddCmd := "airflow pools set test-pool-name 1 'test-pool-description' " + mockContainer.On("ExecCommand", mock.Anything, expectedAddCmd).Return("", errContainer).Once() + + stdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + AddPools(mockContainer, "test-conn-id", 2) + + w.Close() + out, _ := ioutil.ReadAll(r) + os.Stdout = stdout + mockContainer.AssertExpectations(t) + assert.Contains(t, string(out), errContainer.Error()) + assert.Contains(t, string(out), "Error adding pool") +} diff --git a/settings/types.go b/settings/types.go index 4bae5ac64..4b110681d 100644 --- a/settings/types.go +++ b/settings/types.go @@ -1,7 +1,9 @@ package settings -// Connections contains structure of airflow connections -type Connections []struct { +type Connections []Connection + +// Connection contains structure of airflow connection +type Connection struct { ConnID string `mapstructure:"conn_id"` ConnType string `mapstructure:"conn_type"` ConnHost string `mapstructure:"conn_host"`