From ac7201d135dd53258ad2369df82eff16746583cf Mon Sep 17 00:00:00 2001 From: YuriBocharov Date: Wed, 10 Jan 2024 21:45:33 -0500 Subject: [PATCH 1/4] feat: adds a MongoDB connector This commit adds a MongoDB connector. A couple of things to note: - To enable transactions the database must a replica set, I made the connection fail if the condition was not met. MongoDB is a document database, thus concepts translate as: bucket/table = collection key = document value = document value Note: mongo will automatically create a collection if a operation is attempted on one that does not exist. For example: if you delete a table, then call list to check for the contents of the table it will be recreated. closes smallstep/certificates#141 --- database/database.go | 24 +-- go.sum | 24 ++- mongo/mongo.go | 362 +++++++++++++++++++++++++++++++++++++++++++ mongo/nomongo.go | 8 + nosql.go | 5 + nosql_test.go | 37 ++++- 6 files changed, 444 insertions(+), 16 deletions(-) create mode 100644 mongo/mongo.go create mode 100644 mongo/nomongo.go diff --git a/database/database.go b/database/database.go index abdb5ce..1b19891 100644 --- a/database/database.go +++ b/database/database.go @@ -1,9 +1,8 @@ package database import ( - "fmt" - "errors" + "fmt" ) var ( @@ -66,23 +65,28 @@ type DB interface { Open(dataSourceName string, opt ...Option) error // Close closes the current database. Close() error - // Get returns the value stored in the given table/bucket and key. + // Get returns the value stored in the given table/bucket/collection and key. Get(bucket, key []byte) (ret []byte, err error) - // Set sets the given value in the given table/bucket and key. + // Set sets the given value in the given table/bucket/collection and key. Set(bucket, key, value []byte) error - // CmpAndSwap swaps the value at the given bucket and key if the current - // value is equivalent to the oldValue input. Returns 'true' if the - // swap was successful and 'false' otherwise. + // CmpAndSwap takess a bucket, key, oldValue and newValue as inputs: + // CmpAndSwap returns: error, wasSwapped, keyValue + // - if the bucket / key does not exist, it returns database.ErrNotFound + // - if the bucket / key exists but the keyValue is not equivalent to oldValue, + // it returns wasSwapped = false, keyValue + // - if the bucket / key exists and the keyValue is equivalent to oldValue, + // it returns wasSwapped = true, keyValue + // - if an error occurs, it returns error CmpAndSwap(bucket, key, oldValue, newValue []byte) ([]byte, bool, error) // Del deletes the data in the given table/bucket and key. Del(bucket, key []byte) error - // List returns a list of all the entries in a given table/bucket. + // List returns a list of all the entries in a given table/bucket/collection. List(bucket []byte) ([]*Entry, error) // Update performs a transaction with multiple read-write commands. Update(tx *Tx) error - // CreateTable creates a table or a bucket in the database. + // CreateTable creates a table, bucket or collection in the database. CreateTable(bucket []byte) error - // DeleteTable deletes a table or a bucket in the database. + // DeleteTable deletes a table, bucket or collection in the database. DeleteTable(bucket []byte) error } diff --git a/go.sum b/go.sum index 23b9b87..0b513e4 100644 --- a/go.sum +++ b/go.sum @@ -58,14 +58,16 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= @@ -118,8 +120,9 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.12.3 h1:G5AfA94pHPysR56qqrkO2pxEexdDzrpFJ6yt/VqWxVU= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= +github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -140,6 +143,8 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -184,11 +189,21 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.etcd.io/bbolt v1.3.8 h1:xs88BrvEv273UsB79e0hcVrlUWmS0a8upikMFhSyAtA= go.etcd.io/bbolt v1.3.8/go.mod h1:N9Mkw9X8x5fupy0IKsmuqVtoGDyxsaDlbk4Rd05IAQw= +go.mongodb.org/mongo-driver v1.13.1 h1:YIc7HTYsKndGK4RFzJ3covLz1byri52x0IoMB0Pt/vk= +go.mongodb.org/mongo-driver v1.13.1/go.mod h1:wcDf1JBCXy2mOW0bWHwO/IOYqdca1MPCwDtFu/Z9+eo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -211,6 +226,7 @@ golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWP golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= @@ -230,6 +246,7 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= @@ -238,6 +255,7 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -253,6 +271,7 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -269,6 +288,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/mongo/mongo.go b/mongo/mongo.go new file mode 100644 index 0000000..a2b7768 --- /dev/null +++ b/mongo/mongo.go @@ -0,0 +1,362 @@ +//go:build !nomongo +// +build !nomongo + +package mongo + +import ( + "bytes" + "context" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/nosql/database" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/writeconcern" +) + +// DB is a wrapper over *sql.DB, +type DB struct { + db *mongo.Database +} + +type Tuple struct { + Key []byte `bson:"key" json:"key"` + Value []byte `bson:"value" json:"value"` +} + +// Open creates a Driver and connects to the database with the given address +// and access details. +func (db *DB) Open(uri string, opt ...database.Option) error { + opts := &database.Options{} + for _, o := range opt { + if err := o(opts); err != nil { + return err + } + } + + clientOptions := options.Client().ApplyURI(uri) + if rs := clientOptions.ReplicaSet; *rs == "" { + return errors.New("To enable transactions, please provide a replica set name") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client, err := mongo.Connect(ctx, clientOptions) + if err != nil { + return errors.Wrap(err, "error in configuration") + } + + err = client.Ping(context.Background(), nil) + if err != nil { + return errors.Wrap(err, "error connecting to mongo") + } + + dbOpts := options.Database() + db.db = client.Database(opts.Database, dbOpts) + + return nil +} + +func (db *DB) Close() error { + if err := db.db.Client().Disconnect(context.Background()); err != nil { + return errors.WithStack(err) + } + + return nil +} + +func (db *DB) CreateTable(bucket []byte) error { + return db.createTable(bucket, context.Background()) +} + +func (db *DB) DeleteTable(bucket []byte) error { + return db.deleteTable(bucket, context.Background()) +} + +// Get returns the value stored in the given bucked and key. +func (db *DB) Get(bucket, key []byte) (ret []byte, err error) { + return db.get(bucket, key, context.Background()) +} + +// Set stores the given value on bucket and key. +func (db *DB) Set(bucket, key, value []byte) error { + return db.set(bucket, key, value, context.Background()) +} + +// Del deletes the value stored in the given bucket and key. +func (db *DB) Del(bucket, key []byte) error { + return db.del(bucket, key, context.Background()) +} + +func (db *DB) List(bucket []byte) ([]*database.Entry, error) { + return db.list(bucket, context.Background()) +} + +// CmpAndSwap modifies the value at the given bucket and key (to newValue) +// only if the existing (current) value matches oldValue. +func (db *DB) CmpAndSwap(bucket, key, oldValue, newValue []byte) ([]byte, bool, error) { + wc := writeconcern.Majority() + txnOptions := options.Transaction().SetWriteConcern(wc) + + session, err := db.db.Client().StartSession() + if err != nil { + return oldValue, false, errors.Wrap(err, "error starting session") + } + defer session.EndSession(context.Background()) + + val, swapped := []byte{}, false + err = mongo.WithSession(context.Background(), session, func(ctx mongo.SessionContext) error { + if err = session.StartTransaction(txnOptions); err != nil { + return errors.Wrap(err, "error: pending transaction") + } + + val, swapped, err = db.cmpAndSwap(bucket, key, oldValue, newValue, ctx) + if err != nil { + if err = session.AbortTransaction(context.Background()); err != nil { + return errors.Wrapf(err, "failed to execute CmpAndSwap transaction on %s/%s and failed to rollback transaction", bucket, key) + } + return errors.Wrap(err, "error aborting transaction") + } + + if err = session.CommitTransaction(context.Background()); err != nil { + return errors.Wrap(err, "error committing transaction") + } + return nil + }) + + return val, swapped, err +} + +// Update performs multiple commands on one read-write transaction. +func (db *DB) Update(tx *database.Tx) error { + wc := writeconcern.Majority() + txnOptions := options.Transaction().SetWriteConcern(wc) + + session, err := db.db.Client().StartSession() + if err != nil { + return errors.Wrap(err, "error starting session") + } + defer session.EndSession(context.Background()) + + err = mongo.WithSession(context.TODO(), session, func(ctx mongo.SessionContext) error { + if err = session.StartTransaction(txnOptions); err != nil { + return errors.Wrap(err, "error: pending transaction") + } + + err = db.executeTransactions(tx, ctx, session) + if err != nil { + return err + } + + if err = errors.WithStack(session.CommitTransaction(context.Background())); err != nil { + return errors.Wrap(err, "error committing transaction") + } + return nil + }) + + if err != nil { + return err + } + + return nil +} + +// CreateTable creates a bucket or an embedded bucket if it does not exists. +func (db *DB) createTable(bucket []byte, ctx context.Context) error { + if err := db.db.CreateCollection(ctx, string(bucket)); err != nil { + return errors.Wrap(err, "error creating collection") + } + + // create an index on the Key field + index := mongo.IndexModel{ + Keys: createFilter("key", 1), + Options: options.Index().SetUnique(true), + } + + _, err := db.db.Collection(string(bucket)).Indexes().CreateOne(context.TODO(), index) + if err != nil { + return errors.Wrap(err, "error creating collection") + } + + return nil +} + +// DeleteTable deletes a root or embedded bucket. Returns an error if the +// bucket cannot be found or if the key represents a non-bucket value. +func (db *DB) deleteTable(bucket []byte, ctx context.Context) error { + if !collectionExists(db.db, string(bucket)) { + return errors.Wrapf(database.ErrNotFound, "bucket %s not found", bucket) + } + + if err := db.db.Collection(string(bucket)).Drop(ctx); err != nil { + return errors.Wrapf(err, "error dropping collection: %s", bucket) + } + return nil +} + +func (db *DB) get(bucket, key []byte, ctx context.Context) (ret []byte, err error) { + filter := createFilter("key", key) + res := db.db.Collection(string(bucket)).FindOne(ctx, filter) + + if err := res.Err(); err != nil { + return nil, errors.Wrapf(database.ErrNotFound, "%s/%s", bucket, key) + } + + result := Tuple{} + if err := res.Decode(&result); err != nil { + return nil, errors.Wrap(err, "error decoding value") + } + + return []byte(result.Value), nil +} + +func (db *DB) set(bucket, key, value []byte, ctx context.Context) error { + filter := createFilter("key", key) + update := createUpdate(value) + opts := options.Update().SetUpsert(true) + + _, err := db.db.Collection(string(bucket)).UpdateOne(ctx, filter, update, opts) + if err != nil { + return errors.Wrapf(err, "error setting value %s/%s", bucket, key) + } + return nil +} + +// List returns the full list of entries in a bucket. +func (db *DB) list(bucket []byte, ctx context.Context) ([]*database.Entry, error) { + if !collectionExists(db.db, string(bucket)) { + return nil, errors.Wrapf(database.ErrNotFound, "bucket %s not found", bucket) + } + + // match all + filter := bson.D{{}} + + cursor, err := db.db.Collection(string(bucket)).Find(ctx, filter) + if err != nil { + return nil, errors.Wrap(err, "error listing values") + } + defer cursor.Close(context.Background()) + + if err = cursor.Err(); err != nil { + return nil, errors.Wrap(err, "error listing values") + } + + var entries []*database.Entry + + for cursor.Next(context.Background()) { + tuple := Tuple{} + cursor.Decode(&tuple) + entries = append(entries, &database.Entry{ + Bucket: bucket, + Key: []byte(tuple.Key), + Value: []byte(tuple.Value), + }) + } + + if err = cursor.Err(); err != nil { + return nil, errors.Wrap(err, "error listing values") + } + + return entries, nil +} + +// Del deletes the value stored in the given bucket and key. +func (db *DB) del(bucket, key []byte, ctx context.Context) error { + filter := createFilter("key", key) + + mongoRes, err := db.db.Collection(string(bucket)).DeleteOne(ctx, filter) + if err != nil { + return errors.Wrapf(err, "failed to delete %s/%s", bucket, key) + } + + if mongoRes.DeletedCount == 0 { + return errors.Wrapf(err, "failed to delete: %s/%s. Value not found", bucket, key) + } + + return nil +} + +func (db *DB) cmpAndSwap(bucket, key, target, newValue []byte, ctx context.Context) ([]byte, bool, error) { + v, err := db.get(bucket, key, ctx) + if err != nil && !database.IsErrNotFound(err) { + return nil, false, err + } + + if !bytes.Equal(v, target) { + return v, false, nil + } + + err = db.set(bucket, key, newValue, ctx) + if err != nil { + return nil, false, err + } + + return newValue, true, nil +} + +func (db *DB) executeTransactions(tx *database.Tx, ctx mongo.SessionContext, session mongo.Session) error { + for _, op := range tx.Operations { + var err error + switch op.Cmd { + case database.CreateTable: + if err := db.CreateTable(op.Bucket); err != nil { + return abort(session, err) + } + case database.DeleteTable: + if err := db.DeleteTable(op.Bucket); err != nil { + return abort(session, err) + } + case database.Get: + if op.Result, err = db.get(op.Bucket, op.Key, ctx); err != nil { + return abort(session, err) + } + case database.Set: + if err := db.set(op.Bucket, op.Key, op.Value, ctx); err != nil { + return abort(session, err) + } + case database.Delete: + if err := db.del(op.Bucket, op.Key, ctx); err != nil { + return abort(session, err) + } + case database.CmpAndSwap: + op.Result, op.Swapped, err = db.cmpAndSwap(op.Bucket, op.Key, op.CmpValue, op.Value, ctx) + if err != nil { + return abort(session, err) + } + case database.CmpOrRollback: + return abort(session, database.ErrOpNotSupported) + default: + return abort(session, database.ErrOpNotSupported) + } + } + + return nil +} + +func collectionExists(db *mongo.Database, name string) bool { + filter := createFilter("name", name) + list, err := db.ListCollectionNames(context.Background(), filter) + if err != nil { + return false + } + return len(list) > 0 +} + +func createFilter(key string, value any) bson.D { + return bson.D{{Key: key, Value: value}} +} + +func createUpdate(value []byte) bson.D { + return bson.D{{Key: "$set", Value: bson.D{{Key: "value", Value: value}}}} +} + +func abort(session mongo.Session, err error) error { + abortError := session.AbortTransaction(context.Background()) + if abortError != nil { + return errors.Wrap(err, "error aborting transaction") + } + return errors.Wrap(err, "UPDATE failed") +} diff --git a/mongo/nomongo.go b/mongo/nomongo.go new file mode 100644 index 0000000..2e31e80 --- /dev/null +++ b/mongo/nomongo.go @@ -0,0 +1,8 @@ +//go:build nomongo +// +build nomongo + +package mongo + +import "github.com/smallstep/nosql/database" + +type DB = database.NotSupportedDB diff --git a/nosql.go b/nosql.go index a72da2a..9dd3295 100644 --- a/nosql.go +++ b/nosql.go @@ -8,6 +8,7 @@ import ( badgerV2 "github.com/smallstep/nosql/badger/v2" "github.com/smallstep/nosql/bolt" "github.com/smallstep/nosql/database" + "github.com/smallstep/nosql/mongo" "github.com/smallstep/nosql/mysql" "github.com/smallstep/nosql/postgresql" ) @@ -38,6 +39,8 @@ var ( // Available db driver types. // + // MongoDriver indicates the default MySQL database. + MongoDriver = "mongodb" // BadgerDriver indicates the default Badger database - currently Badger V1. BadgerDriver = "badger" // BadgerV1Driver explicitly selects the Badger V1 driver. @@ -72,6 +75,8 @@ func New(driver, dataSourceName string, opt ...Option) (db database.DB, err erro db = &mysql.DB{} case PostgreSQLDriver: db = &postgresql.DB{} + case MongoDriver: + db = &mongo.DB{} default: return nil, errors.Errorf("%s database not supported", driver) } diff --git a/nosql_test.go b/nosql_test.go index 0d92622..f6f8c5c 100644 --- a/nosql_test.go +++ b/nosql_test.go @@ -16,7 +16,7 @@ type testUser struct { } func run(t *testing.T, db database.DB) { - var boogers = []byte("boogers") + boogers := []byte("boogers") ub := []byte("testNoSQLUsers") assert.True(t, IsErrNotFound(db.DeleteTable(ub))) @@ -269,7 +269,6 @@ func run(t *testing.T, db database.DB) { } func TestMain(m *testing.M) { - // setup path := "./tmp" if _, err := os.Stat(path); os.IsNotExist(err) { @@ -291,7 +290,7 @@ func TestMySQL(t *testing.T) { pwd = "password" proto = "tcp" addr = "127.0.0.1:3306" - //path = "/tmp/mysql.sock" + // path = "/tmp/mysql.sock" testDB = "test" ) @@ -310,12 +309,42 @@ func TestMySQL(t *testing.T) { run(t, db) } +func TestMongoDB(t *testing.T) { + var ( + uname = "user" + pwd = "password" + addr = "localhost:27017" + // path = "/tmp/mysql.sock" + replicaSet = "dbrs" + testDB = "test" + ) + + isCITest := os.Getenv("CI") + if isCITest == "" { + fmt.Printf("Not running MongoDB integration tests\n") + return + } + + db, err := New("mongodb", + fmt.Sprintf("mongodb://%s:%s@%s/?replicaSet=%s", uname, pwd, addr, replicaSet), + WithDatabase(testDB)) + if err != nil { + t.Fatalf(fmt.Sprintf("Error: %s\n", err)) + return + } + assert.FatalError(t, err) + + defer db.Close() + + run(t, db) +} + func TestPostgreSQL(t *testing.T) { var ( uname = "user" pwd = "password" addr = "127.0.0.1:5432" - //path = "/tmp/postgresql.sock" + // path = "/tmp/postgresql.sock" testDB = "test" ) From 6fdea12191b6ccac9cf784beb4d4fcd295229fdb Mon Sep 17 00:00:00 2001 From: YuriBocharov Date: Tue, 30 Jan 2024 19:12:41 -0500 Subject: [PATCH 2/4] fix: syntax issues removes use of github/errors package, uses fmt.Errorf instead changes syntax of error messages changes order of context as a parameter makes sure session is passed down where possible renames bucket to collection other minor fixes for review --- go.mod | 9 ++- mongo/mongo.go | 207 +++++++++++++++++++++++++------------------------ nosql_test.go | 8 +- 3 files changed, 114 insertions(+), 110 deletions(-) diff --git a/go.mod b/go.mod index a8f0e05..cb9436f 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5 go.etcd.io/bbolt v1.3.8 + go.mongodb.org/mongo-driver v1.13.1 ) require ( @@ -27,9 +28,15 @@ require ( github.com/jackc/pgproto3/v2 v2.3.2 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgtype v1.14.0 // indirect - github.com/klauspost/compress v1.12.3 // indirect + github.com/klauspost/compress v1.13.6 // indirect + github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect golang.org/x/crypto v0.17.0 // indirect golang.org/x/net v0.17.0 // indirect + golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.25.0 // indirect diff --git a/mongo/mongo.go b/mongo/mongo.go index a2b7768..9b42e11 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -6,9 +6,9 @@ package mongo import ( "bytes" "context" + "fmt" "time" - "github.com/pkg/errors" "github.com/smallstep/nosql/database" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -21,7 +21,7 @@ type DB struct { db *mongo.Database } -type Tuple struct { +type tuple struct { Key []byte `bson:"key" json:"key"` Value []byte `bson:"value" json:"value"` } @@ -38,7 +38,7 @@ func (db *DB) Open(uri string, opt ...database.Option) error { clientOptions := options.Client().ApplyURI(uri) if rs := clientOptions.ReplicaSet; *rs == "" { - return errors.New("To enable transactions, please provide a replica set name") + return fmt.Errorf("to enable transactions, please provide a replica set name") } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -46,12 +46,11 @@ func (db *DB) Open(uri string, opt ...database.Option) error { client, err := mongo.Connect(ctx, clientOptions) if err != nil { - return errors.Wrap(err, "error in configuration") + return fmt.Errorf("failed to configuration: %v, %w", clientOptions, err) } - err = client.Ping(context.Background(), nil) - if err != nil { - return errors.Wrap(err, "error connecting to mongo") + if err = client.Ping(context.Background(), nil); err != nil { + return fmt.Errorf("failed connecting mongo to: %w", err) } dbOpts := options.Database() @@ -62,67 +61,67 @@ func (db *DB) Open(uri string, opt ...database.Option) error { func (db *DB) Close() error { if err := db.db.Client().Disconnect(context.Background()); err != nil { - return errors.WithStack(err) + return fmt.Errorf("failed disconnecting mongo to: %w", err) } return nil } -func (db *DB) CreateTable(bucket []byte) error { - return db.createTable(bucket, context.Background()) +func (db *DB) CreateTable(collection []byte) error { + return db.createTable(context.Background(), collection) } -func (db *DB) DeleteTable(bucket []byte) error { - return db.deleteTable(bucket, context.Background()) +func (db *DB) DeleteTable(collection []byte) error { + return db.deleteTable(context.Background(), collection) } // Get returns the value stored in the given bucked and key. -func (db *DB) Get(bucket, key []byte) (ret []byte, err error) { - return db.get(bucket, key, context.Background()) +func (db *DB) Get(collection, key []byte) (ret []byte, err error) { + return db.get(context.Background(), collection, key) } -// Set stores the given value on bucket and key. -func (db *DB) Set(bucket, key, value []byte) error { - return db.set(bucket, key, value, context.Background()) +// Set stores the given value on collection and key. +func (db *DB) Set(collection, key, value []byte) error { + return db.set(context.Background(), collection, key, value) } -// Del deletes the value stored in the given bucket and key. -func (db *DB) Del(bucket, key []byte) error { - return db.del(bucket, key, context.Background()) +// Del deletes the value stored in the given collection and key. +func (db *DB) Del(collection, key []byte) error { + return db.del(context.Background(), collection, key) } -func (db *DB) List(bucket []byte) ([]*database.Entry, error) { - return db.list(bucket, context.Background()) +func (db *DB) List(collection []byte) ([]*database.Entry, error) { + return db.list(context.Background(), collection) } -// CmpAndSwap modifies the value at the given bucket and key (to newValue) +// CmpAndSwap modifies the value at the given collection and key (to newValue) // only if the existing (current) value matches oldValue. -func (db *DB) CmpAndSwap(bucket, key, oldValue, newValue []byte) ([]byte, bool, error) { +func (db *DB) CmpAndSwap(collection, key, oldValue, newValue []byte) ([]byte, bool, error) { wc := writeconcern.Majority() txnOptions := options.Transaction().SetWriteConcern(wc) session, err := db.db.Client().StartSession() if err != nil { - return oldValue, false, errors.Wrap(err, "error starting session") + return oldValue, false, fmt.Errorf("failed starting session to: %w", err) } defer session.EndSession(context.Background()) val, swapped := []byte{}, false err = mongo.WithSession(context.Background(), session, func(ctx mongo.SessionContext) error { if err = session.StartTransaction(txnOptions); err != nil { - return errors.Wrap(err, "error: pending transaction") + return fmt.Errorf("failed to pending transaction: %w", err) } - val, swapped, err = db.cmpAndSwap(bucket, key, oldValue, newValue, ctx) + val, swapped, err = db.cmpAndSwap(ctx, collection, key, oldValue, newValue) if err != nil { - if err = session.AbortTransaction(context.Background()); err != nil { - return errors.Wrapf(err, "failed to execute CmpAndSwap transaction on %s/%s and failed to rollback transaction", bucket, key) + if err = session.AbortTransaction(ctx); err != nil { + return fmt.Errorf("failed to execute CmpAndSwap transaction on %s/%s and failed to rollback transaction: %w", collection, key, err) } - return errors.Wrap(err, "error aborting transaction") + return fmt.Errorf("failed aborting transaction to: %w", err) } - if err = session.CommitTransaction(context.Background()); err != nil { - return errors.Wrap(err, "error committing transaction") + if err = session.CommitTransaction(ctx); err != nil { + return fmt.Errorf("failed committing transaction to: %w", err) } return nil }) @@ -137,22 +136,22 @@ func (db *DB) Update(tx *database.Tx) error { session, err := db.db.Client().StartSession() if err != nil { - return errors.Wrap(err, "error starting session") + return fmt.Errorf("failed starting session to: %w", err) } defer session.EndSession(context.Background()) - err = mongo.WithSession(context.TODO(), session, func(ctx mongo.SessionContext) error { + err = mongo.WithSession(context.Background(), session, func(ctx mongo.SessionContext) error { if err = session.StartTransaction(txnOptions); err != nil { - return errors.Wrap(err, "error: pending transaction") + return fmt.Errorf("failed to pending transaction: %w", err) } - err = db.executeTransactions(tx, ctx, session) + err = db.executeTransactions(ctx, tx, session) if err != nil { return err } - if err = errors.WithStack(session.CommitTransaction(context.Background())); err != nil { - return errors.Wrap(err, "error committing transaction") + if err = session.CommitTransaction(ctx); err != nil { + return fmt.Errorf("failed committing transaction to: %w", err) } return nil }) @@ -164,10 +163,10 @@ func (db *DB) Update(tx *database.Tx) error { return nil } -// CreateTable creates a bucket or an embedded bucket if it does not exists. -func (db *DB) createTable(bucket []byte, ctx context.Context) error { - if err := db.db.CreateCollection(ctx, string(bucket)); err != nil { - return errors.Wrap(err, "error creating collection") +// CreateTable creates a collection or an embedded collection if it does not exists. +func (db *DB) createTable(ctx context.Context, collection []byte) error { + if err := db.db.CreateCollection(ctx, string(collection)); err != nil { + return fmt.Errorf("failed creating collection %s to: %w", collection, err) } // create an index on the Key field @@ -176,111 +175,115 @@ func (db *DB) createTable(bucket []byte, ctx context.Context) error { Options: options.Index().SetUnique(true), } - _, err := db.db.Collection(string(bucket)).Indexes().CreateOne(context.TODO(), index) + _, err := db.db.Collection(string(collection)).Indexes().CreateOne(ctx, index) if err != nil { - return errors.Wrap(err, "error creating collection") + return fmt.Errorf("failed creating collection %s to: %w", collection, err) } return nil } -// DeleteTable deletes a root or embedded bucket. Returns an error if the -// bucket cannot be found or if the key represents a non-bucket value. -func (db *DB) deleteTable(bucket []byte, ctx context.Context) error { - if !collectionExists(db.db, string(bucket)) { - return errors.Wrapf(database.ErrNotFound, "bucket %s not found", bucket) +// DeleteTable deletes a root or embedded collection. Returns an error if the +// collection cannot be found or if the key represents a non-collection value. +func (db *DB) deleteTable(ctx context.Context, collection []byte) error { + if !collectionExists(ctx, db.db, string(collection)) { + return fmt.Errorf("failed deleting collection %s to: %w ", collection, database.ErrNotFound) } - if err := db.db.Collection(string(bucket)).Drop(ctx); err != nil { - return errors.Wrapf(err, "error dropping collection: %s", bucket) + if err := db.db.Collection(string(collection)).Drop(ctx); err != nil { + return fmt.Errorf("failed dropping collection %s to: %w", collection, err) } return nil } -func (db *DB) get(bucket, key []byte, ctx context.Context) (ret []byte, err error) { +func (db *DB) get(ctx context.Context, collection, key []byte) (ret []byte, err error) { filter := createFilter("key", key) - res := db.db.Collection(string(bucket)).FindOne(ctx, filter) + res := db.db.Collection(string(collection)).FindOne(ctx, filter) if err := res.Err(); err != nil { - return nil, errors.Wrapf(database.ErrNotFound, "%s/%s", bucket, key) + return nil, fmt.Errorf("failed finding %s/%s to: %w", collection, key, database.ErrNotFound) } - result := Tuple{} + result := tuple{} if err := res.Decode(&result); err != nil { - return nil, errors.Wrap(err, "error decoding value") + return nil, fmt.Errorf("failed decoding value for %s/%s to: %w", collection, key, err) } - return []byte(result.Value), nil + return result.Value, nil } -func (db *DB) set(bucket, key, value []byte, ctx context.Context) error { +func (db *DB) set(ctx context.Context, collection, key, value []byte) error { filter := createFilter("key", key) update := createUpdate(value) opts := options.Update().SetUpsert(true) - _, err := db.db.Collection(string(bucket)).UpdateOne(ctx, filter, update, opts) + _, err := db.db.Collection(string(collection)).UpdateOne(ctx, filter, update, opts) if err != nil { - return errors.Wrapf(err, "error setting value %s/%s", bucket, key) + return fmt.Errorf("failed setting value %s/%s to: %w", collection, key, err) } return nil } -// List returns the full list of entries in a bucket. -func (db *DB) list(bucket []byte, ctx context.Context) ([]*database.Entry, error) { - if !collectionExists(db.db, string(bucket)) { - return nil, errors.Wrapf(database.ErrNotFound, "bucket %s not found", bucket) +// List returns the full list of entries in a collection. +func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, error) { + if !collectionExists(ctx, db.db, string(collection)) { + return nil, fmt.Errorf("failed finding collection %s: %w", collection, database.ErrNotFound) } // match all filter := bson.D{{}} - cursor, err := db.db.Collection(string(bucket)).Find(ctx, filter) + cursor, err := db.db.Collection(string(collection)).Find(ctx, filter) if err != nil { - return nil, errors.Wrap(err, "error listing values") + return nil, fmt.Errorf("failed listing values of collection %s to: %w", collection, err) } - defer cursor.Close(context.Background()) + defer cursor.Close(ctx) if err = cursor.Err(); err != nil { - return nil, errors.Wrap(err, "error listing values") + return nil, fmt.Errorf("failed listing values of %s to: %w", collection, err) } var entries []*database.Entry - for cursor.Next(context.Background()) { - tuple := Tuple{} - cursor.Decode(&tuple) + for cursor.Next(ctx) { + t := tuple{} + + if err := cursor.Decode(&t); err != nil { + return nil, fmt.Errorf("failed decoding value to: %w", err) + } + entries = append(entries, &database.Entry{ - Bucket: bucket, - Key: []byte(tuple.Key), - Value: []byte(tuple.Value), + Bucket: collection, + Key: t.Key, + Value: t.Value, }) } if err = cursor.Err(); err != nil { - return nil, errors.Wrap(err, "error listing values") + return nil, fmt.Errorf("failed listing values of collection %s to: %w", collection, err) } return entries, nil } -// Del deletes the value stored in the given bucket and key. -func (db *DB) del(bucket, key []byte, ctx context.Context) error { +// Del deletes the value stored in the given collection and key. +func (db *DB) del(ctx context.Context, collection, key []byte) error { filter := createFilter("key", key) - mongoRes, err := db.db.Collection(string(bucket)).DeleteOne(ctx, filter) + mongoRes, err := db.db.Collection(string(collection)).DeleteOne(ctx, filter) if err != nil { - return errors.Wrapf(err, "failed to delete %s/%s", bucket, key) + return fmt.Errorf("failed deleting %s/%s to: %w", collection, key, err) } if mongoRes.DeletedCount == 0 { - return errors.Wrapf(err, "failed to delete: %s/%s. Value not found", bucket, key) + return fmt.Errorf("failed to delete: %s/%s to: %w", collection, key, database.ErrNotFound) } return nil } -func (db *DB) cmpAndSwap(bucket, key, target, newValue []byte, ctx context.Context) ([]byte, bool, error) { - v, err := db.get(bucket, key, ctx) +func (db *DB) cmpAndSwap(ctx context.Context, collection, key, target, newValue []byte) ([]byte, bool, error) { + v, err := db.get(ctx, collection, key) if err != nil && !database.IsErrNotFound(err) { return nil, false, err } @@ -289,7 +292,7 @@ func (db *DB) cmpAndSwap(bucket, key, target, newValue []byte, ctx context.Conte return v, false, nil } - err = db.set(bucket, key, newValue, ctx) + err = db.set(ctx, collection, key, newValue) if err != nil { return nil, false, err } @@ -297,48 +300,48 @@ func (db *DB) cmpAndSwap(bucket, key, target, newValue []byte, ctx context.Conte return newValue, true, nil } -func (db *DB) executeTransactions(tx *database.Tx, ctx mongo.SessionContext, session mongo.Session) error { +func (db *DB) executeTransactions(ctx mongo.SessionContext, tx *database.Tx, session mongo.Session) error { for _, op := range tx.Operations { var err error switch op.Cmd { case database.CreateTable: if err := db.CreateTable(op.Bucket); err != nil { - return abort(session, err) + return abort(ctx, session, err) } case database.DeleteTable: if err := db.DeleteTable(op.Bucket); err != nil { - return abort(session, err) + return abort(ctx, session, err) } case database.Get: - if op.Result, err = db.get(op.Bucket, op.Key, ctx); err != nil { - return abort(session, err) + if op.Result, err = db.get(ctx, op.Bucket, op.Key); err != nil { + return abort(ctx, session, err) } case database.Set: - if err := db.set(op.Bucket, op.Key, op.Value, ctx); err != nil { - return abort(session, err) + if err := db.set(ctx, op.Bucket, op.Key, op.Value); err != nil { + return abort(ctx, session, err) } case database.Delete: - if err := db.del(op.Bucket, op.Key, ctx); err != nil { - return abort(session, err) + if err := db.del(ctx, op.Bucket, op.Key); err != nil { + return abort(ctx, session, err) } case database.CmpAndSwap: - op.Result, op.Swapped, err = db.cmpAndSwap(op.Bucket, op.Key, op.CmpValue, op.Value, ctx) + op.Result, op.Swapped, err = db.cmpAndSwap(ctx, op.Bucket, op.Key, op.CmpValue, op.Value) if err != nil { - return abort(session, err) + return abort(ctx, session, err) } case database.CmpOrRollback: - return abort(session, database.ErrOpNotSupported) + return abort(ctx, session, database.ErrOpNotSupported) default: - return abort(session, database.ErrOpNotSupported) + return abort(ctx, session, database.ErrOpNotSupported) } } return nil } -func collectionExists(db *mongo.Database, name string) bool { +func collectionExists(ctx context.Context, db *mongo.Database, name string) bool { filter := createFilter("name", name) - list, err := db.ListCollectionNames(context.Background(), filter) + list, err := db.ListCollectionNames(ctx, filter) if err != nil { return false } @@ -353,10 +356,10 @@ func createUpdate(value []byte) bson.D { return bson.D{{Key: "$set", Value: bson.D{{Key: "value", Value: value}}}} } -func abort(session mongo.Session, err error) error { - abortError := session.AbortTransaction(context.Background()) +func abort(ctx context.Context, session mongo.Session, err error) error { + abortError := session.AbortTransaction(ctx) if abortError != nil { - return errors.Wrap(err, "error aborting transaction") + return fmt.Errorf("failed aborting transaction to: %w", err) } - return errors.Wrap(err, "UPDATE failed") + return fmt.Errorf("failed update to: %w", err) } diff --git a/nosql_test.go b/nosql_test.go index f6f8c5c..bfe5a52 100644 --- a/nosql_test.go +++ b/nosql_test.go @@ -321,19 +321,13 @@ func TestMongoDB(t *testing.T) { isCITest := os.Getenv("CI") if isCITest == "" { - fmt.Printf("Not running MongoDB integration tests\n") - return + t.Skip("not running MongoDB integration tests\n") } db, err := New("mongodb", fmt.Sprintf("mongodb://%s:%s@%s/?replicaSet=%s", uname, pwd, addr, replicaSet), WithDatabase(testDB)) - if err != nil { - t.Fatalf(fmt.Sprintf("Error: %s\n", err)) - return - } assert.FatalError(t, err) - defer db.Close() run(t, db) From 00983adf8359f1d090ab2444725aefa5dce29c64 Mon Sep 17 00:00:00 2001 From: elasticspoon Date: Tue, 13 Feb 2024 18:30:10 -0500 Subject: [PATCH 3/4] fix: error syntax changes Co-authored-by: Herman Slatman --- mongo/mongo.go | 52 +++++++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/mongo/mongo.go b/mongo/mongo.go index 9b42e11..4fb531d 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -38,7 +38,7 @@ func (db *DB) Open(uri string, opt ...database.Option) error { clientOptions := options.Client().ApplyURI(uri) if rs := clientOptions.ReplicaSet; *rs == "" { - return fmt.Errorf("to enable transactions, please provide a replica set name") + return fmt.Errorf("replica set name is required to enable transactions") } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -46,11 +46,11 @@ func (db *DB) Open(uri string, opt ...database.Option) error { client, err := mongo.Connect(ctx, clientOptions) if err != nil { - return fmt.Errorf("failed to configuration: %v, %w", clientOptions, err) + return fmt.Errorf("failed to invalid options %v: %w", clientOptions, err) } if err = client.Ping(context.Background(), nil); err != nil { - return fmt.Errorf("failed connecting mongo to: %w", err) + return fmt.Errorf("failed connecting to MongoDB: %w", err) } dbOpts := options.Database() @@ -67,10 +67,13 @@ func (db *DB) Close() error { return nil } +// CreateTable creates a collection or an embedded collection if it does not exists. func (db *DB) CreateTable(collection []byte) error { return db.createTable(context.Background(), collection) } +// DeleteTable deletes a root or embedded collection. Returns an error if the +// collection cannot be found or if the key represents a non-collection value. func (db *DB) DeleteTable(collection []byte) error { return db.deleteTable(context.Background(), collection) } @@ -102,7 +105,7 @@ func (db *DB) CmpAndSwap(collection, key, oldValue, newValue []byte) ([]byte, bo session, err := db.db.Client().StartSession() if err != nil { - return oldValue, false, fmt.Errorf("failed starting session to: %w", err) + return oldValue, false, fmt.Errorf("failed starting session: %w", err) } defer session.EndSession(context.Background()) @@ -117,11 +120,11 @@ func (db *DB) CmpAndSwap(collection, key, oldValue, newValue []byte) ([]byte, bo if err = session.AbortTransaction(ctx); err != nil { return fmt.Errorf("failed to execute CmpAndSwap transaction on %s/%s and failed to rollback transaction: %w", collection, key, err) } - return fmt.Errorf("failed aborting transaction to: %w", err) + return fmt.Errorf("failed aborting transaction: %w", err) } if err = session.CommitTransaction(ctx); err != nil { - return fmt.Errorf("failed committing transaction to: %w", err) + return fmt.Errorf("failed committing transaction: %w", err) } return nil }) @@ -136,7 +139,7 @@ func (db *DB) Update(tx *database.Tx) error { session, err := db.db.Client().StartSession() if err != nil { - return fmt.Errorf("failed starting session to: %w", err) + return fmt.Errorf("failed starting session: %w", err) } defer session.EndSession(context.Background()) @@ -151,7 +154,7 @@ func (db *DB) Update(tx *database.Tx) error { } if err = session.CommitTransaction(ctx); err != nil { - return fmt.Errorf("failed committing transaction to: %w", err) + return fmt.Errorf("failed committing transaction: %w", err) } return nil }) @@ -163,10 +166,9 @@ func (db *DB) Update(tx *database.Tx) error { return nil } -// CreateTable creates a collection or an embedded collection if it does not exists. func (db *DB) createTable(ctx context.Context, collection []byte) error { if err := db.db.CreateCollection(ctx, string(collection)); err != nil { - return fmt.Errorf("failed creating collection %s to: %w", collection, err) + return fmt.Errorf("failed creating collection %q: %w", collection, err) } // create an index on the Key field @@ -177,21 +179,19 @@ func (db *DB) createTable(ctx context.Context, collection []byte) error { _, err := db.db.Collection(string(collection)).Indexes().CreateOne(ctx, index) if err != nil { - return fmt.Errorf("failed creating collection %s to: %w", collection, err) + return fmt.Errorf("failed creating collection %q: %w", collection, err) } return nil } -// DeleteTable deletes a root or embedded collection. Returns an error if the -// collection cannot be found or if the key represents a non-collection value. func (db *DB) deleteTable(ctx context.Context, collection []byte) error { if !collectionExists(ctx, db.db, string(collection)) { - return fmt.Errorf("failed deleting collection %s to: %w ", collection, database.ErrNotFound) + return fmt.Errorf("failed deleting collection %q: %w ", collection, database.ErrNotFound) } if err := db.db.Collection(string(collection)).Drop(ctx); err != nil { - return fmt.Errorf("failed dropping collection %s to: %w", collection, err) + return fmt.Errorf("failed dropping collection %q: %w", collection, err) } return nil } @@ -201,12 +201,12 @@ func (db *DB) get(ctx context.Context, collection, key []byte) (ret []byte, err res := db.db.Collection(string(collection)).FindOne(ctx, filter) if err := res.Err(); err != nil { - return nil, fmt.Errorf("failed finding %s/%s to: %w", collection, key, database.ErrNotFound) + return nil, fmt.Errorf("failed finding %s/%s: %w", collection, key, database.ErrNotFound) } result := tuple{} if err := res.Decode(&result); err != nil { - return nil, fmt.Errorf("failed decoding value for %s/%s to: %w", collection, key, err) + return nil, fmt.Errorf("failed decoding value for %s/%s: %w", collection, key, err) } return result.Value, nil @@ -219,7 +219,7 @@ func (db *DB) set(ctx context.Context, collection, key, value []byte) error { _, err := db.db.Collection(string(collection)).UpdateOne(ctx, filter, update, opts) if err != nil { - return fmt.Errorf("failed setting value %s/%s to: %w", collection, key, err) + return fmt.Errorf("failed setting value %s/%s: %w", collection, key, err) } return nil } @@ -227,7 +227,7 @@ func (db *DB) set(ctx context.Context, collection, key, value []byte) error { // List returns the full list of entries in a collection. func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, error) { if !collectionExists(ctx, db.db, string(collection)) { - return nil, fmt.Errorf("failed finding collection %s: %w", collection, database.ErrNotFound) + return nil, fmt.Errorf("failed finding collection %q: %w", collection, database.ErrNotFound) } // match all @@ -240,7 +240,7 @@ func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, e defer cursor.Close(ctx) if err = cursor.Err(); err != nil { - return nil, fmt.Errorf("failed listing values of %s to: %w", collection, err) + return nil, fmt.Errorf("failed listing values of %q: %w", collection, err) } var entries []*database.Entry @@ -249,7 +249,7 @@ func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, e t := tuple{} if err := cursor.Decode(&t); err != nil { - return nil, fmt.Errorf("failed decoding value to: %w", err) + return nil, fmt.Errorf("failed decoding value: %w", err) } entries = append(entries, &database.Entry{ @@ -260,7 +260,7 @@ func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, e } if err = cursor.Err(); err != nil { - return nil, fmt.Errorf("failed listing values of collection %s to: %w", collection, err) + return nil, fmt.Errorf("failed listing values of collection %q: %w", collection, err) } return entries, nil @@ -272,11 +272,11 @@ func (db *DB) del(ctx context.Context, collection, key []byte) error { mongoRes, err := db.db.Collection(string(collection)).DeleteOne(ctx, filter) if err != nil { - return fmt.Errorf("failed deleting %s/%s to: %w", collection, key, err) + return fmt.Errorf("failed deleting %s/%s: %w", collection, key, err) } if mongoRes.DeletedCount == 0 { - return fmt.Errorf("failed to delete: %s/%s to: %w", collection, key, database.ErrNotFound) + return fmt.Errorf("failed to delete: %s/%s: %w", collection, key, database.ErrNotFound) } return nil @@ -359,7 +359,7 @@ func createUpdate(value []byte) bson.D { func abort(ctx context.Context, session mongo.Session, err error) error { abortError := session.AbortTransaction(ctx) if abortError != nil { - return fmt.Errorf("failed aborting transaction to: %w", err) + return fmt.Errorf("failed aborting transaction due to %q: %w", abortError, err) } - return fmt.Errorf("failed update to: %w", err) + return fmt.Errorf("failed executing transaction: %w", err) } From 41c3abaaf4be8e13cff755ee1670cbca1b7174a2 Mon Sep 17 00:00:00 2001 From: YuriBocharov Date: Thu, 22 Feb 2024 17:01:28 -0500 Subject: [PATCH 4/4] fix: more syntax Changes to error syntax --- mongo/mongo.go | 72 ++++++++++++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/mongo/mongo.go b/mongo/mongo.go index 4fb531d..584ee83 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -6,6 +6,7 @@ package mongo import ( "bytes" "context" + "errors" "fmt" "time" @@ -38,7 +39,7 @@ func (db *DB) Open(uri string, opt ...database.Option) error { clientOptions := options.Client().ApplyURI(uri) if rs := clientOptions.ReplicaSet; *rs == "" { - return fmt.Errorf("replica set name is required to enable transactions") + return errors.New("replica set name is required to enable transactions") } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -46,10 +47,10 @@ func (db *DB) Open(uri string, opt ...database.Option) error { client, err := mongo.Connect(ctx, clientOptions) if err != nil { - return fmt.Errorf("failed to invalid options %v: %w", clientOptions, err) + return fmt.Errorf("failed creating client to invalid options %v: %w", clientOptions, err) } - if err = client.Ping(context.Background(), nil); err != nil { + if err := client.Ping(context.Background(), nil); err != nil { return fmt.Errorf("failed connecting to MongoDB: %w", err) } @@ -61,13 +62,13 @@ func (db *DB) Open(uri string, opt ...database.Option) error { func (db *DB) Close() error { if err := db.db.Client().Disconnect(context.Background()); err != nil { - return fmt.Errorf("failed disconnecting mongo to: %w", err) + return fmt.Errorf("failed disconnecting from MongoDB: %w", err) } return nil } -// CreateTable creates a collection or an embedded collection if it does not exists. +// CreateTable creates a collection or an embedded collection if it does not exist. func (db *DB) CreateTable(collection []byte) error { return db.createTable(context.Background(), collection) } @@ -78,12 +79,12 @@ func (db *DB) DeleteTable(collection []byte) error { return db.deleteTable(context.Background(), collection) } -// Get returns the value stored in the given bucked and key. +// Get returns the value stored in the given collection and key. func (db *DB) Get(collection, key []byte) (ret []byte, err error) { return db.get(context.Background(), collection, key) } -// Set stores the given value on collection and key. +// Set stores the given value in collection and key. func (db *DB) Set(collection, key, value []byte) error { return db.set(context.Background(), collection, key, value) } @@ -93,6 +94,7 @@ func (db *DB) Del(collection, key []byte) error { return db.del(context.Background(), collection, key) } +// List returns the full list of entries in a collection. func (db *DB) List(collection []byte) ([]*database.Entry, error) { return db.list(context.Background(), collection) } @@ -111,19 +113,19 @@ func (db *DB) CmpAndSwap(collection, key, oldValue, newValue []byte) ([]byte, bo val, swapped := []byte{}, false err = mongo.WithSession(context.Background(), session, func(ctx mongo.SessionContext) error { - if err = session.StartTransaction(txnOptions); err != nil { - return fmt.Errorf("failed to pending transaction: %w", err) + if err := session.StartTransaction(txnOptions); err != nil { + return fmt.Errorf("failed starting transaction to pending transaction: %w", err) } val, swapped, err = db.cmpAndSwap(ctx, collection, key, oldValue, newValue) if err != nil { - if err = session.AbortTransaction(ctx); err != nil { + if err := session.AbortTransaction(ctx); err != nil { return fmt.Errorf("failed to execute CmpAndSwap transaction on %s/%s and failed to rollback transaction: %w", collection, key, err) } return fmt.Errorf("failed aborting transaction: %w", err) } - if err = session.CommitTransaction(ctx); err != nil { + if err := session.CommitTransaction(ctx); err != nil { return fmt.Errorf("failed committing transaction: %w", err) } return nil @@ -144,16 +146,15 @@ func (db *DB) Update(tx *database.Tx) error { defer session.EndSession(context.Background()) err = mongo.WithSession(context.Background(), session, func(ctx mongo.SessionContext) error { - if err = session.StartTransaction(txnOptions); err != nil { - return fmt.Errorf("failed to pending transaction: %w", err) + if err := session.StartTransaction(txnOptions); err != nil { + return fmt.Errorf("failed starting transaction to pending transaction: %w", err) } - err = db.executeTransactions(ctx, tx, session) - if err != nil { + if err := db.executeTransactions(ctx, tx, session); err != nil { return err } - if err = session.CommitTransaction(ctx); err != nil { + if err := session.CommitTransaction(ctx); err != nil { return fmt.Errorf("failed committing transaction: %w", err) } return nil @@ -177,8 +178,7 @@ func (db *DB) createTable(ctx context.Context, collection []byte) error { Options: options.Index().SetUnique(true), } - _, err := db.db.Collection(string(collection)).Indexes().CreateOne(ctx, index) - if err != nil { + if _, err := db.db.Collection(string(collection)).Indexes().CreateOne(ctx, index); err != nil { return fmt.Errorf("failed creating collection %q: %w", collection, err) } @@ -191,7 +191,7 @@ func (db *DB) deleteTable(ctx context.Context, collection []byte) error { } if err := db.db.Collection(string(collection)).Drop(ctx); err != nil { - return fmt.Errorf("failed dropping collection %q: %w", collection, err) + return fmt.Errorf("failed deleting collection %q: %w", collection, err) } return nil } @@ -217,14 +217,12 @@ func (db *DB) set(ctx context.Context, collection, key, value []byte) error { update := createUpdate(value) opts := options.Update().SetUpsert(true) - _, err := db.db.Collection(string(collection)).UpdateOne(ctx, filter, update, opts) - if err != nil { + if _, err := db.db.Collection(string(collection)).UpdateOne(ctx, filter, update, opts); err != nil { return fmt.Errorf("failed setting value %s/%s: %w", collection, key, err) } return nil } -// List returns the full list of entries in a collection. func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, error) { if !collectionExists(ctx, db.db, string(collection)) { return nil, fmt.Errorf("failed finding collection %q: %w", collection, database.ErrNotFound) @@ -235,12 +233,12 @@ func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, e cursor, err := db.db.Collection(string(collection)).Find(ctx, filter) if err != nil { - return nil, fmt.Errorf("failed listing values of collection %s to: %w", collection, err) + return nil, fmt.Errorf("failed listing values of collection %q: %w", collection, err) } defer cursor.Close(ctx) - if err = cursor.Err(); err != nil { - return nil, fmt.Errorf("failed listing values of %q: %w", collection, err) + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("failed listing values of collection %q: %w", collection, err) } var entries []*database.Entry @@ -259,14 +257,13 @@ func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, e }) } - if err = cursor.Err(); err != nil { + if err := cursor.Err(); err != nil { return nil, fmt.Errorf("failed listing values of collection %q: %w", collection, err) } return entries, nil } -// Del deletes the value stored in the given collection and key. func (db *DB) del(ctx context.Context, collection, key []byte) error { filter := createFilter("key", key) @@ -276,7 +273,7 @@ func (db *DB) del(ctx context.Context, collection, key []byte) error { } if mongoRes.DeletedCount == 0 { - return fmt.Errorf("failed to delete: %s/%s: %w", collection, key, database.ErrNotFound) + return fmt.Errorf("failed deleting %s/%s: %w", collection, key, database.ErrNotFound) } return nil @@ -292,8 +289,7 @@ func (db *DB) cmpAndSwap(ctx context.Context, collection, key, target, newValue return v, false, nil } - err = db.set(ctx, collection, key, newValue) - if err != nil { + if err := db.set(ctx, collection, key, newValue); err != nil { return nil, false, err } @@ -306,28 +302,28 @@ func (db *DB) executeTransactions(ctx mongo.SessionContext, tx *database.Tx, ses switch op.Cmd { case database.CreateTable: if err := db.CreateTable(op.Bucket); err != nil { - return abort(ctx, session, err) + return abort(ctx, session, fmt.Errorf("failed creating table %s: %w", op.Bucket, err)) } case database.DeleteTable: if err := db.DeleteTable(op.Bucket); err != nil { - return abort(ctx, session, err) + return abort(ctx, session, fmt.Errorf("failed deleting table %s: %w", op.Bucket, err)) } case database.Get: if op.Result, err = db.get(ctx, op.Bucket, op.Key); err != nil { - return abort(ctx, session, err) + return abort(ctx, session, fmt.Errorf("failed getting %s/%s: %w", op.Bucket, op.Key, err)) } case database.Set: if err := db.set(ctx, op.Bucket, op.Key, op.Value); err != nil { - return abort(ctx, session, err) + return abort(ctx, session, fmt.Errorf("failed setting %s/%s: %w", op.Bucket, op.Key, err)) } case database.Delete: if err := db.del(ctx, op.Bucket, op.Key); err != nil { - return abort(ctx, session, err) + return abort(ctx, session, fmt.Errorf("failed deleting %s/%s: %w", op.Bucket, op.Key, err)) } case database.CmpAndSwap: op.Result, op.Swapped, err = db.cmpAndSwap(ctx, op.Bucket, op.Key, op.CmpValue, op.Value) if err != nil { - return abort(ctx, session, err) + return abort(ctx, session, fmt.Errorf("failed load-or-store on %s/%s: %w", op.Bucket, op.Key, err)) } case database.CmpOrRollback: return abort(ctx, session, database.ErrOpNotSupported) @@ -359,7 +355,7 @@ func createUpdate(value []byte) bson.D { func abort(ctx context.Context, session mongo.Session, err error) error { abortError := session.AbortTransaction(ctx) if abortError != nil { - return fmt.Errorf("failed aborting transaction due to %q: %w", abortError, err) + return fmt.Errorf("failed executing update, rollback failed to %w: %w", abortError, err) } - return fmt.Errorf("failed executing transaction: %w", err) + return fmt.Errorf("failed executing update: %w", err) }