Skip to content

Commit afcbce3

Browse files
Copilotintel352
andauthored
feat: tenant-scoped policy loading for GORM adapter (filter + per-tenant table) (#14)
* Initial plan * feat: tenant-scoped policy loading for GORM adapter (Option A + Option B) Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> * fix: address code review comments - identifier quoting, filter validation, tenant scope Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> * style: fix gofmt issues in module_casbin_storage_test.go Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: intel352 <77607+intel352@users.noreply.github.com>
1 parent 026af27 commit afcbce3

4 files changed

Lines changed: 749 additions & 35 deletions

File tree

internal/gorm_adapter.go

Lines changed: 265 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@ package internal
33
// gormCasbinAdapter is a minimal casbin persist.Adapter backed by gorm.
44
// It replaces casbin/gorm-adapter/v3 to avoid the duplicate sqlite driver
55
// registration conflict between glebarez/go-sqlite and modernc.org/sqlite.
6+
//
7+
// Option A – tenant filter: construct the adapter with a filterField / filterValue
8+
// to apply a WHERE clause on every LoadPolicy call, isolating one tenant's rows.
9+
//
10+
// Option B – per-tenant table: pass a resolved table name (e.g. "casbin_rule_acme")
11+
// so each tenant's policies live in a dedicated table.
612

713
import (
814
"fmt"
@@ -11,44 +17,235 @@ import (
1117
"github.com/casbin/casbin/v2/model"
1218
"github.com/casbin/casbin/v2/persist"
1319
"gorm.io/gorm"
20+
"gorm.io/gorm/clause"
1421
)
1522

23+
// identWriter wraps strings.Builder to satisfy clause.Writer (which requires
24+
// both WriteString and WriteByte). It is used to invoke dialector.QuoteTo so
25+
// that identifier quoting is database-specific (backticks for MySQL/SQLite,
26+
// double-quotes for PostgreSQL).
27+
type identWriter struct{ strings.Builder }
28+
29+
func (w *identWriter) WriteByte(b byte) error { return w.Builder.WriteByte(b) }
30+
31+
// quoteIdent returns name quoted as an identifier using the dialector's own
32+
// quoting rules, making SQL safe to use on MySQL, PostgreSQL and SQLite.
33+
func quoteIdent(db *gorm.DB, name string) string {
34+
var w identWriter
35+
db.Dialector.QuoteTo(&w, name)
36+
return w.String()
37+
}
38+
39+
// validFilterFields is the set of column names allowed in a filter to prevent
40+
// SQL injection when building dynamic WHERE clauses.
41+
var validFilterFields = map[string]bool{
42+
"v0": true, "v1": true, "v2": true,
43+
"v3": true, "v4": true, "v5": true,
44+
}
45+
46+
// GORMFilter specifies a WHERE clause for tenant-scoped policy loading.
47+
// It is the concrete filter type accepted by gormAdapter.LoadFilteredPolicy.
48+
type GORMFilter struct {
49+
// Field is the column name to filter on (one of "v0" through "v5").
50+
Field string
51+
// Value is the value the column must equal.
52+
Value string
53+
}
54+
1655
// casbinRule mirrors the table schema used by the upstream gorm-adapter.
56+
// The composite uniqueness constraint is NOT declared in the struct tags because
57+
// GORM uses the literal tag name as the index name (e.g. "unique_index"), which
58+
// conflicts in SQLite's global index namespace when multiple tenant tables are
59+
// created from the same struct. The index is instead created by migrateTable
60+
// with a per-table name ("uidx_<tableName>").
1761
type casbinRule struct {
1862
ID uint `gorm:"primarykey;autoIncrement"`
19-
Ptype string `gorm:"size:512;uniqueIndex:unique_index"`
20-
V0 string `gorm:"size:512;uniqueIndex:unique_index"`
21-
V1 string `gorm:"size:512;uniqueIndex:unique_index"`
22-
V2 string `gorm:"size:512;uniqueIndex:unique_index"`
23-
V3 string `gorm:"size:512;uniqueIndex:unique_index"`
24-
V4 string `gorm:"size:512;uniqueIndex:unique_index"`
25-
V5 string `gorm:"size:512;uniqueIndex:unique_index"`
63+
Ptype string `gorm:"size:512"`
64+
V0 string `gorm:"size:512"`
65+
V1 string `gorm:"size:512"`
66+
V2 string `gorm:"size:512"`
67+
V3 string `gorm:"size:512"`
68+
V4 string `gorm:"size:512"`
69+
V5 string `gorm:"size:512"`
2670
}
2771

28-
// TableName returns the gorm table name.
72+
// TableName returns the default gorm table name.
73+
// When a custom table name is needed the caller uses db.Table(name) explicitly.
2974
func (casbinRule) TableName() string { return "casbin_rule" }
3075

31-
// gormAdapter implements persist.Adapter using a gorm.DB.
76+
// gormAdapter implements persist.FilteredAdapter using a gorm.DB.
77+
// When filterField/filterValue are non-empty all load/save operations are
78+
// scoped to rows where filterField = filterValue (Option A).
79+
// The tableName field enables per-tenant tables (Option B).
80+
//
81+
// FilteredAdapter contract (Casbin v2):
82+
// - IsFiltered() must return false before the first LoadPolicy call so that
83+
// NewEnforcer will call LoadPolicy during initialisation.
84+
// - IsFiltered() returns true once a filtered load has been performed, which
85+
// tells the Casbin enforcer not to allow a bulk SavePolicy.
86+
// - The filter is applied automatically inside LoadPolicy when filterField /
87+
// filterValue are configured, so callers need not use LoadFilteredPolicy.
3288
type gormAdapter struct {
33-
db *gorm.DB
34-
tableName string
89+
db *gorm.DB
90+
tableName string
91+
filterField string // Option A: column name (e.g. "v0"); empty = no filter
92+
filterValue string // Option A: value to match
93+
filtered bool // true after the first filtered LoadPolicy; starts false
3594
}
3695

37-
// newGORMAdapter auto-migrates the casbin_rule table and returns an adapter.
38-
func newGORMAdapter(db *gorm.DB, tableName string) (*gormAdapter, error) {
96+
// validTableNameRune returns true when ch is allowed in a table name used as a
97+
// raw SQL identifier. Only ASCII letters, digits, underscores and hyphens are
98+
// accepted; this prevents characters that could break %q SQL quoting.
99+
func validTableNameRune(ch rune) bool {
100+
return (ch >= 'a' && ch <= 'z') ||
101+
(ch >= 'A' && ch <= 'Z') ||
102+
(ch >= '0' && ch <= '9') ||
103+
ch == '_' || ch == '-'
104+
}
105+
106+
// table returns a *gorm.DB scoped to the adapter's table name.
107+
func (a *gormAdapter) table() *gorm.DB {
108+
return a.db.Table(a.tableName)
109+
}
110+
111+
// newGORMAdapter auto-migrates the table and returns an adapter.
112+
// filterField and filterValue implement Option A (tenant-scoped WHERE clause).
113+
// Pass empty strings to disable filtering.
114+
func newGORMAdapter(db *gorm.DB, tableName, filterField, filterValue string) (*gormAdapter, error) {
39115
if tableName == "" {
40116
tableName = "casbin_rule"
41117
}
42-
if err := db.AutoMigrate(&casbinRule{}); err != nil {
118+
// Validate tableName to prevent SQL injection via the raw CREATE INDEX SQL
119+
// constructed in migrateTable. Only allow characters that are safe to use
120+
// inside a double-quoted SQL identifier without additional escaping.
121+
for _, ch := range tableName {
122+
if !validTableNameRune(ch) {
123+
return nil, fmt.Errorf("gorm casbin adapter: invalid character %q in table name %q", ch, tableName)
124+
}
125+
}
126+
if filterField != "" && !validFilterFields[filterField] {
127+
return nil, fmt.Errorf("gorm casbin adapter: invalid filter_field %q (must be v0-v5)", filterField)
128+
}
129+
// filterField and filterValue must be set together; partial configuration is
130+
// ambiguous and would silently skip filtering while looking like it is enabled.
131+
if (filterField == "") != (filterValue == "") {
132+
return nil, fmt.Errorf("gorm casbin adapter: filter_field and filter_value must both be set or both be empty")
133+
}
134+
if err := migrateTable(db, tableName); err != nil {
43135
return nil, fmt.Errorf("gorm casbin adapter: migrate: %w", err)
44136
}
45-
return &gormAdapter{db: db, tableName: tableName}, nil
137+
return &gormAdapter{
138+
db: db,
139+
tableName: tableName,
140+
filterField: filterField,
141+
filterValue: filterValue,
142+
}, nil
143+
}
144+
145+
// migrateTable creates the casbin rule table for tableName.
146+
//
147+
// The composite unique constraint is created as a named index "uidx_<tableName>"
148+
// rather than via struct tags so that each tenant table gets its own index name.
149+
// This avoids SQLite's flat index namespace where two tables migrated from the
150+
// same struct with the same named index tag would conflict.
151+
//
152+
// On SQLite, AutoMigrate is skipped when the table already exists because the
153+
// generic GORM migrator's HasColumn falls back to information_schema (which does
154+
// not exist in SQLite), causing unintended ALTER TABLE attempts. On PostgreSQL
155+
// and MySQL, AutoMigrate always runs so that schema drift is corrected.
156+
func migrateTable(db *gorm.DB, tableName string) error {
157+
isSQLite := db.Dialector.Name() == "sqlite"
158+
tableExists := db.Migrator().HasTable(tableName)
159+
160+
if !isSQLite || !tableExists {
161+
if err := db.Table(tableName).AutoMigrate(&casbinRule{}); err != nil {
162+
return err
163+
}
164+
}
165+
// Create a composite unique index with a per-table name.
166+
// We check for existence first so the operation is idempotent on all
167+
// supported databases (SQLite, PostgreSQL, MySQL).
168+
idxName := "uidx_" + strings.ReplaceAll(tableName, "-", "_")
169+
if db.Migrator().HasIndex(tableName, idxName) {
170+
return nil
171+
}
172+
// Use the dialector's identifier quoting so the SQL is correct for every
173+
// supported database: backticks on MySQL/SQLite, double-quotes on PostgreSQL.
174+
cols := []string{"ptype", "v0", "v1", "v2", "v3", "v4", "v5"}
175+
quotedCols := make([]string, len(cols))
176+
for i, c := range cols {
177+
quotedCols[i] = quoteIdent(db, c)
178+
}
179+
return db.Exec(fmt.Sprintf(
180+
"CREATE UNIQUE INDEX %s ON %s (%s)",
181+
quoteIdent(db, idxName),
182+
quoteIdent(db, tableName),
183+
strings.Join(quotedCols, ", "),
184+
)).Error
185+
}
186+
187+
// IsFiltered returns true once a filtered LoadPolicy has been performed.
188+
// Before the first load it returns false so that casbin.NewEnforcer will call
189+
// LoadPolicy during initialisation (Casbin skips LoadPolicy when IsFiltered is
190+
// true at construction time).
191+
func (a *gormAdapter) IsFiltered() bool {
192+
return a.filtered
46193
}
47194

48-
// LoadPolicy loads all policies from the database into the model.
195+
// LoadPolicy loads all (or filtered) policies from the database into the model.
196+
// When filterField/filterValue are configured the WHERE clause is applied and
197+
// filtered is set to true so that subsequent SavePolicy calls are scoped.
49198
func (a *gormAdapter) LoadPolicy(mdl model.Model) error {
199+
if err := a.loadWithFilter(mdl, a.filterField, a.filterValue); err != nil {
200+
return err
201+
}
202+
if a.filterField != "" && a.filterValue != "" {
203+
a.filtered = true
204+
}
205+
return nil
206+
}
207+
208+
// LoadFilteredPolicy loads only policies matching the supplied filter.
209+
// filter must be a GORMFilter value. On success the active filter is stored on
210+
// the adapter so that subsequent SavePolicy calls are correctly scoped, and
211+
// IsFiltered is set to true.
212+
func (a *gormAdapter) LoadFilteredPolicy(mdl model.Model, filter interface{}) error {
213+
f, ok := filter.(GORMFilter)
214+
if !ok {
215+
return fmt.Errorf("gorm casbin adapter: LoadFilteredPolicy expects GORMFilter, got %T", filter)
216+
}
217+
if f.Field != "" && !validFilterFields[f.Field] {
218+
return fmt.Errorf("gorm casbin adapter: invalid filter field %q (must be v0-v5)", f.Field)
219+
}
220+
if err := a.loadWithFilter(mdl, f.Field, f.Value); err != nil {
221+
return err
222+
}
223+
// Persist the active filter so that subsequent SavePolicy calls can correctly
224+
// scope their operations. When no effective filter was provided, clear any
225+
// previously stored filter to avoid inconsistent SavePolicy behavior.
226+
if f.Field != "" && f.Value != "" {
227+
a.filterField = f.Field
228+
a.filterValue = f.Value
229+
a.filtered = true
230+
} else {
231+
a.filterField = ""
232+
a.filterValue = ""
233+
a.filtered = false
234+
}
235+
return nil
236+
}
237+
238+
// loadWithFilter is the shared implementation used by LoadPolicy and
239+
// LoadFilteredPolicy. field must be a pre-validated column name (v0-v5) or
240+
// empty. clause.Eq is used for the WHERE clause so that identifier quoting is
241+
// handled by GORM's dialector, making it correct on MySQL, PostgreSQL and SQLite.
242+
func (a *gormAdapter) loadWithFilter(mdl model.Model, field, value string) error {
243+
q := a.table()
244+
if field != "" && value != "" {
245+
q = q.Where(clause.Eq{Column: clause.Column{Name: field}, Value: value})
246+
}
50247
var rules []casbinRule
51-
if err := a.db.Find(&rules).Error; err != nil {
248+
if err := q.Find(&rules).Error; err != nil {
52249
return err
53250
}
54251
for _, rule := range rules {
@@ -58,10 +255,9 @@ func (a *gormAdapter) LoadPolicy(mdl model.Model) error {
58255
}
59256

60257
// SavePolicy saves all policies from the model into the database.
258+
// When a tenant filter is active only the matching rows are replaced, so that
259+
// other tenants' data is not affected.
61260
func (a *gormAdapter) SavePolicy(mdl model.Model) error {
62-
if err := a.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&casbinRule{}).Error; err != nil {
63-
return err
64-
}
65261
var rules []casbinRule
66262
for ptype, assertions := range mdl["p"] {
67263
for _, assertion := range assertions.Policy {
@@ -73,27 +269,69 @@ func (a *gormAdapter) SavePolicy(mdl model.Model) error {
73269
rules = append(rules, lineToRule(ptype, assertion))
74270
}
75271
}
272+
273+
if a.IsFiltered() {
274+
// Delete only rows belonging to this tenant, then re-insert.
275+
if err := a.table().Where(clause.Eq{Column: clause.Column{Name: a.filterField}, Value: a.filterValue}).Delete(&casbinRule{}).Error; err != nil {
276+
return err
277+
}
278+
} else {
279+
// Delete all rows in the table.
280+
if err := a.table().Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&casbinRule{}).Error; err != nil {
281+
return err
282+
}
283+
}
284+
76285
if len(rules) > 0 {
77-
return a.db.CreateInBatches(rules, 100).Error
286+
return a.table().CreateInBatches(rules, 100).Error
287+
}
288+
return nil
289+
}
290+
291+
// checkTenantScope returns an error when the adapter is in filtered mode and
292+
// the rule's field at the filter index does not match filterValue. This
293+
// prevents accidental cross-tenant writes via AddPolicy / RemovePolicy.
294+
func (a *gormAdapter) checkTenantScope(rule []string) error {
295+
if a.filterField == "" {
296+
return nil
297+
}
298+
// filterField is validated to be "v0"–"v5"; extract the numeric index.
299+
fieldIdx := int(a.filterField[1] - '0')
300+
if fieldIdx >= len(rule) {
301+
return fmt.Errorf("gorm casbin adapter: rule has %d field(s), filter requires %s", len(rule), a.filterField)
302+
}
303+
if rule[fieldIdx] != a.filterValue {
304+
return fmt.Errorf("gorm casbin adapter: rule %s=%q does not match tenant filter %s=%q; cross-tenant writes are not allowed",
305+
a.filterField, rule[fieldIdx], a.filterField, a.filterValue)
78306
}
79307
return nil
80308
}
81309

82310
// AddPolicy adds a policy rule to the database.
311+
// When a tenant filter is active the rule is validated to ensure it belongs to
312+
// the configured tenant before being written.
83313
func (a *gormAdapter) AddPolicy(sec, ptype string, rule []string) error {
314+
if err := a.checkTenantScope(rule); err != nil {
315+
return err
316+
}
84317
r := lineToRule(ptype, rule)
85-
return a.db.Create(&r).Error
318+
return a.table().Create(&r).Error
86319
}
87320

88321
// RemovePolicy removes a policy rule from the database.
322+
// When a tenant filter is active the rule is validated to ensure it belongs to
323+
// the configured tenant before deletion is attempted.
89324
func (a *gormAdapter) RemovePolicy(sec, ptype string, rule []string) error {
325+
if err := a.checkTenantScope(rule); err != nil {
326+
return err
327+
}
90328
r := lineToRule(ptype, rule)
91-
return a.db.Where(&r).Delete(&casbinRule{}).Error
329+
return a.table().Where(&r).Delete(&casbinRule{}).Error
92330
}
93331

94332
// RemoveFilteredPolicy removes policy rules matching the given filter.
95333
func (a *gormAdapter) RemoveFilteredPolicy(sec, ptype string, fieldIndex int, fieldValues ...string) error {
96-
query := a.db.Where("ptype = ?", ptype)
334+
query := a.table().Where("ptype = ?", ptype)
97335
fields := []string{"v0", "v1", "v2", "v3", "v4", "v5"}
98336
for i, v := range fieldValues {
99337
if v != "" {
@@ -129,5 +367,6 @@ func lineToRule(ptype string, rule []string) casbinRule {
129367
return r
130368
}
131369

132-
// Compile-time interface check.
133-
var _ persist.Adapter = (*gormAdapter)(nil)
370+
// Compile-time interface check – gormAdapter must satisfy FilteredAdapter
371+
// (which is a superset of Adapter).
372+
var _ persist.FilteredAdapter = (*gormAdapter)(nil)

0 commit comments

Comments
 (0)