fix: use sync map

This commit is contained in:
Kevin Franklin Kim 2025-05-26 21:55:06 +02:00
parent d95f8df187
commit 056f7ab54c
No known key found for this signature in database
4 changed files with 160 additions and 43 deletions

78
env/env.go vendored
View File

@ -3,15 +3,15 @@ package env
import ( import (
"fmt" "fmt"
"os" "os"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync"
) )
var ( var (
defaults = map[string]interface{}{} types = sync.Map{}
requiredKeys []string defaults = sync.Map{}
types = map[string]string{} requiredKeys = sync.Map{}
) )
// Exists return true if env var is defined // Exists return true if env var is defined
@ -25,16 +25,16 @@ func MustExists(key string) {
if !Exists(key) { if !Exists(key) {
panic(fmt.Sprintf("required environment variable `%s` does not exist", key)) panic(fmt.Sprintf("required environment variable `%s` does not exist", key))
} }
if !slices.Contains(requiredKeys, key) { if _, ok := requiredKeys.Load(key); !ok {
requiredKeys = append(requiredKeys, key) requiredKeys.Store(key, true)
} }
} }
// Get env var or fallback // Get env var or fallback
func Get(key, fallback string) string { func Get(key, fallback string) string {
defaults[key] = fallback defaults.Store(key, fallback)
if _, ok := types[key]; !ok { if _, ok := types.Load(key); !ok {
types[key] = "string" types.Store(key, "string")
} }
if v, ok := os.LookupEnv(key); ok { if v, ok := os.LookupEnv(key); ok {
return v return v
@ -50,8 +50,8 @@ func MustGet(key string) string {
// GetInt env var or fallback as int // GetInt env var or fallback as int
func GetInt(key string, fallback int) int { func GetInt(key string, fallback int) int {
if _, ok := types[key]; !ok { if _, ok := types.Load(key); !ok {
types[key] = "int" types.Store(key, "int")
} }
if value, err := strconv.Atoi(Get(key, "")); err == nil { if value, err := strconv.Atoi(Get(key, "")); err == nil {
return value return value
@ -67,8 +67,8 @@ func MustGetInt(key string) int {
// GetInt64 env var or fallback as int64 // GetInt64 env var or fallback as int64
func GetInt64(key string, fallback int64) int64 { func GetInt64(key string, fallback int64) int64 {
if _, ok := types[key]; !ok { if _, ok := types.Load(key); !ok {
types[key] = "int64" types.Store(key, "int64")
} }
if value, err := strconv.ParseInt(Get(key, ""), 10, 64); err == nil { if value, err := strconv.ParseInt(Get(key, ""), 10, 64); err == nil {
return value return value
@ -84,8 +84,8 @@ func MustGetInt64(key string) int64 {
// GetFloat64 env var or fallback as float64 // GetFloat64 env var or fallback as float64
func GetFloat64(key string, fallback float64) float64 { func GetFloat64(key string, fallback float64) float64 {
if _, ok := types[key]; !ok { if _, ok := types.Load(key); !ok {
types[key] = "float64" types.Store(key, "float64")
} }
if value, err := strconv.ParseFloat(Get(key, ""), 64); err == nil { if value, err := strconv.ParseFloat(Get(key, ""), 64); err == nil {
return value return value
@ -101,8 +101,8 @@ func MustGetFloat64(key string) float64 {
// GetBool env var or fallback as bool // GetBool env var or fallback as bool
func GetBool(key string, fallback bool) bool { func GetBool(key string, fallback bool) bool {
if _, ok := types[key]; !ok { if _, ok := types.Load(key); !ok {
types[key] = "bool" types.Store(key, "bool")
} }
if val, err := strconv.ParseBool(Get(key, "")); err == nil { if val, err := strconv.ParseBool(Get(key, "")); err == nil {
return val return val
@ -118,8 +118,8 @@ func MustGetBool(key string) bool {
// GetStringSlice env var or fallback as []string // GetStringSlice env var or fallback as []string
func GetStringSlice(key string, fallback []string) []string { func GetStringSlice(key string, fallback []string) []string {
if _, ok := types[key]; !ok { if _, ok := types.Load(key); !ok {
types[key] = "[]string" types.Store(key, "[]string")
} }
if v := Get(key, ""); v != "" { if v := Get(key, ""); v != "" {
return strings.Split(v, ",") return strings.Split(v, ",")
@ -135,8 +135,8 @@ func MustGetStringSlice(key string) []string {
// GetIntSlice env var or fallback as []string // GetIntSlice env var or fallback as []string
func GetIntSlice(key string, fallback []int) []int { func GetIntSlice(key string, fallback []int) []int {
if _, ok := types[key]; !ok { if _, ok := types.Load(key); !ok {
types[key] = "[]int" types.Store(key, "[]int")
} }
if v := Get(key, ""); v != "" { if v := Get(key, ""); v != "" {
elements := strings.Split(v, ",") elements := strings.Split(v, ",")
@ -159,20 +159,46 @@ func MustGetGetIntSlice(key string) []int {
} }
func RequiredKeys() []string { func RequiredKeys() []string {
return requiredKeys var ret []string
requiredKeys.Range(func(key, value interface{}) bool {
if v, ok := key.(string); ok {
ret = append(ret, v)
}
return true
})
return ret
} }
func Defaults() map[string]interface{} { func Defaults() map[string]interface{} {
return defaults ret := map[string]interface{}{}
defaults.Range(func(key, value interface{}) bool {
if k, ok := key.(string); ok {
ret[k] = value
}
return true
})
return ret
} }
func Types() map[string]string { func Types() map[string]string {
return types ret := map[string]string{}
types.Range(func(key, value interface{}) bool {
if v, ok := value.(string); ok {
if k, ok := key.(string); ok {
ret[k] = v
}
}
return true
})
return ret
} }
func TypeOf(key string) string { func TypeOf(key string) string {
if v, ok := types[key]; ok { if v, ok := types.Load(key); ok {
return v if s, ok := v.(string); ok {
return s
}
return ""
} }
return "" return ""
} }

85
env/env_test.go vendored Normal file
View File

@ -0,0 +1,85 @@
package env_test
import (
"os"
"testing"
"github.com/foomo/keel/env"
"github.com/stretchr/testify/assert"
)
func TestMain(m *testing.M) {
_ = os.Setenv("TEST_ENV_INT", "3")
_ = os.Setenv("TEST_ENV_STRING", "test")
m.Run()
}
func TextExists(t *testing.T) {
t.Parallel()
assert.True(t, env.Exists("TEST_ENV_EXISTS"))
assert.False(t, env.Exists("TEST_ENV_NOOP"))
}
func TestMustExists(t *testing.T) {
t.Parallel()
assert.True(t, env.Exists("TEST_ENV_STRING"))
assert.Panics(t, func() {
env.MustExists("TEST_ENV_NOOP")
})
}
func TestGet(t *testing.T) {
t.Parallel()
assert.Equal(t, "test", env.Get("TEST_ENV_STRING", "fallback"))
assert.Equal(t, "fallback", env.Get("TEST_ENV_NOOP", "fallback"))
}
func TestMustGet(t *testing.T) {
t.Parallel()
assert.Equal(t, "test", env.Get("TEST_ENV_STRING", "fallback"))
assert.Panics(t, func() {
env.MustGet("TEST_ENV_NOOP")
})
}
func TestGetInt(t *testing.T) {
t.Parallel()
assert.Equal(t, 3, env.GetInt("TEST_ENV_INT", 4))
assert.Equal(t, 4, env.GetInt("TEST_ENV_NOOP", 4))
}
func TestMustGetInt(t *testing.T) {
t.Parallel()
assert.Equal(t, 3, env.GetInt("TEST_ENV_INT", 4))
assert.Panics(t, func() {
env.MustGet("TEST_ENV_NOOP")
})
}
func TestRequiredKeys(t *testing.T) {
t.Parallel()
env.MustExists("TEST_ENV_STRING")
assert.Contains(t, env.RequiredKeys(), "TEST_ENV_STRING")
}
func TestDefaults(t *testing.T) {
t.Parallel()
env.Get("TEST_ENV_STRING", "test")
assert.Contains(t, env.Defaults(), "TEST_ENV_STRING")
}
func TestTypes(t *testing.T) {
t.Parallel()
env.Get("TEST_ENV_STRING", "test")
env.GetInt("TEST_ENV_INT", 3)
assert.NotEmpty(t, env.Types())
}
func TestTypeOf(t *testing.T) {
t.Parallel()
env.Get("TEST_ENV_STRING", "test")
env.GetInt("TEST_ENV_INT", 3)
assert.Equal(t, "string", env.TypeOf("TEST_ENV_STRING"))
assert.Equal(t, "int", env.TypeOf("TEST_ENV_INT"))
}

38
env/readme.go vendored
View File

@ -11,23 +11,29 @@ func Readme() string {
md := &markdown.Markdown{} md := &markdown.Markdown{}
{ {
for key, fallback := range defaults { defaults.Range(func(key, fallback any) bool {
rows = append(rows, []string{ if k, ok := key.(string); ok {
markdown.Code(key), rows = append(rows, []string{
markdown.Code(TypeOf(key)), markdown.Code(k),
"", markdown.Code(TypeOf(k)),
markdown.Code(fmt.Sprintf("%v", fallback)), "",
}) markdown.Code(fmt.Sprintf("%v", fallback)),
} })
}
return true
})
for _, key := range requiredKeys { requiredKeys.Range(func(key, fallback any) bool {
rows = append(rows, []string{ if k, ok := key.(string); ok {
markdown.Code(key), rows = append(rows, []string{
markdown.Code(TypeOf(key)), markdown.Code(k),
markdown.Code("true"), markdown.Code(TypeOf(k)),
"", markdown.Code("true"),
}) "",
} })
}
return true
})
} }
if len(rows) > 0 { if len(rows) > 0 {

View File

@ -21,7 +21,7 @@ func waitFor(addr string) {
} }
func httpGet(url string) string { func httpGet(url string) string {
resp, err := http.Get(url) //nolint:all resp, err := http.Get(url)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
} }