diff --git a/database/database.go b/database/database.go index abdb5ce..1b19891 100644 --- a/database/database.go +++ b/database/database.go @@ -1,9 +1,8 @@ package database import ( - "fmt" - "errors" + "fmt" ) var ( @@ -66,23 +65,28 @@ type DB interface { Open(dataSourceName string, opt ...Option) error // Close closes the current database. Close() error - // Get returns the value stored in the given table/bucket and key. + // Get returns the value stored in the given table/bucket/collection and key. Get(bucket, key []byte) (ret []byte, err error) - // Set sets the given value in the given table/bucket and key. + // Set sets the given value in the given table/bucket/collection and key. Set(bucket, key, value []byte) error - // CmpAndSwap swaps the value at the given bucket and key if the current - // value is equivalent to the oldValue input. Returns 'true' if the - // swap was successful and 'false' otherwise. + // CmpAndSwap takess a bucket, key, oldValue and newValue as inputs: + // CmpAndSwap returns: error, wasSwapped, keyValue + // - if the bucket / key does not exist, it returns database.ErrNotFound + // - if the bucket / key exists but the keyValue is not equivalent to oldValue, + // it returns wasSwapped = false, keyValue + // - if the bucket / key exists and the keyValue is equivalent to oldValue, + // it returns wasSwapped = true, keyValue + // - if an error occurs, it returns error CmpAndSwap(bucket, key, oldValue, newValue []byte) ([]byte, bool, error) // Del deletes the data in the given table/bucket and key. Del(bucket, key []byte) error - // List returns a list of all the entries in a given table/bucket. + // List returns a list of all the entries in a given table/bucket/collection. List(bucket []byte) ([]*Entry, error) // Update performs a transaction with multiple read-write commands. Update(tx *Tx) error - // CreateTable creates a table or a bucket in the database. + // CreateTable creates a table, bucket or collection in the database. CreateTable(bucket []byte) error - // DeleteTable deletes a table or a bucket in the database. + // DeleteTable deletes a table, bucket or collection in the database. DeleteTable(bucket []byte) error } diff --git a/go.mod b/go.mod index a8f0e05..cb9436f 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5 go.etcd.io/bbolt v1.3.8 + go.mongodb.org/mongo-driver v1.13.1 ) require ( @@ -27,9 +28,15 @@ require ( github.com/jackc/pgproto3/v2 v2.3.2 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgtype v1.14.0 // indirect - github.com/klauspost/compress v1.12.3 // indirect + github.com/klauspost/compress v1.13.6 // indirect + github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect golang.org/x/crypto v0.17.0 // indirect golang.org/x/net v0.17.0 // indirect + golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.25.0 // indirect diff --git a/go.sum b/go.sum index 23b9b87..0b513e4 100644 --- a/go.sum +++ b/go.sum @@ -58,14 +58,16 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= @@ -118,8 +120,9 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.12.3 h1:G5AfA94pHPysR56qqrkO2pxEexdDzrpFJ6yt/VqWxVU= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= +github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -140,6 +143,8 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -184,11 +189,21 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.etcd.io/bbolt v1.3.8 h1:xs88BrvEv273UsB79e0hcVrlUWmS0a8upikMFhSyAtA= go.etcd.io/bbolt v1.3.8/go.mod h1:N9Mkw9X8x5fupy0IKsmuqVtoGDyxsaDlbk4Rd05IAQw= +go.mongodb.org/mongo-driver v1.13.1 h1:YIc7HTYsKndGK4RFzJ3covLz1byri52x0IoMB0Pt/vk= +go.mongodb.org/mongo-driver v1.13.1/go.mod h1:wcDf1JBCXy2mOW0bWHwO/IOYqdca1MPCwDtFu/Z9+eo= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -211,6 +226,7 @@ golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWP golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= @@ -230,6 +246,7 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= @@ -238,6 +255,7 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -253,6 +271,7 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -269,6 +288,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/mongo/mongo.go b/mongo/mongo.go new file mode 100644 index 0000000..584ee83 --- /dev/null +++ b/mongo/mongo.go @@ -0,0 +1,361 @@ +//go:build !nomongo +// +build !nomongo + +package mongo + +import ( + "bytes" + "context" + "errors" + "fmt" + "time" + + "github.com/smallstep/nosql/database" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/writeconcern" +) + +// DB is a wrapper over *sql.DB, +type DB struct { + db *mongo.Database +} + +type tuple struct { + Key []byte `bson:"key" json:"key"` + Value []byte `bson:"value" json:"value"` +} + +// Open creates a Driver and connects to the database with the given address +// and access details. +func (db *DB) Open(uri string, opt ...database.Option) error { + opts := &database.Options{} + for _, o := range opt { + if err := o(opts); err != nil { + return err + } + } + + clientOptions := options.Client().ApplyURI(uri) + if rs := clientOptions.ReplicaSet; *rs == "" { + return errors.New("replica set name is required to enable transactions") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client, err := mongo.Connect(ctx, clientOptions) + if err != nil { + return fmt.Errorf("failed creating client to invalid options %v: %w", clientOptions, err) + } + + if err := client.Ping(context.Background(), nil); err != nil { + return fmt.Errorf("failed connecting to MongoDB: %w", err) + } + + dbOpts := options.Database() + db.db = client.Database(opts.Database, dbOpts) + + return nil +} + +func (db *DB) Close() error { + if err := db.db.Client().Disconnect(context.Background()); err != nil { + return fmt.Errorf("failed disconnecting from MongoDB: %w", err) + } + + return nil +} + +// CreateTable creates a collection or an embedded collection if it does not exist. +func (db *DB) CreateTable(collection []byte) error { + return db.createTable(context.Background(), collection) +} + +// DeleteTable deletes a root or embedded collection. Returns an error if the +// collection cannot be found or if the key represents a non-collection value. +func (db *DB) DeleteTable(collection []byte) error { + return db.deleteTable(context.Background(), collection) +} + +// Get returns the value stored in the given collection and key. +func (db *DB) Get(collection, key []byte) (ret []byte, err error) { + return db.get(context.Background(), collection, key) +} + +// Set stores the given value in collection and key. +func (db *DB) Set(collection, key, value []byte) error { + return db.set(context.Background(), collection, key, value) +} + +// Del deletes the value stored in the given collection and key. +func (db *DB) Del(collection, key []byte) error { + return db.del(context.Background(), collection, key) +} + +// List returns the full list of entries in a collection. +func (db *DB) List(collection []byte) ([]*database.Entry, error) { + return db.list(context.Background(), collection) +} + +// CmpAndSwap modifies the value at the given collection and key (to newValue) +// only if the existing (current) value matches oldValue. +func (db *DB) CmpAndSwap(collection, key, oldValue, newValue []byte) ([]byte, bool, error) { + wc := writeconcern.Majority() + txnOptions := options.Transaction().SetWriteConcern(wc) + + session, err := db.db.Client().StartSession() + if err != nil { + return oldValue, false, fmt.Errorf("failed starting session: %w", err) + } + defer session.EndSession(context.Background()) + + val, swapped := []byte{}, false + err = mongo.WithSession(context.Background(), session, func(ctx mongo.SessionContext) error { + if err := session.StartTransaction(txnOptions); err != nil { + return fmt.Errorf("failed starting transaction to pending transaction: %w", err) + } + + val, swapped, err = db.cmpAndSwap(ctx, collection, key, oldValue, newValue) + if err != nil { + if err := session.AbortTransaction(ctx); err != nil { + return fmt.Errorf("failed to execute CmpAndSwap transaction on %s/%s and failed to rollback transaction: %w", collection, key, err) + } + return fmt.Errorf("failed aborting transaction: %w", err) + } + + if err := session.CommitTransaction(ctx); err != nil { + return fmt.Errorf("failed committing transaction: %w", err) + } + return nil + }) + + return val, swapped, err +} + +// Update performs multiple commands on one read-write transaction. +func (db *DB) Update(tx *database.Tx) error { + wc := writeconcern.Majority() + txnOptions := options.Transaction().SetWriteConcern(wc) + + session, err := db.db.Client().StartSession() + if err != nil { + return fmt.Errorf("failed starting session: %w", err) + } + defer session.EndSession(context.Background()) + + err = mongo.WithSession(context.Background(), session, func(ctx mongo.SessionContext) error { + if err := session.StartTransaction(txnOptions); err != nil { + return fmt.Errorf("failed starting transaction to pending transaction: %w", err) + } + + if err := db.executeTransactions(ctx, tx, session); err != nil { + return err + } + + if err := session.CommitTransaction(ctx); err != nil { + return fmt.Errorf("failed committing transaction: %w", err) + } + return nil + }) + + if err != nil { + return err + } + + return nil +} + +func (db *DB) createTable(ctx context.Context, collection []byte) error { + if err := db.db.CreateCollection(ctx, string(collection)); err != nil { + return fmt.Errorf("failed creating collection %q: %w", collection, err) + } + + // create an index on the Key field + index := mongo.IndexModel{ + Keys: createFilter("key", 1), + Options: options.Index().SetUnique(true), + } + + if _, err := db.db.Collection(string(collection)).Indexes().CreateOne(ctx, index); err != nil { + return fmt.Errorf("failed creating collection %q: %w", collection, err) + } + + return nil +} + +func (db *DB) deleteTable(ctx context.Context, collection []byte) error { + if !collectionExists(ctx, db.db, string(collection)) { + return fmt.Errorf("failed deleting collection %q: %w ", collection, database.ErrNotFound) + } + + if err := db.db.Collection(string(collection)).Drop(ctx); err != nil { + return fmt.Errorf("failed deleting collection %q: %w", collection, err) + } + return nil +} + +func (db *DB) get(ctx context.Context, collection, key []byte) (ret []byte, err error) { + filter := createFilter("key", key) + res := db.db.Collection(string(collection)).FindOne(ctx, filter) + + if err := res.Err(); err != nil { + return nil, fmt.Errorf("failed finding %s/%s: %w", collection, key, database.ErrNotFound) + } + + result := tuple{} + if err := res.Decode(&result); err != nil { + return nil, fmt.Errorf("failed decoding value for %s/%s: %w", collection, key, err) + } + + return result.Value, nil +} + +func (db *DB) set(ctx context.Context, collection, key, value []byte) error { + filter := createFilter("key", key) + update := createUpdate(value) + opts := options.Update().SetUpsert(true) + + if _, err := db.db.Collection(string(collection)).UpdateOne(ctx, filter, update, opts); err != nil { + return fmt.Errorf("failed setting value %s/%s: %w", collection, key, err) + } + return nil +} + +func (db *DB) list(ctx context.Context, collection []byte) ([]*database.Entry, error) { + if !collectionExists(ctx, db.db, string(collection)) { + return nil, fmt.Errorf("failed finding collection %q: %w", collection, database.ErrNotFound) + } + + // match all + filter := bson.D{{}} + + cursor, err := db.db.Collection(string(collection)).Find(ctx, filter) + if err != nil { + return nil, fmt.Errorf("failed listing values of collection %q: %w", collection, err) + } + defer cursor.Close(ctx) + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("failed listing values of collection %q: %w", collection, err) + } + + var entries []*database.Entry + + for cursor.Next(ctx) { + t := tuple{} + + if err := cursor.Decode(&t); err != nil { + return nil, fmt.Errorf("failed decoding value: %w", err) + } + + entries = append(entries, &database.Entry{ + Bucket: collection, + Key: t.Key, + Value: t.Value, + }) + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("failed listing values of collection %q: %w", collection, err) + } + + return entries, nil +} + +func (db *DB) del(ctx context.Context, collection, key []byte) error { + filter := createFilter("key", key) + + mongoRes, err := db.db.Collection(string(collection)).DeleteOne(ctx, filter) + if err != nil { + return fmt.Errorf("failed deleting %s/%s: %w", collection, key, err) + } + + if mongoRes.DeletedCount == 0 { + return fmt.Errorf("failed deleting %s/%s: %w", collection, key, database.ErrNotFound) + } + + return nil +} + +func (db *DB) cmpAndSwap(ctx context.Context, collection, key, target, newValue []byte) ([]byte, bool, error) { + v, err := db.get(ctx, collection, key) + if err != nil && !database.IsErrNotFound(err) { + return nil, false, err + } + + if !bytes.Equal(v, target) { + return v, false, nil + } + + if err := db.set(ctx, collection, key, newValue); err != nil { + return nil, false, err + } + + return newValue, true, nil +} + +func (db *DB) executeTransactions(ctx mongo.SessionContext, tx *database.Tx, session mongo.Session) error { + for _, op := range tx.Operations { + var err error + switch op.Cmd { + case database.CreateTable: + if err := db.CreateTable(op.Bucket); err != nil { + return abort(ctx, session, fmt.Errorf("failed creating table %s: %w", op.Bucket, err)) + } + case database.DeleteTable: + if err := db.DeleteTable(op.Bucket); err != nil { + return abort(ctx, session, fmt.Errorf("failed deleting table %s: %w", op.Bucket, err)) + } + case database.Get: + if op.Result, err = db.get(ctx, op.Bucket, op.Key); err != nil { + return abort(ctx, session, fmt.Errorf("failed getting %s/%s: %w", op.Bucket, op.Key, err)) + } + case database.Set: + if err := db.set(ctx, op.Bucket, op.Key, op.Value); err != nil { + return abort(ctx, session, fmt.Errorf("failed setting %s/%s: %w", op.Bucket, op.Key, err)) + } + case database.Delete: + if err := db.del(ctx, op.Bucket, op.Key); err != nil { + return abort(ctx, session, fmt.Errorf("failed deleting %s/%s: %w", op.Bucket, op.Key, err)) + } + case database.CmpAndSwap: + op.Result, op.Swapped, err = db.cmpAndSwap(ctx, op.Bucket, op.Key, op.CmpValue, op.Value) + if err != nil { + return abort(ctx, session, fmt.Errorf("failed load-or-store on %s/%s: %w", op.Bucket, op.Key, err)) + } + case database.CmpOrRollback: + return abort(ctx, session, database.ErrOpNotSupported) + default: + return abort(ctx, session, database.ErrOpNotSupported) + } + } + + return nil +} + +func collectionExists(ctx context.Context, db *mongo.Database, name string) bool { + filter := createFilter("name", name) + list, err := db.ListCollectionNames(ctx, filter) + if err != nil { + return false + } + return len(list) > 0 +} + +func createFilter(key string, value any) bson.D { + return bson.D{{Key: key, Value: value}} +} + +func createUpdate(value []byte) bson.D { + return bson.D{{Key: "$set", Value: bson.D{{Key: "value", Value: value}}}} +} + +func abort(ctx context.Context, session mongo.Session, err error) error { + abortError := session.AbortTransaction(ctx) + if abortError != nil { + return fmt.Errorf("failed executing update, rollback failed to %w: %w", abortError, err) + } + return fmt.Errorf("failed executing update: %w", err) +} diff --git a/mongo/nomongo.go b/mongo/nomongo.go new file mode 100644 index 0000000..2e31e80 --- /dev/null +++ b/mongo/nomongo.go @@ -0,0 +1,8 @@ +//go:build nomongo +// +build nomongo + +package mongo + +import "github.com/smallstep/nosql/database" + +type DB = database.NotSupportedDB diff --git a/nosql.go b/nosql.go index a72da2a..9dd3295 100644 --- a/nosql.go +++ b/nosql.go @@ -8,6 +8,7 @@ import ( badgerV2 "github.com/smallstep/nosql/badger/v2" "github.com/smallstep/nosql/bolt" "github.com/smallstep/nosql/database" + "github.com/smallstep/nosql/mongo" "github.com/smallstep/nosql/mysql" "github.com/smallstep/nosql/postgresql" ) @@ -38,6 +39,8 @@ var ( // Available db driver types. // + // MongoDriver indicates the default MySQL database. + MongoDriver = "mongodb" // BadgerDriver indicates the default Badger database - currently Badger V1. BadgerDriver = "badger" // BadgerV1Driver explicitly selects the Badger V1 driver. @@ -72,6 +75,8 @@ func New(driver, dataSourceName string, opt ...Option) (db database.DB, err erro db = &mysql.DB{} case PostgreSQLDriver: db = &postgresql.DB{} + case MongoDriver: + db = &mongo.DB{} default: return nil, errors.Errorf("%s database not supported", driver) } diff --git a/nosql_test.go b/nosql_test.go index 0d92622..bfe5a52 100644 --- a/nosql_test.go +++ b/nosql_test.go @@ -16,7 +16,7 @@ type testUser struct { } func run(t *testing.T, db database.DB) { - var boogers = []byte("boogers") + boogers := []byte("boogers") ub := []byte("testNoSQLUsers") assert.True(t, IsErrNotFound(db.DeleteTable(ub))) @@ -269,7 +269,6 @@ func run(t *testing.T, db database.DB) { } func TestMain(m *testing.M) { - // setup path := "./tmp" if _, err := os.Stat(path); os.IsNotExist(err) { @@ -291,7 +290,7 @@ func TestMySQL(t *testing.T) { pwd = "password" proto = "tcp" addr = "127.0.0.1:3306" - //path = "/tmp/mysql.sock" + // path = "/tmp/mysql.sock" testDB = "test" ) @@ -310,12 +309,36 @@ func TestMySQL(t *testing.T) { run(t, db) } +func TestMongoDB(t *testing.T) { + var ( + uname = "user" + pwd = "password" + addr = "localhost:27017" + // path = "/tmp/mysql.sock" + replicaSet = "dbrs" + testDB = "test" + ) + + isCITest := os.Getenv("CI") + if isCITest == "" { + t.Skip("not running MongoDB integration tests\n") + } + + db, err := New("mongodb", + fmt.Sprintf("mongodb://%s:%s@%s/?replicaSet=%s", uname, pwd, addr, replicaSet), + WithDatabase(testDB)) + assert.FatalError(t, err) + defer db.Close() + + run(t, db) +} + func TestPostgreSQL(t *testing.T) { var ( uname = "user" pwd = "password" addr = "127.0.0.1:5432" - //path = "/tmp/postgresql.sock" + // path = "/tmp/postgresql.sock" testDB = "test" )