Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 265 additions & 26 deletions internal/gorm_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ package internal
// gormCasbinAdapter is a minimal casbin persist.Adapter backed by gorm.
// It replaces casbin/gorm-adapter/v3 to avoid the duplicate sqlite driver
// registration conflict between glebarez/go-sqlite and modernc.org/sqlite.
//
// Option A – tenant filter: construct the adapter with a filterField / filterValue
// to apply a WHERE clause on every LoadPolicy call, isolating one tenant's rows.
//
// Option B – per-tenant table: pass a resolved table name (e.g. "casbin_rule_acme")
// so each tenant's policies live in a dedicated table.

import (
"fmt"
Expand All @@ -11,44 +17,235 @@ import (
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)

// identWriter wraps strings.Builder to satisfy clause.Writer (which requires
// both WriteString and WriteByte). It is used to invoke dialector.QuoteTo so
// that identifier quoting is database-specific (backticks for MySQL/SQLite,
// double-quotes for PostgreSQL).
type identWriter struct{ strings.Builder }

func (w *identWriter) WriteByte(b byte) error { return w.Builder.WriteByte(b) }

// quoteIdent returns name quoted as an identifier using the dialector's own
// quoting rules, making SQL safe to use on MySQL, PostgreSQL and SQLite.
func quoteIdent(db *gorm.DB, name string) string {
var w identWriter
db.Dialector.QuoteTo(&w, name)
return w.String()
}

// validFilterFields is the set of column names allowed in a filter to prevent
// SQL injection when building dynamic WHERE clauses.
var validFilterFields = map[string]bool{
"v0": true, "v1": true, "v2": true,
"v3": true, "v4": true, "v5": true,
}

// GORMFilter specifies a WHERE clause for tenant-scoped policy loading.
// It is the concrete filter type accepted by gormAdapter.LoadFilteredPolicy.
type GORMFilter struct {
// Field is the column name to filter on (one of "v0" through "v5").
Field string
// Value is the value the column must equal.
Value string
}

// casbinRule mirrors the table schema used by the upstream gorm-adapter.
// The composite uniqueness constraint is NOT declared in the struct tags because
// GORM uses the literal tag name as the index name (e.g. "unique_index"), which
// conflicts in SQLite's global index namespace when multiple tenant tables are
// created from the same struct. The index is instead created by migrateTable
// with a per-table name ("uidx_<tableName>").
type casbinRule struct {
ID uint `gorm:"primarykey;autoIncrement"`
Ptype string `gorm:"size:512;uniqueIndex:unique_index"`
V0 string `gorm:"size:512;uniqueIndex:unique_index"`
V1 string `gorm:"size:512;uniqueIndex:unique_index"`
V2 string `gorm:"size:512;uniqueIndex:unique_index"`
V3 string `gorm:"size:512;uniqueIndex:unique_index"`
V4 string `gorm:"size:512;uniqueIndex:unique_index"`
V5 string `gorm:"size:512;uniqueIndex:unique_index"`
Ptype string `gorm:"size:512"`
V0 string `gorm:"size:512"`
V1 string `gorm:"size:512"`
V2 string `gorm:"size:512"`
V3 string `gorm:"size:512"`
V4 string `gorm:"size:512"`
V5 string `gorm:"size:512"`
}

// TableName returns the gorm table name.
// TableName returns the default gorm table name.
// When a custom table name is needed the caller uses db.Table(name) explicitly.
func (casbinRule) TableName() string { return "casbin_rule" }

// gormAdapter implements persist.Adapter using a gorm.DB.
// gormAdapter implements persist.FilteredAdapter using a gorm.DB.
// When filterField/filterValue are non-empty all load/save operations are
// scoped to rows where filterField = filterValue (Option A).
// The tableName field enables per-tenant tables (Option B).
//
// FilteredAdapter contract (Casbin v2):
// - IsFiltered() must return false before the first LoadPolicy call so that
// NewEnforcer will call LoadPolicy during initialisation.
// - IsFiltered() returns true once a filtered load has been performed, which
// tells the Casbin enforcer not to allow a bulk SavePolicy.
// - The filter is applied automatically inside LoadPolicy when filterField /
// filterValue are configured, so callers need not use LoadFilteredPolicy.
type gormAdapter struct {
db *gorm.DB
tableName string
db *gorm.DB
tableName string
filterField string // Option A: column name (e.g. "v0"); empty = no filter
filterValue string // Option A: value to match
filtered bool // true after the first filtered LoadPolicy; starts false
}

// newGORMAdapter auto-migrates the casbin_rule table and returns an adapter.
func newGORMAdapter(db *gorm.DB, tableName string) (*gormAdapter, error) {
// validTableNameRune returns true when ch is allowed in a table name used as a
// raw SQL identifier. Only ASCII letters, digits, underscores and hyphens are
// accepted; this prevents characters that could break %q SQL quoting.
func validTableNameRune(ch rune) bool {
return (ch >= 'a' && ch <= 'z') ||
(ch >= 'A' && ch <= 'Z') ||
(ch >= '0' && ch <= '9') ||
ch == '_' || ch == '-'
}

// table returns a *gorm.DB scoped to the adapter's table name.
func (a *gormAdapter) table() *gorm.DB {
return a.db.Table(a.tableName)
}

// newGORMAdapter auto-migrates the table and returns an adapter.
// filterField and filterValue implement Option A (tenant-scoped WHERE clause).
// Pass empty strings to disable filtering.
func newGORMAdapter(db *gorm.DB, tableName, filterField, filterValue string) (*gormAdapter, error) {
if tableName == "" {
tableName = "casbin_rule"
}
if err := db.AutoMigrate(&casbinRule{}); err != nil {
// Validate tableName to prevent SQL injection via the raw CREATE INDEX SQL
// constructed in migrateTable. Only allow characters that are safe to use
// inside a double-quoted SQL identifier without additional escaping.
for _, ch := range tableName {
if !validTableNameRune(ch) {
return nil, fmt.Errorf("gorm casbin adapter: invalid character %q in table name %q", ch, tableName)
}
}
if filterField != "" && !validFilterFields[filterField] {
return nil, fmt.Errorf("gorm casbin adapter: invalid filter_field %q (must be v0-v5)", filterField)
}
// filterField and filterValue must be set together; partial configuration is
// ambiguous and would silently skip filtering while looking like it is enabled.
if (filterField == "") != (filterValue == "") {
return nil, fmt.Errorf("gorm casbin adapter: filter_field and filter_value must both be set or both be empty")
}
if err := migrateTable(db, tableName); err != nil {
Comment on lines +126 to +134
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

newGORMAdapter accepts filterField/filterValue independently. If only one is set (e.g., filter_field provided but filter_value empty), LoadPolicy won’t apply a filter and IsFiltered will never flip, but the config looks like tenant isolation is enabled. Consider validating that filterField and filterValue are either both set or both empty, and fail adapter creation on partial configuration.

Copilot uses AI. Check for mistakes.
return nil, fmt.Errorf("gorm casbin adapter: migrate: %w", err)
}
return &gormAdapter{db: db, tableName: tableName}, nil
return &gormAdapter{
db: db,
tableName: tableName,
filterField: filterField,
filterValue: filterValue,
}, nil
}

// migrateTable creates the casbin rule table for tableName.
//
// The composite unique constraint is created as a named index "uidx_<tableName>"
// rather than via struct tags so that each tenant table gets its own index name.
// This avoids SQLite's flat index namespace where two tables migrated from the
// same struct with the same named index tag would conflict.
//
// On SQLite, AutoMigrate is skipped when the table already exists because the
// generic GORM migrator's HasColumn falls back to information_schema (which does
// not exist in SQLite), causing unintended ALTER TABLE attempts. On PostgreSQL
// and MySQL, AutoMigrate always runs so that schema drift is corrected.
func migrateTable(db *gorm.DB, tableName string) error {
isSQLite := db.Dialector.Name() == "sqlite"
tableExists := db.Migrator().HasTable(tableName)

if !isSQLite || !tableExists {
if err := db.Table(tableName).AutoMigrate(&casbinRule{}); err != nil {
return err
}
}
Comment on lines +156 to +164
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

migrateTable skips AutoMigrate whenever the table already exists, regardless of dialect. On postgres/mysql this means existing deployments won’t get schema drift corrected (missing columns, type changes, etc.), which AutoMigrate previously handled. If the intent is SQLite-only safety, consider restricting the “only migrate new tables” behavior to the SQLite dialector, or documenting that existing tables will not be auto-migrated on any DB.

Copilot uses AI. Check for mistakes.
// Create a composite unique index with a per-table name.
// We check for existence first so the operation is idempotent on all
// supported databases (SQLite, PostgreSQL, MySQL).
idxName := "uidx_" + strings.ReplaceAll(tableName, "-", "_")
if db.Migrator().HasIndex(tableName, idxName) {
return nil
}
// Use the dialector's identifier quoting so the SQL is correct for every
// supported database: backticks on MySQL/SQLite, double-quotes on PostgreSQL.
cols := []string{"ptype", "v0", "v1", "v2", "v3", "v4", "v5"}
quotedCols := make([]string, len(cols))
for i, c := range cols {
quotedCols[i] = quoteIdent(db, c)
}
return db.Exec(fmt.Sprintf(
"CREATE UNIQUE INDEX %s ON %s (%s)",
quoteIdent(db, idxName),
quoteIdent(db, tableName),
strings.Join(quotedCols, ", "),
)).Error
Comment on lines +179 to +184
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

migrateTable uses a raw CREATE UNIQUE INDEX %q ON %q ("ptype",...) statement. %q produces a Go string literal (double-quoted with backslash escapes) and the double-quote identifier quoting is not portable (notably breaks MySQL unless ANSI_QUOTES is enabled). Consider using GORM/dialector identifier quoting (e.g., via the dialector’s QuoteTo / clause builders) or a migrator-based index creation approach so the SQL is correct across sqlite/postgres/mysql.

Copilot uses AI. Check for mistakes.
}

// IsFiltered returns true once a filtered LoadPolicy has been performed.
// Before the first load it returns false so that casbin.NewEnforcer will call
// LoadPolicy during initialisation (Casbin skips LoadPolicy when IsFiltered is
// true at construction time).
func (a *gormAdapter) IsFiltered() bool {
return a.filtered
}

// LoadPolicy loads all policies from the database into the model.
// LoadPolicy loads all (or filtered) policies from the database into the model.
// When filterField/filterValue are configured the WHERE clause is applied and
// filtered is set to true so that subsequent SavePolicy calls are scoped.
func (a *gormAdapter) LoadPolicy(mdl model.Model) error {
if err := a.loadWithFilter(mdl, a.filterField, a.filterValue); err != nil {
return err
}
if a.filterField != "" && a.filterValue != "" {
a.filtered = true
}
return nil
}

// LoadFilteredPolicy loads only policies matching the supplied filter.
// filter must be a GORMFilter value. On success the active filter is stored on
// the adapter so that subsequent SavePolicy calls are correctly scoped, and
// IsFiltered is set to true.
func (a *gormAdapter) LoadFilteredPolicy(mdl model.Model, filter interface{}) error {
f, ok := filter.(GORMFilter)
if !ok {
return fmt.Errorf("gorm casbin adapter: LoadFilteredPolicy expects GORMFilter, got %T", filter)
}
if f.Field != "" && !validFilterFields[f.Field] {
return fmt.Errorf("gorm casbin adapter: invalid filter field %q (must be v0-v5)", f.Field)
}
if err := a.loadWithFilter(mdl, f.Field, f.Value); err != nil {
return err
}
// Persist the active filter so that subsequent SavePolicy calls can correctly
// scope their operations. When no effective filter was provided, clear any
// previously stored filter to avoid inconsistent SavePolicy behavior.
if f.Field != "" && f.Value != "" {
a.filterField = f.Field
a.filterValue = f.Value
a.filtered = true
} else {
a.filterField = ""
a.filterValue = ""
a.filtered = false
}
return nil
}

// loadWithFilter is the shared implementation used by LoadPolicy and
// LoadFilteredPolicy. field must be a pre-validated column name (v0-v5) or
// empty. clause.Eq is used for the WHERE clause so that identifier quoting is
// handled by GORM's dialector, making it correct on MySQL, PostgreSQL and SQLite.
func (a *gormAdapter) loadWithFilter(mdl model.Model, field, value string) error {
q := a.table()
if field != "" && value != "" {
q = q.Where(clause.Eq{Column: clause.Column{Name: field}, Value: value})
}
Comment on lines +244 to +246
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loadWithFilter builds the WHERE clause using backtick-quoted identifiers (""+field+" = ?"). Backticks are not valid identifier quoting on PostgreSQL, so tenant filtering will fail on the postgres driver. Use GORM’s clause API / dialector quoting to build an identifier-safe equality expression without hard-coding quote characters.

Copilot uses AI. Check for mistakes.
var rules []casbinRule
if err := a.db.Find(&rules).Error; err != nil {
if err := q.Find(&rules).Error; err != nil {
return err
}
for _, rule := range rules {
Expand All @@ -58,10 +255,9 @@ func (a *gormAdapter) LoadPolicy(mdl model.Model) error {
}

// SavePolicy saves all policies from the model into the database.
// When a tenant filter is active only the matching rows are replaced, so that
// other tenants' data is not affected.
func (a *gormAdapter) SavePolicy(mdl model.Model) error {
if err := a.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&casbinRule{}).Error; err != nil {
return err
}
var rules []casbinRule
for ptype, assertions := range mdl["p"] {
for _, assertion := range assertions.Policy {
Expand All @@ -73,27 +269,69 @@ func (a *gormAdapter) SavePolicy(mdl model.Model) error {
rules = append(rules, lineToRule(ptype, assertion))
}
}

if a.IsFiltered() {
// Delete only rows belonging to this tenant, then re-insert.
if err := a.table().Where(clause.Eq{Column: clause.Column{Name: a.filterField}, Value: a.filterValue}).Delete(&casbinRule{}).Error; err != nil {
return err
}
} else {
// Delete all rows in the table.
if err := a.table().Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&casbinRule{}).Error; err != nil {
return err
}
}

if len(rules) > 0 {
return a.db.CreateInBatches(rules, 100).Error
return a.table().CreateInBatches(rules, 100).Error
}
return nil
}

// checkTenantScope returns an error when the adapter is in filtered mode and
// the rule's field at the filter index does not match filterValue. This
// prevents accidental cross-tenant writes via AddPolicy / RemovePolicy.
func (a *gormAdapter) checkTenantScope(rule []string) error {
if a.filterField == "" {
return nil
}
// filterField is validated to be "v0"–"v5"; extract the numeric index.
fieldIdx := int(a.filterField[1] - '0')
if fieldIdx >= len(rule) {
return fmt.Errorf("gorm casbin adapter: rule has %d field(s), filter requires %s", len(rule), a.filterField)
}
if rule[fieldIdx] != a.filterValue {
return fmt.Errorf("gorm casbin adapter: rule %s=%q does not match tenant filter %s=%q; cross-tenant writes are not allowed",
a.filterField, rule[fieldIdx], a.filterField, a.filterValue)
}
return nil
}

// AddPolicy adds a policy rule to the database.
// When a tenant filter is active the rule is validated to ensure it belongs to
// the configured tenant before being written.
func (a *gormAdapter) AddPolicy(sec, ptype string, rule []string) error {
if err := a.checkTenantScope(rule); err != nil {
return err
}
r := lineToRule(ptype, rule)
return a.db.Create(&r).Error
return a.table().Create(&r).Error
}

// RemovePolicy removes a policy rule from the database.
// When a tenant filter is active the rule is validated to ensure it belongs to
// the configured tenant before deletion is attempted.
func (a *gormAdapter) RemovePolicy(sec, ptype string, rule []string) error {
if err := a.checkTenantScope(rule); err != nil {
return err
}
r := lineToRule(ptype, rule)
return a.db.Where(&r).Delete(&casbinRule{}).Error
return a.table().Where(&r).Delete(&casbinRule{}).Error
}
Comment on lines 310 to 330
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When tenant filtering is enabled, AddPolicy/RemovePolicy currently write/delete rows solely based on the provided rule values, without enforcing the adapter’s filter scope. A caller can accidentally insert/delete another tenant’s rows by passing a rule with a different vN value than the configured filter. To make tenant isolation reliable, consider validating that the rule’s filtered field matches filterValue (or overriding it) and returning an error when it doesn’t.

Copilot uses AI. Check for mistakes.

// RemoveFilteredPolicy removes policy rules matching the given filter.
func (a *gormAdapter) RemoveFilteredPolicy(sec, ptype string, fieldIndex int, fieldValues ...string) error {
query := a.db.Where("ptype = ?", ptype)
query := a.table().Where("ptype = ?", ptype)
fields := []string{"v0", "v1", "v2", "v3", "v4", "v5"}
for i, v := range fieldValues {
if v != "" {
Expand Down Expand Up @@ -129,5 +367,6 @@ func lineToRule(ptype string, rule []string) casbinRule {
return r
}

// Compile-time interface check.
var _ persist.Adapter = (*gormAdapter)(nil)
// Compile-time interface check – gormAdapter must satisfy FilteredAdapter
// (which is a superset of Adapter).
var _ persist.FilteredAdapter = (*gormAdapter)(nil)
Loading
Loading