@@ -3,6 +3,12 @@ package internal
33// gormCasbinAdapter is a minimal casbin persist.Adapter backed by gorm.
44// It replaces casbin/gorm-adapter/v3 to avoid the duplicate sqlite driver
55// registration conflict between glebarez/go-sqlite and modernc.org/sqlite.
6+ //
7+ // Option A – tenant filter: construct the adapter with a filterField / filterValue
8+ // to apply a WHERE clause on every LoadPolicy call, isolating one tenant's rows.
9+ //
10+ // Option B – per-tenant table: pass a resolved table name (e.g. "casbin_rule_acme")
11+ // so each tenant's policies live in a dedicated table.
612
713import (
814 "fmt"
@@ -11,44 +17,235 @@ import (
1117 "github.com/casbin/casbin/v2/model"
1218 "github.com/casbin/casbin/v2/persist"
1319 "gorm.io/gorm"
20+ "gorm.io/gorm/clause"
1421)
1522
23+ // identWriter wraps strings.Builder to satisfy clause.Writer (which requires
24+ // both WriteString and WriteByte). It is used to invoke dialector.QuoteTo so
25+ // that identifier quoting is database-specific (backticks for MySQL/SQLite,
26+ // double-quotes for PostgreSQL).
27+ type identWriter struct { strings.Builder }
28+
29+ func (w * identWriter ) WriteByte (b byte ) error { return w .Builder .WriteByte (b ) }
30+
31+ // quoteIdent returns name quoted as an identifier using the dialector's own
32+ // quoting rules, making SQL safe to use on MySQL, PostgreSQL and SQLite.
33+ func quoteIdent (db * gorm.DB , name string ) string {
34+ var w identWriter
35+ db .Dialector .QuoteTo (& w , name )
36+ return w .String ()
37+ }
38+
39+ // validFilterFields is the set of column names allowed in a filter to prevent
40+ // SQL injection when building dynamic WHERE clauses.
41+ var validFilterFields = map [string ]bool {
42+ "v0" : true , "v1" : true , "v2" : true ,
43+ "v3" : true , "v4" : true , "v5" : true ,
44+ }
45+
46+ // GORMFilter specifies a WHERE clause for tenant-scoped policy loading.
47+ // It is the concrete filter type accepted by gormAdapter.LoadFilteredPolicy.
48+ type GORMFilter struct {
49+ // Field is the column name to filter on (one of "v0" through "v5").
50+ Field string
51+ // Value is the value the column must equal.
52+ Value string
53+ }
54+
1655// casbinRule mirrors the table schema used by the upstream gorm-adapter.
56+ // The composite uniqueness constraint is NOT declared in the struct tags because
57+ // GORM uses the literal tag name as the index name (e.g. "unique_index"), which
58+ // conflicts in SQLite's global index namespace when multiple tenant tables are
59+ // created from the same struct. The index is instead created by migrateTable
60+ // with a per-table name ("uidx_<tableName>").
1761type casbinRule struct {
1862 ID uint `gorm:"primarykey;autoIncrement"`
19- Ptype string `gorm:"size:512;uniqueIndex:unique_index "`
20- V0 string `gorm:"size:512;uniqueIndex:unique_index "`
21- V1 string `gorm:"size:512;uniqueIndex:unique_index "`
22- V2 string `gorm:"size:512;uniqueIndex:unique_index "`
23- V3 string `gorm:"size:512;uniqueIndex:unique_index "`
24- V4 string `gorm:"size:512;uniqueIndex:unique_index "`
25- V5 string `gorm:"size:512;uniqueIndex:unique_index "`
63+ Ptype string `gorm:"size:512"`
64+ V0 string `gorm:"size:512"`
65+ V1 string `gorm:"size:512"`
66+ V2 string `gorm:"size:512"`
67+ V3 string `gorm:"size:512"`
68+ V4 string `gorm:"size:512"`
69+ V5 string `gorm:"size:512"`
2670}
2771
28- // TableName returns the gorm table name.
72+ // TableName returns the default gorm table name.
73+ // When a custom table name is needed the caller uses db.Table(name) explicitly.
2974func (casbinRule ) TableName () string { return "casbin_rule" }
3075
31- // gormAdapter implements persist.Adapter using a gorm.DB.
76+ // gormAdapter implements persist.FilteredAdapter using a gorm.DB.
77+ // When filterField/filterValue are non-empty all load/save operations are
78+ // scoped to rows where filterField = filterValue (Option A).
79+ // The tableName field enables per-tenant tables (Option B).
80+ //
81+ // FilteredAdapter contract (Casbin v2):
82+ // - IsFiltered() must return false before the first LoadPolicy call so that
83+ // NewEnforcer will call LoadPolicy during initialisation.
84+ // - IsFiltered() returns true once a filtered load has been performed, which
85+ // tells the Casbin enforcer not to allow a bulk SavePolicy.
86+ // - The filter is applied automatically inside LoadPolicy when filterField /
87+ // filterValue are configured, so callers need not use LoadFilteredPolicy.
3288type gormAdapter struct {
33- db * gorm.DB
34- tableName string
89+ db * gorm.DB
90+ tableName string
91+ filterField string // Option A: column name (e.g. "v0"); empty = no filter
92+ filterValue string // Option A: value to match
93+ filtered bool // true after the first filtered LoadPolicy; starts false
3594}
3695
37- // newGORMAdapter auto-migrates the casbin_rule table and returns an adapter.
38- func newGORMAdapter (db * gorm.DB , tableName string ) (* gormAdapter , error ) {
96+ // validTableNameRune returns true when ch is allowed in a table name used as a
97+ // raw SQL identifier. Only ASCII letters, digits, underscores and hyphens are
98+ // accepted; this prevents characters that could break %q SQL quoting.
99+ func validTableNameRune (ch rune ) bool {
100+ return (ch >= 'a' && ch <= 'z' ) ||
101+ (ch >= 'A' && ch <= 'Z' ) ||
102+ (ch >= '0' && ch <= '9' ) ||
103+ ch == '_' || ch == '-'
104+ }
105+
106+ // table returns a *gorm.DB scoped to the adapter's table name.
107+ func (a * gormAdapter ) table () * gorm.DB {
108+ return a .db .Table (a .tableName )
109+ }
110+
111+ // newGORMAdapter auto-migrates the table and returns an adapter.
112+ // filterField and filterValue implement Option A (tenant-scoped WHERE clause).
113+ // Pass empty strings to disable filtering.
114+ func newGORMAdapter (db * gorm.DB , tableName , filterField , filterValue string ) (* gormAdapter , error ) {
39115 if tableName == "" {
40116 tableName = "casbin_rule"
41117 }
42- if err := db .AutoMigrate (& casbinRule {}); err != nil {
118+ // Validate tableName to prevent SQL injection via the raw CREATE INDEX SQL
119+ // constructed in migrateTable. Only allow characters that are safe to use
120+ // inside a double-quoted SQL identifier without additional escaping.
121+ for _ , ch := range tableName {
122+ if ! validTableNameRune (ch ) {
123+ return nil , fmt .Errorf ("gorm casbin adapter: invalid character %q in table name %q" , ch , tableName )
124+ }
125+ }
126+ if filterField != "" && ! validFilterFields [filterField ] {
127+ return nil , fmt .Errorf ("gorm casbin adapter: invalid filter_field %q (must be v0-v5)" , filterField )
128+ }
129+ // filterField and filterValue must be set together; partial configuration is
130+ // ambiguous and would silently skip filtering while looking like it is enabled.
131+ if (filterField == "" ) != (filterValue == "" ) {
132+ return nil , fmt .Errorf ("gorm casbin adapter: filter_field and filter_value must both be set or both be empty" )
133+ }
134+ if err := migrateTable (db , tableName ); err != nil {
43135 return nil , fmt .Errorf ("gorm casbin adapter: migrate: %w" , err )
44136 }
45- return & gormAdapter {db : db , tableName : tableName }, nil
137+ return & gormAdapter {
138+ db : db ,
139+ tableName : tableName ,
140+ filterField : filterField ,
141+ filterValue : filterValue ,
142+ }, nil
143+ }
144+
145+ // migrateTable creates the casbin rule table for tableName.
146+ //
147+ // The composite unique constraint is created as a named index "uidx_<tableName>"
148+ // rather than via struct tags so that each tenant table gets its own index name.
149+ // This avoids SQLite's flat index namespace where two tables migrated from the
150+ // same struct with the same named index tag would conflict.
151+ //
152+ // On SQLite, AutoMigrate is skipped when the table already exists because the
153+ // generic GORM migrator's HasColumn falls back to information_schema (which does
154+ // not exist in SQLite), causing unintended ALTER TABLE attempts. On PostgreSQL
155+ // and MySQL, AutoMigrate always runs so that schema drift is corrected.
156+ func migrateTable (db * gorm.DB , tableName string ) error {
157+ isSQLite := db .Dialector .Name () == "sqlite"
158+ tableExists := db .Migrator ().HasTable (tableName )
159+
160+ if ! isSQLite || ! tableExists {
161+ if err := db .Table (tableName ).AutoMigrate (& casbinRule {}); err != nil {
162+ return err
163+ }
164+ }
165+ // Create a composite unique index with a per-table name.
166+ // We check for existence first so the operation is idempotent on all
167+ // supported databases (SQLite, PostgreSQL, MySQL).
168+ idxName := "uidx_" + strings .ReplaceAll (tableName , "-" , "_" )
169+ if db .Migrator ().HasIndex (tableName , idxName ) {
170+ return nil
171+ }
172+ // Use the dialector's identifier quoting so the SQL is correct for every
173+ // supported database: backticks on MySQL/SQLite, double-quotes on PostgreSQL.
174+ cols := []string {"ptype" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" }
175+ quotedCols := make ([]string , len (cols ))
176+ for i , c := range cols {
177+ quotedCols [i ] = quoteIdent (db , c )
178+ }
179+ return db .Exec (fmt .Sprintf (
180+ "CREATE UNIQUE INDEX %s ON %s (%s)" ,
181+ quoteIdent (db , idxName ),
182+ quoteIdent (db , tableName ),
183+ strings .Join (quotedCols , ", " ),
184+ )).Error
185+ }
186+
187+ // IsFiltered returns true once a filtered LoadPolicy has been performed.
188+ // Before the first load it returns false so that casbin.NewEnforcer will call
189+ // LoadPolicy during initialisation (Casbin skips LoadPolicy when IsFiltered is
190+ // true at construction time).
191+ func (a * gormAdapter ) IsFiltered () bool {
192+ return a .filtered
46193}
47194
48- // LoadPolicy loads all policies from the database into the model.
195+ // LoadPolicy loads all (or filtered) policies from the database into the model.
196+ // When filterField/filterValue are configured the WHERE clause is applied and
197+ // filtered is set to true so that subsequent SavePolicy calls are scoped.
49198func (a * gormAdapter ) LoadPolicy (mdl model.Model ) error {
199+ if err := a .loadWithFilter (mdl , a .filterField , a .filterValue ); err != nil {
200+ return err
201+ }
202+ if a .filterField != "" && a .filterValue != "" {
203+ a .filtered = true
204+ }
205+ return nil
206+ }
207+
208+ // LoadFilteredPolicy loads only policies matching the supplied filter.
209+ // filter must be a GORMFilter value. On success the active filter is stored on
210+ // the adapter so that subsequent SavePolicy calls are correctly scoped, and
211+ // IsFiltered is set to true.
212+ func (a * gormAdapter ) LoadFilteredPolicy (mdl model.Model , filter interface {}) error {
213+ f , ok := filter .(GORMFilter )
214+ if ! ok {
215+ return fmt .Errorf ("gorm casbin adapter: LoadFilteredPolicy expects GORMFilter, got %T" , filter )
216+ }
217+ if f .Field != "" && ! validFilterFields [f .Field ] {
218+ return fmt .Errorf ("gorm casbin adapter: invalid filter field %q (must be v0-v5)" , f .Field )
219+ }
220+ if err := a .loadWithFilter (mdl , f .Field , f .Value ); err != nil {
221+ return err
222+ }
223+ // Persist the active filter so that subsequent SavePolicy calls can correctly
224+ // scope their operations. When no effective filter was provided, clear any
225+ // previously stored filter to avoid inconsistent SavePolicy behavior.
226+ if f .Field != "" && f .Value != "" {
227+ a .filterField = f .Field
228+ a .filterValue = f .Value
229+ a .filtered = true
230+ } else {
231+ a .filterField = ""
232+ a .filterValue = ""
233+ a .filtered = false
234+ }
235+ return nil
236+ }
237+
238+ // loadWithFilter is the shared implementation used by LoadPolicy and
239+ // LoadFilteredPolicy. field must be a pre-validated column name (v0-v5) or
240+ // empty. clause.Eq is used for the WHERE clause so that identifier quoting is
241+ // handled by GORM's dialector, making it correct on MySQL, PostgreSQL and SQLite.
242+ func (a * gormAdapter ) loadWithFilter (mdl model.Model , field , value string ) error {
243+ q := a .table ()
244+ if field != "" && value != "" {
245+ q = q .Where (clause.Eq {Column : clause.Column {Name : field }, Value : value })
246+ }
50247 var rules []casbinRule
51- if err := a . db .Find (& rules ).Error ; err != nil {
248+ if err := q .Find (& rules ).Error ; err != nil {
52249 return err
53250 }
54251 for _ , rule := range rules {
@@ -58,10 +255,9 @@ func (a *gormAdapter) LoadPolicy(mdl model.Model) error {
58255}
59256
60257// SavePolicy saves all policies from the model into the database.
258+ // When a tenant filter is active only the matching rows are replaced, so that
259+ // other tenants' data is not affected.
61260func (a * gormAdapter ) SavePolicy (mdl model.Model ) error {
62- if err := a .db .Session (& gorm.Session {AllowGlobalUpdate : true }).Delete (& casbinRule {}).Error ; err != nil {
63- return err
64- }
65261 var rules []casbinRule
66262 for ptype , assertions := range mdl ["p" ] {
67263 for _ , assertion := range assertions .Policy {
@@ -73,27 +269,69 @@ func (a *gormAdapter) SavePolicy(mdl model.Model) error {
73269 rules = append (rules , lineToRule (ptype , assertion ))
74270 }
75271 }
272+
273+ if a .IsFiltered () {
274+ // Delete only rows belonging to this tenant, then re-insert.
275+ if err := a .table ().Where (clause.Eq {Column : clause.Column {Name : a .filterField }, Value : a .filterValue }).Delete (& casbinRule {}).Error ; err != nil {
276+ return err
277+ }
278+ } else {
279+ // Delete all rows in the table.
280+ if err := a .table ().Session (& gorm.Session {AllowGlobalUpdate : true }).Delete (& casbinRule {}).Error ; err != nil {
281+ return err
282+ }
283+ }
284+
76285 if len (rules ) > 0 {
77- return a .db .CreateInBatches (rules , 100 ).Error
286+ return a .table ().CreateInBatches (rules , 100 ).Error
287+ }
288+ return nil
289+ }
290+
291+ // checkTenantScope returns an error when the adapter is in filtered mode and
292+ // the rule's field at the filter index does not match filterValue. This
293+ // prevents accidental cross-tenant writes via AddPolicy / RemovePolicy.
294+ func (a * gormAdapter ) checkTenantScope (rule []string ) error {
295+ if a .filterField == "" {
296+ return nil
297+ }
298+ // filterField is validated to be "v0"–"v5"; extract the numeric index.
299+ fieldIdx := int (a .filterField [1 ] - '0' )
300+ if fieldIdx >= len (rule ) {
301+ return fmt .Errorf ("gorm casbin adapter: rule has %d field(s), filter requires %s" , len (rule ), a .filterField )
302+ }
303+ if rule [fieldIdx ] != a .filterValue {
304+ return fmt .Errorf ("gorm casbin adapter: rule %s=%q does not match tenant filter %s=%q; cross-tenant writes are not allowed" ,
305+ a .filterField , rule [fieldIdx ], a .filterField , a .filterValue )
78306 }
79307 return nil
80308}
81309
82310// AddPolicy adds a policy rule to the database.
311+ // When a tenant filter is active the rule is validated to ensure it belongs to
312+ // the configured tenant before being written.
83313func (a * gormAdapter ) AddPolicy (sec , ptype string , rule []string ) error {
314+ if err := a .checkTenantScope (rule ); err != nil {
315+ return err
316+ }
84317 r := lineToRule (ptype , rule )
85- return a .db .Create (& r ).Error
318+ return a .table () .Create (& r ).Error
86319}
87320
88321// RemovePolicy removes a policy rule from the database.
322+ // When a tenant filter is active the rule is validated to ensure it belongs to
323+ // the configured tenant before deletion is attempted.
89324func (a * gormAdapter ) RemovePolicy (sec , ptype string , rule []string ) error {
325+ if err := a .checkTenantScope (rule ); err != nil {
326+ return err
327+ }
90328 r := lineToRule (ptype , rule )
91- return a .db .Where (& r ).Delete (& casbinRule {}).Error
329+ return a .table () .Where (& r ).Delete (& casbinRule {}).Error
92330}
93331
94332// RemoveFilteredPolicy removes policy rules matching the given filter.
95333func (a * gormAdapter ) RemoveFilteredPolicy (sec , ptype string , fieldIndex int , fieldValues ... string ) error {
96- query := a .db .Where ("ptype = ?" , ptype )
334+ query := a .table () .Where ("ptype = ?" , ptype )
97335 fields := []string {"v0" , "v1" , "v2" , "v3" , "v4" , "v5" }
98336 for i , v := range fieldValues {
99337 if v != "" {
@@ -129,5 +367,6 @@ func lineToRule(ptype string, rule []string) casbinRule {
129367 return r
130368}
131369
132- // Compile-time interface check.
133- var _ persist.Adapter = (* gormAdapter )(nil )
370+ // Compile-time interface check – gormAdapter must satisfy FilteredAdapter
371+ // (which is a superset of Adapter).
372+ var _ persist.FilteredAdapter = (* gormAdapter )(nil )
0 commit comments