diff --git a/internal/core/persistence/persistence.go b/internal/core/persistence/persistence.go index b0d4725306..7dbeb5541c 100644 --- a/internal/core/persistence/persistence.go +++ b/internal/core/persistence/persistence.go @@ -2,6 +2,7 @@ package persistence import ( "encoding/hex" + "errors" "fmt" "path" "strings" @@ -10,6 +11,7 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/db" "github.com/langgenius/dify-plugin-daemon/internal/types/models" "github.com/langgenius/dify-plugin-daemon/pkg/utils/cache" + "gorm.io/gorm" ) type Persistence struct { @@ -47,54 +49,72 @@ func (c *Persistence) Save(tenantId string, pluginId string, maxSize int64, key maxSize = c.maxStorageSize } - if err := c.storage.Save(tenantId, pluginId, key, data); err != nil { + lockKey := fmt.Sprintf("persistence:lock:%s:%s:%s", tenantId, pluginId, key) + if err := cache.Lock(lockKey, 30*time.Second, 3*time.Second); err != nil { return err } + defer func() { _ = cache.Unlock(lockKey) }() - allocatedSize := int64(len(data)) - - storage, err := db.GetOne[models.TenantStorage]( - db.Equal("tenant_id", tenantId), - db.Equal("plugin_id", pluginId), - ) - if err != nil { - if allocatedSize > c.maxStorageSize || allocatedSize > maxSize { - return fmt.Errorf("allocated size is greater than max storage size") + newSize := int64(len(data)) + var oldSize int64 = 0 + if exist, err := c.storage.Exists(tenantId, pluginId, key); err == nil && exist { + if s, err2 := c.storage.StateSize(tenantId, pluginId, key); err2 == nil { + oldSize = s } + } + delta := newSize - oldSize - if err == db.ErrDatabaseNotFound { - storage = models.TenantStorage{ - TenantID: tenantId, - PluginID: pluginId, - Size: allocatedSize, - } - if err := db.Create(&storage); err != nil { + txErr := db.WithTransaction(func(tx *gorm.DB) error { + record, err := db.GetOne[models.TenantStorage]( + db.WithTransactionContext(tx), + db.Equal("tenant_id", tenantId), + db.Equal("plugin_id", pluginId), + db.WLock(), + ) + + if err != nil { + if !errors.Is(err, db.ErrDatabaseNotFound) { return err } - } else { - return err + record = models.TenantStorage{TenantID: tenantId, PluginID: pluginId, Size: 0} + if cerr := db.Create(&record, tx); cerr != nil { + return cerr + } } - } else { - if allocatedSize+storage.Size > maxSize || allocatedSize+storage.Size > c.maxStorageSize { - return fmt.Errorf("allocated size is greater than max storage size") + + if delta > 0 { + if record.Size+delta > c.maxStorageSize || record.Size+delta > maxSize { + return fmt.Errorf("allocated size is greater than max storage size") + } } - err = db.Run( - db.Model(&models.TenantStorage{}), - db.Equal("tenant_id", tenantId), - db.Equal("plugin_id", pluginId), - db.Inc(map[string]int64{"size": allocatedSize}), - ) - if err != nil { + if err := c.storage.Save(tenantId, pluginId, key, data); err != nil { return err } + + if delta != 0 { + if err := db.Run( + db.WithTransactionContext(tx), + db.Model(&models.TenantStorage{}), + db.Equal("tenant_id", tenantId), + db.Equal("plugin_id", pluginId), + db.Inc(map[string]int64{"size": delta}), + ); err != nil { + return err + } + } + + return nil + }) + if txErr != nil { + return txErr } - // delete from cache - if _, err = cache.Del(c.getCacheKey(tenantId, pluginId, key)); err == cache.ErrNotFound { + if _, err := cache.Del(c.getCacheKey(tenantId, pluginId, key)); errors.Is(err, cache.ErrNotFound) { return nil + } else { + return err } - return err } // TODO: raises specific error to avoid confusion @@ -125,6 +145,12 @@ func (c *Persistence) Load(tenantId string, pluginId string, key string) ([]byte } func (c *Persistence) Delete(tenantId string, pluginId string, key string) (int64, error) { + lockKey := fmt.Sprintf("persistence:lock:%s:%s:%s", tenantId, pluginId, key) + if err := cache.Lock(lockKey, 30*time.Second, 3*time.Second); err != nil { + return 0, err + } + defer func() { _ = cache.Unlock(lockKey) }() + // delete from cache and storage deletedNum, err := cache.Del(c.getCacheKey(tenantId, pluginId, key)) if err != nil { @@ -137,19 +163,25 @@ func (c *Persistence) Delete(tenantId string, pluginId string, key string) (int6 return 0, err } - err = c.storage.Delete(tenantId, pluginId, key) - if err != nil { + if err = c.storage.Delete(tenantId, pluginId, key); err != nil { return 0, err } - // update storage size - err = db.Run( - db.Model(&models.TenantStorage{}), - db.Equal("tenant_id", tenantId), - db.Equal("plugin_id", pluginId), - db.Dec(map[string]int64{"size": size}), - ) - if err != nil { + if err := db.WithTransaction(func(tx *gorm.DB) error { + _, _ = db.GetOne[models.TenantStorage]( + db.WithTransactionContext(tx), + db.Equal("tenant_id", tenantId), + db.Equal("plugin_id", pluginId), + db.WLock(), + ) + return db.Run( + db.WithTransactionContext(tx), + db.Model(&models.TenantStorage{}), + db.Equal("tenant_id", tenantId), + db.Equal("plugin_id", pluginId), + db.Inc(map[string]int64{"size": -size}), + ) + }); err != nil { return 0, err } diff --git a/internal/core/persistence/persistence_test.go b/internal/core/persistence/persistence_test.go index c84791624f..7adfd66b6a 100644 --- a/internal/core/persistence/persistence_test.go +++ b/internal/core/persistence/persistence_test.go @@ -8,6 +8,7 @@ import ( "github.com/langgenius/dify-cloud-kit/oss/factory" "github.com/langgenius/dify-plugin-daemon/internal/db" "github.com/langgenius/dify-plugin-daemon/internal/types/app" + "github.com/langgenius/dify-plugin-daemon/internal/types/models" "github.com/langgenius/dify-plugin-daemon/pkg/utils/cache" "github.com/langgenius/dify-plugin-daemon/pkg/utils/strings" "github.com/stretchr/testify/assert" @@ -69,6 +70,104 @@ func TestPersistenceStoreAndLoad(t *testing.T) { assert.Equal(t, string(cacheDataBytes), "data") } +func TestPersistenceOverwriteAdjustsCounter(t *testing.T) { + // init deps + err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil) + assert.Nil(t, err) + defer cache.Close() + + db.Init(&app.Config{ + DBType: app.DB_TYPE_POSTGRESQL, + DBUsername: "postgres", + DBPassword: "difyai123456", + DBHost: "localhost", + DBPort: 5432, + DBDatabase: "dify_plugin_daemon", + DBSslMode: "disable", + }) + defer db.Close() + + oss, err := factory.Load("local", cloudoss.OSSArgs{Local: &cloudoss.Local{Path: "./storage"}}) + assert.Nil(t, err) + + InitPersistence(oss, &app.Config{PersistenceStoragePath: "./persistence_storage", PersistenceStorageMaxSize: 1024 * 1024}) + + tenant := "tenant_" + strings.RandomString(6) + plugin := "plugin_" + strings.RandomString(6) + key := "k_" + strings.RandomString(6) + + // write 4 bytes + assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("abcd"))) + st, err := db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin)) + assert.Nil(t, err) + assert.Equal(t, int64(4), st.Size) + + // overwrite with 2 bytes -> size should be 2 + assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("bb"))) + st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin)) + assert.Nil(t, err) + assert.Equal(t, int64(2), st.Size) + + // overwrite with 3 bytes -> size should be 3 + assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("ccc"))) + st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin)) + assert.Nil(t, err) + assert.Equal(t, int64(3), st.Size) + + // and data should be latest + data, err := persistence.Load(tenant, plugin, key) + assert.Nil(t, err) + assert.Equal(t, "ccc", string(data)) +} + +func TestPersistenceOverwriteLimitEnforcedByDelta(t *testing.T) { + // init deps + err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil) + assert.Nil(t, err) + defer cache.Close() + + db.Init(&app.Config{ + DBType: app.DB_TYPE_POSTGRESQL, + DBUsername: "postgres", + DBPassword: "difyai123456", + DBHost: "localhost", + DBPort: 5432, + DBDatabase: "dify_plugin_daemon", + DBSslMode: "disable", + }) + defer db.Close() + + oss, err := factory.Load("local", cloudoss.OSSArgs{Local: &cloudoss.Local{Path: "./storage"}}) + assert.Nil(t, err) + + // set a small global limit 5 bytes + InitPersistence(oss, &app.Config{PersistenceStoragePath: "./persistence_storage", PersistenceStorageMaxSize: 5}) + + tenant := "tenant_" + strings.RandomString(6) + plugin := "plugin_" + strings.RandomString(6) + key := "k_" + strings.RandomString(6) + + // write 4 bytes OK + assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("aaaa"))) + st, err := db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin)) + assert.Nil(t, err) + assert.Equal(t, int64(4), st.Size) + + // overwrite with 6 bytes -> delta = +2, 4+2=6 > 5 -> expect error, no change + if err := persistence.Save(tenant, plugin, -1, key, []byte("abcdef")); err == nil { + t.Fatalf("expected limit error, got nil") + } + + st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin)) + assert.Nil(t, err) + assert.Equal(t, int64(4), st.Size) + + // stored data should remain old value + data, err := persistence.Load(tenant, plugin, key) + assert.Nil(t, err) + assert.Equal(t, "aaaa", string(data)) +} + func TestPersistenceSaveAndLoadWithLongKey(t *testing.T) { err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil) assert.Nil(t, err)