keel/persistence/mongo/collection.go
2024-04-16 09:38:28 +02:00

447 lines
12 KiB
Go

package keelmongo
import (
"context"
"slices"
"time"
keelerrors "github.com/foomo/keel/errors"
keelpersistence "github.com/foomo/keel/persistence"
keeltime "github.com/foomo/keel/time"
"github.com/pkg/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
)
type (
DecodeFn func(val interface{}) error
IterateHandlerFn func(decode DecodeFn) error
)
// Collection can only be used in the Persistor.WithCollection call.ss
type (
Collection struct {
db *mongo.Database
collection *mongo.Collection
}
CollectionOptions struct {
*options.CollectionOptions
*options.CreateIndexesOptions
Indexes []mongo.IndexModel
IndexesContext context.Context
}
CollectionOption func(*CollectionOptions)
)
// ------------------------------------------------------------------------------------------------
// ~ Options
// ------------------------------------------------------------------------------------------------
func DefaultCollectionOptions() CollectionOptions {
return CollectionOptions{
CollectionOptions: options.Collection(),
CreateIndexesOptions: options.CreateIndexes(),
IndexesContext: context.Background(),
}
}
func CollectionWithReadConcern(v *readconcern.ReadConcern) CollectionOption {
return func(o *CollectionOptions) {
o.CollectionOptions.SetReadConcern(v)
}
}
func CollectionWithWriteConcern(v *writeconcern.WriteConcern) CollectionOption {
return func(o *CollectionOptions) {
o.CollectionOptions.SetWriteConcern(v)
}
}
func CollectionWithReadPreference(v *readpref.ReadPref) CollectionOption {
return func(o *CollectionOptions) {
o.CollectionOptions.SetReadPreference(v)
}
}
func CollectionWithRegistry(v *bsoncodec.Registry) CollectionOption {
return func(o *CollectionOptions) {
o.CollectionOptions.SetRegistry(v)
}
}
func CollectionWithIndexes(v ...mongo.IndexModel) CollectionOption {
return func(o *CollectionOptions) {
o.Indexes = v
}
}
func CollectionWithIndexesMaxTime(v time.Duration) CollectionOption {
return func(o *CollectionOptions) {
o.CreateIndexesOptions.SetMaxTime(v)
}
}
func CollectionWithIndexesContext(v int32) CollectionOption {
return func(o *CollectionOptions) {
o.CreateIndexesOptions.SetCommitQuorumInt(v)
}
}
func CollectionWithIndexesQuorumMajority() CollectionOption {
return func(o *CollectionOptions) {
o.CreateIndexesOptions.SetCommitQuorumMajority()
}
}
func CollectionWithIndexesCommitQuorumString(v string) CollectionOption {
return func(o *CollectionOptions) {
o.CreateIndexesOptions.SetCommitQuorumString(v)
}
}
func CollectionWithIndexesCommitQuorumVotingMembers(v context.Context) CollectionOption {
return func(o *CollectionOptions) {
o.CreateIndexesOptions.SetCommitQuorumVotingMembers()
}
}
// ------------------------------------------------------------------------------------------------
// ~ Constructor
// ------------------------------------------------------------------------------------------------
func NewCollection(db *mongo.Database, name string, opts ...CollectionOption) (*Collection, error) {
o := DefaultCollectionOptions()
for _, opt := range opts {
opt(&o)
}
col := db.Collection(name, o.CollectionOptions)
if !slices.Contains(dbs[db.Name()], name) {
dbs[db.Name()] = append(dbs[db.Name()], name)
}
if len(o.Indexes) > 0 {
if _, err := col.Indexes().CreateMany(o.IndexesContext, o.Indexes, o.CreateIndexesOptions); err != nil {
return nil, err
}
if _, ok := indices[db.Name()]; !ok {
indices[db.Name()] = map[string][]string{}
}
for _, index := range o.Indexes {
if index.Options != nil && index.Options.Name != nil {
indices[db.Name()][name] = append(indices[db.Name()][name], *index.Options.Name)
}
}
}
return &Collection{
db: db,
collection: col,
}, nil
}
// ------------------------------------------------------------------------------------------------
// ~ Getter
// ------------------------------------------------------------------------------------------------
func (c *Collection) DB() *mongo.Database {
return c.db
}
func (c *Collection) Col() *mongo.Collection {
return c.collection
}
// ------------------------------------------------------------------------------------------------
// ~ Public methods
// ------------------------------------------------------------------------------------------------
func (c *Collection) Get(ctx context.Context, id string, result interface{}, opts ...*options.FindOneOptions) error {
if id == "" {
return keelpersistence.ErrNotFound
}
return c.FindOne(ctx, bson.M{"id": id}, result, opts...)
}
func (c *Collection) Exists(ctx context.Context, id string) (bool, error) {
if id == "" {
return false, nil
}
ret, err := c.collection.CountDocuments(ctx, bson.M{"id": id})
return ret > 0, err
}
func (c *Collection) Upsert(ctx context.Context, id string, entity Entity) error {
if id == "" {
return errors.New("id must not be empty")
} else if entity == nil {
return errors.New("entity must not be nil")
}
if v, ok := entity.(EntityWithTimestamps); ok {
now := keeltime.Now()
if ct := v.GetCreatedAt(); ct.IsZero() {
v.SetCreatedAt(now)
}
v.SetUpdatedAt(now)
}
if v, ok := entity.(EntityWithVersion); ok {
currentVersion := v.GetVersion()
// increment version
v.IncreaseVersion()
if currentVersion == 0 {
// insert the new document
return c.Insert(ctx, entity)
} else if err := c.collection.FindOneAndUpdate(
ctx,
bson.D{bson.E{Key: "id", Value: id}, bson.E{Key: "version", Value: currentVersion}},
bson.D{bson.E{Key: "$set", Value: entity}},
options.FindOneAndUpdate().SetUpsert(false),
).Err(); errors.Is(err, mongo.ErrNoDocuments) {
return keelerrors.NewWrappedError(keelpersistence.ErrDirtyWrite, err)
} else if err != nil {
return err
}
} else if _, err := c.collection.UpdateOne(
ctx,
bson.D{bson.E{Key: "id", Value: id}},
bson.D{bson.E{Key: "$set", Value: entity}},
options.Update().SetUpsert(true),
); err != nil {
return err
}
return nil
}
// UpsertMany - NOTE: upsert many does NOT through an explicit error on dirty write so we can only assume it.
func (c *Collection) UpsertMany(ctx context.Context, entities []Entity) error {
var versionUpserts int64
var operations []mongo.WriteModel
for _, entity := range entities {
if entity == nil {
return errors.New("entity must not be nil")
} else if entity.GetID() == "" {
return errors.New("id must not be empty")
}
if v, ok := entity.(EntityWithTimestamps); ok {
now := keeltime.Now()
if ct := v.GetCreatedAt(); ct.IsZero() {
v.SetCreatedAt(now)
}
v.SetUpdatedAt(now)
}
if v, ok := entity.(EntityWithVersion); ok {
currentVersion := v.GetVersion()
// increment version
v.IncreaseVersion()
if currentVersion == 0 {
operations = append(operations,
mongo.NewInsertOneModel().SetDocument(entity),
)
} else {
versionUpserts++
operations = append(operations,
mongo.NewUpdateOneModel().
SetFilter(bson.D{bson.E{Key: "id", Value: entity.GetID()}, bson.E{Key: "version", Value: currentVersion}}).
SetUpdate(bson.D{bson.E{Key: "$set", Value: entity}}).
SetUpsert(false),
)
}
} else {
operations = append(operations,
mongo.NewUpdateOneModel().
SetFilter(bson.D{bson.E{Key: "id", Value: entity.GetID()}}).
SetUpdate(bson.D{bson.E{Key: "$set", Value: entity}}).
SetUpsert(true),
)
}
}
// Specify an option to turn the bulk insertion in order of operation
bulkOption := options.BulkWriteOptions{}
bulkOption.SetOrdered(false)
res, err := c.Col().BulkWrite(ctx, operations, &bulkOption)
if err != nil {
return err
} else if versionUpserts > 0 && (res.MatchedCount < versionUpserts || res.ModifiedCount != res.MatchedCount) {
// log.Logger().Info("missing upserts",
// zap.Int64("MatchedCount", res.MatchedCount),
// zap.Int64("InsertedCount", res.InsertedCount),
// zap.Int64("UpsertedCount", res.UpsertedCount),
// zap.Int64("ModifiedCount", res.ModifiedCount),
// zap.Any("UpsertedIDs", res.UpsertedIDs),
// zap.Any("versionUpserts", versionUpserts),
// )
return keelpersistence.ErrDirtyWrite
}
return nil
}
func (c *Collection) Insert(ctx context.Context, entity Entity) error {
if entity == nil {
return errors.New("entity must not be nil")
} else if entity.GetID() == "" {
return errors.New("id must not be empty")
}
if v, ok := entity.(EntityWithTimestamps); ok {
now := keeltime.Now()
if ct := v.GetCreatedAt(); ct.IsZero() {
v.SetCreatedAt(now)
}
v.SetUpdatedAt(now)
}
if v, ok := entity.(EntityWithVersion); ok {
// increment version
v.IncreaseVersion()
}
if _, err := c.collection.InsertOne(ctx, entity); err != nil {
return err
}
return nil
}
func (c *Collection) InsertMany(ctx context.Context, entities []Entity) error {
inserts := make([]interface{}, len(entities))
for i, entity := range entities {
if entity == nil {
return errors.New("entity must not be nil")
} else if entity.GetID() == "" {
return errors.New("id must not be empty")
}
if v, ok := entity.(EntityWithTimestamps); ok {
now := keeltime.Now()
if ct := v.GetCreatedAt(); ct.IsZero() {
v.SetCreatedAt(now)
}
v.SetUpdatedAt(now)
}
if v, ok := entity.(EntityWithVersion); ok {
// increment version
v.IncreaseVersion()
}
inserts[i] = entity
}
if _, err := c.collection.InsertMany(ctx, inserts); err != nil {
return err
}
return nil
}
func (c *Collection) Delete(ctx context.Context, id string) error {
if id == "" {
return keelpersistence.ErrNotFound
}
if err := c.collection.FindOneAndDelete(ctx, bson.M{"id": id}).Err(); errors.Is(err, mongo.ErrNoDocuments) {
return keelerrors.NewWrappedError(keelpersistence.ErrNotFound, err)
} else if err != nil {
return err
}
return nil
}
func (c *Collection) Find(ctx context.Context, filter, results interface{}, opts ...*options.FindOptions) error {
cursor, err := c.collection.Find(ctx, filter, opts...)
if errors.Is(err, mongo.ErrNoDocuments) {
return keelerrors.NewWrappedError(keelpersistence.ErrNotFound, err)
} else if err != nil {
return err
}
if err = cursor.All(ctx, results); err != nil {
return err
}
return cursor.Err()
}
func (c *Collection) FindOne(ctx context.Context, filter, result interface{}, opts ...*options.FindOneOptions) error {
res := c.collection.FindOne(ctx, filter, opts...)
if errors.Is(res.Err(), mongo.ErrNoDocuments) {
return keelerrors.NewWrappedError(keelpersistence.ErrNotFound, res.Err())
} else if res.Err() != nil {
return res.Err()
}
return res.Decode(result)
}
func (c *Collection) FindIterate(ctx context.Context, filter interface{}, handler IterateHandlerFn, opts ...*options.FindOptions) error {
cursor, err := c.collection.Find(ctx, filter, opts...)
if errors.Is(err, mongo.ErrNoDocuments) {
return keelerrors.NewWrappedError(keelpersistence.ErrNotFound, err)
} else if err != nil {
return err
}
defer CloseCursor(context.WithoutCancel(ctx), cursor)
for cursor.Next(ctx) {
if err := handler(cursor.Decode); err != nil {
return err
}
}
return cursor.Err()
}
func (c *Collection) Aggregate(ctx context.Context, pipeline mongo.Pipeline, results interface{}, opts ...*options.AggregateOptions) error {
cursor, err := c.collection.Aggregate(ctx, pipeline, opts...)
if err != nil {
return err
}
if err = cursor.All(ctx, results); err != nil {
return err
}
return cursor.Err()
}
func (c *Collection) AggregateIterate(ctx context.Context, pipeline mongo.Pipeline, handler IterateHandlerFn, opts ...*options.AggregateOptions) error {
cursor, err := c.collection.Aggregate(ctx, pipeline, opts...)
if err != nil {
return err
}
defer CloseCursor(context.WithoutCancel(ctx), cursor)
for cursor.Next(ctx) {
if err := handler(cursor.Decode); err != nil {
return err
}
}
return cursor.Err()
}
// Count returns the count of documents
func (c *Collection) Count(ctx context.Context, filter interface{}, opts ...*options.CountOptions) (int64, error) {
return c.collection.CountDocuments(ctx, filter, opts...)
}
// CountAll returns the count of all documents
func (c *Collection) CountAll(ctx context.Context) (int64, error) {
return c.Count(ctx, bson.D{})
}