Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 74 additions & 42 deletions internal/core/persistence/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package persistence

import (
"encoding/hex"
"errors"
"fmt"
"path"
"strings"
Expand All @@ -10,6 +11,7 @@ import (
"github.com/langgenius/dify-plugin-daemon/internal/db"
"github.com/langgenius/dify-plugin-daemon/internal/types/models"
"github.com/langgenius/dify-plugin-daemon/pkg/utils/cache"
"gorm.io/gorm"
)

type Persistence struct {
Expand Down Expand Up @@ -47,54 +49,72 @@ func (c *Persistence) Save(tenantId string, pluginId string, maxSize int64, key
maxSize = c.maxStorageSize
}

if err := c.storage.Save(tenantId, pluginId, key, data); err != nil {
lockKey := fmt.Sprintf("persistence:lock:%s:%s:%s", tenantId, pluginId, key)
if err := cache.Lock(lockKey, 30*time.Second, 3*time.Second); err != nil {
return err
}
defer func() { _ = cache.Unlock(lockKey) }()

allocatedSize := int64(len(data))

storage, err := db.GetOne[models.TenantStorage](
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
)
if err != nil {
if allocatedSize > c.maxStorageSize || allocatedSize > maxSize {
return fmt.Errorf("allocated size is greater than max storage size")
newSize := int64(len(data))
var oldSize int64 = 0
if exist, err := c.storage.Exists(tenantId, pluginId, key); err == nil && exist {
if s, err2 := c.storage.StateSize(tenantId, pluginId, key); err2 == nil {
oldSize = s
}
}
delta := newSize - oldSize

if err == db.ErrDatabaseNotFound {
storage = models.TenantStorage{
TenantID: tenantId,
PluginID: pluginId,
Size: allocatedSize,
}
if err := db.Create(&storage); err != nil {
txErr := db.WithTransaction(func(tx *gorm.DB) error {
record, err := db.GetOne[models.TenantStorage](
db.WithTransactionContext(tx),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.WLock(),
)

if err != nil {
if !errors.Is(err, db.ErrDatabaseNotFound) {
return err
}
} else {
return err
record = models.TenantStorage{TenantID: tenantId, PluginID: pluginId, Size: 0}
if cerr := db.Create(&record, tx); cerr != nil {
return cerr
}
}
} else {
if allocatedSize+storage.Size > maxSize || allocatedSize+storage.Size > c.maxStorageSize {
return fmt.Errorf("allocated size is greater than max storage size")

if delta > 0 {
if record.Size+delta > c.maxStorageSize || record.Size+delta > maxSize {
return fmt.Errorf("allocated size is greater than max storage size")
}
}

err = db.Run(
db.Model(&models.TenantStorage{}),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.Inc(map[string]int64{"size": allocatedSize}),
)
if err != nil {
if err := c.storage.Save(tenantId, pluginId, key, data); err != nil {
return err
}

if delta != 0 {
if err := db.Run(
db.WithTransactionContext(tx),
db.Model(&models.TenantStorage{}),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.Inc(map[string]int64{"size": delta}),
); err != nil {
return err
}
}

return nil
})
if txErr != nil {
return txErr
}

// delete from cache
if _, err = cache.Del(c.getCacheKey(tenantId, pluginId, key)); err == cache.ErrNotFound {
if _, err := cache.Del(c.getCacheKey(tenantId, pluginId, key)); errors.Is(err, cache.ErrNotFound) {
return nil
} else {
return err
}
return err
}

// TODO: raises specific error to avoid confusion
Expand Down Expand Up @@ -125,6 +145,12 @@ func (c *Persistence) Load(tenantId string, pluginId string, key string) ([]byte
}

func (c *Persistence) Delete(tenantId string, pluginId string, key string) (int64, error) {
lockKey := fmt.Sprintf("persistence:lock:%s:%s:%s", tenantId, pluginId, key)
if err := cache.Lock(lockKey, 30*time.Second, 3*time.Second); err != nil {
return 0, err
}
defer func() { _ = cache.Unlock(lockKey) }()

// delete from cache and storage
deletedNum, err := cache.Del(c.getCacheKey(tenantId, pluginId, key))
if err != nil {
Expand All @@ -137,19 +163,25 @@ func (c *Persistence) Delete(tenantId string, pluginId string, key string) (int6
return 0, err
}

err = c.storage.Delete(tenantId, pluginId, key)
if err != nil {
if err = c.storage.Delete(tenantId, pluginId, key); err != nil {
return 0, err
}

// update storage size
err = db.Run(
db.Model(&models.TenantStorage{}),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.Dec(map[string]int64{"size": size}),
)
if err != nil {
if err := db.WithTransaction(func(tx *gorm.DB) error {
_, _ = db.GetOne[models.TenantStorage](
db.WithTransactionContext(tx),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.WLock(),
)
return db.Run(
db.WithTransactionContext(tx),
db.Model(&models.TenantStorage{}),
db.Equal("tenant_id", tenantId),
db.Equal("plugin_id", pluginId),
db.Inc(map[string]int64{"size": -size}),
)
}); err != nil {
return 0, err
}

Expand Down
99 changes: 99 additions & 0 deletions internal/core/persistence/persistence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/langgenius/dify-cloud-kit/oss/factory"
"github.com/langgenius/dify-plugin-daemon/internal/db"
"github.com/langgenius/dify-plugin-daemon/internal/types/app"
"github.com/langgenius/dify-plugin-daemon/internal/types/models"
"github.com/langgenius/dify-plugin-daemon/pkg/utils/cache"
"github.com/langgenius/dify-plugin-daemon/pkg/utils/strings"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -69,6 +70,104 @@ func TestPersistenceStoreAndLoad(t *testing.T) {
assert.Equal(t, string(cacheDataBytes), "data")
}

func TestPersistenceOverwriteAdjustsCounter(t *testing.T) {
// init deps
err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil)
assert.Nil(t, err)
defer cache.Close()

db.Init(&app.Config{
DBType: app.DB_TYPE_POSTGRESQL,
DBUsername: "postgres",
DBPassword: "difyai123456",
DBHost: "localhost",
DBPort: 5432,
DBDatabase: "dify_plugin_daemon",
DBSslMode: "disable",
})
defer db.Close()

oss, err := factory.Load("local", cloudoss.OSSArgs{Local: &cloudoss.Local{Path: "./storage"}})
assert.Nil(t, err)

InitPersistence(oss, &app.Config{PersistenceStoragePath: "./persistence_storage", PersistenceStorageMaxSize: 1024 * 1024})

tenant := "tenant_" + strings.RandomString(6)
plugin := "plugin_" + strings.RandomString(6)
key := "k_" + strings.RandomString(6)

// write 4 bytes
assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("abcd")))
st, err := db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(4), st.Size)

// overwrite with 2 bytes -> size should be 2
assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("bb")))
st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(2), st.Size)

// overwrite with 3 bytes -> size should be 3
assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("ccc")))
st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(3), st.Size)

// and data should be latest
data, err := persistence.Load(tenant, plugin, key)
assert.Nil(t, err)
assert.Equal(t, "ccc", string(data))
}

func TestPersistenceOverwriteLimitEnforcedByDelta(t *testing.T) {
// init deps
err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil)
assert.Nil(t, err)
defer cache.Close()

db.Init(&app.Config{
DBType: app.DB_TYPE_POSTGRESQL,
DBUsername: "postgres",
DBPassword: "difyai123456",
DBHost: "localhost",
DBPort: 5432,
DBDatabase: "dify_plugin_daemon",
DBSslMode: "disable",
})
defer db.Close()

oss, err := factory.Load("local", cloudoss.OSSArgs{Local: &cloudoss.Local{Path: "./storage"}})
assert.Nil(t, err)

// set a small global limit 5 bytes
InitPersistence(oss, &app.Config{PersistenceStoragePath: "./persistence_storage", PersistenceStorageMaxSize: 5})

tenant := "tenant_" + strings.RandomString(6)
plugin := "plugin_" + strings.RandomString(6)
key := "k_" + strings.RandomString(6)

// write 4 bytes OK
assert.Nil(t, persistence.Save(tenant, plugin, -1, key, []byte("aaaa")))
st, err := db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(4), st.Size)

// overwrite with 6 bytes -> delta = +2, 4+2=6 > 5 -> expect error, no change
if err := persistence.Save(tenant, plugin, -1, key, []byte("abcdef")); err == nil {
t.Fatalf("expected limit error, got nil")
}

st, err = db.GetOne[models.TenantStorage](db.Equal("tenant_id", tenant), db.Equal("plugin_id", plugin))
assert.Nil(t, err)
assert.Equal(t, int64(4), st.Size)

// stored data should remain old value
data, err := persistence.Load(tenant, plugin, key)
assert.Nil(t, err)
assert.Equal(t, "aaaa", string(data))
}

func TestPersistenceSaveAndLoadWithLongKey(t *testing.T) {
err := cache.InitRedisClient("localhost:6379", "", "difyai123456", false, 0, nil)
assert.Nil(t, err)
Expand Down
Loading