Skip to content

Commit

Permalink
Fix command parsing logic (#507)
Browse files Browse the repository at this point in the history
* fixed command parsing logic in podman

* code review changes

* added more test cases

* code review changes
  • Loading branch information
neel-astro committed Mar 8, 2022
1 parent e22843e commit 03e8fa8
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 51 deletions.
2 changes: 1 addition & 1 deletion airflow/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type ContainerHandler interface {
Stop() error
PS() error
Run(args []string, user string) error
ExecCommand(containerID, command string) string
ExecCommand(containerID, command string) (string, error)
GetContainerID(containerName string) (string, error)
}

Expand Down
6 changes: 3 additions & 3 deletions airflow/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,18 @@ func (d *DockerCompose) Run(args []string, user string) error {
}

// ExecCommand executes a command on webserver container, and sends the response as string, this can be clubbed with Run()
func (d *DockerCompose) ExecCommand(containerID, command string) string {
func (d *DockerCompose) ExecCommand(containerID, command string) (string, error) {
cmd := exec.Command("docker", "exec", "-it", containerID, "bash", "-c", command)
cmd.Stdin = os.Stdin
cmd.Stderr = os.Stderr

out, err := cmd.Output()
if err != nil {
_ = errors.Wrapf(err, "error encountered")
return "", err
}

stringOut := string(out)
return stringOut
return stringOut, nil
}

func (d *DockerCompose) GetContainerID(containerName string) (string, error) {
Expand Down
11 changes: 9 additions & 2 deletions airflow/mocks/ContainerHandler.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 22 additions & 10 deletions airflow/podman.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/astronomer/astro-cli/airflow/include"
"github.com/astronomer/astro-cli/config"
"github.com/astronomer/astro-cli/messages"
"github.com/google/shlex"

"github.com/containers/podman/v3/pkg/api/handlers"
"github.com/containers/podman/v3/pkg/bindings/containers"
Expand All @@ -35,6 +36,11 @@ var (
webserverHealthCheckInterval = 5 * time.Second
)

const (
webserverHealthCheckCmd = "curl --fail http://127.0.0.1:8080/health"
webserverHealthStatus = "\"healthy\""
)

type Podman struct {
projectDir string
envFile string
Expand Down Expand Up @@ -176,29 +182,31 @@ func (p *Podman) Stop() error {
return err
}

func (p *Podman) ExecCommand(containerID, command string) string {
command = strings.TrimLeft(command, " ")
command = strings.TrimRight(command, " ")
func (p *Podman) ExecCommand(containerID, command string) (string, error) {
execConfig := new(handlers.ExecCreateConfig)
execConfig.AttachStdout = true
execConfig.AttachStderr = true
execConfig.Cmd = strings.Split(command, " ")
cmd, err := parseCommand(command)
if err != nil {
return "", err
}
execConfig.Cmd = cmd

execID, err := p.podmanBind.ExecCreate(p.conn, containerID, execConfig)
if err != nil {
return err.Error()
return "", err
}

r, w, err := os.Pipe()
if err != nil {
return err.Error()
return "", err
}
defer r.Close()

streams := new(containers.ExecStartAndAttachOptions).WithOutputStream(w).WithErrorStream(w).WithAttachOutput(true).WithAttachError(true)
err = p.podmanBind.ExecStartAndAttach(p.conn, execID, streams)
if err != nil {
return err.Error()
return "", err
}

outputC := make(chan string)
Expand All @@ -211,7 +219,7 @@ func (p *Podman) ExecCommand(containerID, command string) string {
}()

w.Close()
return <-outputC
return <-outputC, nil
}

func (p *Podman) Run(args []string, user string) error {
Expand Down Expand Up @@ -406,8 +414,8 @@ func (p *Podman) webserverHealthCheck() {
if err != nil {
goto sleep
}
resp = p.ExecCommand(containerID, "airflow db check")
if strings.Contains(resp, "Connection successful.") {
resp, _ = p.ExecCommand(containerID, webserverHealthCheckCmd)
if strings.Contains(resp, webserverHealthStatus) {
break
}
sleep:
Expand All @@ -419,3 +427,7 @@ func (p *Podman) webserverHealthCheck() {
func (p *Podman) getWebserverContainerID() (string, error) {
return p.GetContainerID(config.CFG.WebserverContainerName.GetString())
}

func parseCommand(command string) ([]string, error) {
return shlex.Split(command)
}
37 changes: 37 additions & 0 deletions airflow/podman_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,40 @@ func TestPodmanRunFailure(t *testing.T) {
err := podmanMock.Run([]string{"db", "check"}, "user")
assert.Contains(t, err.Error(), errPodman.Error())
}

func TestParseCommand(t *testing.T) {
tests := []struct {
name string
command string
expectedOutput []string
expectedError string
}{
{
name: "test with intentional extra white spaces",
command: ` airflow connections add "local_postgres" --conn-type "postgres" --conn-host 'test.db.sql.com' --conn-login 'user' --conn-password 'pass' --conn-schema 'schema' --conn-port 5432`,
expectedOutput: []string{"airflow", "connections", "add", "local_postgres", "--conn-type", "postgres", "--conn-host", "test.db.sql.com", "--conn-login", "user", "--conn-password", "pass", "--conn-schema", "schema", "--conn-port", "5432"},
expectedError: "",
},
{
name: "test with spaces inside quoted strings",
command: `airflow connections add "azure_batch_default" --conn-type "azure_batch" --conn-extra '{"account_url": "<ACCOUNT_URL>"}' --conn-host '<ACCOUNT_URL>' --conn-login '<ACCOUNT_NAME>' --conn-password '<ACCOUNT_KEY>'`,
expectedOutput: []string{"airflow", "connections", "add", "azure_batch_default", "--conn-type", "azure_batch", "--conn-extra", "{\"account_url\": \"<ACCOUNT_URL>\"}", "--conn-host", "<ACCOUNT_URL>", "--conn-login", "<ACCOUNT_NAME>", "--conn-password", "<ACCOUNT_KEY>"},
expectedError: "",
},
{
name: "test with intentional comments, new line and back slash",
command: "one two \"three four\" \"five \\\"six\\\"\" seven#eight # nine # ten\n eleven 'twelve\\' thirteen=13 fourteen/14",
expectedOutput: []string{"one", "two", "three four", "five \"six\"", "seven#eight", "eleven", "twelve\\", "thirteen=13", "fourteen/14"},
expectedError: "",
},
}
for _, tt := range tests {
output, err := parseCommand(tt.command)
if tt.expectedError != "" {
assert.EqualError(t, err, tt.expectedError)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.expectedOutput, output)
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/docker/compose/v2 v2.2.2
github.com/docker/docker v20.10.8+incompatible
github.com/fatih/camelcase v1.0.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/gorilla/websocket v1.4.2
github.com/iancoleman/strcase v0.0.0-20180726023541-3605ed457bf7
github.com/joho/godotenv v1.4.0
Expand Down Expand Up @@ -93,7 +94,6 @@ require (
github.com/google/go-cmp v0.5.6 // indirect
github.com/google/go-intervals v0.0.2 // indirect
github.com/google/gofuzz v1.1.0 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/mux v1.8.0 // indirect
github.com/gorilla/schema v1.2.0 // indirect
Expand Down
87 changes: 55 additions & 32 deletions settings/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,12 @@ func AddVariables(containerEngine airflow.ContainerHandler, id string, version u

airflowCommand += fmt.Sprintf("'%s'", variable.VariableValue)

containerEngine.ExecCommand(id, airflowCommand)
fmt.Printf("Added Variable: %s\n", variable.VariableName)
_, err := containerEngine.ExecCommand(id, airflowCommand)
if err != nil {
fmt.Printf("Error adding variable %s: %s\n", variable.VariableName, err.Error())
} else {
fmt.Printf("Added Variable: %s\n", variable.VariableName)
}
}
}
}
Expand Down Expand Up @@ -131,7 +135,7 @@ func AddConnections(containerEngine airflow.ContainerHandler, id string, airflow
}

airflowCommand := baseListCmd
out := containerEngine.ExecCommand(id, airflowCommand)
out, _ := containerEngine.ExecCommand(id, airflowCommand)
for idx := range connections {
conn := connections[idx]
if !objectValidator(0, conn.ConnID) {
Expand All @@ -142,40 +146,23 @@ func AddConnections(containerEngine airflow.ContainerHandler, id string, airflow
if strings.Contains(out, quotedConnID) {
fmt.Printf("Found Connection: \"%s\"...replacing...\n", conn.ConnID)
airflowCommand = fmt.Sprintf("%s %s \"%s\"", baseRmCmd, connIDArg, conn.ConnID)
containerEngine.ExecCommand(id, airflowCommand)
_, err := containerEngine.ExecCommand(id, airflowCommand)
if err != nil {
fmt.Printf("Error replacing connection %s: %s\n", conn.ConnID, err.Error())
}
}

if !objectValidator(1, conn.ConnType, conn.ConnURI) {
fmt.Printf("Skipping %s: conn_type or conn_uri must be specified.\n", conn.ConnID)
continue
}
airflowCommand = fmt.Sprintf("%s %s \"%s\" ", baseAddCmd, connIDArg, conn.ConnID)
if objectValidator(0, conn.ConnType) {
airflowCommand += fmt.Sprintf("%s \"%s\" ", connTypeArg, conn.ConnType)
}
if objectValidator(0, conn.ConnURI) {
airflowCommand += fmt.Sprintf("%s '%s' ", connURIArg, conn.ConnURI)
}
if objectValidator(0, conn.ConnExtra) {
airflowCommand += fmt.Sprintf("%s '%s' ", connExtraArg, conn.ConnExtra)
}
if objectValidator(0, conn.ConnHost) {
airflowCommand += fmt.Sprintf("%s '%s' ", connHostArg, conn.ConnHost)
airflowCommand = prepareAddConnCmd(baseAddCmd, connIDArg, connTypeArg, connURIArg, connExtraArg, connHostArg, connLoginArg, connPasswordArg, connSchemaArg, connPortArg, &conn)
_, err := containerEngine.ExecCommand(id, airflowCommand)
if err != nil {
fmt.Printf("Error adding connection %s: %s\n", conn.ConnID, err.Error())
} else {
fmt.Printf("Added Connection: %s\n", conn.ConnID)
}
if objectValidator(0, conn.ConnLogin) {
airflowCommand += fmt.Sprintf("%s '%s' ", connLoginArg, conn.ConnLogin)
}
if objectValidator(0, conn.ConnPassword) {
airflowCommand += fmt.Sprintf("%s '%s' ", connPasswordArg, conn.ConnPassword)
}
if objectValidator(0, conn.ConnSchema) {
airflowCommand += fmt.Sprintf("%s '%s' ", connSchemaArg, conn.ConnSchema)
}
if conn.ConnPort != 0 {
airflowCommand += fmt.Sprintf("%s %v", connPortArg, conn.ConnPort)
}
containerEngine.ExecCommand(id, airflowCommand)
fmt.Printf("Added Connection: %s\n", conn.ConnID)
}
}

Expand Down Expand Up @@ -204,15 +191,51 @@ func AddPools(containerEngine airflow.ContainerHandler, id string, airflowVersio
} else {
airflowCommand += ""
}
containerEngine.ExecCommand(id, airflowCommand)
fmt.Printf("Added Pool: %s\n", pool.PoolName)
_, err := containerEngine.ExecCommand(id, airflowCommand)
if err != nil {
fmt.Printf("Error adding pool %s: %s\n", pool.PoolName, err.Error())
} else {
fmt.Printf("Added Pool: %s\n", pool.PoolName)
}
} else {
fmt.Printf("Skipping %s: Pool Slot must be set.\n", pool.PoolName)
}
}
}
}

func prepareAddConnCmd(baseAddCmd, connIDArg, connTypeArg, connURIArg, connExtraArg, connHostArg, connLoginArg, connPasswordArg, connSchemaArg, connPortArg string, conn *Connection) string {
if conn == nil {
return ""
}
airflowCommand := fmt.Sprintf("%s %s \"%s\" ", baseAddCmd, connIDArg, conn.ConnID)
if objectValidator(0, conn.ConnType) {
airflowCommand += fmt.Sprintf("%s \"%s\" ", connTypeArg, conn.ConnType)
}
if objectValidator(0, conn.ConnURI) {
airflowCommand += fmt.Sprintf("%s '%s' ", connURIArg, conn.ConnURI)
}
if objectValidator(0, conn.ConnExtra) {
airflowCommand += fmt.Sprintf("%s '%s' ", connExtraArg, conn.ConnExtra)
}
if objectValidator(0, conn.ConnHost) {
airflowCommand += fmt.Sprintf("%s '%s' ", connHostArg, conn.ConnHost)
}
if objectValidator(0, conn.ConnLogin) {
airflowCommand += fmt.Sprintf("%s '%s' ", connLoginArg, conn.ConnLogin)
}
if objectValidator(0, conn.ConnPassword) {
airflowCommand += fmt.Sprintf("%s '%s' ", connPasswordArg, conn.ConnPassword)
}
if objectValidator(0, conn.ConnSchema) {
airflowCommand += fmt.Sprintf("%s '%s' ", connSchemaArg, conn.ConnSchema)
}
if conn.ConnPort != 0 {
airflowCommand += fmt.Sprintf("%s %v", connPortArg, conn.ConnPort)
}
return airflowCommand
}

func objectValidator(bound int, args ...string) bool {
count := 0
for _, arg := range args {
Expand Down
Loading

0 comments on commit 03e8fa8

Please sign in to comment.