diff --git a/env/env.go b/env/env.go index 31b5c93..c1e3127 100644 --- a/env/env.go +++ b/env/env.go @@ -3,15 +3,15 @@ package env import ( "fmt" "os" - "slices" "strconv" "strings" + "sync" ) var ( - defaults = map[string]interface{}{} - requiredKeys []string - types = map[string]string{} + types = sync.Map{} + defaults = sync.Map{} + requiredKeys = sync.Map{} ) // Exists return true if env var is defined @@ -25,16 +25,16 @@ func MustExists(key string) { if !Exists(key) { panic(fmt.Sprintf("required environment variable `%s` does not exist", key)) } - if !slices.Contains(requiredKeys, key) { - requiredKeys = append(requiredKeys, key) + if _, ok := requiredKeys.Load(key); !ok { + requiredKeys.Store(key, true) } } // Get env var or fallback func Get(key, fallback string) string { - defaults[key] = fallback - if _, ok := types[key]; !ok { - types[key] = "string" + defaults.Store(key, fallback) + if _, ok := types.Load(key); !ok { + types.Store(key, "string") } if v, ok := os.LookupEnv(key); ok { return v @@ -50,8 +50,8 @@ func MustGet(key string) string { // GetInt env var or fallback as int func GetInt(key string, fallback int) int { - if _, ok := types[key]; !ok { - types[key] = "int" + if _, ok := types.Load(key); !ok { + types.Store(key, "int") } if value, err := strconv.Atoi(Get(key, "")); err == nil { return value @@ -67,8 +67,8 @@ func MustGetInt(key string) int { // GetInt64 env var or fallback as int64 func GetInt64(key string, fallback int64) int64 { - if _, ok := types[key]; !ok { - types[key] = "int64" + if _, ok := types.Load(key); !ok { + types.Store(key, "int64") } if value, err := strconv.ParseInt(Get(key, ""), 10, 64); err == nil { return value @@ -84,8 +84,8 @@ func MustGetInt64(key string) int64 { // GetFloat64 env var or fallback as float64 func GetFloat64(key string, fallback float64) float64 { - if _, ok := types[key]; !ok { - types[key] = "float64" + if _, ok := types.Load(key); !ok { + types.Store(key, "float64") } if value, err := strconv.ParseFloat(Get(key, ""), 64); err == nil { return value @@ -101,8 +101,8 @@ func MustGetFloat64(key string) float64 { // GetBool env var or fallback as bool func GetBool(key string, fallback bool) bool { - if _, ok := types[key]; !ok { - types[key] = "bool" + if _, ok := types.Load(key); !ok { + types.Store(key, "bool") } if val, err := strconv.ParseBool(Get(key, "")); err == nil { return val @@ -118,8 +118,8 @@ func MustGetBool(key string) bool { // GetStringSlice env var or fallback as []string func GetStringSlice(key string, fallback []string) []string { - if _, ok := types[key]; !ok { - types[key] = "[]string" + if _, ok := types.Load(key); !ok { + types.Store(key, "[]string") } if v := Get(key, ""); v != "" { return strings.Split(v, ",") @@ -135,8 +135,8 @@ func MustGetStringSlice(key string) []string { // GetIntSlice env var or fallback as []string func GetIntSlice(key string, fallback []int) []int { - if _, ok := types[key]; !ok { - types[key] = "[]int" + if _, ok := types.Load(key); !ok { + types.Store(key, "[]int") } if v := Get(key, ""); v != "" { elements := strings.Split(v, ",") @@ -159,20 +159,46 @@ func MustGetGetIntSlice(key string) []int { } func RequiredKeys() []string { - return requiredKeys + var ret []string + requiredKeys.Range(func(key, value interface{}) bool { + if v, ok := key.(string); ok { + ret = append(ret, v) + } + return true + }) + return ret } func Defaults() map[string]interface{} { - return defaults + ret := map[string]interface{}{} + defaults.Range(func(key, value interface{}) bool { + if k, ok := key.(string); ok { + ret[k] = value + } + return true + }) + return ret } func Types() map[string]string { - return types + ret := map[string]string{} + types.Range(func(key, value interface{}) bool { + if v, ok := value.(string); ok { + if k, ok := key.(string); ok { + ret[k] = v + } + } + return true + }) + return ret } func TypeOf(key string) string { - if v, ok := types[key]; ok { - return v + if v, ok := types.Load(key); ok { + if s, ok := v.(string); ok { + return s + } + return "" } return "" } diff --git a/env/env_test.go b/env/env_test.go new file mode 100644 index 0000000..dc1f1a6 --- /dev/null +++ b/env/env_test.go @@ -0,0 +1,85 @@ +package env_test + +import ( + "os" + "testing" + + "github.com/foomo/keel/env" + "github.com/stretchr/testify/assert" +) + +func TestMain(m *testing.M) { + _ = os.Setenv("TEST_ENV_INT", "3") + _ = os.Setenv("TEST_ENV_STRING", "test") + m.Run() +} + +func TextExists(t *testing.T) { + t.Parallel() + assert.True(t, env.Exists("TEST_ENV_EXISTS")) + assert.False(t, env.Exists("TEST_ENV_NOOP")) +} + +func TestMustExists(t *testing.T) { + t.Parallel() + assert.True(t, env.Exists("TEST_ENV_STRING")) + assert.Panics(t, func() { + env.MustExists("TEST_ENV_NOOP") + }) +} + +func TestGet(t *testing.T) { + t.Parallel() + + assert.Equal(t, "test", env.Get("TEST_ENV_STRING", "fallback")) + assert.Equal(t, "fallback", env.Get("TEST_ENV_NOOP", "fallback")) +} + +func TestMustGet(t *testing.T) { + t.Parallel() + assert.Equal(t, "test", env.Get("TEST_ENV_STRING", "fallback")) + assert.Panics(t, func() { + env.MustGet("TEST_ENV_NOOP") + }) +} + +func TestGetInt(t *testing.T) { + t.Parallel() + assert.Equal(t, 3, env.GetInt("TEST_ENV_INT", 4)) + assert.Equal(t, 4, env.GetInt("TEST_ENV_NOOP", 4)) +} + +func TestMustGetInt(t *testing.T) { + t.Parallel() + assert.Equal(t, 3, env.GetInt("TEST_ENV_INT", 4)) + assert.Panics(t, func() { + env.MustGet("TEST_ENV_NOOP") + }) +} + +func TestRequiredKeys(t *testing.T) { + t.Parallel() + env.MustExists("TEST_ENV_STRING") + assert.Contains(t, env.RequiredKeys(), "TEST_ENV_STRING") +} + +func TestDefaults(t *testing.T) { + t.Parallel() + env.Get("TEST_ENV_STRING", "test") + assert.Contains(t, env.Defaults(), "TEST_ENV_STRING") +} + +func TestTypes(t *testing.T) { + t.Parallel() + env.Get("TEST_ENV_STRING", "test") + env.GetInt("TEST_ENV_INT", 3) + assert.NotEmpty(t, env.Types()) +} + +func TestTypeOf(t *testing.T) { + t.Parallel() + env.Get("TEST_ENV_STRING", "test") + env.GetInt("TEST_ENV_INT", 3) + assert.Equal(t, "string", env.TypeOf("TEST_ENV_STRING")) + assert.Equal(t, "int", env.TypeOf("TEST_ENV_INT")) +} diff --git a/env/readme.go b/env/readme.go index 909d1bf..752ea5d 100644 --- a/env/readme.go +++ b/env/readme.go @@ -11,23 +11,29 @@ func Readme() string { md := &markdown.Markdown{} { - for key, fallback := range defaults { - rows = append(rows, []string{ - markdown.Code(key), - markdown.Code(TypeOf(key)), - "", - markdown.Code(fmt.Sprintf("%v", fallback)), - }) - } + defaults.Range(func(key, fallback any) bool { + if k, ok := key.(string); ok { + rows = append(rows, []string{ + markdown.Code(k), + markdown.Code(TypeOf(k)), + "", + markdown.Code(fmt.Sprintf("%v", fallback)), + }) + } + return true + }) - for _, key := range requiredKeys { - rows = append(rows, []string{ - markdown.Code(key), - markdown.Code(TypeOf(key)), - markdown.Code("true"), - "", - }) - } + requiredKeys.Range(func(key, fallback any) bool { + if k, ok := key.(string); ok { + rows = append(rows, []string{ + markdown.Code(k), + markdown.Code(TypeOf(k)), + markdown.Code("true"), + "", + }) + } + return true + }) } if len(rows) > 0 { diff --git a/service/helper_test.go b/service/helper_test.go index 67a5568..773696f 100644 --- a/service/helper_test.go +++ b/service/helper_test.go @@ -21,7 +21,7 @@ func waitFor(addr string) { } func httpGet(url string) string { - resp, err := http.Get(url) //nolint:all + resp, err := http.Get(url) if err != nil { panic(err.Error()) }