From f1dae6a74599b8039d34845b68c759346d79c384 Mon Sep 17 00:00:00 2001 From: Kevin Franklin Kim Date: Thu, 21 Mar 2024 18:01:24 +0100 Subject: [PATCH] fix: race conditions --- pkg/repo/loader.go | 29 +++++++++++------- pkg/repo/repo.go | 70 +++++++++++++++++++++++++++++-------------- pkg/repo/repo_test.go | 44 ++++++++++++--------------- 3 files changed, 85 insertions(+), 58 deletions(-) diff --git a/pkg/repo/loader.go b/pkg/repo/loader.go index 4999157..c6df9be 100644 --- a/pkg/repo/loader.go +++ b/pkg/repo/loader.go @@ -1,6 +1,7 @@ package repo import ( + "bytes" "context" "io" "net/http" @@ -130,7 +131,7 @@ func (r *Repo) _updateDimension(dimension string, newNode *content.RepoNode) err // copy old datastructure to prevent concurrent map access // collect other dimension in the Directory newRepoDirectory := map[string]*Dimension{} - for d, D := range r.Directory { + for d, D := range r.Directory() { if d != dimension { newRepoDirectory[d] = D } @@ -142,7 +143,7 @@ func (r *Repo) _updateDimension(dimension string, newNode *content.RepoNode) err Directory: newDirectory, URIDirectory: newURIDirectory, } - r.Directory = newRepoDirectory + r.SetDirectory(newRepoDirectory) // --------------------------------------------- @@ -193,7 +194,7 @@ func wireAliases(directory map[string]*content.RepoNode) error { func (r *Repo) loadNodesFromJSON() (nodes map[string]*content.RepoNode, err error) { nodes = make(map[string]*content.RepoNode) - err = json.Unmarshal(r.jsonBuf.Bytes(), &nodes) + err = json.Unmarshal(r.JSONBufferBytes(), &nodes) if err != nil { r.l.Error("Failed to deserialize nodes", zap.Error(err)) return nil, errors.New("failed to deserialize nodes") @@ -202,10 +203,12 @@ func (r *Repo) loadNodesFromJSON() (nodes map[string]*content.RepoNode, err erro } func (r *Repo) tryToRestoreCurrent() error { - err := r.history.GetCurrent(&r.jsonBuf) + buffer := &bytes.Buffer{} + err := r.history.GetCurrent(buffer) if err != nil { return err } + r.SetJSONBuffer(buffer) return r.loadJSONBytes() } @@ -225,13 +228,14 @@ func (r *Repo) get(ctx context.Context, url string) error { } // Log.Info(ansi.Red + "RESETTING BUFFER" + ansi.Reset) - r.jsonBuf.Reset() + buffer := &bytes.Buffer{} // Log.Info(ansi.Green + "LOADING DATA INTO BUFFER" + ansi.Reset) - _, err = io.Copy(&r.jsonBuf, response.Body) + _, err = io.Copy(buffer, response.Body) if err != nil { return errors.Wrap(err, "failed to copy IO stream") } + r.SetJSONBuffer(buffer) return nil } @@ -280,7 +284,7 @@ func (r *Repo) update(ctx context.Context) (repoRuntime int64, err error) { r.l.Debug("failed to load json", zap.Error(err)) return repoRuntime, err } - r.l.Debug("loading json", zap.String("server", repoURL), zap.Int("length", len(r.jsonBuf.Bytes()))) + r.l.Debug("loading json", zap.String("server", repoURL), zap.Int("length", len(r.JSONBufferBytes()))) nodes, err := r.loadNodesFromJSON() if err != nil { // could not load nodes from json @@ -314,7 +318,7 @@ func (r *Repo) tryUpdate() (repoRuntime int64, err error) { func (r *Repo) loadJSONBytes() error { nodes, err := r.loadNodesFromJSON() if err != nil { - data := r.jsonBuf.Bytes() + data := r.JSONBufferBytes() if len(data) > 10 { r.l.Debug("could not parse json", @@ -327,7 +331,7 @@ func (r *Repo) loadJSONBytes() error { err = r.loadNodes(nodes) if err == nil { - errHistory := r.history.Add(r.jsonBuf.Bytes()) + errHistory := r.history.Add(r.JSONBufferBytes()) if errHistory != nil { r.l.Error("Could not add valid JSON to history", zap.Error(errHistory)) metrics.HistoryPersistFailedCounter.WithLabelValues().Inc() @@ -361,11 +365,14 @@ func (r *Repo) loadNodes(newNodes map[string]*content.RepoNode) error { return false } // we need to throw away orphaned dimensions - for dimension := range r.Directory { + directory := map[string]*Dimension{} + for dimension, value := range r.Directory() { if !dimensionIsValid(dimension) { r.l.Info("removing orphaned dimension", zap.String("dimension", dimension)) - delete(r.Directory, dimension) + continue } + directory[dimension] = value } + r.SetDirectory(directory) return nil } diff --git a/pkg/repo/repo.go b/pkg/repo/repo.go index b0b6e64..055b7cf 100644 --- a/pkg/repo/repo.go +++ b/pkg/repo/repo.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "strings" + "sync" "sync/atomic" "time" @@ -33,13 +34,14 @@ type ( loaded *atomic.Bool history *History httpClient *http.Client - Directory map[string]*Dimension // updateLock sync.Mutex dimensionUpdateChannel chan *RepoDimension dimensionUpdateDoneChannel chan error updateInProgressChannel chan chan updateResponse - // jsonBytes []byte - jsonBuf bytes.Buffer + directory map[string]*Dimension + directoryLock sync.RWMutex + jsonBuffer *bytes.Buffer + jsonBufferLock sync.RWMutex } Option func(*Repo) ) @@ -56,7 +58,7 @@ func New(l *zap.Logger, url string, history *History, opts ...Option) *Repo { loaded: &atomic.Bool{}, history: history, httpClient: http.DefaultClient, - Directory: map[string]*Dimension{}, + directory: map[string]*Dimension{}, dimensionUpdateChannel: make(chan *RepoDimension), dimensionUpdateDoneChannel: make(chan error), updateInProgressChannel: make(chan chan updateResponse), @@ -93,14 +95,38 @@ func (r *Repo) Loaded() bool { return r.loaded.Load() } -func (r *Repo) OnStart(fn func()) { - r.onStart = fn +func (r *Repo) Directory() map[string]*Dimension { + r.directoryLock.RLock() + defer r.directoryLock.RUnlock() + return r.directory +} + +func (r *Repo) SetDirectory(v map[string]*Dimension) { + r.directoryLock.Lock() + defer r.directoryLock.Unlock() + r.directory = v +} + +func (r *Repo) JSONBufferBytes() []byte { + r.jsonBufferLock.RLock() + defer r.jsonBufferLock.RUnlock() + return r.jsonBuffer.Bytes() +} + +func (r *Repo) SetJSONBuffer(v *bytes.Buffer) { + r.jsonBufferLock.Lock() + defer r.jsonBufferLock.Unlock() + r.jsonBuffer = v } // ------------------------------------------------------------------------------------------------ // ~ Public methods // ------------------------------------------------------------------------------------------------ +func (r *Repo) OnStart(fn func()) { + r.onStart = fn +} + // GetURIs get many uris at once func (r *Repo) GetURIs(dimension string, ids []string) map[string]string { uris := map[string]string{} @@ -149,7 +175,7 @@ func (r *Repo) GetContent(req *requests.Content) (*content.SiteContent, error) { c.Path = node.GetPath(req.PathDataFields) // fetch URIs for all dimensions uris := make(map[string]string) - for dimensionName := range r.Directory { + for dimensionName := range r.Directory() { uris[dimensionName] = r.getURI(dimensionName, node.ID) } c.URIs = uris @@ -179,7 +205,7 @@ func (r *Repo) GetContent(req *requests.Content) (*content.SiteContent, error) { // GetRepo get the whole repo in all dimensions func (r *Repo) GetRepo() map[string]*content.RepoNode { response := make(map[string]*content.RepoNode) - for dimensionName, dimension := range r.Directory { + for dimensionName, dimension := range r.Directory() { response[dimensionName] = dimension.Node } return response @@ -238,15 +264,15 @@ func (r *Repo) Update() (updateResponse *responses.Update) { } else { updateResponse.Success = true // persist the currently loaded one - historyErr := r.history.Add(r.jsonBuf.Bytes()) + historyErr := r.history.Add(r.JSONBufferBytes()) if historyErr != nil { r.l.Error("Could not persist current repo in history", zap.Error(historyErr)) metrics.HistoryPersistFailedCounter.WithLabelValues().Inc() } // add some stats - for dimension := range r.Directory { - updateResponse.Stats.NumberOfNodes += len(r.Directory[dimension].Directory) - updateResponse.Stats.NumberOfURIs += len(r.Directory[dimension].URIDirectory) + for _, dimension := range r.Directory() { + updateResponse.Stats.NumberOfNodes += len(dimension.Directory) + updateResponse.Stats.NumberOfURIs += len(dimension.URIDirectory) } r.loaded.Store(true) } @@ -294,7 +320,7 @@ func (r *Repo) Start(ctx context.Context) error { } else if !r.Loaded() { l.Debug("trying to update initial state") if resp := r.Update(); !resp.Success { - l.Fatal("failed to update", + l.Error("failed to update initial state", zap.String("error", resp.ErrorMessage), zap.Int("num_modes", resp.Stats.NumberOfNodes), zap.Int("num_uris", resp.Stats.NumberOfURIs), @@ -332,13 +358,13 @@ func (r *Repo) getNodes(nodeRequests map[string]*requests.Node, env *requests.En groups = nodeRequest.Groups } - dimensionNode, ok := r.Directory[nodeRequest.Dimension] + dimensionNode, ok := r.Directory()[nodeRequest.Dimension] nodes[nodeName] = nil if !ok && nodeRequest.Dimension == "" { r.l.Debug("Could not get dimension root node", zap.String("dimension", nodeRequest.Dimension)) for _, dimension := range env.Dimensions { - dimensionNode, ok = r.Directory[dimension] + dimensionNode, ok = r.Directory()[dimension] if ok { r.l.Debug("Found root node in env.Dimensions", zap.String("dimension", dimension)) break @@ -376,7 +402,7 @@ func (r *Repo) resolveContent(dimensions []string, uri string) (resolved bool, r testURI = content.PathSeparator } for _, dimension := range dimensions { - if d, ok := r.Directory[dimension]; ok { + if d, ok := r.Directory()[dimension]; ok { r.l.Debug("Checking node", zap.String("dimension", dimension), zap.String("URI", testURI), @@ -402,7 +428,7 @@ func (r *Repo) getURIForNode(dimension string, repoNode *content.RepoNode, recur uri = repoNode.URI return } - linkedNode, ok := r.Directory[dimension].Directory[repoNode.LinkID] + linkedNode, ok := r.Directory()[dimension].Directory[repoNode.LinkID] if ok { if recursionLevel > maxGetURIForNodeRecursionLevel { r.l.Error("maxGetURIForNodeRecursionLevel reached", zap.String("repoNode.ID", repoNode.ID), zap.String("linkID", repoNode.LinkID), zap.String("dimension", dimension)) @@ -414,7 +440,7 @@ func (r *Repo) getURIForNode(dimension string, repoNode *content.RepoNode, recur } func (r *Repo) getURI(dimension string, id string) string { - directory, ok := r.Directory[dimension] + directory, ok := r.Directory()[dimension] if !ok { return "" } @@ -463,14 +489,14 @@ func (r *Repo) validateContentRequest(req *requests.Content) (err error) { } for _, envDimension := range req.Env.Dimensions { if !r.hasDimension(envDimension) { - availableDimensions := make([]string, 0, len(r.Directory)) - for availableDimension := range r.Directory { + availableDimensions := make([]string, 0, len(r.Directory())) + for availableDimension := range r.Directory() { availableDimensions = append(availableDimensions, availableDimension) } return errors.New(fmt.Sprint( "unknown dimension ", envDimension, " in r.Env must be one of ", availableDimensions, - " repo has ", len(r.Directory), " dimensions", + " repo has ", len(availableDimensions), " dimensions", )) } } @@ -478,6 +504,6 @@ func (r *Repo) validateContentRequest(req *requests.Content) (err error) { } func (r *Repo) hasDimension(d string) bool { - _, hasDimension := r.Directory[d] + _, hasDimension := r.Directory()[d] return hasDimension } diff --git a/pkg/repo/repo_test.go b/pkg/repo/repo_test.go index 0af28fd..3a48ced 100644 --- a/pkg/repo/repo_test.go +++ b/pkg/repo/repo_test.go @@ -14,9 +14,9 @@ import ( "go.uber.org/zap/zaptest" ) -func NewTestRepo(l *zap.Logger, server, varDir string) *Repo { +func NewTestRepo(l *zap.Logger, url, varDir string) *Repo { h := NewHistory(l, HistoryWithMax(2), HistoryWithVarDir(varDir)) - r := New(l, server, h) + r := New(l, url, h) go r.Start(context.Background()) //nolint:errcheck time.Sleep(100 * time.Millisecond) return r @@ -25,23 +25,22 @@ func NewTestRepo(l *zap.Logger, server, varDir string) *Repo { func assertRepoIsEmpty(t *testing.T, r *Repo, empty bool) { t.Helper() if empty { - if len(r.Directory) > 0 { + if len(r.Directory()) > 0 { t.Fatal("directory should have been empty, but is not") } } else { - if len(r.Directory) == 0 { + if len(r.Directory()) == 0 { t.Fatal("directory is empty, but should have been not") } } } func TestLoad404(t *testing.T) { - l := zaptest.NewLogger(t) - var ( + l = zaptest.NewLogger(t) mockServer, varDir = mock.GetMockData(t) - server = mockServer.URL + "/repo-no-have" - r = NewTestRepo(l, server, varDir) + url = mockServer.URL + "/repo-no-have" + r = NewTestRepo(l, url, varDir) ) response := r.Update() @@ -51,9 +50,8 @@ func TestLoad404(t *testing.T) { } func TestLoadBrokenRepo(t *testing.T) { - l := zaptest.NewLogger(t) - var ( + l = zaptest.NewLogger(t) mockServer, varDir = mock.GetMockData(t) server = mockServer.URL + "/repo-broken-json.json" r = NewTestRepo(l, server, varDir) @@ -66,14 +64,13 @@ func TestLoadBrokenRepo(t *testing.T) { } func TestLoadRepo(t *testing.T) { - l := zaptest.NewLogger(t) - var ( + l = zaptest.NewLogger(t) mockServer, varDir = mock.GetMockData(t) server = mockServer.URL + "/repo-ok.json" r = NewTestRepo(l, server, varDir) ) - assertRepoIsEmpty(t, r, true) + assertRepoIsEmpty(t, r, false) response := r.Update() assertRepoIsEmpty(t, r, false) @@ -94,23 +91,19 @@ func TestLoadRepo(t *testing.T) { } func BenchmarkLoadRepo(b *testing.B) { - l := zaptest.NewLogger(b) - var ( + l = zaptest.NewLogger(b) t = &testing.T{} mockServer, varDir = mock.GetMockData(t) server = mockServer.URL + "/repo-ok.json" r = NewTestRepo(l, server, varDir) ) - if len(r.Directory) > 0 { - b.Fatal("directory should have been empty, but is not") - } b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { response := r.Update() - if len(r.Directory) == 0 { + if len(r.Directory()) == 0 { b.Fatal("directory is empty, but should have been not") } @@ -121,11 +114,12 @@ func BenchmarkLoadRepo(b *testing.B) { } func TestLoadRepoDuplicateUris(t *testing.T) { - l := zaptest.NewLogger(t) - - mockServer, varDir := mock.GetMockData(t) - server := mockServer.URL + "/repo-duplicate-uris.json" - r := NewTestRepo(l, server, varDir) + var ( + l = zaptest.NewLogger(t) + mockServer, varDir = mock.GetMockData(t) + server = mockServer.URL + "/repo-duplicate-uris.json" + r = NewTestRepo(l, server, varDir) + ) response := r.Update() require.False(t, response.Success, "there are duplicates, this repo update should have failed") @@ -147,7 +141,7 @@ func TestDimensionHygiene(t *testing.T) { response = r.Update() require.True(t, response.Success, "it is called repo ok") - assert.Lenf(t, r.Directory, 1, "directory hygiene failed") + assert.Lenf(t, r.Directory(), 1, "directory hygiene failed") } func getTestRepo(t *testing.T, path string) *Repo {