From 1281dd4c2fd2d393600ea6415ebe42c329413133 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 11:57:39 +0000 Subject: [PATCH 1/4] Initial plan From b6de4cf48a704cbf8a67e24146f916ebec035252 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 12:39:02 +0000 Subject: [PATCH 2/4] feat: tenant-scoped policy loading for GORM adapter (Option A + Option B) Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- internal/gorm_adapter.go | 217 +++++++++++++++-- internal/module_casbin.go | 59 ++++- internal/module_casbin_storage_test.go | 325 ++++++++++++++++++++++++- internal/sqlite_dialector.go | 49 +++- 4 files changed, 615 insertions(+), 35 deletions(-) diff --git a/internal/gorm_adapter.go b/internal/gorm_adapter.go index 1558582..c19e365 100644 --- a/internal/gorm_adapter.go +++ b/internal/gorm_adapter.go @@ -3,6 +3,12 @@ package internal // gormCasbinAdapter is a minimal casbin persist.Adapter backed by gorm. // It replaces casbin/gorm-adapter/v3 to avoid the duplicate sqlite driver // registration conflict between glebarez/go-sqlite and modernc.org/sqlite. +// +// Option A – tenant filter: construct the adapter with a filterField / filterValue +// to apply a WHERE clause on every LoadPolicy call, isolating one tenant's rows. +// +// Option B – per-tenant table: pass a resolved table name (e.g. "casbin_rule_acme") +// so each tenant's policies live in a dedicated table. import ( "fmt" @@ -13,42 +19,188 @@ import ( "gorm.io/gorm" ) +// validFilterFields is the set of column names allowed in a filter to prevent +// SQL injection when building dynamic WHERE clauses. +var validFilterFields = map[string]bool{ + "v0": true, "v1": true, "v2": true, + "v3": true, "v4": true, "v5": true, +} + +// GORMFilter specifies a WHERE clause for tenant-scoped policy loading. +// It is the concrete filter type accepted by gormAdapter.LoadFilteredPolicy. +type GORMFilter struct { + // Field is the column name to filter on (one of "v0" through "v5"). + Field string + // Value is the value the column must equal. + Value string +} + // casbinRule mirrors the table schema used by the upstream gorm-adapter. +// The composite uniqueness constraint is NOT declared in the struct tags because +// GORM uses the literal tag name as the index name (e.g. "unique_index"), which +// conflicts in SQLite's global index namespace when multiple tenant tables are +// created from the same struct. The index is instead created by migrateTable +// with a per-table name ("uidx_"). type casbinRule struct { ID uint `gorm:"primarykey;autoIncrement"` - Ptype string `gorm:"size:512;uniqueIndex:unique_index"` - V0 string `gorm:"size:512;uniqueIndex:unique_index"` - V1 string `gorm:"size:512;uniqueIndex:unique_index"` - V2 string `gorm:"size:512;uniqueIndex:unique_index"` - V3 string `gorm:"size:512;uniqueIndex:unique_index"` - V4 string `gorm:"size:512;uniqueIndex:unique_index"` - V5 string `gorm:"size:512;uniqueIndex:unique_index"` + Ptype string `gorm:"size:512"` + V0 string `gorm:"size:512"` + V1 string `gorm:"size:512"` + V2 string `gorm:"size:512"` + V3 string `gorm:"size:512"` + V4 string `gorm:"size:512"` + V5 string `gorm:"size:512"` } -// TableName returns the gorm table name. +// TableName returns the default gorm table name. +// When a custom table name is needed the caller uses db.Table(name) explicitly. func (casbinRule) TableName() string { return "casbin_rule" } -// gormAdapter implements persist.Adapter using a gorm.DB. +// gormAdapter implements persist.FilteredAdapter using a gorm.DB. +// When filterField/filterValue are non-empty all load/save operations are +// scoped to rows where filterField = filterValue (Option A). +// The tableName field enables per-tenant tables (Option B). +// +// FilteredAdapter contract (Casbin v2): +// - IsFiltered() must return false before the first LoadPolicy call so that +// NewEnforcer will call LoadPolicy during initialisation. +// - IsFiltered() returns true once a filtered load has been performed, which +// tells the Casbin enforcer not to allow a bulk SavePolicy. +// - The filter is applied automatically inside LoadPolicy when filterField / +// filterValue are configured, so callers need not use LoadFilteredPolicy. type gormAdapter struct { - db *gorm.DB - tableName string + db *gorm.DB + tableName string + filterField string // Option A: column name (e.g. "v0"); empty = no filter + filterValue string // Option A: value to match + filtered bool // true after the first filtered LoadPolicy; starts false +} + +// validTableNameRune returns true when ch is allowed in a table name used as a +// raw SQL identifier. Only ASCII letters, digits, underscores and hyphens are +// accepted; this prevents characters that could break %q SQL quoting. +func validTableNameRune(ch rune) bool { + return (ch >= 'a' && ch <= 'z') || + (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || + ch == '_' || ch == '-' +} + +// table returns a *gorm.DB scoped to the adapter's table name. +func (a *gormAdapter) table() *gorm.DB { + return a.db.Table(a.tableName) } -// newGORMAdapter auto-migrates the casbin_rule table and returns an adapter. -func newGORMAdapter(db *gorm.DB, tableName string) (*gormAdapter, error) { +// newGORMAdapter auto-migrates the table and returns an adapter. +// filterField and filterValue implement Option A (tenant-scoped WHERE clause). +// Pass empty strings to disable filtering. +func newGORMAdapter(db *gorm.DB, tableName, filterField, filterValue string) (*gormAdapter, error) { if tableName == "" { tableName = "casbin_rule" } - if err := db.AutoMigrate(&casbinRule{}); err != nil { + // Validate tableName to prevent SQL injection via the raw CREATE INDEX SQL + // constructed in migrateTable. Only allow characters that are safe to use + // inside a double-quoted SQL identifier without additional escaping. + for _, ch := range tableName { + if !validTableNameRune(ch) { + return nil, fmt.Errorf("gorm casbin adapter: invalid character %q in table name %q", ch, tableName) + } + } + if filterField != "" && !validFilterFields[filterField] { + return nil, fmt.Errorf("gorm casbin adapter: invalid filter_field %q (must be v0-v5)", filterField) + } + if err := migrateTable(db, tableName); err != nil { return nil, fmt.Errorf("gorm casbin adapter: migrate: %w", err) } - return &gormAdapter{db: db, tableName: tableName}, nil + return &gormAdapter{ + db: db, + tableName: tableName, + filterField: filterField, + filterValue: filterValue, + }, nil +} + +// migrateTable creates the casbin rule table for tableName. +// +// The composite unique constraint is created as a named index "uidx_" +// rather than via struct tags so that each tenant table gets its own index name. +// This avoids SQLite's flat index namespace where two tables migrated from the +// same struct with the same named index tag would conflict. +// +// AutoMigrate is only invoked when the table does not yet exist. Skipping it +// for existing tables avoids SQLite-specific issues in the generic GORM migrator +// (e.g. HasColumn falling back to an information_schema query that does not +// exist in SQLite, which can cause unintended ALTER TABLE attempts). +func migrateTable(db *gorm.DB, tableName string) error { + if !db.Migrator().HasTable(tableName) { + if err := db.Table(tableName).AutoMigrate(&casbinRule{}); err != nil { + return err + } + } + // Create a composite unique index with a per-table name. + // We check for existence first so the operation is idempotent on all + // supported databases (SQLite, PostgreSQL, MySQL). + idxName := "uidx_" + strings.ReplaceAll(tableName, "-", "_") + if db.Migrator().HasIndex(tableName, idxName) { + return nil + } + return db.Exec(fmt.Sprintf( + `CREATE UNIQUE INDEX %q ON %q ("ptype","v0","v1","v2","v3","v4","v5")`, + idxName, tableName, + )).Error +} + +// IsFiltered returns true once a filtered LoadPolicy has been performed. +// Before the first load it returns false so that casbin.NewEnforcer will call +// LoadPolicy during initialisation (Casbin skips LoadPolicy when IsFiltered is +// true at construction time). +func (a *gormAdapter) IsFiltered() bool { + return a.filtered } -// LoadPolicy loads all policies from the database into the model. +// LoadPolicy loads all (or filtered) policies from the database into the model. +// When filterField/filterValue are configured the WHERE clause is applied and +// filtered is set to true so that subsequent SavePolicy calls are scoped. func (a *gormAdapter) LoadPolicy(mdl model.Model) error { + if err := a.loadWithFilter(mdl, a.filterField, a.filterValue); err != nil { + return err + } + if a.filterField != "" && a.filterValue != "" { + a.filtered = true + } + return nil +} + +// LoadFilteredPolicy loads only policies matching the supplied filter. +// filter must be a GORMFilter value. Sets IsFiltered to true on success. +func (a *gormAdapter) LoadFilteredPolicy(mdl model.Model, filter interface{}) error { + f, ok := filter.(GORMFilter) + if !ok { + return fmt.Errorf("gorm casbin adapter: LoadFilteredPolicy expects GORMFilter, got %T", filter) + } + if f.Field != "" && !validFilterFields[f.Field] { + return fmt.Errorf("gorm casbin adapter: invalid filter field %q (must be v0-v5)", f.Field) + } + if err := a.loadWithFilter(mdl, f.Field, f.Value); err != nil { + return err + } + a.filtered = true + return nil +} + +// loadWithFilter is the shared implementation used by LoadPolicy and +// LoadFilteredPolicy. field must be a pre-validated column name (v0-v5) or +// empty; the backtick quoting is an additional defence-in-depth measure. +func (a *gormAdapter) loadWithFilter(mdl model.Model, field, value string) error { + q := a.table() + if field != "" && value != "" { + // field is already validated against validFilterFields (v0-v5), so + // backtick-quoting is safe and provides defence-in-depth against any + // future code path that might supply an unvalidated column name. + q = q.Where("`"+field+"` = ?", value) + } var rules []casbinRule - if err := a.db.Find(&rules).Error; err != nil { + if err := q.Find(&rules).Error; err != nil { return err } for _, rule := range rules { @@ -58,10 +210,9 @@ func (a *gormAdapter) LoadPolicy(mdl model.Model) error { } // SavePolicy saves all policies from the model into the database. +// When a tenant filter is active only the matching rows are replaced, so that +// other tenants' data is not affected. func (a *gormAdapter) SavePolicy(mdl model.Model) error { - if err := a.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&casbinRule{}).Error; err != nil { - return err - } var rules []casbinRule for ptype, assertions := range mdl["p"] { for _, assertion := range assertions.Policy { @@ -73,8 +224,21 @@ func (a *gormAdapter) SavePolicy(mdl model.Model) error { rules = append(rules, lineToRule(ptype, assertion)) } } + + if a.IsFiltered() { + // Delete only rows belonging to this tenant, then re-insert. + if err := a.table().Where(a.filterField+" = ?", a.filterValue).Delete(&casbinRule{}).Error; err != nil { + return err + } + } else { + // Delete all rows in the table. + if err := a.table().Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&casbinRule{}).Error; err != nil { + return err + } + } + if len(rules) > 0 { - return a.db.CreateInBatches(rules, 100).Error + return a.table().CreateInBatches(rules, 100).Error } return nil } @@ -82,18 +246,18 @@ func (a *gormAdapter) SavePolicy(mdl model.Model) error { // AddPolicy adds a policy rule to the database. func (a *gormAdapter) AddPolicy(sec, ptype string, rule []string) error { r := lineToRule(ptype, rule) - return a.db.Create(&r).Error + return a.table().Create(&r).Error } // RemovePolicy removes a policy rule from the database. func (a *gormAdapter) RemovePolicy(sec, ptype string, rule []string) error { r := lineToRule(ptype, rule) - return a.db.Where(&r).Delete(&casbinRule{}).Error + return a.table().Where(&r).Delete(&casbinRule{}).Error } // RemoveFilteredPolicy removes policy rules matching the given filter. func (a *gormAdapter) RemoveFilteredPolicy(sec, ptype string, fieldIndex int, fieldValues ...string) error { - query := a.db.Where("ptype = ?", ptype) + query := a.table().Where("ptype = ?", ptype) fields := []string{"v0", "v1", "v2", "v3", "v4", "v5"} for i, v := range fieldValues { if v != "" { @@ -129,5 +293,6 @@ func lineToRule(ptype string, rule []string) casbinRule { return r } -// Compile-time interface check. -var _ persist.Adapter = (*gormAdapter)(nil) +// Compile-time interface check – gormAdapter must satisfy FilteredAdapter +// (which is a superset of Adapter). +var _ persist.FilteredAdapter = (*gormAdapter)(nil) diff --git a/internal/module_casbin.go b/internal/module_casbin.go index 5ede417..b1156f1 100644 --- a/internal/module_casbin.go +++ b/internal/module_casbin.go @@ -1,10 +1,12 @@ package internal import ( + "bytes" "context" "fmt" "strings" "sync" + "text/template" "time" "github.com/casbin/casbin/v2" @@ -42,7 +44,19 @@ type adapterConfig struct { // GORM adapter fields. Driver string `yaml:"driver"` // "postgres", "mysql", or "sqlite3" DSN string `yaml:"dsn"` - TableName string `yaml:"table_name"` // optional; defaults to "casbin_rule" + TableName string `yaml:"table_name"` // optional; defaults to "casbin_rule"; supports Go templates + + // Option B: Tenant identifier used as {{.Tenant}} in a table_name template. + // Example: table_name "casbin_rule_{{.Tenant}}" with Tenant "acme" resolves + // to table "casbin_rule_acme", giving each tenant its own policy table. + Tenant string `yaml:"tenant"` + + // Option A: Tenant-scoped policy loading via a WHERE clause on LoadPolicy. + // FilterField is a column name (v0–v5) and FilterValue is the value that + // column must equal. When both are set the adapter implements + // persist.FilteredAdapter and only loads/saves matching rows. + FilterField string `yaml:"filter_field"` + FilterValue string `yaml:"filter_value"` } // watcherConfig describes the optional polling reload behaviour. @@ -136,6 +150,9 @@ func parseAdapterConfig(raw map[string]any) adapterConfig { a.Driver, _ = raw["driver"].(string) a.DSN, _ = raw["dsn"].(string) a.TableName, _ = raw["table_name"].(string) + a.Tenant, _ = raw["tenant"].(string) + a.FilterField, _ = raw["filter_field"].(string) + a.FilterValue, _ = raw["filter_value"].(string) return a } @@ -217,13 +234,38 @@ func (m *CasbinModule) buildGORMAdapter() (persist.Adapter, error) { return nil, fmt.Errorf("authz.casbin %q: open gorm db: %w", m.name, err) } - a, err := newGORMAdapter(db, m.config.Adapter.TableName) + // Option B: resolve table_name template (e.g. "casbin_rule_{{.Tenant}}"). + tableName, err := resolveTableNameTemplate(m.config.Adapter.TableName, m.config.Adapter) + if err != nil { + return nil, fmt.Errorf("authz.casbin %q: %w", m.name, err) + } + + // Option A: pass filter fields for tenant-scoped policy loading. + a, err := newGORMAdapter(db, tableName, m.config.Adapter.FilterField, m.config.Adapter.FilterValue) if err != nil { return nil, fmt.Errorf("authz.casbin %q: create gorm adapter: %w", m.name, err) } return a, nil } +// resolveTableNameTemplate resolves a Go template in tableName. +// The template data is the adapterConfig, so {{.Tenant}} expands to cfg.Tenant. +// If tableName contains no template markers it is returned unchanged. +func resolveTableNameTemplate(tableName string, cfg adapterConfig) (string, error) { + if !strings.Contains(tableName, "{{") { + return tableName, nil + } + tmpl, err := template.New("table_name").Parse(tableName) + if err != nil { + return "", fmt.Errorf("parse table_name template: %w", err) + } + var buf bytes.Buffer + if err := tmpl.Execute(&buf, cfg); err != nil { + return "", fmt.Errorf("execute table_name template: %w", err) + } + return buf.String(), nil +} + // Init builds the Casbin enforcer from the configured adapter. func (m *CasbinModule) Init() error { m.mu.Lock() @@ -319,6 +361,8 @@ func (m *CasbinModule) Enforce(sub, obj, act string) (bool, error) { } // AddPolicy adds a policy rule and saves it to the adapter. +// When the enforcer uses a FilteredAdapter, SavePolicy is skipped because +// the incremental adapter.AddPolicy already persisted the row. func (m *CasbinModule) AddPolicy(rule []string) (bool, error) { m.mu.Lock() defer m.mu.Unlock() @@ -329,7 +373,7 @@ func (m *CasbinModule) AddPolicy(rule []string) (bool, error) { if err != nil { return false, err } - if ok { + if ok && !m.enforcer.IsFiltered() { if err := m.enforcer.SavePolicy(); err != nil { return false, err } @@ -338,6 +382,7 @@ func (m *CasbinModule) AddPolicy(rule []string) (bool, error) { } // RemovePolicy removes a policy rule and saves the adapter. +// When the enforcer uses a FilteredAdapter, SavePolicy is skipped. func (m *CasbinModule) RemovePolicy(rule []string) (bool, error) { m.mu.Lock() defer m.mu.Unlock() @@ -348,7 +393,7 @@ func (m *CasbinModule) RemovePolicy(rule []string) (bool, error) { if err != nil { return false, err } - if ok { + if ok && !m.enforcer.IsFiltered() { if err := m.enforcer.SavePolicy(); err != nil { return false, err } @@ -357,6 +402,7 @@ func (m *CasbinModule) RemovePolicy(rule []string) (bool, error) { } // AddGroupingPolicy adds a role mapping and saves the adapter. +// When the enforcer uses a FilteredAdapter, SavePolicy is skipped. func (m *CasbinModule) AddGroupingPolicy(rule []string) (bool, error) { m.mu.Lock() defer m.mu.Unlock() @@ -367,7 +413,7 @@ func (m *CasbinModule) AddGroupingPolicy(rule []string) (bool, error) { if err != nil { return false, err } - if ok { + if ok && !m.enforcer.IsFiltered() { if err := m.enforcer.SavePolicy(); err != nil { return false, err } @@ -376,6 +422,7 @@ func (m *CasbinModule) AddGroupingPolicy(rule []string) (bool, error) { } // RemoveGroupingPolicy removes a role mapping and saves the adapter. +// When the enforcer uses a FilteredAdapter, SavePolicy is skipped. func (m *CasbinModule) RemoveGroupingPolicy(rule []string) (bool, error) { m.mu.Lock() defer m.mu.Unlock() @@ -386,7 +433,7 @@ func (m *CasbinModule) RemoveGroupingPolicy(rule []string) (bool, error) { if err != nil { return false, err } - if ok { + if ok && !m.enforcer.IsFiltered() { if err := m.enforcer.SavePolicy(); err != nil { return false, err } diff --git a/internal/module_casbin_storage_test.go b/internal/module_casbin_storage_test.go index 3681e9c..5e1b738 100644 --- a/internal/module_casbin_storage_test.go +++ b/internal/module_casbin_storage_test.go @@ -219,7 +219,330 @@ func TestGORMAdapter_MissingDSN(t *testing.T) { } } -// --- inMemoryAdapter mutation tests --- +// --- Option A: tenant-filter GORM tests --- + +// TestGORMAdapter_FilterField_InvalidField checks that an invalid filter_field +// is rejected at adapter creation time. +func TestGORMAdapter_FilterField_InvalidField(t *testing.T) { + m, err := newCasbinModule("authz", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": ":memory:", + "filter_field": "not_a_column", // invalid + "filter_value": "tenant_a", + }, + }) + if err != nil { + t.Fatalf("newCasbinModule: %v", err) + } + if err := m.Init(); err == nil { + t.Error("expected Init to fail for invalid filter_field") + } +} + +// TestGORMAdapter_InvalidTableName checks that a table name with characters +// unsafe for use as a raw SQL identifier is rejected at Init time. +func TestGORMAdapter_InvalidTableName(t *testing.T) { + m, err := newCasbinModule("authz", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": ":memory:", + "table_name": `casbin_rule"; DROP TABLE casbin_rule; --`, // injection attempt + }, + }) + if err != nil { + t.Fatalf("newCasbinModule: %v", err) + } + if err := m.Init(); err == nil { + t.Error("expected Init to fail for unsafe table name") + } +} + +// TestGORMAdapter_TenantFilter demonstrates Option A: two modules share the +// same SQLite file-based database but each only loads/manages its own rows, +// identified by the value stored in v0 (the "tenant" column). +func TestGORMAdapter_TenantFilter(t *testing.T) { + dir := t.TempDir() + dsn := "file:" + dir + "/authz.db" + + // Build a shared multi-tenant model where v0 holds the tenant. + // Policy line: p, , , , + // Matcher: g(r.sub, p.sub, p.v0) && p.v0 == r.dom && r.obj == p.obj ... + // For simplicity we reuse testModel (sub, obj, act) and put tenant in v0. + // We seed the DB directly via a plain (no-filter) adapter first. + + // --- seed phase: populate both tenants' rows --- + seed, err := newCasbinModule("seed", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": dsn, + }, + }) + if err != nil { + t.Fatalf("seed newCasbinModule: %v", err) + } + if err := seed.Init(); err != nil { + t.Fatalf("seed Init: %v", err) + } + // tenant_a policy: alice → admin → GET /api + if _, err := seed.AddPolicy([]string{"tenant_a", "/api", "GET"}); err != nil { + t.Fatalf("seed AddPolicy tenant_a: %v", err) + } + // tenant_b policy: bob → admin → POST /data + if _, err := seed.AddPolicy([]string{"tenant_b", "/data", "POST"}); err != nil { + t.Fatalf("seed AddPolicy tenant_b: %v", err) + } + + // --- tenant_a module: filter on v0 = "tenant_a" --- + modA, err := newCasbinModule("authz_a", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": dsn, + "filter_field": "v0", + "filter_value": "tenant_a", + }, + }) + if err != nil { + t.Fatalf("modA newCasbinModule: %v", err) + } + if err := modA.Init(); err != nil { + t.Fatalf("modA Init: %v", err) + } + + // tenant_a can access /api GET + if ok, err := modA.Enforce("tenant_a", "/api", "GET"); err != nil || !ok { + t.Errorf("tenant_a should be allowed GET /api: ok=%v err=%v", ok, err) + } + // tenant_b's policy is NOT loaded into modA + if ok, err := modA.Enforce("tenant_b", "/data", "POST"); err != nil || ok { + t.Errorf("tenant_b policy must not be visible in tenant_a module: ok=%v err=%v", ok, err) + } + + // --- tenant_b module: filter on v0 = "tenant_b" --- + modB, err := newCasbinModule("authz_b", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": dsn, + "filter_field": "v0", + "filter_value": "tenant_b", + }, + }) + if err != nil { + t.Fatalf("modB newCasbinModule: %v", err) + } + if err := modB.Init(); err != nil { + t.Fatalf("modB Init: %v", err) + } + + // tenant_b can access /data POST + if ok, err := modB.Enforce("tenant_b", "/data", "POST"); err != nil || !ok { + t.Errorf("tenant_b should be allowed POST /data: ok=%v err=%v", ok, err) + } + // tenant_a's policy is NOT loaded into modB + if ok, err := modB.Enforce("tenant_a", "/api", "GET"); err != nil || ok { + t.Errorf("tenant_a policy must not be visible in tenant_b module: ok=%v err=%v", ok, err) + } +} + +// TestGORMAdapter_TenantFilter_MutationIsolation verifies that AddPolicy and +// RemovePolicy on a filtered module only affect that tenant's rows and do not +// corrupt other tenants' data. +func TestGORMAdapter_TenantFilter_MutationIsolation(t *testing.T) { + dir := t.TempDir() + dsn := "file:" + dir + "/authz.db" + + // Seed shared DB. + seed, err := newCasbinModule("seed", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": dsn, + }, + }) + if err != nil { + t.Fatalf("seed newCasbinModule: %v", err) + } + if err := seed.Init(); err != nil { + t.Fatalf("seed Init: %v", err) + } + if _, err := seed.AddPolicy([]string{"tenant_b", "/reports", "GET"}); err != nil { + t.Fatalf("seed AddPolicy: %v", err) + } + + // Filtered module for tenant_a. + modA, err := newCasbinModule("authz_a", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": dsn, + "filter_field": "v0", + "filter_value": "tenant_a", + }, + }) + if err != nil { + t.Fatalf("modA newCasbinModule: %v", err) + } + if err := modA.Init(); err != nil { + t.Fatalf("modA Init: %v", err) + } + + // Add a policy for tenant_a. + if _, err := modA.AddPolicy([]string{"tenant_a", "/metrics", "GET"}); err != nil { + t.Fatalf("modA AddPolicy: %v", err) + } + if ok, err := modA.Enforce("tenant_a", "/metrics", "GET"); err != nil || !ok { + t.Errorf("tenant_a should be allowed GET /metrics: ok=%v err=%v", ok, err) + } + + // Remove tenant_a's policy. + if _, err := modA.RemovePolicy([]string{"tenant_a", "/metrics", "GET"}); err != nil { + t.Fatalf("modA RemovePolicy: %v", err) + } + if ok, err := modA.Enforce("tenant_a", "/metrics", "GET"); err != nil || ok { + t.Errorf("tenant_a /metrics should be denied after removal: ok=%v err=%v", ok, err) + } + + // Verify tenant_b's row is still intact in the shared DB. + modB, err := newCasbinModule("authz_b", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": dsn, + "filter_field": "v0", + "filter_value": "tenant_b", + }, + }) + if err != nil { + t.Fatalf("modB newCasbinModule: %v", err) + } + if err := modB.Init(); err != nil { + t.Fatalf("modB Init: %v", err) + } + if ok, err := modB.Enforce("tenant_b", "/reports", "GET"); err != nil || !ok { + t.Errorf("tenant_b /reports should still be allowed: ok=%v err=%v", ok, err) + } +} + +// --- Option B: per-tenant table GORM tests --- + +// TestGORMAdapter_PerTenantTable verifies that two modules configured with +// different table names have completely independent policy stores. +func TestGORMAdapter_PerTenantTable(t *testing.T) { + dir := t.TempDir() + dsn := "file:" + dir + "/authz.db" + + makeModule := func(name, table string) *CasbinModule { + t.Helper() + m, err := newCasbinModule(name, map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": dsn, + "table_name": table, + }, + }) + if err != nil { + t.Fatalf("newCasbinModule %s: %v", name, err) + } + if err := m.Init(); err != nil { + t.Fatalf("Init %s: %v", name, err) + } + return m + } + + modA := makeModule("authz_a", "casbin_rule_tenant_a") + modB := makeModule("authz_b", "casbin_rule_tenant_b") + + // Add distinct policies to each table. + if _, err := modA.AddPolicy([]string{"alice", "/a", "GET"}); err != nil { + t.Fatalf("modA AddPolicy: %v", err) + } + if _, err := modB.AddPolicy([]string{"bob", "/b", "POST"}); err != nil { + t.Fatalf("modB AddPolicy: %v", err) + } + + // Each module only sees its own policies. + if ok, err := modA.Enforce("alice", "/a", "GET"); err != nil || !ok { + t.Errorf("alice should be allowed GET /a in tenant_a: ok=%v err=%v", ok, err) + } + if ok, err := modA.Enforce("bob", "/b", "POST"); err != nil || ok { + t.Errorf("bob's policy must not exist in tenant_a table: ok=%v err=%v", ok, err) + } + if ok, err := modB.Enforce("bob", "/b", "POST"); err != nil || !ok { + t.Errorf("bob should be allowed POST /b in tenant_b: ok=%v err=%v", ok, err) + } + if ok, err := modB.Enforce("alice", "/a", "GET"); err != nil || ok { + t.Errorf("alice's policy must not exist in tenant_b table: ok=%v err=%v", ok, err) + } +} + +// TestGORMAdapter_TableNameTemplate verifies that a Go template in table_name +// is resolved using the Tenant config field (Option B). +func TestGORMAdapter_TableNameTemplate(t *testing.T) { + dir := t.TempDir() + dsn := "file:" + dir + "/authz.db" + + m, err := newCasbinModule("authz", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": dsn, + "table_name": "casbin_rule_{{.Tenant}}", + "tenant": "acme_corp", + }, + }) + if err != nil { + t.Fatalf("newCasbinModule: %v", err) + } + if err := m.Init(); err != nil { + t.Fatalf("Init: %v", err) + } + + // Add and enforce a policy – proves the table was created and is usable. + if _, err := m.AddPolicy([]string{"alice", "/dashboard", "GET"}); err != nil { + t.Fatalf("AddPolicy: %v", err) + } + if ok, err := m.Enforce("alice", "/dashboard", "GET"); err != nil || !ok { + t.Errorf("alice should be allowed GET /dashboard: ok=%v err=%v", ok, err) + } +} + +// TestGORMAdapter_TableNameTemplate_Invalid checks that an invalid template +// expression is rejected at Init time. +func TestGORMAdapter_TableNameTemplate_Invalid(t *testing.T) { + m, err := newCasbinModule("authz", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": ":memory:", + "table_name": "casbin_rule_{{.Unclosed", // broken template + }, + }) + if err != nil { + t.Fatalf("newCasbinModule: %v", err) + } + if err := m.Init(); err == nil { + t.Error("expected Init to fail for invalid table_name template") + } +} + + func TestInMemoryAdapter_AddRemovePolicy(t *testing.T) { m := buildModule(t, diff --git a/internal/sqlite_dialector.go b/internal/sqlite_dialector.go index f7dd19c..864e152 100644 --- a/internal/sqlite_dialector.go +++ b/internal/sqlite_dialector.go @@ -65,11 +65,56 @@ func (d *sqliteDialector) DefaultValueOf(field *schema.Field) clause.Expression } func (d *sqliteDialector) Migrator(db *gorm.DB) gorm.Migrator { - return migrator.Migrator{Config: migrator.Config{ + return &sqliteMigrator{Migrator: migrator.Migrator{Config: migrator.Config{ DB: db, Dialector: d, CreateIndexAfterCreateTable: true, - }} + }}} +} + +// sqliteMigrator extends the generic GORM migrator with SQLite-specific +// implementations for HasTable and HasIndex. The generic migrator queries +// information_schema which does not exist in SQLite. +type sqliteMigrator struct { + migrator.Migrator +} + +// HasTable reports whether the named table exists in the SQLite database. +func (m sqliteMigrator) HasTable(value interface{}) bool { + var count int + _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = ?", + stmt.Table, + ).Row().Scan(&count) + }) + return count > 0 +} + +// HasIndex reports whether an index with the given name exists in the +// SQLite database (index names are global in SQLite). +func (m sqliteMigrator) HasIndex(value interface{}, name string) bool { + var count int + _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = 'index' AND name = ?", + name, + ).Row().Scan(&count) + }) + return count > 0 +} + +// HasColumn reports whether the given column exists in the table using SQLite's +// PRAGMA table_info, which works correctly for all SQLite table names. +func (m sqliteMigrator) HasColumn(value interface{}, name string) bool { + var count int + _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw( + "SELECT count(*) FROM pragma_table_info(?) WHERE name = ?", + stmt.Table, name, + ).Row().Scan(&count) + }) + return count > 0 } func (d *sqliteDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { From 4bbfad3f80ed3e259118f5f7939fe16aa0303e43 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:59:14 +0000 Subject: [PATCH 3/4] fix: address code review comments - identifier quoting, filter validation, tenant scope Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- internal/gorm_adapter.go | 104 +++++++++++++++++++++---- internal/module_casbin_storage_test.go | 64 ++++++++++++++- internal/sqlite_dialector.go | 4 +- 3 files changed, 154 insertions(+), 18 deletions(-) diff --git a/internal/gorm_adapter.go b/internal/gorm_adapter.go index c19e365..b4ff9c4 100644 --- a/internal/gorm_adapter.go +++ b/internal/gorm_adapter.go @@ -17,8 +17,25 @@ import ( "github.com/casbin/casbin/v2/model" "github.com/casbin/casbin/v2/persist" "gorm.io/gorm" + "gorm.io/gorm/clause" ) +// identWriter wraps strings.Builder to satisfy clause.Writer (which requires +// both WriteString and WriteByte). It is used to invoke dialector.QuoteTo so +// that identifier quoting is database-specific (backticks for MySQL/SQLite, +// double-quotes for PostgreSQL). +type identWriter struct{ strings.Builder } + +func (w *identWriter) WriteByte(b byte) error { return w.Builder.WriteByte(b) } + +// quoteIdent returns name quoted as an identifier using the dialector's own +// quoting rules, making SQL safe to use on MySQL, PostgreSQL and SQLite. +func quoteIdent(db *gorm.DB, name string) string { + var w identWriter + db.Dialector.QuoteTo(&w, name) + return w.String() +} + // validFilterFields is the set of column names allowed in a filter to prevent // SQL injection when building dynamic WHERE clauses. var validFilterFields = map[string]bool{ @@ -109,6 +126,11 @@ func newGORMAdapter(db *gorm.DB, tableName, filterField, filterValue string) (*g if filterField != "" && !validFilterFields[filterField] { return nil, fmt.Errorf("gorm casbin adapter: invalid filter_field %q (must be v0-v5)", filterField) } + // filterField and filterValue must be set together; partial configuration is + // ambiguous and would silently skip filtering while looking like it is enabled. + if (filterField == "") != (filterValue == "") { + return nil, fmt.Errorf("gorm casbin adapter: filter_field and filter_value must both be set or both be empty") + } if err := migrateTable(db, tableName); err != nil { return nil, fmt.Errorf("gorm casbin adapter: migrate: %w", err) } @@ -127,12 +149,15 @@ func newGORMAdapter(db *gorm.DB, tableName, filterField, filterValue string) (*g // This avoids SQLite's flat index namespace where two tables migrated from the // same struct with the same named index tag would conflict. // -// AutoMigrate is only invoked when the table does not yet exist. Skipping it -// for existing tables avoids SQLite-specific issues in the generic GORM migrator -// (e.g. HasColumn falling back to an information_schema query that does not -// exist in SQLite, which can cause unintended ALTER TABLE attempts). +// On SQLite, AutoMigrate is skipped when the table already exists because the +// generic GORM migrator's HasColumn falls back to information_schema (which does +// not exist in SQLite), causing unintended ALTER TABLE attempts. On PostgreSQL +// and MySQL, AutoMigrate always runs so that schema drift is corrected. func migrateTable(db *gorm.DB, tableName string) error { - if !db.Migrator().HasTable(tableName) { + isSQLite := db.Dialector.Name() == "sqlite" + tableExists := db.Migrator().HasTable(tableName) + + if !isSQLite || !tableExists { if err := db.Table(tableName).AutoMigrate(&casbinRule{}); err != nil { return err } @@ -144,9 +169,18 @@ func migrateTable(db *gorm.DB, tableName string) error { if db.Migrator().HasIndex(tableName, idxName) { return nil } + // Use the dialector's identifier quoting so the SQL is correct for every + // supported database: backticks on MySQL/SQLite, double-quotes on PostgreSQL. + cols := []string{"ptype", "v0", "v1", "v2", "v3", "v4", "v5"} + quotedCols := make([]string, len(cols)) + for i, c := range cols { + quotedCols[i] = quoteIdent(db, c) + } return db.Exec(fmt.Sprintf( - `CREATE UNIQUE INDEX %q ON %q ("ptype","v0","v1","v2","v3","v4","v5")`, - idxName, tableName, + "CREATE UNIQUE INDEX %s ON %s (%s)", + quoteIdent(db, idxName), + quoteIdent(db, tableName), + strings.Join(quotedCols, ", "), )).Error } @@ -172,7 +206,9 @@ func (a *gormAdapter) LoadPolicy(mdl model.Model) error { } // LoadFilteredPolicy loads only policies matching the supplied filter. -// filter must be a GORMFilter value. Sets IsFiltered to true on success. +// filter must be a GORMFilter value. On success the active filter is stored on +// the adapter so that subsequent SavePolicy calls are correctly scoped, and +// IsFiltered is set to true. func (a *gormAdapter) LoadFilteredPolicy(mdl model.Model, filter interface{}) error { f, ok := filter.(GORMFilter) if !ok { @@ -184,20 +220,29 @@ func (a *gormAdapter) LoadFilteredPolicy(mdl model.Model, filter interface{}) er if err := a.loadWithFilter(mdl, f.Field, f.Value); err != nil { return err } - a.filtered = true + // Persist the active filter so that subsequent SavePolicy calls can correctly + // scope their operations. When no effective filter was provided, clear any + // previously stored filter to avoid inconsistent SavePolicy behavior. + if f.Field != "" && f.Value != "" { + a.filterField = f.Field + a.filterValue = f.Value + a.filtered = true + } else { + a.filterField = "" + a.filterValue = "" + a.filtered = false + } return nil } // loadWithFilter is the shared implementation used by LoadPolicy and // LoadFilteredPolicy. field must be a pre-validated column name (v0-v5) or -// empty; the backtick quoting is an additional defence-in-depth measure. +// empty. clause.Eq is used for the WHERE clause so that identifier quoting is +// handled by GORM's dialector, making it correct on MySQL, PostgreSQL and SQLite. func (a *gormAdapter) loadWithFilter(mdl model.Model, field, value string) error { q := a.table() if field != "" && value != "" { - // field is already validated against validFilterFields (v0-v5), so - // backtick-quoting is safe and provides defence-in-depth against any - // future code path that might supply an unvalidated column name. - q = q.Where("`"+field+"` = ?", value) + q = q.Where(clause.Eq{Column: clause.Column{Name: field}, Value: value}) } var rules []casbinRule if err := q.Find(&rules).Error; err != nil { @@ -227,7 +272,7 @@ func (a *gormAdapter) SavePolicy(mdl model.Model) error { if a.IsFiltered() { // Delete only rows belonging to this tenant, then re-insert. - if err := a.table().Where(a.filterField+" = ?", a.filterValue).Delete(&casbinRule{}).Error; err != nil { + if err := a.table().Where(clause.Eq{Column: clause.Column{Name: a.filterField}, Value: a.filterValue}).Delete(&casbinRule{}).Error; err != nil { return err } } else { @@ -243,14 +288,43 @@ func (a *gormAdapter) SavePolicy(mdl model.Model) error { return nil } +// checkTenantScope returns an error when the adapter is in filtered mode and +// the rule's field at the filter index does not match filterValue. This +// prevents accidental cross-tenant writes via AddPolicy / RemovePolicy. +func (a *gormAdapter) checkTenantScope(rule []string) error { + if a.filterField == "" { + return nil + } + // filterField is validated to be "v0"–"v5"; extract the numeric index. + fieldIdx := int(a.filterField[1] - '0') + if fieldIdx >= len(rule) { + return fmt.Errorf("gorm casbin adapter: rule has %d field(s), filter requires %s", len(rule), a.filterField) + } + if rule[fieldIdx] != a.filterValue { + return fmt.Errorf("gorm casbin adapter: rule %s=%q does not match tenant filter %s=%q; cross-tenant writes are not allowed", + a.filterField, rule[fieldIdx], a.filterField, a.filterValue) + } + return nil +} + // AddPolicy adds a policy rule to the database. +// When a tenant filter is active the rule is validated to ensure it belongs to +// the configured tenant before being written. func (a *gormAdapter) AddPolicy(sec, ptype string, rule []string) error { + if err := a.checkTenantScope(rule); err != nil { + return err + } r := lineToRule(ptype, rule) return a.table().Create(&r).Error } // RemovePolicy removes a policy rule from the database. +// When a tenant filter is active the rule is validated to ensure it belongs to +// the configured tenant before deletion is attempted. func (a *gormAdapter) RemovePolicy(sec, ptype string, rule []string) error { + if err := a.checkTenantScope(rule); err != nil { + return err + } r := lineToRule(ptype, rule) return a.table().Where(&r).Delete(&casbinRule{}).Error } diff --git a/internal/module_casbin_storage_test.go b/internal/module_casbin_storage_test.go index 5e1b738..48d39ad 100644 --- a/internal/module_casbin_storage_test.go +++ b/internal/module_casbin_storage_test.go @@ -242,6 +242,38 @@ func TestGORMAdapter_FilterField_InvalidField(t *testing.T) { } } +// TestGORMAdapter_FilterField_PartialConfig checks that specifying only one of +// filter_field / filter_value is rejected; both must be set or neither. +func TestGORMAdapter_FilterField_PartialConfig(t *testing.T) { + for _, tc := range []struct { + name string + field string + value string + }{ + {"field only", "v0", ""}, + {"value only", "", "tenant_a"}, + } { + t.Run(tc.name, func(t *testing.T) { + m, err := newCasbinModule("authz", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": ":memory:", + "filter_field": tc.field, + "filter_value": tc.value, + }, + }) + if err != nil { + t.Fatalf("newCasbinModule: %v", err) + } + if err := m.Init(); err == nil { + t.Error("expected Init to fail for partial filter config") + } + }) + } +} + // TestGORMAdapter_InvalidTableName checks that a table name with characters // unsafe for use as a raw SQL identifier is rejected at Init time. func TestGORMAdapter_InvalidTableName(t *testing.T) { @@ -436,7 +468,37 @@ func TestGORMAdapter_TenantFilter_MutationIsolation(t *testing.T) { } } -// --- Option B: per-tenant table GORM tests --- +// TestGORMAdapter_TenantFilter_CrossTenantWriteRejected verifies that AddPolicy +// and RemovePolicy reject rules whose tenant field does not match the adapter's +// filter value, preventing accidental cross-tenant writes. +func TestGORMAdapter_TenantFilter_CrossTenantWriteRejected(t *testing.T) { + m, err := newCasbinModule("authz_a", map[string]any{ + "model": testModel, + "adapter": map[string]any{ + "type": "gorm", + "driver": "sqlite3", + "dsn": ":memory:", + "filter_field": "v0", + "filter_value": "tenant_a", + }, + }) + if err != nil { + t.Fatalf("newCasbinModule: %v", err) + } + if err := m.Init(); err != nil { + t.Fatalf("Init: %v", err) + } + + // Attempting to add a rule for a different tenant must be rejected. + if _, err := m.AddPolicy([]string{"tenant_b", "/api", "GET"}); err == nil { + t.Error("expected AddPolicy to fail for cross-tenant rule (tenant_b via tenant_a adapter)") + } + + // Attempting to remove a rule for a different tenant must also be rejected. + if _, err := m.RemovePolicy([]string{"tenant_b", "/api", "GET"}); err == nil { + t.Error("expected RemovePolicy to fail for cross-tenant rule") + } +} // TestGORMAdapter_PerTenantTable verifies that two modules configured with // different table names have completely independent policy stores. diff --git a/internal/sqlite_dialector.go b/internal/sqlite_dialector.go index 864e152..d196459 100644 --- a/internal/sqlite_dialector.go +++ b/internal/sqlite_dialector.go @@ -73,8 +73,8 @@ func (d *sqliteDialector) Migrator(db *gorm.DB) gorm.Migrator { } // sqliteMigrator extends the generic GORM migrator with SQLite-specific -// implementations for HasTable and HasIndex. The generic migrator queries -// information_schema which does not exist in SQLite. +// implementations for HasTable, HasIndex, and HasColumn. The generic migrator +// queries information_schema which does not exist in SQLite. type sqliteMigrator struct { migrator.Migrator } From c4437664b66fa65fd381ee54e5725193de831d6c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:12:11 +0000 Subject: [PATCH 4/4] style: fix gofmt issues in module_casbin_storage_test.go Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- internal/module_casbin_storage_test.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/internal/module_casbin_storage_test.go b/internal/module_casbin_storage_test.go index 48d39ad..8b143c3 100644 --- a/internal/module_casbin_storage_test.go +++ b/internal/module_casbin_storage_test.go @@ -246,9 +246,9 @@ func TestGORMAdapter_FilterField_InvalidField(t *testing.T) { // filter_field / filter_value is rejected; both must be set or neither. func TestGORMAdapter_FilterField_PartialConfig(t *testing.T) { for _, tc := range []struct { - name string - field string - value string + name string + field string + value string }{ {"field only", "v0", ""}, {"value only", "", "tenant_a"}, @@ -604,8 +604,6 @@ func TestGORMAdapter_TableNameTemplate_Invalid(t *testing.T) { } } - - func TestInMemoryAdapter_AddRemovePolicy(t *testing.T) { m := buildModule(t, [][]string{{"admin", "/api", "GET"}},