diff --git a/cmd/zoekt-sourcegraph-indexserver/main.go b/cmd/zoekt-sourcegraph-indexserver/main.go index aff3d2c95..c1ac19f5b 100644 --- a/cmd/zoekt-sourcegraph-indexserver/main.go +++ b/cmd/zoekt-sourcegraph-indexserver/main.go @@ -1213,7 +1213,7 @@ func (rc *rootConfig) registerRootFlags(fs *flag.FlagSet) { fs.IntVar(&rc.blockProfileRate, "block_profile_rate", getEnvWithDefaultInt("BLOCK_PROFILE_RATE", -1), "Sampling rate of Go's block profiler in nanoseconds. Values <=0 disable the blocking profiler Var(default). A value of 1 includes every blocking event. See https://pkg.go.dev/runtime#SetBlockProfileRate") fs.DurationVar(&rc.backoffDuration, "backoff_duration", getEnvWithDefaultDuration("BACKOFF_DURATION", 10*time.Minute), "for the given duration we backoff from enqueue operations for a repository that's failed its previous indexing attempt. Consecutive failures increase the duration of the delay linearly up to the maxBackoffDuration. A negative value disables indexing backoff.") fs.DurationVar(&rc.maxBackoffDuration, "max_backoff_duration", getEnvWithDefaultDuration("MAX_BACKOFF_DURATION", 120*time.Minute), "the maximum duration to backoff from enqueueing a repo for indexing. A negative value disables indexing backoff.") - fs.BoolVar(&rc.useGRPC, "use_grpc", getEnvWithDefaultBool("GRPC_ENABLED", false), "use the gRPC API to talk to Sourcegraph") + fs.BoolVar(&rc.useGRPC, "use_grpc", mustGetBoolFromEnvironmentVariables([]string{"GRPC_ENABLED", "SG_FEATURE_FLAG_GRPC"}, true), "use the gRPC API to talk to Sourcegraph") // flags related to shard merging fs.DurationVar(&rc.vacuumInterval, "vacuum_interval", getEnvWithDefaultDuration("SRC_VACUUM_INTERVAL", 24*time.Hour), "run vacuum this often") @@ -1596,3 +1596,36 @@ func main() { log.Fatal(err) } } + +// mustGetBoolFromEnvironmentVariables is like getBoolFromEnvironmentVariables, but it panics +// if any of the provided environment variables fails to parse as a boolean. +func mustGetBoolFromEnvironmentVariables(envVarNames []string, defaultBool bool) bool { + value, err := getBoolFromEnvironmentVariables(envVarNames, defaultBool) + if err != nil { + panic(err) + } + + return value +} + +// getBoolFromEnvironmentVariables returns the boolean defined by the first environment +// variable listed in envVarNames that is set in the current process environment, or the defaultBool if none are set. +// +// An error is returned of the provided environment variables fails to parse as a boolean. +func getBoolFromEnvironmentVariables(envVarNames []string, defaultBool bool) (bool, error) { + for _, envVar := range envVarNames { + v := os.Getenv(envVar) + if v == "" { + continue + } + + b, err := strconv.ParseBool(v) + if err != nil { + return false, fmt.Errorf("parsing environment variable %q to boolean: %v", envVar, err) + } + + return b, nil + } + + return defaultBool, nil +} diff --git a/cmd/zoekt-sourcegraph-indexserver/main_test.go b/cmd/zoekt-sourcegraph-indexserver/main_test.go index 6981601ac..0d3f65cd3 100644 --- a/cmd/zoekt-sourcegraph-indexserver/main_test.go +++ b/cmd/zoekt-sourcegraph-indexserver/main_test.go @@ -286,6 +286,103 @@ func TestDefaultGRPCServiceConfigurationSyntax(t *testing.T) { } } +func TestGetBoolFromEnvironmentVariables(t *testing.T) { + testCases := []struct { + name string + envVarsToSet map[string]string + + envVarNames []string + defaultBool bool + + wantBool bool + wantErr bool + }{ + { + name: "respect default value: true", + + envVarsToSet: map[string]string{}, + + envVarNames: []string{"FOO", "BAR"}, + defaultBool: true, + + wantBool: true, + }, + { + name: "respect default value: false", + + envVarsToSet: map[string]string{}, + + envVarNames: []string{"FOO", "BAR"}, + defaultBool: false, + + wantBool: false, + }, + { + name: "read from environment", + + envVarsToSet: map[string]string{"FOO": "1"}, + + envVarNames: []string{"FOO"}, + defaultBool: false, + + wantBool: true, + }, + { + name: "read from first env var that is set", + + envVarsToSet: map[string]string{ + "BAR": "false", + "BAZ": "true", + }, + + envVarNames: []string{"FOO", "BAR", "BAZ"}, + defaultBool: true, + + wantBool: false, + }, + + { + name: "should error for invalid input", + + envVarsToSet: map[string]string{"INVALID": "not a boolean"}, + + envVarNames: []string{"INVALID"}, + defaultBool: false, + + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + // Prepare the environment by loading all the appropriate environment variables + for _, v := range tc.envVarNames { + _ = os.Unsetenv(v) + } + + for k, _ := range tc.envVarsToSet { + _ = os.Unsetenv(k) + } + + for k, v := range tc.envVarsToSet { + t.Setenv(k, v) + } + + // Run the test + got, err := getBoolFromEnvironmentVariables(tc.envVarNames, tc.defaultBool) + + // Examine the results + if tc.wantErr != (err != nil) { + t.Fatalf("unexpected error (wantErr = %t): %v", tc.wantErr, err) + } + + if got != tc.wantBool { + t.Errorf("got %v, want %v", got, tc.wantBool) + } + }) + } +} + func TestAddDefaultPort(t *testing.T) { tests := []struct { name string