diff --git a/service/entityresolution/multi-strategy/providers/sql/sql_provider.go b/service/entityresolution/multi-strategy/providers/sql/sql_provider.go index 47f4328edc..d9c8c4c181 100644 --- a/service/entityresolution/multi-strategy/providers/sql/sql_provider.go +++ b/service/entityresolution/multi-strategy/providers/sql/sql_provider.go @@ -5,15 +5,57 @@ import ( "database/sql" "fmt" "strings" + "sync" - // Database drivers would be imported here: - // _ "github.com/lib/pq" // PostgreSQL driver - // _ "github.com/go-sql-driver/mysql" // MySQL driver - // _ "github.com/mattn/go-sqlite3" // SQLite driver - + "github.com/jackc/pgx/v5/stdlib" "github.com/opentdf/platform/service/entityresolution/multi-strategy/types" ) +var ( + // driverRegMu guards lazy driver registration to prevent duplicate-register panics. + driverRegMu sync.Mutex + registeredDrivers = make(map[string]struct{}) +) + +// ensureDriverRegistered lazily registers the named database/sql driver the first +// time a SQL provider for that driver is created. This avoids the need for +// consumers to add blank driver imports to their own binaries. +// +// Uses pgx/v5/stdlib for postgres (already a platform dependency). Other drivers +// (mysql, sqlite) are not currently auto-registered and must be imported by the +// consumer. Consumers that have already registered the driver themselves are +// handled gracefully via a sql.Drivers() pre-check. +func ensureDriverRegistered(driver string) { + // Normalize to lowercase so "Postgres", "POSTGRES", and "postgres" all resolve + // to the same registered driver name. sql.Register is case-sensitive. + driver = strings.ToLower(strings.TrimSpace(driver)) + + driverRegMu.Lock() + defer driverRegMu.Unlock() + + if _, ok := registeredDrivers[driver]; ok { + return + } + + // Check whether the driver was already registered externally (e.g. via a + // blank import in the consumer binary) before attempting to register it. + // Use strings.EqualFold so the pre-check is also case-insensitive. + for _, d := range sql.Drivers() { + if strings.EqualFold(d, driver) { + registeredDrivers[driver] = struct{}{} + return + } + } + + switch driver { + case "postgres": + sql.Register("postgres", stdlib.GetDefaultDriver()) + registeredDrivers[driver] = struct{}{} + } + // mysql and sqlite require imports not present in this module's dependencies. + // Add cases here when those drivers are added to go.mod. +} + // Provider implements the Provider interface for SQL databases type Provider struct { name string @@ -24,6 +66,15 @@ type Provider struct { // NewProvider creates a new SQL provider func NewProvider(ctx context.Context, name string, config Config) (*Provider, error) { + // Normalize the driver name so "Postgres", "POSTGRES", and "postgres" all + // resolve correctly through ensureDriverRegistered and sql.Open, both of + // which use case-sensitive driver name matching. + config.Driver = strings.ToLower(strings.TrimSpace(config.Driver)) + + // Register the database/sql driver for this provider's configured driver name + // if it has not already been registered. + ensureDriverRegistered(config.Driver) + provider := &Provider{ name: name, config: config,