diff --git a/.release-please-manifest.json b/.release-please-manifest.json index a7fd6a0..c3f1463 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.1.5" + ".": "1.2.0" } diff --git a/.release-please-manifest.premain.json b/.release-please-manifest.premain.json index a7fd6a0..6362bad 100644 --- a/.release-please-manifest.premain.json +++ b/.release-please-manifest.premain.json @@ -1,3 +1,3 @@ { - ".": "1.1.5" + ".": "1.2.0-rc" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ccbbe5..ccc8403 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,35 @@ # Changelog +## [1.2.0](https://github.com/theory-cloud/TableTheory/compare/v1.1.5...v1.2.0) (2026-01-22) + + +### Features + +* **mocks:** add transaction builder mock ([ea39672](https://github.com/theory-cloud/TableTheory/commit/ea39672edffd22bf24b1471e244c14b79f06211d)) +* **mocks:** add transaction builder mock ([16ab5a5](https://github.com/theory-cloud/TableTheory/commit/16ab5a5d7b22d1087973f96a5c69b2c3a3796c3e)) + + +### Bug Fixes + +* **ci:** retry git fetch in branch-version-sync ([1b4d855](https://github.com/theory-cloud/TableTheory/commit/1b4d8557fe66c5c333846469369d9e5285cc1232)) +* **mocks:** satisfy lint gates ([a9cd117](https://github.com/theory-cloud/TableTheory/commit/a9cd1170fc200489369b76f098635321ed3d81c0)) +* **premain:** restore prerelease version alignment ([9b07cdb](https://github.com/theory-cloud/TableTheory/commit/9b07cdb7df5e69be8012374f742d89252ffde942)) + +## [1.2.0-rc](https://github.com/theory-cloud/TableTheory/compare/v1.1.5...v1.2.0-rc) (2026-01-22) + + +### Features + +* **mocks:** add transaction builder mock ([ea39672](https://github.com/theory-cloud/TableTheory/commit/ea39672edffd22bf24b1471e244c14b79f06211d)) +* **mocks:** add transaction builder mock ([16ab5a5](https://github.com/theory-cloud/TableTheory/commit/16ab5a5d7b22d1087973f96a5c69b2c3a3796c3e)) + + +### Bug Fixes + +* **ci:** retry git fetch in branch-version-sync ([1b4d855](https://github.com/theory-cloud/TableTheory/commit/1b4d8557fe66c5c333846469369d9e5285cc1232)) +* **mocks:** satisfy lint gates ([a9cd117](https://github.com/theory-cloud/TableTheory/commit/a9cd1170fc200489369b76f098635321ed3d81c0)) +* **premain:** restore prerelease version alignment ([9b07cdb](https://github.com/theory-cloud/TableTheory/commit/9b07cdb7df5e69be8012374f742d89252ffde942)) + ## [1.1.5](https://github.com/theory-cloud/TableTheory/compare/v1.1.4...v1.1.5) (2026-01-20) diff --git a/hgm-infra/evidence/hgm-rubric-report.json b/hgm-infra/evidence/hgm-rubric-report.json index 9d5a175..619cf47 100644 --- a/hgm-infra/evidence/hgm-rubric-report.json +++ b/hgm-infra/evidence/hgm-rubric-report.json @@ -1,7 +1,7 @@ { "$schema": "https://hgm.pai.dev/schemas/hgm-rubric-report.schema.json", "schemaVersion": 1, - "timestamp": "2026-01-19T05:20:17Z", + "timestamp": "2026-01-23T01:42:09Z", "pack": { "version": "816465a1618d", "digest": "896aed16549928f21626fb4effe9bb6236fc60292a8f50bae8ce77e873ac775b" diff --git a/pkg/mocks/README_SIMPLE_EXAMPLE.md b/pkg/mocks/README_SIMPLE_EXAMPLE.md index e256dfa..14620dc 100644 --- a/pkg/mocks/README_SIMPLE_EXAMPLE.md +++ b/pkg/mocks/README_SIMPLE_EXAMPLE.md @@ -153,6 +153,35 @@ mockUpdateBuilder.On("Set", "Email", "new@email.com").Return(mockUpdateBuilder) mockUpdateBuilder.On("Execute").Return(nil) ``` +### Pattern 7: Transactions (`TransactWrite`) +If your code uses the transaction DSL: + +```go +err := db.TransactWrite(ctx, func(tx core.TransactionBuilder) error { + tx.WithContext(ctx) + tx.UpdateWithBuilder(&User{ID: "123"}, func(ub core.UpdateBuilder) error { + ub.Set("Status", "active") + return nil + }) + return nil +}) +``` + +You can mock it without boilerplate `.Run(...)` calls by providing a `MockTransactionBuilder`: + +```go +mockDB := new(mocks.MockExtendedDB) +mockTx := new(mocks.MockTransactionBuilder) + +// Tell the ExtendedDB mock which builder to run the callback with +mockDB.TransactWriteBuilder = mockTx + +mockDB.On("TransactWrite", ctx, mock.Anything).Return(nil) +mockTx.On("WithContext", ctx).Return(mockTx) +mockTx.On("UpdateWithBuilder", mock.Anything, mock.Anything, mock.Anything).Return(mockTx) +mockTx.On("Execute").Return(nil) +``` + ## 🚨 Common Mistakes ### ❌ Mistake 1: Forgetting AssertExpectations @@ -191,4 +220,4 @@ Check out these files for complete, working examples: 5. **Always verify expectations** - call `AssertExpectations(t)` 6. **Use `mock.Anything`** when you don't care about specific arguments -Happy testing! 🎉 \ No newline at end of file +Happy testing! 🎉 diff --git a/pkg/mocks/db.go b/pkg/mocks/db.go index d5e68c1..21821a0 100644 --- a/pkg/mocks/db.go +++ b/pkg/mocks/db.go @@ -30,8 +30,37 @@ func (m *MockDB) Model(model any) core.Query { // Transaction executes a function within a database transaction func (m *MockDB) Transaction(fn func(tx *core.Tx) error) error { - args := m.Called(fn) - return args.Error(0) + if fn == nil { + args := m.Called(fn) + return args.Error(0) + } + + var ( + callbackInvoked bool + callbackErr error + ) + + wrapped := func(tx *core.Tx) error { + callbackInvoked = true + callbackErr = fn(tx) + return callbackErr + } + + args := m.Called(wrapped) + + if err := args.Error(0); err != nil { + return err + } + + if !callbackInvoked { + tx := &core.Tx{} + tx.SetDB(m) + if err := wrapped(tx); err != nil { + return err + } + } + + return callbackErr } // Migrate runs all pending migrations diff --git a/pkg/mocks/extended_db.go b/pkg/mocks/extended_db.go index f01a9ba..0aedcc8 100644 --- a/pkg/mocks/extended_db.go +++ b/pkg/mocks/extended_db.go @@ -22,6 +22,13 @@ import ( // mockDB.On("Model", &User{}).Return(mockQuery) // mockQuery.On("Create").Return(nil) type MockExtendedDB struct { + // TransactWriteBuilder is used when TransactWrite auto-executes the provided callback. + // If nil, a new MockTransactionBuilder is created for each call. + TransactWriteBuilder core.TransactionBuilder + + // TransactionFuncTx is passed to TransactionFunc when auto-executing callbacks. + TransactionFuncTx any + MockDB // Embed MockDB to inherit base methods } @@ -81,23 +88,97 @@ func (m *MockExtendedDB) WithLambdaTimeoutBuffer(buffer time.Duration) core.DB { // TransactionFunc executes a function within a full transaction context func (m *MockExtendedDB) TransactionFunc(fn func(tx any) error) error { - args := m.Called(fn) - return args.Error(0) + if fn == nil { + args := m.Called(fn) + return args.Error(0) + } + + var ( + callbackInvoked bool + callbackErr error + ) + + wrapped := func(tx any) error { + callbackInvoked = true + callbackErr = fn(tx) + return callbackErr + } + + args := m.Called(wrapped) + if err := args.Error(0); err != nil { + return err + } + + if !callbackInvoked { + if err := wrapped(m.TransactionFuncTx); err != nil { + return err + } + } + + return callbackErr } // Transact returns a transaction builder mock func (m *MockExtendedDB) Transact() core.TransactionBuilder { args := m.Called() - if builder, ok := args.Get(0).(core.TransactionBuilder); ok { - return builder - } - return nil + return mustTransactionBuilder(args.Get(0)) } // TransactWrite executes a function with a transaction builder func (m *MockExtendedDB) TransactWrite(ctx context.Context, fn func(core.TransactionBuilder) error) error { - args := m.Called(ctx, fn) - return args.Error(0) + if fn == nil { + args := m.Called(ctx, fn) + return args.Error(0) + } + + builder := m.TransactWriteBuilder + if builder == nil { + builder = new(MockTransactionBuilder) + } + + var ( + callbackInvoked bool + callbackInvokedDuringCall bool + callbackErr error + builderUsed core.TransactionBuilder + ) + + inCalled := true + wrapped := func(tx core.TransactionBuilder) error { + callbackInvoked = true + if inCalled { + callbackInvokedDuringCall = true + } + builderUsed = tx + callbackErr = fn(tx) + return callbackErr + } + + args := m.Called(ctx, wrapped) + inCalled = false + + if err := args.Error(0); err != nil { + return err + } + + if !callbackInvoked { + if err := wrapped(builder); err != nil { + return err + } + } + + if callbackErr != nil { + return callbackErr + } + + if !callbackInvokedDuringCall { + if builderUsed == nil { + builderUsed = builder + } + return builderUsed.Execute() + } + + return nil } // NewMockExtendedDB creates a new MockExtendedDB with sensible defaults diff --git a/pkg/mocks/mocks.go b/pkg/mocks/mocks.go index dd27dd9..746c932 100644 --- a/pkg/mocks/mocks.go +++ b/pkg/mocks/mocks.go @@ -101,6 +101,9 @@ type ( // DB is an alias for MockDB to allow shorter declarations DB = MockDB + // TransactionBuilder is an alias for MockTransactionBuilder to simplify usage in tests. + TransactionBuilder = MockTransactionBuilder + // UpdateBuilder is an alias for MockUpdateBuilder to allow shorter declarations UpdateBuilder = MockUpdateBuilder diff --git a/pkg/mocks/mocks_test.go b/pkg/mocks/mocks_test.go index f7f3f9f..72f78d9 100644 --- a/pkg/mocks/mocks_test.go +++ b/pkg/mocks/mocks_test.go @@ -218,5 +218,6 @@ func TestTypeAliases(t *testing.T) { // These should compile without issues _ = new(mocks.Query) _ = new(mocks.DB) + _ = new(mocks.TransactionBuilder) _ = new(mocks.UpdateBuilder) } diff --git a/pkg/mocks/query.go b/pkg/mocks/query.go index e70c153..e8f5996 100644 --- a/pkg/mocks/query.go +++ b/pkg/mocks/query.go @@ -64,6 +64,17 @@ func mustUpdateBuilder(v any) core.UpdateBuilder { return builder } +func mustTransactionBuilder(v any) core.TransactionBuilder { + if v == nil { + return nil + } + builder, ok := v.(core.TransactionBuilder) + if !ok { + panic("unexpected type: expected core.TransactionBuilder") + } + return builder +} + // MockQuery is a mock implementation of the core.Query interface. // It can be used for unit testing code that depends on TableTheory queries. // diff --git a/pkg/mocks/transaction_builder.go b/pkg/mocks/transaction_builder.go new file mode 100644 index 0000000..861cab3 --- /dev/null +++ b/pkg/mocks/transaction_builder.go @@ -0,0 +1,180 @@ +package mocks + +import ( + "context" + + "github.com/stretchr/testify/mock" + + "github.com/theory-cloud/tabletheory/pkg/core" +) + +type noopUpdateBuilder struct{} + +func (n *noopUpdateBuilder) Set(_ string, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) SetIfNotExists(_ string, _ any, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) Add(_ string, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) Increment(_ string) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) Decrement(_ string) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) Remove(_ string) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) Delete(_ string, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) AppendToList(_ string, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) PrependToList(_ string, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) RemoveFromListAt(_ string, _ int) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) SetListElement(_ string, _ int, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) Condition(_ string, _ string, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) OrCondition(_ string, _ string, _ any) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) ConditionExists(_ string) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) ConditionNotExists(_ string) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) ConditionVersion(_ int64) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) ReturnValues(_ string) core.UpdateBuilder { return n } +func (n *noopUpdateBuilder) Execute() error { return nil } +func (n *noopUpdateBuilder) ExecuteWithResult(_ any) error { return nil } + +// MockTransactionBuilder is a mock implementation of the core.TransactionBuilder interface. +// +// It supports the fluent transaction DSL used by ExtendedDB.TransactWrite and can optionally +// execute UpdateWithBuilder callbacks when Execute/ExecuteWithContext is invoked. +type MockTransactionBuilder struct { + mock.Mock + + // UpdateBuilder is used when executing UpdateWithBuilder callbacks during Execute/ExecuteWithContext. + // If nil, a no-op implementation is used. + UpdateBuilder core.UpdateBuilder + + pendingUpdateFns []func(core.UpdateBuilder) error +} + +var _ core.TransactionBuilder = (*MockTransactionBuilder)(nil) + +func (m *MockTransactionBuilder) hasExpectedCall(method string) bool { + for _, call := range m.ExpectedCalls { + if call.Method == method { + return true + } + } + return false +} + +// Put adds a put (upsert) operation. +func (m *MockTransactionBuilder) Put(model any, conditions ...core.TransactCondition) core.TransactionBuilder { + if m.hasExpectedCall("Put") { + args := m.Called(model, conditions) + return mustTransactionBuilder(args.Get(0)) + } + return m +} + +// Create adds a put operation guarded by attribute_not_exists on the primary key. +func (m *MockTransactionBuilder) Create(model any, conditions ...core.TransactCondition) core.TransactionBuilder { + if m.hasExpectedCall("Create") { + args := m.Called(model, conditions) + return mustTransactionBuilder(args.Get(0)) + } + return m +} + +// Update updates selected fields on the provided model. +func (m *MockTransactionBuilder) Update(model any, fields []string, conditions ...core.TransactCondition) core.TransactionBuilder { + if m.hasExpectedCall("Update") { + args := m.Called(model, fields, conditions) + return mustTransactionBuilder(args.Get(0)) + } + return m +} + +// UpdateWithBuilder allows complex expression-based updates. +func (m *MockTransactionBuilder) UpdateWithBuilder(model any, updateFn func(core.UpdateBuilder) error, conditions ...core.TransactCondition) core.TransactionBuilder { + if updateFn != nil { + ran := false + var runErr error + wrapped := func(ub core.UpdateBuilder) error { + if ran { + return runErr + } + ran = true + runErr = updateFn(ub) + return runErr + } + m.pendingUpdateFns = append(m.pendingUpdateFns, wrapped) + + if m.hasExpectedCall("UpdateWithBuilder") { + args := m.Called(model, wrapped, conditions) + return mustTransactionBuilder(args.Get(0)) + } + return m + } + + if m.hasExpectedCall("UpdateWithBuilder") { + args := m.Called(model, updateFn, conditions) + return mustTransactionBuilder(args.Get(0)) + } + return m +} + +// Delete removes the provided model by primary key. +func (m *MockTransactionBuilder) Delete(model any, conditions ...core.TransactCondition) core.TransactionBuilder { + if m.hasExpectedCall("Delete") { + args := m.Called(model, conditions) + return mustTransactionBuilder(args.Get(0)) + } + return m +} + +// ConditionCheck adds a pure condition check without mutating data. +func (m *MockTransactionBuilder) ConditionCheck(model any, conditions ...core.TransactCondition) core.TransactionBuilder { + if m.hasExpectedCall("ConditionCheck") { + args := m.Called(model, conditions) + return mustTransactionBuilder(args.Get(0)) + } + return m +} + +// WithContext sets the context used for DynamoDB calls. +func (m *MockTransactionBuilder) WithContext(ctx context.Context) core.TransactionBuilder { + if m.hasExpectedCall("WithContext") { + args := m.Called(ctx) + return mustTransactionBuilder(args.Get(0)) + } + return m +} + +// Execute commits the transaction using the currently configured context. +func (m *MockTransactionBuilder) Execute() error { + if m.hasExpectedCall("Execute") { + args := m.Called() + if err := args.Error(0); err != nil { + return err + } + } + return m.runPendingUpdateFns() +} + +// ExecuteWithContext commits the transaction with an explicit context override. +func (m *MockTransactionBuilder) ExecuteWithContext(ctx context.Context) error { + if m.hasExpectedCall("ExecuteWithContext") { + args := m.Called(ctx) + if err := args.Error(0); err != nil { + return err + } + } + return m.runPendingUpdateFns() +} + +func (m *MockTransactionBuilder) runPendingUpdateFns() error { + ub := m.UpdateBuilder + if ub == nil { + ub = &noopUpdateBuilder{} + } + + for _, fn := range m.pendingUpdateFns { + if fn == nil { + continue + } + if err := fn(ub); err != nil { + return err + } + } + + m.pendingUpdateFns = nil + return nil +} diff --git a/pkg/mocks/transaction_builder_test.go b/pkg/mocks/transaction_builder_test.go new file mode 100644 index 0000000..f246b92 --- /dev/null +++ b/pkg/mocks/transaction_builder_test.go @@ -0,0 +1,295 @@ +package mocks_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/theory-cloud/tabletheory/pkg/core" + "github.com/theory-cloud/tabletheory/pkg/mocks" +) + +func TestMockTransactionBuilderImplementsInterface(t *testing.T) { + var _ core.TransactionBuilder = (*mocks.MockTransactionBuilder)(nil) +} + +func TestMockExtendedDB_TransactWrite_AutoRunsCallback(t *testing.T) { + ctx := context.Background() + db := mocks.NewMockExtendedDBStrict() + + tx := new(mocks.MockTransactionBuilder) + db.TransactWriteBuilder = tx + + db.On("TransactWrite", ctx, mock.Anything).Return(nil).Once() + tx.On("WithContext", ctx).Return(tx).Once() + tx.On("UpdateWithBuilder", mock.Anything, mock.Anything, mock.Anything).Return(tx).Once() + tx.On("Execute").Return(nil).Once() + + var callbackCalls int + var updateFnCalls int + + err := db.TransactWrite(ctx, func(tb core.TransactionBuilder) error { + callbackCalls++ + tb.WithContext(ctx) + tb.UpdateWithBuilder(&struct{}{}, func(core.UpdateBuilder) error { + updateFnCalls++ + return nil + }) + return nil + }) + + assert.NoError(t, err) + assert.Equal(t, 1, callbackCalls) + assert.Equal(t, 1, updateFnCalls) + + db.AssertExpectations(t) + tx.AssertExpectations(t) +} + +func TestMockExtendedDB_TransactWrite_DoesNotDoubleRunCallback(t *testing.T) { + ctx := context.Background() + db := mocks.NewMockExtendedDBStrict() + tx := new(mocks.MockTransactionBuilder) + + db.On("TransactWrite", ctx, mock.Anything).Run(func(args mock.Arguments) { + fn := args.Get(1).(func(core.TransactionBuilder) error) + _ = fn(tx) + }).Return(nil).Once() + + var callbackCalls int + err := db.TransactWrite(ctx, func(core.TransactionBuilder) error { + callbackCalls++ + return nil + }) + + assert.NoError(t, err) + assert.Equal(t, 1, callbackCalls) + db.AssertExpectations(t) +} + +func TestMockDB_Transaction_AutoRunsCallback(t *testing.T) { + db := new(mocks.MockDB) + db.On("Transaction", mock.Anything).Return(nil).Once() + + var calls int + err := db.Transaction(func(*core.Tx) error { + calls++ + return nil + }) + + assert.NoError(t, err) + assert.Equal(t, 1, calls) + db.AssertExpectations(t) +} + +func TestMockDB_Transaction_DoesNotDoubleRunCallback(t *testing.T) { + db := new(mocks.MockDB) + db.On("Transaction", mock.Anything).Run(func(args mock.Arguments) { + fn := args.Get(0).(func(*core.Tx) error) + _ = fn(&core.Tx{}) + }).Return(nil).Once() + + var calls int + err := db.Transaction(func(*core.Tx) error { + calls++ + return nil + }) + + assert.NoError(t, err) + assert.Equal(t, 1, calls) + db.AssertExpectations(t) +} + +func TestMockTransactionBuilder_DefaultBehavior_ReturnsSelf(t *testing.T) { + ctx := context.Background() + tx := new(mocks.MockTransactionBuilder) + model := &struct{}{} + cond := core.TransactCondition{Kind: core.TransactConditionKindPrimaryKeyExists} + + assert.Same(t, tx, tx.Put(model, cond)) + assert.Same(t, tx, tx.Create(model, cond)) + assert.Same(t, tx, tx.Update(model, []string{"a"}, cond)) + assert.Same(t, tx, tx.UpdateWithBuilder(model, nil, cond)) + assert.Same(t, tx, tx.Delete(model, cond)) + assert.Same(t, tx, tx.ConditionCheck(model, cond)) + assert.Same(t, tx, tx.WithContext(ctx)) + assert.NoError(t, tx.Execute()) + assert.NoError(t, tx.ExecuteWithContext(ctx)) +} + +func TestMockTransactionBuilder_ExpectedCalls_AreRoutedThroughTestify(t *testing.T) { + ctx := context.Background() + tx := new(mocks.MockTransactionBuilder) + model := &struct{}{} + cond := core.TransactCondition{Kind: core.TransactConditionKindPrimaryKeyExists} + + tx.On("Put", model, []core.TransactCondition{cond}).Return(tx).Once() + tx.On("Create", model, []core.TransactCondition{cond}).Return(tx).Once() + tx.On("Update", model, []string{"a"}, []core.TransactCondition{cond}).Return(tx).Once() + tx.On("UpdateWithBuilder", model, mock.Anything, []core.TransactCondition{cond}).Return(tx).Once() + tx.On("Delete", model, []core.TransactCondition{cond}).Return(tx).Once() + tx.On("ConditionCheck", model, []core.TransactCondition{cond}).Return(tx).Once() + tx.On("WithContext", ctx).Return(tx).Once() + tx.On("ExecuteWithContext", ctx).Return(nil).Once() + + assert.Same(t, tx, tx.Put(model, cond)) + assert.Same(t, tx, tx.Create(model, cond)) + assert.Same(t, tx, tx.Update(model, []string{"a"}, cond)) + assert.Same(t, tx, tx.UpdateWithBuilder(model, func(core.UpdateBuilder) error { return nil }, cond)) + assert.Same(t, tx, tx.Delete(model, cond)) + assert.Same(t, tx, tx.ConditionCheck(model, cond)) + assert.Same(t, tx, tx.WithContext(ctx)) + assert.NoError(t, tx.ExecuteWithContext(ctx)) + + tx.AssertExpectations(t) +} + +func TestMockTransactionBuilder_UpdateWithBuilder_RunsUpdateFnOnExecute_AndCoversNoopUpdateBuilder(t *testing.T) { + tx := new(mocks.MockTransactionBuilder) + model := &struct{}{} + + var calls int + tx.UpdateWithBuilder(model, func(ub core.UpdateBuilder) error { + calls++ + ub.Set("a", 1) + ub.SetIfNotExists("b", 2, 0) + ub.Add("c", 3) + ub.Increment("d") + ub.Decrement("e") + ub.Remove("f") + ub.Delete("g", 1) + ub.AppendToList("h", []string{"x"}) + ub.PrependToList("i", []string{"y"}) + ub.RemoveFromListAt("j", 0) + ub.SetListElement("k", 1, "z") + ub.Condition("l", "=", 1) + ub.OrCondition("m", "=", 2) + ub.ConditionExists("n") + ub.ConditionNotExists("o") + ub.ConditionVersion(1) + ub.ReturnValues("ALL_NEW") + _ = ub.Execute() + _ = ub.ExecuteWithResult(&struct{}{}) + return nil + }) + + assert.NoError(t, tx.Execute()) + assert.Equal(t, 1, calls) + assert.NoError(t, tx.Execute()) + assert.Equal(t, 1, calls) +} + +func TestMockTransactionBuilder_ExecuteWithContext_RunsPendingUpdateFns(t *testing.T) { + tx := new(mocks.MockTransactionBuilder) + model := &struct{}{} + ctx := context.Background() + + var calls int + tx.UpdateWithBuilder(model, func(core.UpdateBuilder) error { + calls++ + return nil + }) + + assert.NoError(t, tx.ExecuteWithContext(ctx)) + assert.Equal(t, 1, calls) +} + +func TestMockTransactionBuilder_Execute_ReturnsExpectedCallError(t *testing.T) { + tx := new(mocks.MockTransactionBuilder) + expectedErr := errors.New("execute failed") + + tx.On("Execute").Return(expectedErr).Once() + assert.ErrorIs(t, tx.Execute(), expectedErr) + tx.AssertExpectations(t) +} + +func TestMockTransactionBuilder_ExecuteWithBuilder_ErrorIsPropagated(t *testing.T) { + tx := new(mocks.MockTransactionBuilder) + expectedErr := errors.New("update failed") + + tx.UpdateWithBuilder(&struct{}{}, func(core.UpdateBuilder) error { + return expectedErr + }) + + assert.ErrorIs(t, tx.Execute(), expectedErr) +} + +func TestMockTransactionBuilder_ExecuteWithContext_ReturnsExpectedCallError(t *testing.T) { + tx := new(mocks.MockTransactionBuilder) + ctx := context.Background() + expectedErr := errors.New("execute-with-context failed") + + tx.On("ExecuteWithContext", ctx).Return(expectedErr).Once() + assert.ErrorIs(t, tx.ExecuteWithContext(ctx), expectedErr) + tx.AssertExpectations(t) +} + +func TestMockTransactionBuilder_UpdateWithBuilder_NilFn_UsesExpectedCallWhenConfigured(t *testing.T) { + tx := new(mocks.MockTransactionBuilder) + tx.On("UpdateWithBuilder", mock.Anything, mock.Anything, mock.Anything).Return(tx).Once() + + assert.Same(t, tx, tx.UpdateWithBuilder(&struct{}{}, nil)) + tx.AssertExpectations(t) +} + +func TestMockTransactionBuilder_UsesProvidedUpdateBuilder(t *testing.T) { + tx := new(mocks.MockTransactionBuilder) + tx.UpdateBuilder = new(mocks.MockUpdateBuilder) + + var calls int + tx.UpdateWithBuilder(&struct{}{}, func(core.UpdateBuilder) error { + calls++ + return nil + }) + + assert.NoError(t, tx.Execute()) + assert.Equal(t, 1, calls) +} + +func TestMockExtendedDB_DescribeTable_ReturnsNilWhenValueIsNil(t *testing.T) { + db := mocks.NewMockExtendedDBStrict() + expectedErr := errors.New("not found") + + db.On("DescribeTable", mock.Anything).Return(nil, expectedErr).Once() + got, err := db.DescribeTable(&struct{}{}) + + assert.Nil(t, got) + assert.ErrorIs(t, err, expectedErr) + db.AssertExpectations(t) +} + +func TestMockExtendedDB_Transact_PanicsOnUnexpectedReturnType(t *testing.T) { + db := mocks.NewMockExtendedDBStrict() + db.On("Transact").Return("not-a-builder").Once() + + assert.Panics(t, func() { _ = db.Transact() }) + db.AssertExpectations(t) +} + +func TestMockExtendedDB_TransactWrite_ReturnsCallbackError(t *testing.T) { + ctx := context.Background() + db := mocks.NewMockExtendedDBStrict() + db.TransactWriteBuilder = new(mocks.MockTransactionBuilder) + + db.On("TransactWrite", ctx, mock.Anything).Return(nil).Once() + + expectedErr := errors.New("tx failed") + err := db.TransactWrite(ctx, func(core.TransactionBuilder) error { return expectedErr }) + + assert.ErrorIs(t, err, expectedErr) + db.AssertExpectations(t) +} + +func TestMockDB_Transaction_ReturnsExpectationError(t *testing.T) { + db := new(mocks.MockDB) + expectedErr := errors.New("tx failed") + + db.On("Transaction", mock.Anything).Return(expectedErr).Once() + err := db.Transaction(func(*core.Tx) error { return nil }) + + assert.ErrorIs(t, err, expectedErr) + db.AssertExpectations(t) +} diff --git a/py/README.md b/py/README.md index 11f3224..e7e27b2 100644 --- a/py/README.md +++ b/py/README.md @@ -47,6 +47,7 @@ class Note: pk: str = theorydb_field(roles=["pk"]) sk: str = theorydb_field(roles=["sk"]) value: int = theorydb_field() + note: str = theorydb_field(omitempty=True, default="") client = boto3.client( @@ -76,7 +77,7 @@ page2 = table.query("A", cursor=page1.next_cursor) if page1.next_cursor else Non ## Batch + transactions ```python -from theorydb_py import TransactUpdate +from theorydb_py import TransactUpdate, UpdateAdd, UpdateSetIfNotExists table.batch_write(puts=[Note(pk="A", sk="2", value=1)], deletes=[("A", "1")]) @@ -85,10 +86,10 @@ table.transact_write( TransactUpdate( pk="A", sk="2", - updates={"value": 2}, - condition_expression="#v = :expected", + updates={"value": UpdateAdd(1), "note": UpdateSetIfNotExists("first")}, + condition_expression="attribute_not_exists(#v) OR #v < :max_allowed", expression_attribute_names={"#v": "value"}, - expression_attribute_values={":expected": 1}, + expression_attribute_values={":max_allowed": 100}, ) ] ) diff --git a/py/src/theorydb_py/__init__.py b/py/src/theorydb_py/__init__.py index ec8667c..f1c84d8 100644 --- a/py/src/theorydb_py/__init__.py +++ b/py/src/theorydb_py/__init__.py @@ -35,6 +35,8 @@ TransactPut, TransactUpdate, TransactWriteAction, + UpdateAdd, + UpdateSetIfNotExists, ) if TYPE_CHECKING: @@ -262,6 +264,8 @@ def __getattr__(name: str) -> Any: "TransactPut", "TransactUpdate", "TransactWriteAction", + "UpdateAdd", + "UpdateSetIfNotExists", "TransactionCanceledError", "Table", "ValidationError", diff --git a/py/src/theorydb_py/table.py b/py/src/theorydb_py/table.py index e4d82d4..2d6d9e3 100644 --- a/py/src/theorydb_py/table.py +++ b/py/src/theorydb_py/table.py @@ -36,6 +36,8 @@ TransactPut, TransactUpdate, TransactWriteAction, + UpdateAdd, + UpdateSetIfNotExists, ) if TYPE_CHECKING: @@ -854,6 +856,14 @@ def _build_update_request( update_values: dict[str, Any] = {} set_parts: list[str] = [] remove_parts: list[str] = [] + add_parts: list[str] = [] + + def normalize_set(value: Any) -> set[Any]: + if isinstance(value, set): + return value + if isinstance(value, (list, tuple)): + return set(value) + return {value} for field_name, value in updates.items(): if field_name not in self._model.attributes: @@ -867,6 +877,30 @@ def _build_update_request( name_ref = f"#d_{field_name}" update_names[name_ref] = attr_def.attribute_name + if isinstance(value, UpdateAdd): + if attr_def.encrypted: + raise ValidationError(f"encrypted fields cannot be used in ADD: {field_name}") + + value_ref = f":d_{field_name}" + if attr_def.set: + update_values[value_ref] = self._serialize_attr_value( + attr_def, + normalize_set(value.value), + ) + else: + if not isinstance(value.value, (int, float, Decimal)): + raise ValidationError("ADD requires a numeric value for non-set fields") + update_values[value_ref] = self._serializer.serialize(value.value) + + add_parts.append(f"{name_ref} {value_ref}") + continue + + if isinstance(value, UpdateSetIfNotExists): + value_ref = f":d_{field_name}" + update_values[value_ref] = self._serialize_attr_value(attr_def, value.default_value) + set_parts.append(f"{name_ref} = if_not_exists({name_ref}, {value_ref})") + continue + if value is None: remove_parts.append(name_ref) continue @@ -880,6 +914,8 @@ def _build_update_request( expr_parts.append("SET " + ", ".join(set_parts)) if remove_parts: expr_parts.append("REMOVE " + ", ".join(remove_parts)) + if add_parts: + expr_parts.append("ADD " + ", ".join(add_parts)) if not expr_parts: raise ValidationError("no updates provided") diff --git a/py/src/theorydb_py/transaction.py b/py/src/theorydb_py/transaction.py index 560a9c5..f12d823 100644 --- a/py/src/theorydb_py/transaction.py +++ b/py/src/theorydb_py/transaction.py @@ -5,6 +5,16 @@ from typing import Any +@dataclass(frozen=True) +class UpdateAdd: + value: Any + + +@dataclass(frozen=True) +class UpdateSetIfNotExists: + default_value: Any + + @dataclass(frozen=True) class TransactPut[T]: item: T diff --git a/py/src/theorydb_py/version.json b/py/src/theorydb_py/version.json index f4ad707..a15e91f 100644 --- a/py/src/theorydb_py/version.json +++ b/py/src/theorydb_py/version.json @@ -1,3 +1,3 @@ { - "version": "1.1.5" + "version": "1.2.0" } diff --git a/py/tests/unit/test_table_additional_coverage.py b/py/tests/unit/test_table_additional_coverage.py index b4a8f88..00ee705 100644 --- a/py/tests/unit/test_table_additional_coverage.py +++ b/py/tests/unit/test_table_additional_coverage.py @@ -14,6 +14,8 @@ SortKeyCondition, Table, TransactionCanceledError, + UpdateAdd, + UpdateSetIfNotExists, ValidationError, theorydb_field, ) @@ -306,6 +308,33 @@ def test_transact_write_builds_requests_for_all_action_types() -> None: assert {next(iter(i.keys())) for i in transact_items} == {"Put", "Delete", "Update", "ConditionCheck"} +def test_transact_update_supports_add_and_if_not_exists() -> None: + model = ModelDefinition.from_dataclass(Item, table_name="tbl") + stub = _StubClient() + table: Table[Item] = Table(model, client=stub) + + table.transact_write( + [ + TransactUpdate( + pk="A", + sk="1", + updates={ + "value": UpdateAdd(1), + "note": UpdateSetIfNotExists("first"), + }, + condition_expression="attribute_not_exists(#v) OR #v < :max", + expression_attribute_names={"#v": "value"}, + expression_attribute_values={":max": 100}, + ) + ] + ) + + update = stub.transact_write_reqs[0]["TransactItems"][0]["Update"] + assert "ADD" in update["UpdateExpression"] + assert "if_not_exists" in update["UpdateExpression"] + assert update["ExpressionAttributeValues"][":d_value"]["N"] == "1" + + def test_put_delete_update_expression_attribute_maps_and_build_request_merges() -> None: model = ModelDefinition.from_dataclass(Item, table_name="tbl") stub = _StubClient() diff --git a/scripts/verify-branch-version-sync.sh b/scripts/verify-branch-version-sync.sh index f209bfb..7cd1ddf 100644 --- a/scripts/verify-branch-version-sync.sh +++ b/scripts/verify-branch-version-sync.sh @@ -8,6 +8,32 @@ set -euo pipefail # - `premain` cuts prereleases using `.release-please-manifest.premain.json` # If `premain` doesn't regularly back-merge `main`, prereleases can get stuck on an old major/minor track. +git_fetch_retry() { + local remote="$1" + shift + + local -a refspecs=("$@") + local attempts="${GIT_FETCH_RETRIES:-5}" + local base_sleep="${GIT_FETCH_RETRY_SLEEP_SECS:-2}" + + local i=1 + while true; do + if git fetch --quiet --depth=1 "${remote}" "${refspecs[@]}"; then + return 0 + fi + + if [[ "${i}" -ge "${attempts}" ]]; then + echo "branch-version-sync: FAIL (git fetch failed after ${attempts} attempts)" >&2 + return 1 + fi + + sleep_for=$((base_sleep * i)) + echo "branch-version-sync: retrying git fetch in ${sleep_for}s (${i}/${attempts})..." >&2 + sleep "${sleep_for}" + i=$((i + 1)) + done +} + base_ref="${GITHUB_BASE_REF:-}" head_ref="${GITHUB_HEAD_REF:-}" ref_name="${GITHUB_REF_NAME:-}" @@ -35,7 +61,7 @@ for f in ".release-please-manifest.json" ".release-please-manifest.premain.json" fi done -git fetch --quiet --depth=1 origin main +git_fetch_retry origin main main_stable="$( python3 - <<'PY' @@ -79,7 +105,7 @@ print(data.get(".", "")) PY )" else - git fetch --quiet --depth=1 origin premain + git_fetch_retry origin premain premain_stable="$( python3 - <<'PY' @@ -164,4 +190,3 @@ if premain_tuple < main_tuple: PY echo "branch-version-sync: PASS (main=${main_stable}, premain=${premain_version})" - diff --git a/ts/README.md b/ts/README.md index 4657f1d..badeb9c 100644 --- a/ts/README.md +++ b/ts/README.md @@ -115,6 +115,19 @@ await db.transactWrite([ item: { PK: 'U#1', SK: 'TX' }, ifNotExists: true, }, + { + kind: 'update', + model: 'User', + key: { PK: 'U#1', SK: 'TX' }, + updateExpression: 'SET #ws = if_not_exists(#ws, :ws) ADD #count :inc', + conditionExpression: 'attribute_not_exists(#count) OR #count < :maxAllowed', + expressionAttributeNames: { '#ws': 'WindowStart', '#count': 'Count' }, + expressionAttributeValues: { + ':ws': { S: '2026-01-23T00:00:00Z' }, + ':inc': { N: '1' }, + ':maxAllowed': { N: '100' }, + }, + }, ]); ``` diff --git a/ts/docs/core-patterns.md b/ts/docs/core-patterns.md index 231b4c2..0d33c97 100644 --- a/ts/docs/core-patterns.md +++ b/ts/docs/core-patterns.md @@ -78,6 +78,19 @@ await db.transactWrite([ item: { PK: 'U#1', SK: 'TX' }, ifNotExists: true, }, + { + kind: 'update', + model: 'User', + key: { PK: 'U#1', SK: 'TX' }, + updateExpression: 'SET #ws = if_not_exists(#ws, :ws) ADD #count :inc', + conditionExpression: 'attribute_not_exists(#count) OR #count < :maxAllowed', + expressionAttributeNames: { '#ws': 'WindowStart', '#count': 'Count' }, + expressionAttributeValues: { + ':ws': { S: '2026-01-23T00:00:00Z' }, + ':inc': { N: '1' }, + ':maxAllowed': { N: '100' }, + }, + }, ]); ``` diff --git a/ts/package-lock.json b/ts/package-lock.json index 5401383..2588532 100644 --- a/ts/package-lock.json +++ b/ts/package-lock.json @@ -1,12 +1,12 @@ { "name": "@theory-cloud/tabletheory-ts", - "version": "1.1.5", + "version": "1.2.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@theory-cloud/tabletheory-ts", - "version": "1.1.5", + "version": "1.2.0", "license": "Apache-2.0", "dependencies": { "@aws-sdk/client-dynamodb": "^3.971.0", diff --git a/ts/package.json b/ts/package.json index fd480a2..48d5dc6 100644 --- a/ts/package.json +++ b/ts/package.json @@ -1,6 +1,6 @@ { "name": "@theory-cloud/tabletheory-ts", - "version": "1.1.5", + "version": "1.2.0", "license": "Apache-2.0", "private": true, "type": "module", diff --git a/ts/src/client.ts b/ts/src/client.ts index e5a7901..32933d9 100644 --- a/ts/src/client.ts +++ b/ts/src/client.ts @@ -11,6 +11,7 @@ import { type ConditionCheck, type Delete, type Put, + type Update, type TransactWriteItem, type WriteRequest, } from '@aws-sdk/client-dynamodb'; @@ -474,6 +475,37 @@ export class TheorydbClient { transactItems.push({ Put: put }); break; } + case 'update': { + const update: Update = { + TableName: model.tableName, + Key: marshalKey(model, a.key), + UpdateExpression: '', + }; + + if ('updateFn' in a) { + const builder = new UpdateBuilder( + this.ddb, + model, + a.key, + provider, + this.sendOptions, + ); + await a.updateFn(builder); + const built = await builder.build(); + update.UpdateExpression = built.updateExpression; + update.ConditionExpression = built.conditionExpression; + update.ExpressionAttributeNames = built.expressionAttributeNames; + update.ExpressionAttributeValues = built.expressionAttributeValues; + } else { + update.UpdateExpression = a.updateExpression; + update.ConditionExpression = a.conditionExpression; + update.ExpressionAttributeNames = a.expressionAttributeNames; + update.ExpressionAttributeValues = a.expressionAttributeValues; + } + + transactItems.push({ Update: update }); + break; + } case 'delete': transactItems.push({ Delete: { diff --git a/ts/src/transaction.ts b/ts/src/transaction.ts index 3c8dca6..f45a013 100644 --- a/ts/src/transaction.ts +++ b/ts/src/transaction.ts @@ -1,5 +1,29 @@ import type { AttributeValue } from '@aws-sdk/client-dynamodb'; +import type { UpdateBuilder } from './update-builder.js'; + +type TransactUpdateRaw = { + kind: 'update'; + model: string; + key: Record; + updateExpression: string; + conditionExpression?: string; + expressionAttributeNames?: Record; + expressionAttributeValues?: Record; + updateFn?: never; +}; + +type TransactUpdateWithBuilder = { + kind: 'update'; + model: string; + key: Record; + updateFn: (builder: UpdateBuilder) => void | Promise; + updateExpression?: never; + conditionExpression?: never; + expressionAttributeNames?: never; + expressionAttributeValues?: never; +}; + export type TransactAction = | { kind: 'put'; @@ -7,6 +31,8 @@ export type TransactAction = item: Record; ifNotExists?: boolean; } + | TransactUpdateRaw + | TransactUpdateWithBuilder | { kind: 'delete'; model: string; diff --git a/ts/src/update-builder.ts b/ts/src/update-builder.ts index a3956c1..4214190 100644 --- a/ts/src/update-builder.ts +++ b/ts/src/update-builder.ts @@ -372,7 +372,12 @@ export class UpdateBuilder { return this; } - async execute(): Promise | undefined> { + async build(): Promise<{ + updateExpression: string; + conditionExpression?: string; + expressionAttributeNames?: Record; + expressionAttributeValues?: Record; + }> { if (this.updateOps.length === 0) { throw new TheorydbError('ErrInvalidOperator', 'No updates provided'); } @@ -406,15 +411,28 @@ export class UpdateBuilder { values[k] = v; } + return { + updateExpression, + ...(cond.expression ? { conditionExpression: cond.expression } : {}), + ...(Object.keys(names).length ? { expressionAttributeNames: names } : {}), + ...(Object.keys(values).length + ? { expressionAttributeValues: values } + : {}), + }; + } + + async execute(): Promise | undefined> { + const built = await this.build(); + const cmd = new UpdateItemCommand({ TableName: this.model.tableName, Key: marshalKey(this.model, this.key), - UpdateExpression: updateExpression, - ...(cond.expression ? { ConditionExpression: cond.expression } : {}), - ExpressionAttributeNames: Object.keys(names).length ? names : undefined, - ExpressionAttributeValues: Object.keys(values).length - ? values - : undefined, + UpdateExpression: built.updateExpression, + ...(built.conditionExpression + ? { ConditionExpression: built.conditionExpression } + : {}), + ExpressionAttributeNames: built.expressionAttributeNames, + ExpressionAttributeValues: built.expressionAttributeValues, ReturnValues: this.returnValuesOpt, }); diff --git a/ts/test/unit/client.test.ts b/ts/test/unit/client.test.ts index 57ec580..7b7b616 100644 --- a/ts/test/unit/client.test.ts +++ b/ts/test/unit/client.test.ts @@ -249,6 +249,15 @@ class StubDdb { ifNotExists: true, }, { kind: 'delete', model: 'User', key: { PK: 'A', SK: '1' } }, + { + kind: 'update', + model: 'User', + key: { PK: 'A', SK: '1' }, + updateExpression: 'ADD #v :inc', + conditionExpression: 'attribute_exists(PK)', + expressionAttributeNames: { '#v': 'version' }, + expressionAttributeValues: { ':inc': { N: '1' } }, + }, { kind: 'condition', model: 'User', @@ -259,6 +268,43 @@ class StubDdb { assert.ok(ddb.sent[0] instanceof TransactWriteItemsCommand); } +{ + const ddb = new StubDdb((cmd) => { + if (cmd instanceof TransactWriteItemsCommand) return {}; + throw new Error('unexpected'); + }); + const client = new TheorydbClient(ddb as unknown as DynamoDBClient).register( + User, + ); + + await client.transactWrite([ + { + kind: 'update', + model: 'User', + key: { PK: 'A', SK: '1' }, + updateFn: (u) => { + u.add('version', 1); + u.setIfNotExists( + 'createdAt', + undefined, + '2026-01-23T00:00:00.000000000Z', + ); + u.conditionNotExists('version').orCondition('version', '<', 100); + }, + }, + ]); + + const cmd = ddb.sent[0]; + assert.ok(cmd instanceof TransactWriteItemsCommand); + assert.equal(cmd.input.TransactItems?.length, 1); + const update = cmd.input.TransactItems?.[0]?.Update; + assert.equal(update?.TableName, 'users_contract'); + assert.ok(update?.UpdateExpression?.includes('ADD')); + assert.ok(update?.UpdateExpression?.includes('if_not_exists')); + assert.ok(update?.ConditionExpression?.includes('attribute_not_exists')); + assert.ok(update?.ConditionExpression?.includes('<')); +} + { const ddb = new StubDdb((cmd) => { if (cmd instanceof QueryCommand) return { Items: [] };