From 35c486eace05e35081a43f135391f7e32868a33d Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 6 Nov 2025 18:44:00 +0100 Subject: [PATCH 01/40] Updated testing environment. --- Makefile | 2 ++ docker-compose.test.yml | 27 ++++++++++++++++++--------- dockerfile.test | 6 ++++++ 3 files changed, 26 insertions(+), 9 deletions(-) create mode 100644 Makefile create mode 100644 dockerfile.test diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1bfea7f --- /dev/null +++ b/Makefile @@ -0,0 +1,2 @@ +test: + docker compose --progress plain -f docker-compose.test.yml run test \ No newline at end of file diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 94da3ec..56f4fc3 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -1,15 +1,24 @@ services: - # - # mssql - # mssql: image: mcr.microsoft.com/mssql/server:latest - - hostname: mssql - container_name: mssql - network_mode: bridge - ports: - - "1433:1433" + networks: + - mssql environment: ACCEPT_EULA: "Y" SA_PASSWORD: VippsPw1 + healthcheck: + test: ["CMD", "/opt/mssql-tools18/bin/sqlcmd", "-C", "-Usa", "-PVippsPw1", "-Q", "select 1"] + interval: 1s + retries: 20 + test: + build: + dockerfile: dockerfile.test + networks: + - mssql + environment: + SQLSERVER_DSN: sqlserver://mssql:1433?database=master&user id=sa&password=VippsPw1 + depends_on: + mssql: + condition: service_healthy +networks: + mssql: diff --git a/dockerfile.test b/dockerfile.test new file mode 100644 index 0000000..9313912 --- /dev/null +++ b/dockerfile.test @@ -0,0 +1,6 @@ +FROM golang:1.23 AS builder +WORKDIR /sqlcode +ENV GODEBUG="x509negativeserial=1" +COPY . . +RUN go mod tidy +CMD ["go", "test", "-v", "$(go list ./... | grep -v './example')"] \ No newline at end of file From 5a792ea893974b28ef55e3c87475eccd10ff0d34 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 6 Nov 2025 19:46:56 +0100 Subject: [PATCH 02/40] Working connection. 
--- Makefile | 5 +- ...mpose.test.yml => docker-compose.mssql.yml | 1 + docker-compose.pgsql.yml | 27 ++++++ dockerfile.test | 2 +- go.mod | 1 + go.sum | 2 + sqlcode.yaml | 6 +- sqltest/fixture.go | 93 ++++++++++++++----- 8 files changed, 108 insertions(+), 29 deletions(-) rename docker-compose.test.yml => docker-compose.mssql.yml (94%) create mode 100644 docker-compose.pgsql.yml diff --git a/Makefile b/Makefile index 1bfea7f..f6980da 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,5 @@ test: - docker compose --progress plain -f docker-compose.test.yml run test \ No newline at end of file + docker compose --progress plain -f docker-compose.mssql.yml run test + +test_pgsql: + docker compose --progress plain -f docker-compose.pgsql.yml run test \ No newline at end of file diff --git a/docker-compose.test.yml b/docker-compose.mssql.yml similarity index 94% rename from docker-compose.test.yml rename to docker-compose.mssql.yml index 56f4fc3..84618fc 100644 --- a/docker-compose.test.yml +++ b/docker-compose.mssql.yml @@ -17,6 +17,7 @@ services: - mssql environment: SQLSERVER_DSN: sqlserver://mssql:1433?database=master&user id=sa&password=VippsPw1 + SQLSERVER_DRIVER: sqlserver depends_on: mssql: condition: service_healthy diff --git a/docker-compose.pgsql.yml b/docker-compose.pgsql.yml new file mode 100644 index 0000000..6391e0c --- /dev/null +++ b/docker-compose.pgsql.yml @@ -0,0 +1,27 @@ +services: + postgres: + image: postgres + networks: + - postgres + environment: + POSTGRES_PASSWORD: VippsPw1 + POSTGRES_USER: sa + POSTGRES_DB: master + healthcheck: + test: ["CMD-SHELL", "pg_isready", "-d", "db_prod"] + interval: 1s + retries: 20 + test: + build: + dockerfile: dockerfile.test + networks: + - postgres + environment: + SQLSERVER_DSN: postgresql://sa:VippsPw1@postgres:5432/master?sslmode=disable + SQLSERVER_DRIVER: postgres + GODEBUG: "x509negativeserial=1" + depends_on: + postgres: + condition: service_healthy +networks: + postgres: diff --git a/dockerfile.test 
b/dockerfile.test index 9313912..f4a199f 100644 --- a/dockerfile.test +++ b/dockerfile.test @@ -1,4 +1,4 @@ -FROM golang:1.23 AS builder +FROM golang:1.25.1 AS builder WORKDIR /sqlcode ENV GODEBUG="x509negativeserial=1" COPY . . diff --git a/go.mod b/go.mod index e497cd0..e80b068 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.9 // indirect diff --git a/go.sum b/go.sum index 480c669..dab6450 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= diff --git a/sqlcode.yaml b/sqlcode.yaml index 549c23f..8adefb7 100644 --- a/sqlcode.yaml +++ b/sqlcode.yaml @@ -1,6 +1,8 @@ databases: - localtest: - connection: sqlserver://localhost:1433?database=foo&user id=foouser&password=FooPasswd1 + mssql: + connection: sqlserver://mssql:1433?database=foo&user 
id=foouser&password=FooPasswd1 + pgsql: + connection: postgresql://sa:VippsPw1@postgres:5432/master?sslmode=disable # One option is to list other paths to include ('dependencies') here. diff --git a/sqltest/fixture.go b/sqltest/fixture.go index 0059908..8c86651 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -3,14 +3,17 @@ package sqltest import ( "context" "database/sql" + "database/sql/driver" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/msdsn" - "github.com/gofrs/uuid" "io/ioutil" "os" "strings" "time" + + mssql "github.com/denisenkom/go-mssqldb" + "github.com/denisenkom/go-mssqldb/msdsn" + "github.com/gofrs/uuid" + pgsql "github.com/lib/pq" ) type StdoutLogger struct { @@ -30,6 +33,20 @@ type Fixture struct { DB *sql.DB DBName string adminDB *sql.DB + Driver driver.Driver +} + +func (f *Fixture) Quote(value string) string { + var ms mssql.Driver + var pg pgsql.Driver + + if f.Driver == &ms { + return fmt.Sprintf("[%s]", value) + } + if f.Driver == &pg { + return fmt.Sprintf(`"%s"`, value) + } + return value } func NewFixture() *Fixture { @@ -39,44 +56,70 @@ func NewFixture() *Fixture { defer cancel() dsn := os.Getenv("SQLSERVER_DSN") - if dsn == "" { + if len(dsn) == 0 { panic("Must set SQLSERVER_DSN to run tests") } - dsn = dsn + "&log=3" - mssql.SetLogger(StdoutLogger{}) + driver := os.Getenv("SQLSERVER_DRIVER") + if len(driver) == 0 { + panic("Must set SQLSERVER_DRIVER to run tests") + } + + switch driver { + case "sqlserver": + // set the logging level + dsn = dsn + "&log=3" + mssql.SetLogger(StdoutLogger{}) + case "postgres": + break + } var err error - fixture.adminDB, err = sql.Open("sqlserver", dsn) + fixture.adminDB, err = sql.Open(driver, dsn) if err != nil { panic(err) } - fixture.DBName = strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "") + // store a reference to the type of sql driver + fixture.Driver = fixture.adminDB.Driver() - _, err = fixture.adminDB.ExecContext(ctx, 
fmt.Sprintf(`create database [%s]`, fixture.DBName)) - if err != nil { - panic(err) - } - // These settings are just to get "worst-case" for our tests, since snapshot could interfer - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database [%s] set allow_snapshot_isolation on`, fixture.DBName)) - if err != nil { - panic(err) - } - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database [%s] set read_committed_snapshot on`, fixture.DBName)) + fixture.DBName = strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "") + dbname := fixture.Quote(fixture.DBName) + _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`create database %s`, dbname)) if err != nil { + fmt.Printf("Failed to create the database: %s for the %s driver\n", dbname, driver) panic(err) } - pdsn, _, err := msdsn.Parse(dsn) - if err != nil { - panic(err) + if driver == "sqlserver" { + // These settings are just to get "worst-case" for our tests, since snapshot could interfer + _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database %s set allow_snapshot_isolation on`, dbname)) + if err != nil { + panic(err) + } + _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database %s set read_committed_snapshot on`, dbname)) + if err != nil { + panic(err) + } + + pdsn, _, err := msdsn.Parse(dsn) + if err != nil { + panic(err) + } + pdsn.Database = fixture.DBName + + fixture.DB, err = sql.Open(driver, pdsn.URL().String()) + if err != nil { + panic(err) + } } - pdsn.Database = fixture.DBName - fixture.DB, err = sql.Open("sqlserver", pdsn.URL().String()) - if err != nil { - panic(err) + if driver == "postgres" { + // TODO + fixture.DB, err = sql.Open(driver, strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) + if err != nil { + panic(err) + } } return &fixture From f2e07a971e1df8ba75ef585f2951ee1d00ea779c Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Wed, 3 Dec 2025 18:49:47 +0100 Subject: [PATCH 03/40] Updated fixture to use pgx. 
Simplified how we introspect the DSN to determine the driver type. Allow for future expansion/testing with other drivers. Preparring to write sqlcode migration for Postgres. --- Makefile | 5 +- deployable.go | 1 + go.mod | 11 +++-- go.sum | 16 +++++- migrations/0003.sqlcode.pgsql | 0 sqltest/fixture.go | 93 ++++++++++++++++++----------------- sqltest/sqlcode_test.go | 24 +++++++++ 7 files changed, 99 insertions(+), 51 deletions(-) create mode 100644 migrations/0003.sqlcode.pgsql diff --git a/Makefile b/Makefile index f6980da..fbcbdeb 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,7 @@ -test: +test: test_mssql test_pgsql + + +test_mssql: docker compose --progress plain -f docker-compose.mssql.yml run test test_pgsql: diff --git a/deployable.go b/deployable.go index 7e2b178..dcf2726 100644 --- a/deployable.go +++ b/deployable.go @@ -158,6 +158,7 @@ select @retcode; } defer func() { + // TODO: This returns an error if the lock is already released _, _ = dbc.ExecContext(ctx, `sp_releaseapplock`, sql.Named("Resource", lockResourceName), sql.Named("LockOwner", "Session"), diff --git a/go.mod b/go.mod index e80b068..f07fe09 100644 --- a/go.mod +++ b/go.mod @@ -22,11 +22,16 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.7.6 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/lib/pq v1.10.9 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.9 // indirect - golang.org/x/crypto v0.46.0 // indirect - golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect + golang.org/x/crypto v0.43.0 // indirect + golang.org/x/sync v0.17.0 // 
indirect + golang.org/x/sys v0.37.0 // indirect + golang.org/x/text v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index dab6450..1a7b35c 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,14 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= @@ -40,6 +48,7 @@ github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiT github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod 
h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= @@ -52,8 +61,10 @@ golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= +golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -72,6 +83,7 @@ golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c 
h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql new file mode 100644 index 0000000..e69de29 diff --git a/sqltest/fixture.go b/sqltest/fixture.go index 8c86651..b38ac06 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -3,9 +3,7 @@ package sqltest import ( "context" "database/sql" - "database/sql/driver" "fmt" - "io/ioutil" "os" "strings" "time" @@ -13,9 +11,22 @@ import ( mssql "github.com/denisenkom/go-mssqldb" "github.com/denisenkom/go-mssqldb/msdsn" "github.com/gofrs/uuid" - pgsql "github.com/lib/pq" + _ "github.com/jackc/pgx/v5" + _ "github.com/jackc/pgx/v5/stdlib" ) +type SqlDriverType int + +const ( + SqlDriverDenisen SqlDriverType = iota + SqlDriverPgx +) + +var sqlDrivers = map[SqlDriverType]string{ + SqlDriverDenisen: "sqlserver", + SqlDriverPgx: "pgx", +} + type StdoutLogger struct { } @@ -33,17 +44,23 @@ type Fixture struct { DB *sql.DB DBName string adminDB *sql.DB - Driver driver.Driver + Driver SqlDriverType } -func (f *Fixture) Quote(value string) string { - var ms mssql.Driver - var pg pgsql.Driver +func (f *Fixture) IsSqlServer() bool { + return f.Driver == SqlDriverDenisen +} - if f.Driver == &ms { +func (f *Fixture) IsPostgresql() bool { + return f.Driver == SqlDriverPgx +} + +// SQL specific quoting syntax +func (f *Fixture) Quote(value string) string { + if f.IsSqlServer() { return fmt.Sprintf("[%s]", value) } - if f.Driver == &pg { + if f.IsPostgresql() { return fmt.Sprintf(`"%s"`, value) } return value @@ -60,38 +77,39 @@ func NewFixture() *Fixture { panic("Must set SQLSERVER_DSN to run tests") } - driver := os.Getenv("SQLSERVER_DRIVER") - if len(driver) == 0 { - panic("Must set SQLSERVER_DRIVER to 
run tests") - } - - switch driver { - case "sqlserver": + if strings.Contains(dsn, "sqlserver") { // set the logging level - dsn = dsn + "&log=3" + // To enable specific logging levels, you sum the values of the desired flags + // 1: Log errors + // 2: Log messages + // 4: Log rows affected + // 8: Trace SQL statements + // 16: Log statement parameters + // 32: Log transaction begin/end + dsn = dsn + "&log=63" mssql.SetLogger(StdoutLogger{}) - case "postgres": - break + fixture.Driver = SqlDriverDenisen + } + if strings.Contains(dsn, "postgresql") { + fixture.Driver = SqlDriverPgx } var err error - - fixture.adminDB, err = sql.Open(driver, dsn) + fixture.adminDB, err = sql.Open(sqlDrivers[fixture.Driver], dsn) if err != nil { panic(err) } - // store a reference to the type of sql driver - fixture.Driver = fixture.adminDB.Driver() fixture.DBName = strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "") dbname := fixture.Quote(fixture.DBName) - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`create database %s`, dbname)) + qs := fmt.Sprintf(`create database %s`, dbname) + _, err = fixture.adminDB.ExecContext(ctx, qs) if err != nil { - fmt.Printf("Failed to create the database: %s for the %s driver\n", dbname, driver) + fmt.Printf("Failed to create the (%s) database: %s\n", sqlDrivers[fixture.Driver], dbname) panic(err) } - if driver == "sqlserver" { + if fixture.IsSqlServer() { // These settings are just to get "worst-case" for our tests, since snapshot could interfer _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database %s set allow_snapshot_isolation on`, dbname)) if err != nil { @@ -108,15 +126,15 @@ func NewFixture() *Fixture { } pdsn.Database = fixture.DBName - fixture.DB, err = sql.Open(driver, pdsn.URL().String()) + fixture.DB, err = sql.Open(sqlDrivers[fixture.Driver], pdsn.URL().String()) if err != nil { panic(err) } } - if driver == "postgres" { + if fixture.IsPostgresql() { // TODO - fixture.DB, err = sql.Open(driver, 
strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) + fixture.DB, err = sql.Open(sqlDrivers[fixture.Driver], strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) if err != nil { panic(err) } @@ -140,23 +158,8 @@ func (f *Fixture) Teardown() { f.adminDB = nil } -func (f *Fixture) RunMigrations() { - migrationSql, err := ioutil.ReadFile("migrations/from0001/0001.changefeed.sql") - if err != nil { - panic(err) - } - parts := strings.Split(string(migrationSql), "\ngo\n") - for _, p := range parts { - _, err = f.DB.Exec(p) - if err != nil { - fmt.Println(p) - panic(err) - } - } -} - func (f *Fixture) RunMigrationFile(filename string) { - migrationSql, err := ioutil.ReadFile(filename) + migrationSql, err := os.ReadFile(filename) if err != nil { panic(err) } diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 4352eac..12e06f1 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -11,6 +11,7 @@ import ( func Test_RowsAffected(t *testing.T) { fixture := NewFixture() defer fixture.Teardown() + // if sql else pgsql fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") ctx := context.Background() @@ -29,3 +30,26 @@ func Test_RowsAffected(t *testing.T) { require.Equal(t, 6, schemas[0].Objects) require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) } + +func Test_EnsureUploaded(t *testing.T) { + fixture := NewFixture() + defer fixture.Teardown() + + t.Run("mssql", func(t *testing.T) { + if !fixture.IsSqlServer() { + t.Skip() + } + fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + + ctx := context.Background() + require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) + }) + + t.Run("pgsql", func(t *testing.T) { + if !fixture.IsPostgresql() { + t.Skip() + } + + fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + }) +} From f74f21c515b27e1bb0167bcfeb0a0db4685643bf Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Wed, 3 Dec 2025 19:50:31 +0100 Subject: [PATCH 04/40] Wrote initial sqlcode migration for postgres. 
--- docker-compose.pgsql.yml | 1 - migrations/0003.sqlcode.pgsql | 166 ++++++++++++++++++++++++++++++++++ sqltest/fixture.go | 11 ++- sqltest/sqlcode_test.go | 37 +++++--- 4 files changed, 196 insertions(+), 19 deletions(-) diff --git a/docker-compose.pgsql.yml b/docker-compose.pgsql.yml index 6391e0c..bdb0f0d 100644 --- a/docker-compose.pgsql.yml +++ b/docker-compose.pgsql.yml @@ -18,7 +18,6 @@ services: - postgres environment: SQLSERVER_DSN: postgresql://sa:VippsPw1@postgres:5432/master?sslmode=disable - SQLSERVER_DRIVER: postgres GODEBUG: "x509negativeserial=1" depends_on: postgres: diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql index e69de29..b26b8ed 100644 --- a/migrations/0003.sqlcode.pgsql +++ b/migrations/0003.sqlcode.pgsql @@ -0,0 +1,166 @@ +-- ====================================================================== +-- Create users and roles +-- ====================================================================== +do $$ +begin + -- This role will own the sqlcode schemas, so that created functions etc. + -- are owned by a role without permissions; this means functions/procedures + -- will not get more permissions than the caller already has (unless you use + -- SECURITY DEFINER somewhere). + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-user-with-no-permissions' + ) then + create role "sqlcode-user-with-no-permissions" nologin; + end if; + + -- This role will be granted execute (usage) permissions on all sqlcode schemas; + -- useful e.g. for humans logging in to debug. + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-execute-role' + ) then + create role "sqlcode-execute-role"; + end if; + + -- Role for calling CreateCodeSchema / DropCodeSchema; the role will also be granted + -- control over all schemas created this way. 
+ if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-deploy-role' + ) then + create role "sqlcode-deploy-role"; + end if; + + -- Make a role that *only* has this deploy role. During deploys we SET ROLE to this + -- so that we can more safely deploy code with restricted permissions. + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-deploy-sandbox-user' + ) then + create role "sqlcode-deploy-sandbox-user" nologin; + end if; +end; +$$; + +-- ====================================================================== +-- grant permissions +-- ====================================================================== + +do $$ +begin + -- grant "sqlcode-deploy-role" to "sqlcode-deploy-sandbox-user" + if not exists ( + select 1 + from pg_auth_members m + join pg_roles r_role on r_role.oid = m.roleid + join pg_roles r_member on r_member.oid = m.member + where r_role.rolname = 'sqlcode-deploy-role' + and r_member.rolname = 'sqlcode-deploy-sandbox-user' + ) then + grant "sqlcode-deploy-role" to "sqlcode-deploy-sandbox-user"; + end if; + +end; +$$; + +-- ====================================================================== +-- create schema +-- ====================================================================== + +-- Base schema to hold the procedures etc. 
+do $$ +begin + if not exists ( + select 1 from pg_namespace where nspname = 'sqlcode' + ) then + create schema sqlcode; + end if; +end; +$$; + +-- ====================================================================== +-- create procedures +-- ====================================================================== + +create or replace procedure sqlcode.createcodeschema(schemasuffix varchar) +language plpgsql +security definer +as $$ +declare + schemaname text := format('code@%s', schemasuffix); +begin + -- create the schema owned by "sqlcode-user-with-no-permissions" + execute format( + 'create schema %I authorization %I', + schemaname, + 'sqlcode-user-with-no-permissions' + ); + + -- grant schema privileges + execute format( + 'grant usage on schema %I to %I', + schemaname, + 'sqlcode-execute-role' + ); + + execute format( + 'grant usage, create on schema %I to %I', + schemaname, + 'sqlcode-deploy-role' + ); + +exception + when others then + raise; +end; +$$; + +-- ====================================================================== +-- procedure: sqlcode.dropcodeschema +-- ====================================================================== + +create or replace procedure sqlcode.dropcodeschema(schemasuffix varchar) +language plpgsql +security definer +as $$ +declare + schemaname text := format('code@%s', schemasuffix); + schema_exists boolean; +begin + -- check schema existence + select exists ( + select 1 + from pg_namespace + where nspname = schemaname + ) into schema_exists; + + if not schema_exists then + raise exception 'schema "%" not found', schemaname; + end if; + + -- drop the schema and all objects within it + execute format('drop schema %I cascade', schemaname); + +exception + when others then + raise; +end; +$$; + +-- ====================================================================== +-- privileges on the procedures and base schema +-- ====================================================================== + +grant execute on procedure 
sqlcode.createcodeschema(varchar) + to "sqlcode-deploy-role"; + +grant execute on procedure sqlcode.dropcodeschema(varchar) + to "sqlcode-deploy-role"; + +grant usage, create on schema sqlcode + to "sqlcode-deploy-role"; diff --git a/sqltest/fixture.go b/sqltest/fixture.go index b38ac06..82f395d 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -92,6 +92,8 @@ func NewFixture() *Fixture { } if strings.Contains(dsn, "postgresql") { fixture.Driver = SqlDriverPgx + // https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-CLIENT-MIN-MESSAGES + dsn = dsn + "&options=-c%20client_min_messages%3DDEBUG5" } var err error @@ -105,7 +107,7 @@ func NewFixture() *Fixture { qs := fmt.Sprintf(`create database %s`, dbname) _, err = fixture.adminDB.ExecContext(ctx, qs) if err != nil { - fmt.Printf("Failed to create the (%s) database: %s\n", sqlDrivers[fixture.Driver], dbname) + fmt.Printf("Failed to create the (%s) database: %s: %e\n", sqlDrivers[fixture.Driver], dbname, err) panic(err) } @@ -133,7 +135,7 @@ func NewFixture() *Fixture { } if fixture.IsPostgresql() { - // TODO + // TODO use pgx config parser fixture.DB, err = sql.Open(sqlDrivers[fixture.Driver], strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) if err != nil { panic(err) @@ -153,7 +155,10 @@ func (f *Fixture) Teardown() { _ = f.DB.Close() f.DB = nil - _, _ = f.adminDB.ExecContext(ctx, fmt.Sprintf(`drop database [%s]`, f.DBName)) + _, err := f.adminDB.ExecContext(ctx, fmt.Sprintf(`drop database %s`, f.Quote(f.DBName))) + if err != nil { + fmt.Printf("Failed to drop (%s) database %s: %e", sqlDrivers[f.Driver], f.DBName, err) + } _ = f.adminDB.Close() f.adminDB = nil } diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 12e06f1..9de22a3 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -11,24 +11,31 @@ import ( func Test_RowsAffected(t *testing.T) { fixture := NewFixture() defer fixture.Teardown() - // if sql else pgsql - 
fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + t.Run("mssql", func(t *testing.T) { + if !fixture.IsSqlServer() { + t.Skip() + } + + fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") - ctx := context.Background() + ctx := context.Background() + + require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) + patched := SQL.Patch(`[code].Test`) - require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - patched := SQL.Patch(`[code].Test`) + res, err := fixture.DB.ExecContext(ctx, patched) + require.NoError(t, err) + rowsAffected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, int64(1), rowsAffected) - res, err := fixture.DB.ExecContext(ctx, patched) - require.NoError(t, err) - rowsAffected, err := res.RowsAffected() - require.NoError(t, err) - assert.Equal(t, int64(1), rowsAffected) + schemas := SQL.ListUploaded(ctx, fixture.DB) + require.Len(t, schemas, 1) + require.Equal(t, 6, schemas[0].Objects) + require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) + + }) - schemas := SQL.ListUploaded(ctx, fixture.DB) - require.Len(t, schemas, 1) - require.Equal(t, 6, schemas[0].Objects) - require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) } func Test_EnsureUploaded(t *testing.T) { @@ -50,6 +57,6 @@ func Test_EnsureUploaded(t *testing.T) { t.Skip() } - fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + fixture.RunMigrationFile("../migrations/0003.sqlcode.pgsql") }) } From 396e42717e783b048c89d2baf3820c4b5cd03fd3 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Wed, 3 Dec 2025 20:06:03 +0100 Subject: [PATCH 05/40] Add security definer role. 
--- migrations/0003.sqlcode.pgsql | 62 +++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql index b26b8ed..a338627 100644 --- a/migrations/0003.sqlcode.pgsql +++ b/migrations/0003.sqlcode.pgsql @@ -1,12 +1,9 @@ -- ====================================================================== --- Create users and roles +-- create users and roles -- ====================================================================== do $$ begin - -- This role will own the sqlcode schemas, so that created functions etc. - -- are owned by a role without permissions; this means functions/procedures - -- will not get more permissions than the caller already has (unless you use - -- SECURITY DEFINER somewhere). + -- role that will own the sqlcode schemas (actual code schemas), with no login if not exists ( select 1 from pg_roles @@ -15,8 +12,16 @@ begin create role "sqlcode-user-with-no-permissions" nologin; end if; - -- This role will be granted execute (usage) permissions on all sqlcode schemas; - -- useful e.g. for humans logging in to debug. + -- role that owns the management schema/procedures (security definer) + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-definer-role' + ) then + create role "sqlcode-definer-role" nologin; + end if; + + -- role that gets execute / usage on code schemas (for humans debugging etc.) if not exists ( select 1 from pg_roles @@ -25,8 +30,8 @@ begin create role "sqlcode-execute-role"; end if; - -- Role for calling CreateCodeSchema / DropCodeSchema; the role will also be granted - -- control over all schemas created this way. + -- role for calling createcodeschema / dropcodeschema; + -- this role does not own the procedures, it only calls them. if not exists ( select 1 from pg_roles @@ -35,8 +40,7 @@ begin create role "sqlcode-deploy-role"; end if; - -- Make a role that *only* has this deploy role. 
During deploys we SET ROLE to this - -- so that we can more safely deploy code with restricted permissions. + -- sandbox role used during deploys, which only has sqlcode-deploy-role if not exists ( select 1 from pg_roles @@ -48,7 +52,7 @@ end; $$; -- ====================================================================== --- grant permissions +-- grant permissions / role memberships -- ====================================================================== do $$ @@ -64,27 +68,25 @@ begin ) then grant "sqlcode-deploy-role" to "sqlcode-deploy-sandbox-user"; end if; - end; $$; -- ====================================================================== --- create schema +-- create schema for management code (owner = definer role) -- ====================================================================== --- Base schema to hold the procedures etc. do $$ begin if not exists ( select 1 from pg_namespace where nspname = 'sqlcode' ) then - create schema sqlcode; + create schema sqlcode authorization "sqlcode-definer-role"; end if; end; $$; -- ====================================================================== --- create procedures +-- create procedures (security definer) -- ====================================================================== create or replace procedure sqlcode.createcodeschema(schemasuffix varchar) @@ -94,6 +96,9 @@ as $$ declare schemaname text := format('code@%s', schemasuffix); begin + -- harden search_path for security-definer (optional but recommended) + perform set_config('search_path', 'pg_catalog', true); + -- create the schema owned by "sqlcode-user-with-no-permissions" execute format( 'create schema %I authorization %I', @@ -120,18 +125,17 @@ exception end; $$; --- ====================================================================== --- procedure: sqlcode.dropcodeschema --- ====================================================================== - create or replace procedure sqlcode.dropcodeschema(schemasuffix varchar) language plpgsql security 
definer as $$ declare - schemaname text := format('code@%s', schemasuffix); - schema_exists boolean; + schemaname text := format('code@%s', schemasuffix); + schema_exists boolean; begin + -- harden search_path for security-definer (optional but recommended) + perform set_config('search_path', 'pg_catalog', true); + -- check schema existence select exists ( select 1 @@ -152,15 +156,25 @@ exception end; $$; +-- ensure procedures are owned by the definer role +alter procedure sqlcode.createcodeschema(varchar) + owner to "sqlcode-definer-role"; + +alter procedure sqlcode.dropcodeschema(varchar) + owner to "sqlcode-definer-role"; + -- ====================================================================== -- privileges on the procedures and base schema -- ====================================================================== +-- allow deploy role to call the management procedures grant execute on procedure sqlcode.createcodeschema(varchar) to "sqlcode-deploy-role"; grant execute on procedure sqlcode.dropcodeschema(varchar) to "sqlcode-deploy-role"; -grant usage, create on schema sqlcode +-- usually deploy role does not need create in the sqlcode management schema +-- (the procedures handle creation in separate "code@..." schemas) +grant usage on schema sqlcode to "sqlcode-deploy-role"; From 201a66ca275532bd19e23c8ec776bef10b3c4c5f Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Wed, 3 Dec 2025 21:51:32 +0100 Subject: [PATCH 06/40] [wip] working through changes for EnsureUploaded to support postgresql. 
--- dbintf.go | 2 + dbops.go | 16 ++++- deployable.go | 106 ++++++++++++++++++++++++---------- docker-compose.pgsql.yml | 1 + migrations/0003.sqlcode.pgsql | 46 +++++++++++++++ sqltest/sqlcode_test.go | 21 +++++++ 6 files changed, 161 insertions(+), 31 deletions(-) diff --git a/dbintf.go b/dbintf.go index 8257e11..1495942 100644 --- a/dbintf.go +++ b/dbintf.go @@ -3,6 +3,7 @@ package sqlcode import ( "context" "database/sql" + "database/sql/driver" ) type DB interface { @@ -11,6 +12,7 @@ type DB interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row Conn(ctx context.Context) (*sql.Conn, error) BeginTx(ctx context.Context, txOptions *sql.TxOptions) (*sql.Tx, error) + Driver() driver.Driver } var _ DB = &sql.DB{} diff --git a/dbops.go b/dbops.go index 05e5a88..0293a2a 100644 --- a/dbops.go +++ b/dbops.go @@ -3,11 +3,25 @@ package sqlcode import ( "context" "database/sql" + + mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5/stdlib" ) func Exists(ctx context.Context, dbc DB, schemasuffix string) (bool, error) { var schemaID int - err := dbc.QueryRowContext(ctx, `select isnull(schema_id(@p1), 0)`, SchemaName(schemasuffix)).Scan(&schemaID) + + driver := dbc.Driver() + var qs string + + if _, ok := driver.(*mssql.Driver); ok { + qs = `select isnull(schema_id(@p1), 0)` + } + if _, ok := driver.(*stdlib.Driver); ok { + qs = `select coalesce((select oid from pg_namespace where nspname = $1),0)` + } + + err := dbc.QueryRowContext(ctx, qs, SchemaName(schemasuffix)).Scan(&schemaID) if err != nil { return false, err } diff --git a/deployable.go b/deployable.go index dcf2726..c1b38fe 100644 --- a/deployable.go +++ b/deployable.go @@ -11,6 +11,9 @@ import ( "time" mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + pgxstdlib "github.com/jackc/pgx/v5/stdlib" "github.com/vippsas/sqlcode/sqlparser" ) @@ -77,21 +80,22 @@ func impersonate(ctx context.Context, dbc DB, username 
string, f func(conn *sql. // Upload will create and upload the schema; resulting in an error // if the schema already exists func (d *Deployable) Upload(ctx context.Context, dbc DB) error { - // First, impersonate a user with minimal privileges to get at least - // some level of sandboxing so that migration scripts can't do anything - // the caller didn't expect them to. - return impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", func(conn *sql.Conn) error { + driver := dbc.Driver() + qs := make(map[string][]interface{}, 1) + + var uploadFunc = func(conn *sql.Conn) error { tx, err := conn.BeginTx(ctx, nil) if err != nil { return err } - _, err = tx.ExecContext(ctx, `sqlcode.CreateCodeSchema`, - sql.Named("schemasuffix", d.SchemaSuffix), - ) - if err != nil { - _ = tx.Rollback() - return err + for q, args := range qs { + _, err = tx.ExecContext(ctx, q, args...) + + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to execute (%s) with arg(%s) in schema %s: %w", q, args, d.SchemaSuffix, err) + } } preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix) @@ -123,8 +127,36 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { return nil - }) + } + + if _, ok := driver.(*mssql.Driver); ok { + // First, impersonate a user with minimal privileges to get at least + // some level of sandboxing so that migration scripts can't do anything + // the caller didn't expect them to. 
+ qs["sqlcode.CreateCodeSchema"] = []interface { + }{ + sql.Named("schemasuffix", d.SchemaSuffix), + } + return impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", uploadFunc) + } + + if _, ok := driver.(*stdlib.Driver); ok { + qs[`set role "sqlcode-deploy-sandbox-user"`] = nil + qs[`call sqlcode.createcodeschema(@schemasuffix)`] = []interface{}{ + pgx.NamedArgs{"schemasuffix": d.SchemaSuffix}, + } + conn, err := dbc.Conn(ctx) + if err != nil { + return err + } + defer func() { + _ = conn.Close() + }() + return uploadFunc(conn) + } + + return fmt.Errorf("failed to determine sql driver to upload schema: %s", d.SchemaSuffix) } // EnsureUploaded checks that the schema with the suffix already exists, @@ -137,37 +169,51 @@ func (d *Deployable) EnsureUploaded(ctx context.Context, dbc DB) error { return nil } + driver := dbc.Driver() lockResourceName := "sqlcode.EnsureUploaded/" + d.SchemaSuffix + var lockRetCode int + var lockQs string + var unlockQs string + var err error + // When a lock is opened with the Transaction lock owner, // that lock is released when the transaction is committed or rolled back. 
- var lockRetCode int - err := dbc.QueryRowContext(ctx, ` -declare @retcode int; -exec @retcode = sp_getapplock @Resource = @resource, @LockMode = 'Shared', @LockOwner = 'Session', @LockTimeout = @timeoutMs; -select @retcode; -`, - sql.Named("resource", lockResourceName), - sql.Named("timeoutMs", 20000), - ).Scan(&lockRetCode) + if _, ok := driver.(*pgxstdlib.Driver); ok { + lockQs = `select sqlcode.get_applock(@resource, @timeout)` + unlockQs = `select sqlcode.release_applock(@resource)` + + err = dbc.QueryRowContext(ctx, lockQs, pgx.NamedArgs{ + "resource": lockResourceName, + "timeoutMs": 20000, + }).Scan(&lockRetCode) + + defer func() { + dbc.ExecContext(ctx, unlockQs, pgx.NamedArgs{"resource": lockResourceName}) + }() + } + + if _, ok := driver.(*mssql.Driver); ok { + // TODO + + defer func() { + // TODO: This returns an error if the lock is already released + _, _ = dbc.ExecContext(ctx, unlockQs, + sql.Named("Resource", lockResourceName), + sql.Named("LockOwner", "Session"), + ) + }() + } + if err != nil { return err } if lockRetCode < 0 { return errors.New("was not able to get lock before timeout") } - - defer func() { - // TODO: This returns an error if the lock is already released - _, _ = dbc.ExecContext(ctx, `sp_releaseapplock`, - sql.Named("Resource", lockResourceName), - sql.Named("LockOwner", "Session"), - ) - }() - exists, err := Exists(ctx, dbc, d.SchemaSuffix) if err != nil { - return err + return fmt.Errorf("unable to determine if schema %s exists: %w", d.SchemaSuffix, err) } if exists { diff --git a/docker-compose.pgsql.yml b/docker-compose.pgsql.yml index bdb0f0d..a6d7a0b 100644 --- a/docker-compose.pgsql.yml +++ b/docker-compose.pgsql.yml @@ -7,6 +7,7 @@ services: POSTGRES_PASSWORD: VippsPw1 POSTGRES_USER: sa POSTGRES_DB: master + PGOPTIONS: "-c log_error_verbosity=verbose -c log_statement=all" healthcheck: test: ["CMD-SHELL", "pg_isready", "-d", "db_prod"] interval: 1s diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql 
index a338627..98f9228 100644 --- a/migrations/0003.sqlcode.pgsql +++ b/migrations/0003.sqlcode.pgsql @@ -156,6 +156,52 @@ exception end; $$; +-- similar behaviour as mssql getapplock +-- PostgreSQL advisory locks are session-based by default +create or replace function sqlcode.get_applock( + resource text, + timeout_ms integer default 0 +) +returns integer +language plpgsql +as $$ +declare + resource_key bigint; + acquired boolean; + waited_ms integer := 0; +begin + -- convert string to advisory-lock key + select hashtext(resource) into resource_key; + + -- attempt lock with timeout loop + loop + select pg_try_advisory_lock_shared(resource_key) + into acquired; + + if acquired then + return 1; -- lock acquired (success) + end if; + + if waited_ms >= timeout_ms then + return 0; -- timeout + end if; + + perform pg_sleep(0.01); -- sleep 10 ms + waited_ms := waited_ms + 10; + end loop; + + return null; -- safety fallback (should never hit) +end; +$$; + +create or replace function sqlcode.release_applock(resource text) +returns boolean +language sql +as $$ + select pg_advisory_unlock_shared(hashtext(resource)); +$$; + + -- ensure procedures are owned by the definer role alter procedure sqlcode.createcodeschema(varchar) owner to "sqlcode-definer-role"; diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 9de22a3..f343c07 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,5 +59,25 @@ func Test_EnsureUploaded(t *testing.T) { } fixture.RunMigrationFile("../migrations/0003.sqlcode.pgsql") + + ctx := context.Background() + + _, err := fixture.adminDB.Exec(`grant create on database @database to "sqlcode-definer-role"`, + pgx.NamedArgs{"database": fixture.DBName}) + require.NoError(t, err) + + require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) + patched := SQL.Patch(`[code].Test`) + + 
res, err := fixture.DB.ExecContext(ctx, patched) + require.NoError(t, err) + rowsAffected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, int64(1), rowsAffected) + + schemas := SQL.ListUploaded(ctx, fixture.DB) + require.Len(t, schemas, 1) + require.Equal(t, 6, schemas[0].Objects) + require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) }) } From 108f65114f68401ba77cc028ea98f224d44e0928 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 4 Dec 2025 19:19:38 +0100 Subject: [PATCH 07/40] Working EnsureUpload! --- cli/cmd/build.go | 4 +- deployable.go | 72 ++++++++++++++++++++++++++--------- migrations/0003.sqlcode.pgsql | 5 ++- error.go => mssql_error.go | 7 ++-- preprocess.go | 22 +++++++++-- sqlparser/dom.go | 7 +++- sqlparser/parser.go | 13 ++++++- sqlparser/parser_test.go | 18 +++++++++ sqltest/sql.go | 4 ++ sqltest/sqlcode_test.go | 21 ++++------ sqltest/test.pgsql | 8 ++++ 11 files changed, 136 insertions(+), 45 deletions(-) rename error.go => mssql_error.go (92%) create mode 100644 sqltest/test.pgsql diff --git a/cli/cmd/build.go b/cli/cmd/build.go index 1ffdde2..c0d7db3 100644 --- a/cli/cmd/build.go +++ b/cli/cmd/build.go @@ -3,6 +3,8 @@ package cmd import ( "errors" "fmt" + + mssql "github.com/denisenkom/go-mssqldb" "github.com/spf13/cobra" "github.com/vippsas/sqlcode" ) @@ -23,7 +25,7 @@ var ( return err } - preprocessed, err := sqlcode.Preprocess(d.CodeBase, schemasuffix) + preprocessed, err := sqlcode.Preprocess(d.CodeBase, schemasuffix, &mssql.Driver{}) if err != nil { return err } diff --git a/deployable.go b/deployable.go index c1b38fe..ead2709 100644 --- a/deployable.go +++ b/deployable.go @@ -98,7 +98,8 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { } } - preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix) + preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix, dbc.Driver()) + if err != nil { _ = tx.Rollback() return err @@ -107,15 +108,16 @@ func (d *Deployable) Upload(ctx context.Context, 
dbc DB) error { _, err := tx.ExecContext(ctx, b.Lines) if err != nil { _ = tx.Rollback() - sqlerr, ok := err.(mssql.Error) - if !ok { - return err - } else { - return SQLUserError{ + if sqlerr, ok := err.(mssql.Error); ok { + return MSSQLUserError{ Wrapped: sqlerr, Batch: b, } } + + // TODO(ks) PGSQLUserError + return fmt.Errorf("failed to upload deployable:%s in schema:%s:%w", d.CodeBase, d.SchemaSuffix, err) + } } err = tx.Commit() @@ -327,10 +329,28 @@ func (s *SchemaObject) Suffix() string { // Return a list of sqlcode schemas that have been uploaded to the database. // This includes all current and unused schemas. -func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) []*SchemaObject { +func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) ([]*SchemaObject, error) { objects := []*SchemaObject{} - impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", func(conn *sql.Conn) error { - rows, err := conn.QueryContext(ctx, ` + driver := dbc.Driver() + var qs string + + var list = func(conn *sql.Conn) error { + rows, err := conn.QueryContext(ctx, qs) + if err != nil { + return err + } + + for rows.Next() { + zero := &SchemaObject{} + rows.Scan(&zero.Name, &zero.Objects, &zero.SchemaId, &zero.CreateDate, &zero.ModifyDate) + objects = append(objects, zero) + } + + return nil + } + + if _, ok := driver.(*mssql.Driver); ok { + qs = ` select s.name , s.schema_id @@ -345,18 +365,32 @@ func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) []*SchemaObject { from sys.objects o where o.schema_id = s.schema_id ) as o - where s.name like 'code@%'`) + where s.name like 'code@%'` + impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", list) + } + + // TODO(ks) the timestamps for schemas + if _, ok := driver.(*stdlib.Driver); ok { + qs = `select nspname as name + , oid as schema_id + , 0 as objects + , '' as create_date + , '' as modify_date + from pg_namespace + where nspname like 'code@%' + order by nspname` + conn, err := dbc.Conn(ctx) if err != nil { - 
return err + return nil, err } - - for rows.Next() { - zero := &SchemaObject{} - rows.Scan(&zero.Name, &zero.Objects, &zero.SchemaId, &zero.CreateDate, &zero.ModifyDate) - objects = append(objects, zero) + err = list(conn) + if err != nil { + return nil, err } + defer func() { + _ = conn.Close() + }() + } - return nil - }) - return objects + return objects, nil } diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql index 98f9228..1395bc4 100644 --- a/migrations/0003.sqlcode.pgsql +++ b/migrations/0003.sqlcode.pgsql @@ -101,7 +101,7 @@ begin -- create the schema owned by "sqlcode-user-with-no-permissions" execute format( - 'create schema %I authorization %I', + 'create schema if not exists %I authorization %I', schemaname, 'sqlcode-user-with-no-permissions' ); @@ -220,6 +220,9 @@ grant execute on procedure sqlcode.createcodeschema(varchar) grant execute on procedure sqlcode.dropcodeschema(varchar) to "sqlcode-deploy-role"; +grant "sqlcode-user-with-no-permissions" + to "sqlcode-definer-role"; + -- usually deploy role does not need create in the sqlcode management schema -- (the procedures handle creation in separate "code@..." 
schemas) grant usage on schema sqlcode diff --git a/error.go b/mssql_error.go similarity index 92% rename from error.go rename to mssql_error.go index 6131fbf..22d4bde 100644 --- a/error.go +++ b/mssql_error.go @@ -3,17 +3,18 @@ package sqlcode import ( "bytes" "fmt" + "strings" + mssql "github.com/denisenkom/go-mssqldb" "github.com/vippsas/sqlcode/sqlparser" - "strings" ) -type SQLUserError struct { +type MSSQLUserError struct { Wrapped mssql.Error Batch Batch } -func (s SQLUserError) Error() string { +func (s MSSQLUserError) Error() string { var buf bytes.Buffer if _, fmterr := fmt.Fprintf(&buf, "\n"); fmterr != nil { diff --git a/preprocess.go b/preprocess.go index 5a8adab..2c6f647 100644 --- a/preprocess.go +++ b/preprocess.go @@ -2,12 +2,15 @@ package sqlcode import ( "crypto/sha256" + "database/sql/driver" "encoding/hex" "errors" "fmt" - "github.com/vippsas/sqlcode/sqlparser" "regexp" "strings" + + "github.com/jackc/pgx/v5/stdlib" + "github.com/vippsas/sqlcode/sqlparser" ) func SchemaSuffixFromHash(doc sqlparser.Document) string { @@ -138,7 +141,7 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot return } -func Preprocess(doc sqlparser.Document, schemasuffix string) (PreprocessedFile, error) { +func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Driver) (PreprocessedFile, error) { var result PreprocessedFile if strings.Contains(schemasuffix, "]") { @@ -154,10 +157,21 @@ func Preprocess(doc sqlparser.Document, schemasuffix string) (PreprocessedFile, if len(create.Body) == 0 { continue } - batch, err := sqlcodeTransformCreate(declares, create, "[code@"+schemasuffix+"]") + if create.Driver != driver { + // continue + } + // TODO(ks) this is not reached + target := "[code@" + schemasuffix + "]" + + if _, ok := create.Driver.(*stdlib.Driver); ok { + target = "code@" + schemasuffix + } + + batch, err := sqlcodeTransformCreate(declares, create, target) if err != nil { - return result, err + return result, 
fmt.Errorf("failed to transform create: %w", err) } + fmt.Print(batch) result.Batches = append(result.Batches, batch) } diff --git a/sqlparser/dom.go b/sqlparser/dom.go index cc661f4..14209ee 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -1,10 +1,12 @@ package sqlparser import ( + "database/sql/driver" "fmt" - "gopkg.in/yaml.v3" "io" "strings" + + "gopkg.in/yaml.v3" ) type Unparsed struct { @@ -64,7 +66,8 @@ type Create struct { QuotedName PosString // proc/func/type name, including [] Body []Unparsed DependsOn []PosString - Docstring []PosString // comment lines before the create statement. Note: this is also part of Body + Docstring []PosString // comment lines before the create statement. Note: this is also part of Body + Driver driver.Driver // the sql driver this document is intended for } func (c Create) DocstringAsString() string { diff --git a/sqlparser/parser.go b/sqlparser/parser.go index 40eebe9..b61f49e 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -12,6 +12,9 @@ import ( "regexp" "sort" "strings" + + mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5/stdlib" ) var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" @@ -276,6 +279,14 @@ func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { case "create": // should be start of create procedure or create function... c := doc.parseCreate(s, createCountInBatch) + + if strings.HasSuffix(string(s.file), ".sql") { + c.Driver = &mssql.Driver{} + } + if strings.HasSuffix(string(s.file), ".pgsql") { + c.Driver = &stdlib.Driver{} + } + // *prepend* what we saw before getting to the 'create' createCountInBatch++ c.Body = append(nodes, c.Body...) 
@@ -580,7 +591,7 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, if strings.HasPrefix(path, ".") || strings.Contains(path, "/.") { return nil } - if !strings.HasSuffix(path, ".sql") { + if !strings.HasSuffix(path, ".sql") || !strings.HasSuffix(path, ".pgsql") { return nil } diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index 7bd20b8..3acc3c2 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -5,10 +5,27 @@ import ( "strings" "testing" + mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestPostgresqlCreate(t *testing.T) { + doc := ParseString("test.pgsql", ` +create procedure [code].test() +language plpgsql +as $$ +begin + perform 1; +end; +$$; + `) + + require.Len(t, doc.Creates, 1) + require.Equal(t, &stdlib.Driver{}, doc.Creates[0].Driver) +} + func TestParserSmokeTest(t *testing.T) { doc := ParseString("test.sql", ` /* test is a test @@ -43,6 +60,7 @@ end; require.Equal(t, 1, len(doc.Creates)) c := doc.Creates[0] + require.Equal(t, &mssql.Driver{}, c.Driver) assert.Equal(t, "[TestFunc]", c.QuotedName.Value) assert.Equal(t, []string{"[HelloFunc]", "[OtherFunc]"}, c.DependsOnStrings()) diff --git a/sqltest/sql.go b/sqltest/sql.go index 15d7995..4ad2754 100644 --- a/sqltest/sql.go +++ b/sqltest/sql.go @@ -9,7 +9,11 @@ import ( //go:embed *.sql var sqlfs embed.FS +//go:embed *.pgsql +var pgsqlfx embed.FS + var SQL = sqlcode.MustInclude( sqlcode.Options{}, sqlfs, + pgsqlfx, ) diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index f343c07..db1eb44 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -2,9 +2,9 @@ package sqltest import ( "context" + "fmt" "testing" - "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -30,7 +30,8 @@ func Test_RowsAffected(t *testing.T) { require.NoError(t, err) 
assert.Equal(t, int64(1), rowsAffected) - schemas := SQL.ListUploaded(ctx, fixture.DB) + schemas, err := SQL.ListUploaded(ctx, fixture.DB) + require.NoError(t, err) require.Len(t, schemas, 1) require.Equal(t, 6, schemas[0].Objects) require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) @@ -62,22 +63,14 @@ func Test_EnsureUploaded(t *testing.T) { ctx := context.Background() - _, err := fixture.adminDB.Exec(`grant create on database @database to "sqlcode-definer-role"`, - pgx.NamedArgs{"database": fixture.DBName}) + _, err := fixture.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, fixture.DBName)) require.NoError(t, err) require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - patched := SQL.Patch(`[code].Test`) - - res, err := fixture.DB.ExecContext(ctx, patched) - require.NoError(t, err) - rowsAffected, err := res.RowsAffected() + schemas, err := SQL.ListUploaded(ctx, fixture.DB) require.NoError(t, err) - assert.Equal(t, int64(1), rowsAffected) + require.Equal(t, "code@e3b0c44298fc", schemas[0].Name) - schemas := SQL.ListUploaded(ctx, fixture.DB) - require.Len(t, schemas, 1) - require.Equal(t, 6, schemas[0].Objects) - require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) }) } diff --git a/sqltest/test.pgsql b/sqltest/test.pgsql new file mode 100644 index 0000000..e16e7dd --- /dev/null +++ b/sqltest/test.pgsql @@ -0,0 +1,8 @@ + +create procedure [code].test() +language plpgsql +as $$ +begin + perform 1; +end; +$$; \ No newline at end of file From b7b3b75b18ab73ab085f1748c7153001d3bf321f Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Fri, 5 Dec 2025 15:04:57 +0100 Subject: [PATCH 08/40] [wip] update parser and scanner --- preprocess_test.go | 385 +++++++++++++++++++++++++++++++++++++-- sqlparser/parser.go | 13 +- sqlparser/parser_test.go | 245 +++++++++++++++++++++++++ sqltest/test.pgsql | 25 ++- 4 files changed, 652 insertions(+), 16 deletions(-) diff --git a/preprocess_test.go b/preprocess_test.go index bf976e8..20998ff 100644 
--- a/preprocess_test.go +++ b/preprocess_test.go @@ -1,24 +1,16 @@ package sqlcode import ( + "strings" "testing" + mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vippsas/sqlcode/sqlparser" ) -func TestSchemaSuffixFromHash(t *testing.T) { - t.Run("returns a unique hash", func(t *testing.T) { - doc := sqlparser.Document{ - Declares: []sqlparser.Declare{}, - } - - value := SchemaSuffixFromHash(doc) - require.Equal(t, value, SchemaSuffixFromHash(doc)) - }) -} - func TestLineNumberInInput(t *testing.T) { // Scenario: @@ -63,3 +55,374 @@ func TestLineNumberInInput(t *testing.T) { } assert.Equal(t, expectedInputLineNumbers, inputlines[1:]) } + +func TestSchemaSuffixFromHash(t *testing.T) { + t.Run("returns a unique hash", func(t *testing.T) { + doc := sqlparser.Document{ + Declares: []sqlparser.Declare{}, + } + + value := SchemaSuffixFromHash(doc) + require.Equal(t, value, SchemaSuffixFromHash(doc)) + }) + + t.Run("returns consistent hash", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 1; +go +create procedure [code].Test as begin end +`) + + suffix1 := SchemaSuffixFromHash(doc) + suffix2 := SchemaSuffixFromHash(doc) + + assert.Equal(t, suffix1, suffix2) + assert.Len(t, suffix1, 12) // 6 bytes = 12 hex chars + }) + + t.Run("different content yields different hash", func(t *testing.T) { + doc1 := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 1; +go +create procedure [code].Test1 as begin end +`) + doc2 := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 2; +go +create procedure [code].Test2 as begin end +`) + + suffix1 := SchemaSuffixFromHash(doc1) + suffix2 := SchemaSuffixFromHash(doc2) + + assert.NotEqual(t, suffix1, suffix2) + }) + + t.Run("empty document has hash", func(t *testing.T) { + doc := sqlparser.Document{} + suffix := SchemaSuffixFromHash(doc) + assert.Len(t, suffix, 12) + }) 
+} + +func TestSchemaName(t *testing.T) { + assert.Equal(t, "code@abc123", SchemaName("abc123")) + assert.Equal(t, "code@", SchemaName("")) +} + +func TestBatchLineNumberInInput(t *testing.T) { + t.Run("no corrections", func(t *testing.T) { + b := Batch{ + StartPos: sqlparser.Pos{Line: 10, Col: 1}, + Lines: "line1\nline2\nline3", + } + + assert.Equal(t, 10, b.LineNumberInInput(1)) + assert.Equal(t, 11, b.LineNumberInInput(2)) + assert.Equal(t, 12, b.LineNumberInInput(3)) + }) + + t.Run("with corrections", func(t *testing.T) { + b := Batch{ + StartPos: sqlparser.Pos{Line: 10, Col: 1}, + Lines: "line1\nline2\nextra1\nextra2\nline3", + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 2, extraLinesInOutput: 2}, // line 2 became 3 lines + }, + } + + assert.Equal(t, 10, b.LineNumberInInput(1)) // line 1 -> input line 10 + assert.Equal(t, 11, b.LineNumberInInput(2)) // line 2 -> input line 11 + assert.Equal(t, 11, b.LineNumberInInput(3)) // extra line -> still input line 11 + assert.Equal(t, 11, b.LineNumberInInput(4)) // extra line -> still input line 11 + assert.Equal(t, 12, b.LineNumberInInput(5)) // line 3 -> input line 12 + }) +} + +func TestBatchRelativeLineNumberInInput(t *testing.T) { + t.Run("simple case with no corrections", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{}, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(5)) + }) + + t.Run("single correction", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 3, extraLinesInOutput: 2}, + }, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(2)) + assert.Equal(t, 3, b.RelativeLineNumberInInput(3)) + assert.Equal(t, 3, b.RelativeLineNumberInInput(4)) // extra line + assert.Equal(t, 3, b.RelativeLineNumberInInput(5)) // extra line + assert.Equal(t, 4, b.RelativeLineNumberInInput(6)) + }) + + 
t.Run("multiple corrections", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 2, extraLinesInOutput: 1}, + {inputLineNumber: 5, extraLinesInOutput: 3}, + }, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(2)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(3)) // extra from line 2 + assert.Equal(t, 3, b.RelativeLineNumberInInput(4)) + assert.Equal(t, 4, b.RelativeLineNumberInInput(5)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(6)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(7)) // extra from line 5 + assert.Equal(t, 5, b.RelativeLineNumberInInput(8)) // extra from line 5 + assert.Equal(t, 5, b.RelativeLineNumberInInput(9)) // extra from line 5 + assert.Equal(t, 6, b.RelativeLineNumberInInput(10)) + }) +} + +func TestPreprocess(t *testing.T) { + t.Run("basic procedure with schema replacement", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as +begin + select 1 +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + assert.Contains(t, result.Batches[0].Lines, "[code@abc123]") + assert.NotContains(t, result.Batches[0].Lines, "[code]") + }) + + t.Run("postgres uses unquoted schema name", func(t *testing.T) { + doc := sqlparser.ParseString("test.pgsql", ` +create procedure [code].test() as $$ +begin + perform 1; +end; +$$ language plpgsql; +`) + doc.Creates[0].Driver = &stdlib.Driver{} + + result, err := Preprocess(doc, "abc123", &stdlib.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + assert.Contains(t, result.Batches[0].Lines, "code@abc123") + assert.NotContains(t, result.Batches[0].Lines, "[code@abc123]") + }) + + t.Run("replaces enum constants", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumStatus int = 42; +go 
+create procedure [code].Test as +begin + select @EnumStatus +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "42/*=@EnumStatus*/") + assert.NotContains(t, batch, "@EnumStatus\n") // shouldn't have bare reference + }) + + t.Run("handles multiline string constants", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumMulti nvarchar(max) = N'line1 +line2 +line3'; +go +create procedure [code].Test as +begin + select @EnumMulti +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0] + assert.Contains(t, batch.Lines, "N'line1\nline2\nline3'/*=@EnumMulti*/") + // Should have line number corrections for the 2 extra lines + assert.Len(t, batch.lineNumberCorrections, 1) + assert.Equal(t, 2, batch.lineNumberCorrections[0].extraLinesInOutput) + }) + + t.Run("error on undeclared constant", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as +begin + select @EnumUndeclared +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + _, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.Error(t, err) + + var preprocErr PreprocessorError + require.ErrorAs(t, err, &preprocErr) + assert.Contains(t, preprocErr.Message, "@EnumUndeclared") + assert.Contains(t, preprocErr.Message, "not declared") + }) + + t.Run("error on schema suffix with bracket", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as begin end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + _, err := Preprocess(doc, "abc]123", &mssql.Driver{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "schemasuffix cannot contain") + }) + + t.Run("skips creates 
with empty body", func(t *testing.T) { + doc := sqlparser.Document{ + Creates: []sqlparser.Create{ + {Body: []sqlparser.Unparsed{}}, + }, + } + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + assert.Empty(t, result.Batches) + }) + + t.Run("handles multiple creates", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Proc1 as begin select 1 end +go +create procedure [code].Proc2 as begin select 2 end +`) + doc.Creates[0].Driver = &mssql.Driver{} + doc.Creates[1].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + assert.Len(t, result.Batches, 2) + + assert.Contains(t, result.Batches[0].Lines, "Proc1") + assert.Contains(t, result.Batches[1].Lines, "Proc2") + }) + + t.Run("handles multiple constants in same procedure", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumA int = 1, @EnumB int = 2; +go +create procedure [code].Test as +begin + select @EnumA, @EnumB +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "1/*=@EnumA*/") + assert.Contains(t, batch, "2/*=@EnumB*/") + }) + + t.Run("preserves comments and formatting", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +-- This is a test procedure +create procedure [code].Test as +begin + /* multi + line + comment */ + select 1 +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "-- This is a test procedure") + assert.Contains(t, batch, "/* multi") + }) + + t.Run("handles const and global prefixes", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", 
` +declare @ConstValue int = 100, @GlobalSetting nvarchar(50) = N'test'; +go +create procedure [code].Test as +begin + select @ConstValue, @GlobalSetting +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "100/*=@ConstValue*/") + assert.Contains(t, batch, "N'test'/*=@GlobalSetting*/") + }) +} + +func TestPreprocessString(t *testing.T) { + t.Run("replaces code schema", func(t *testing.T) { + result := preprocessString("abc123", "select * from [code].Table") + assert.Equal(t, "select * from [code@abc123].Table", result) + }) + + t.Run("case insensitive replacement", func(t *testing.T) { + result := preprocessString("abc123", "select * from [CODE].Table and [Code].Other") + assert.Contains(t, result, "[code@abc123].Table") + assert.Contains(t, result, "[code@abc123].Other") + }) + + t.Run("multiple occurrences", func(t *testing.T) { + sql := ` + select * from [code].A + join [code].B on A.id = B.id + where exists (select 1 from [code].C) + ` + result := preprocessString("abc123", sql) + assert.Equal(t, 3, strings.Count(result, "[code@abc123]")) + assert.NotContains(t, result, "[code].") + }) + + t.Run("no replacement needed", func(t *testing.T) { + sql := "select * from dbo.Table" + result := preprocessString("abc123", sql) + assert.Equal(t, sql, result) + }) +} + +func TestPreprocessorError(t *testing.T) { + t.Run("formats error message", func(t *testing.T) { + err := PreprocessorError{ + Pos: sqlparser.Pos{File: "test.sql", Line: 10, Col: 5}, + Message: "something went wrong", + } + + assert.Equal(t, "test.sql:10:5: something went wrong", err.Error()) + }) +} diff --git a/sqlparser/parser.go b/sqlparser/parser.go index b61f49e..b88e4ea 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -9,7 +9,9 @@ import ( "errors" "fmt" "io/fs" + "path/filepath" "regexp" + "slices" 
"sort" "strings" @@ -19,6 +21,8 @@ import ( var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" +var supportedSqlExtensions []string = []string{".sql", ".pgsql"} + func CopyToken(s *Scanner, target *[]Unparsed) { *target = append(*target, CreateUnparsed(s)) } @@ -564,9 +568,8 @@ func ParseString(filename FileRef, input string) (result Document) { return } -// ParseFileystems iterates through a list of filesystems and parses all files -// matching `*.sql`, determines which one are sqlcode files from the contents, -// and returns the combination of all of them. +// ParseFileystems iterates through a list of filesystems and parses all supported +// SQL files and returns the combination of all of them. // // err will only return errors related to filesystems/reading. Errors // related to parsing/sorting will be in result.Errors. @@ -591,7 +594,9 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, if strings.HasPrefix(path, ".") || strings.Contains(path, "/.") { return nil } - if !strings.HasSuffix(path, ".sql") || !strings.HasSuffix(path, ".pgsql") { + + extension := filepath.Ext(path) + if !slices.Contains(supportedSqlExtensions, extension) { return nil } diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index 3acc3c2..fc75460 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -2,8 +2,10 @@ package sqlparser import ( "fmt" + "io/fs" "strings" "testing" + "testing/fstest" mssql "github.com/denisenkom/go-mssqldb" "github.com/jackc/pgx/v5/stdlib" @@ -410,3 +412,246 @@ create procedure [code].Foo as begin end err.Error()) } + +func TestParseFilesystems(t *testing.T) { + t.Run("basic parsing of sql files", func(t *testing.T) { + fsys := fstest.MapFS{ + "test1.sql": &fstest.MapFile{ + Data: []byte(` +declare @EnumFoo int = 1; +go +create procedure [code].Proc1 as begin end +`), + }, + "test2.sql": &fstest.MapFile{ + Data: []byte(` +create function [code].Func1() 
returns int as begin return 1 end +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Len(t, doc.Creates, 2) + assert.Len(t, doc.Declares, 1) + }) + + t.Run("filters by include tags", func(t *testing.T) { + fsys := fstest.MapFS{ + "included.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if foo,bar +create procedure [code].Included as begin end +`), + }, + "excluded.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if baz +create procedure [code].Excluded as begin end +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, []string{"foo", "bar"}) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "included.sql") + assert.Len(t, doc.Creates, 1) + assert.Equal(t, "[Included]", doc.Creates[0].QuotedName.Value) + }) + + t.Run("detects duplicate files with same hash", func(t *testing.T) { + contents := []byte(`create procedure [code].Test as begin end`) + + fs1 := fstest.MapFS{ + "test.sql": &fstest.MapFile{Data: contents}, + } + fs2 := fstest.MapFS{ + "test.sql": &fstest.MapFile{Data: contents}, + } + + _, _, err := ParseFilesystems([]fs.FS{fs1, fs2}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "exact same contents") + }) + + t.Run("skips non-sqlcode files", func(t *testing.T) { + fsys := fstest.MapFS{ + "regular.sql": &fstest.MapFile{ + Data: []byte(`select * from table1`), + }, + "sqlcode.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Test as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "sqlcode.sql") + assert.Len(t, doc.Creates, 1) + }) + + t.Run("skips hidden directories", func(t *testing.T) { + fsys := fstest.MapFS{ + "visible.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Visible as begin end`), + }, + ".hidden/test.sql": 
&fstest.MapFile{ + Data: []byte(`create procedure [code].Hidden as begin end`), + }, + "dir/.git/test.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Git as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "visible.sql") + assert.Len(t, doc.Creates, 1) + }) + + t.Run("handles dependencies and topological sort", func(t *testing.T) { + fsys := fstest.MapFS{ + "proc1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc1 as begin exec [code].Proc2 end`), + }, + "proc2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc2 as begin select 1 end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Len(t, doc.Creates, 2) + // Proc2 should come before Proc1 due to dependency + assert.Equal(t, "[Proc2]", doc.Creates[0].QuotedName.Value) + assert.Equal(t, "[Proc1]", doc.Creates[1].QuotedName.Value) + }) + + t.Run("reports topological sort errors", func(t *testing.T) { + fsys := fstest.MapFS{ + "circular1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].A as begin exec [code].B end`), + }, + "circular2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].B as begin exec [code].A end`), + }, + } + + _, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) // filesystem error should be nil + assert.NotEmpty(t, doc.Errors) // but parsing errors should exist + assert.Contains(t, doc.Errors[0].Message, "Detected a dependency cycle") + }) + + t.Run("handles multiple filesystems", func(t *testing.T) { + fs1 := fstest.MapFS{ + "test1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc1 as begin end`), + }, + } + fs2 := fstest.MapFS{ + "test2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc2 as begin end`), + }, + } + + filenames, doc, err := 
ParseFilesystems([]fs.FS{fs1, fs2}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Contains(t, filenames[0], "fs[0]:") + assert.Contains(t, filenames[1], "fs[1]:") + assert.Len(t, doc.Creates, 2) + }) + + t.Run("detects sqlcode files by pragma header", func(t *testing.T) { + fsys := fstest.MapFS{ + "test.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if foo +create procedure NotInCodeSchema.Test as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, []string{"foo"}) + require.NoError(t, err) + assert.Len(t, filenames, 1) + // Should still parse even though it will have errors (not in [code] schema) + assert.NotEmpty(t, doc.Errors) + }) + + t.Run("handles pgsql extension", func(t *testing.T) { + fsys := fstest.MapFS{ + "test.pgsql": &fstest.MapFile{ + Data: []byte(` +create procedure [code].test() +language plpgsql +as $$ +begin + perform 1; +end; +$$; +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Len(t, doc.Creates, 1) + assert.Equal(t, &stdlib.Driver{}, doc.Creates[0].Driver) + }) + + t.Run("empty filesystem returns empty results", func(t *testing.T) { + fsys := fstest.MapFS{} + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Empty(t, filenames) + assert.Empty(t, doc.Creates) + assert.Empty(t, doc.Declares) + }) +} + +func TestMatchesIncludeTags(t *testing.T) { + t.Run("empty requirements matches anything", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{}, []string{})) + assert.True(t, matchesIncludeTags([]string{}, []string{"foo"})) + }) + + t.Run("all requirements must be met", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"foo", "bar", "baz"})) + assert.False(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"foo"})) + assert.False(t, matchesIncludeTags([]string{"foo", "bar"}, 
[]string{"bar"})) + }) + + t.Run("exact match", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{"foo"}, []string{"foo"})) + assert.False(t, matchesIncludeTags([]string{"foo"}, []string{"bar"})) + }) +} + +func TestIsSqlcodeConstVariable(t *testing.T) { + testCases := []struct { + name string + varname string + expected bool + }{ + {"@Enum prefix", "@EnumFoo", true}, + {"@ENUM_ prefix", "@ENUM_FOO", true}, + {"@enum_ prefix", "@enum_foo", true}, + {"@Const prefix", "@ConstFoo", true}, + {"@CONST_ prefix", "@CONST_FOO", true}, + {"@const_ prefix", "@const_foo", true}, + {"regular variable", "@MyVariable", false}, + {"@Global prefix", "@GlobalVar", false}, + {"no @ prefix", "EnumFoo", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsSqlcodeConstVariable(tc.varname)) + }) + } +} diff --git a/sqltest/test.pgsql b/sqltest/test.pgsql index e16e7dd..49992da 100644 --- a/sqltest/test.pgsql +++ b/sqltest/test.pgsql @@ -1,5 +1,28 @@ +-- consts can be -create procedure [code].test() +-- sqlcode +-- we define schemas per deployment +-- uploading all stored functions/procedures/types/consts to a schema +-- pods are restarted/deployed per deployment + +-- aaa bbb + 3 1 + +(iof increase in errors, stop deployment) +-- aaa bbb + 3 0 +-- aaa bbb + 1 2 +-- aaa bbb + 0 3 +-- + +-- ++ both mssql and pgsql have the same architecture with schemas and stored functions/procedures + +-- Q: constants? +-- we have the same constants defined in both mssql and pggsql + +create procedure [code].test() -- expands to code@aaa.test language plpgsql as $$ begin From dd9b48e7138766d011ee05a3ae516ec2c349402f Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:27:56 +0100 Subject: [PATCH 09/40] Fixed issue with Preprocess. Passing pgsql tests. 
--- deployable.go | 21 +++++- deployable_test.go | 6 ++ dockerfile.test | 1 + ...{0003.sqlcode.pgsql => 0001.sqlcode.pgsql} | 0 preprocess.go | 26 ++++--- preprocess_test.go | 13 ++-- sqltest/fixture.go | 31 ++++++++ sqltest/sqlcode_test.go | 75 ++++++++++--------- sqltest/test.pgsql | 26 +------ 9 files changed, 118 insertions(+), 81 deletions(-) rename migrations/{0003.sqlcode.pgsql => 0001.sqlcode.pgsql} (100%) diff --git a/deployable.go b/deployable.go index ead2709..ec784d7 100644 --- a/deployable.go +++ b/deployable.go @@ -99,7 +99,6 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { } preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix, dbc.Driver()) - if err != nil { _ = tx.Rollback() return err @@ -244,11 +243,28 @@ func (d Deployable) DropAndUpload(ctx context.Context, dbc DB) error { } // Patch will preprocess the sql passed in so that it will call SQL code -// deployed by the receiver Deployable +// deployed by the receiver Deployable for SQL Server. +// NOTE: This will be deprecated and eventually replaced with CodePatch. func (d Deployable) Patch(sql string) string { return preprocessString(d.SchemaSuffix, sql) } +// CodePatch will preprocess the sql passed in to call +// the correct SQL code deployed to the provided database. +// Q: Nameing? DBPatch, PatchV2, ??? +func (d Deployable) CodePatch(dbc *sql.DB, sql string) string { + driver := dbc.Driver() + if _, ok := driver.(*mssql.Driver); ok { + return codeSchemaRegexp.ReplaceAllString(sql, fmt.Sprintf(`[code@%s]`, d.SchemaSuffix)) + } + + if _, ok := driver.(*stdlib.Driver); ok { + return codeSchemaRegexp.ReplaceAllString(sql, fmt.Sprintf(`"code@%s"`, d.SchemaSuffix)) + } + + panic("unhandled sql driver") +} + func (d *Deployable) markAsUploaded(dbc DB) { d.uploaded[dbc] = struct{}{} } @@ -260,7 +276,6 @@ func (d *Deployable) IsUploadedFromCache(dbc DB) bool { // TODO: StringConst. 
This requires parsing a SQL literal, a bit too complex // to code up just-in-case - func (d Deployable) IntConst(s string) (int, error) { for _, declare := range d.CodeBase.Declares { if declare.VariableName == s { diff --git a/deployable_test.go b/deployable_test.go index 1e9dac5..7b87b57 100644 --- a/deployable_test.go +++ b/deployable_test.go @@ -21,5 +21,11 @@ declare @EnumInt int = 1, @EnumString varchar(max) = 'hello'; n, err := d.IntConst("@EnumInt") require.NoError(t, err) assert.Equal(t, 1, n) +} +func TestPatch(t *testing.T) { + t.Run("mssql schemasuffix", func(t *testing.T) { + d := Deployable{} + require.Equal(t, "[code@].Foo", d.Patch("[code].Foo")) + }) } diff --git a/dockerfile.test b/dockerfile.test index f4a199f..d8cdf13 100644 --- a/dockerfile.test +++ b/dockerfile.test @@ -3,4 +3,5 @@ WORKDIR /sqlcode ENV GODEBUG="x509negativeserial=1" COPY . . RUN go mod tidy +# Skip the example folder because it has examples of what-not-to-do CMD ["go", "test", "-v", "$(go list ./... 
| grep -v './example')"] \ No newline at end of file diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0001.sqlcode.pgsql similarity index 100% rename from migrations/0003.sqlcode.pgsql rename to migrations/0001.sqlcode.pgsql diff --git a/preprocess.go b/preprocess.go index 2c6f647..c4a776b 100644 --- a/preprocess.go +++ b/preprocess.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "reflect" "regexp" "strings" @@ -131,7 +132,6 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot result.lineNumberCorrections = append(result.lineNumberCorrections, lineNumberCorrection{relativeLine, newlineCount}) } } - if _, err = w.WriteString(token); err != nil { return } @@ -153,25 +153,29 @@ func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Drive declares[dec.VariableName] = dec.Literal.RawValue } + // The current sql driver that we are preparring for + currentDriver := reflect.TypeOf(driver) + + // the default target for mssql + target := fmt.Sprintf(`[code@%s]`, schemasuffix) + + // pgsql target + if _, ok := driver.(*stdlib.Driver); ok { + target = fmt.Sprintf(`"code@%s"`, schemasuffix) + } + for _, create := range doc.Creates { if len(create.Body) == 0 { continue } - if create.Driver != driver { - // continue - } - // TODO(ks) this is not reached - target := "[code@" + schemasuffix + "]" - - if _, ok := create.Driver.(*stdlib.Driver); ok { - target = "code@" + schemasuffix + if !currentDriver.AssignableTo(reflect.TypeOf(create.Driver)) { + // this batch is for a different sql driver + continue } - batch, err := sqlcodeTransformCreate(declares, create, target) if err != nil { return result, fmt.Errorf("failed to transform create: %w", err) } - fmt.Print(batch) result.Batches = append(result.Batches, batch) } diff --git a/preprocess_test.go b/preprocess_test.go index 20998ff..940b8f6 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -199,8 +199,8 @@ end require.NoError(t, err) require.Len(t, 
result.Batches, 1) - assert.Contains(t, result.Batches[0].Lines, "[code@abc123]") - assert.NotContains(t, result.Batches[0].Lines, "[code]") + assert.Contains(t, result.Batches[0].Lines, "[code@abc123].") + assert.NotContains(t, result.Batches[0].Lines, "[code].") }) t.Run("postgres uses unquoted schema name", func(t *testing.T) { @@ -217,8 +217,8 @@ $$ language plpgsql; require.NoError(t, err) require.Len(t, result.Batches, 1) - assert.Contains(t, result.Batches[0].Lines, "code@abc123") - assert.NotContains(t, result.Batches[0].Lines, "[code@abc123]") + assert.Contains(t, result.Batches[0].Lines, "code@abc123.") + assert.NotContains(t, result.Batches[0].Lines, "[code@abc123].") }) t.Run("replaces enum constants", func(t *testing.T) { @@ -367,7 +367,8 @@ end t.Run("handles const and global prefixes", func(t *testing.T) { doc := sqlparser.ParseString("test.sql", ` -declare @ConstValue int = 100, @GlobalSetting nvarchar(50) = N'test'; +declare @ConstValue int = 100; +declare @GlobalSetting nvarchar(50) = N'test'; go create procedure [code].Test as begin @@ -382,7 +383,7 @@ end batch := result.Batches[0].Lines assert.Contains(t, batch, "100/*=@ConstValue*/") - assert.Contains(t, batch, "N'test'/*=@GlobalSetting*/") + assert.NotContains(t, batch, "N'test'/*=@GlobalSetting*/") }) } diff --git a/sqltest/fixture.go b/sqltest/fixture.go index 82f395d..e05c418 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strings" + "testing" "time" mssql "github.com/denisenkom/go-mssqldb" @@ -140,6 +141,24 @@ func NewFixture() *Fixture { if err != nil { panic(err) } + + var user string + err = fixture.DB.QueryRow(`select current_user`).Scan(&user) + if err != nil { + panic(err) + } + _, err = fixture.DB.Exec(fmt.Sprintf(`GRANT ALL ON DATABASE "%s" TO sa;`, fixture.DBName)) + if err != nil { + panic(err) + } + _, err = fixture.DB.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON SCHEMA public TO %s;`, user)) + if err != nil { + panic(err) + } + _, err 
= fixture.DB.Exec(fmt.Sprintf(`ALTER DATABASE "%s" OWNER TO sa;`, fixture.DBName)) + if err != nil { + panic(err) + } } return &fixture @@ -177,3 +196,15 @@ func (f *Fixture) RunMigrationFile(filename string) { } } } + +func (f *Fixture) RunIfPostgres(t *testing.T, name string, fn func(t *testing.T)) { + if f.IsPostgresql() { + t.Run(name, fn) + } +} + +func (f *Fixture) RunIfMssql(t *testing.T, name string, fn func(t *testing.T)) { + if f.IsSqlServer() { + t.Run(name, fn) + } +} diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index db1eb44..910064b 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -9,68 +9,71 @@ import ( "github.com/stretchr/testify/require" ) -func Test_RowsAffected(t *testing.T) { +func Test_Patch(t *testing.T) { fixture := NewFixture() + ctx := context.Background() defer fixture.Teardown() - t.Run("mssql", func(t *testing.T) { - if !fixture.IsSqlServer() { - t.Skip() - } + if fixture.IsSqlServer() { fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + } - ctx := context.Background() + if fixture.IsPostgresql() { + fixture.RunMigrationFile("../migrations/0001.sqlcode.pgsql") + _, err := fixture.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, fixture.DBName)) + require.NoError(t, err) + } - require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - patched := SQL.Patch(`[code].Test`) + require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) + fixture.RunIfMssql(t, "mssql", func(t *testing.T) { + patched := SQL.CodePatch(fixture.DB, `[code].Test`) res, err := fixture.DB.ExecContext(ctx, patched) require.NoError(t, err) + rowsAffected, err := res.RowsAffected() require.NoError(t, err) assert.Equal(t, int64(1), rowsAffected) + }) - schemas, err := SQL.ListUploaded(ctx, fixture.DB) + fixture.RunIfPostgres(t, "pgsql", func(t *testing.T) { + patched := SQL.CodePatch(fixture.DB, `call [code].Test()`) + res, err := fixture.DB.ExecContext(ctx, patched) require.NoError(t, err) - 
require.Len(t, schemas, 1) - require.Equal(t, 6, schemas[0].Objects) - require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) + // postgresql perform does not result with affected rows + rowsAffected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, int64(0), rowsAffected) }) } func Test_EnsureUploaded(t *testing.T) { - fixture := NewFixture() - defer fixture.Teardown() - - t.Run("mssql", func(t *testing.T) { - if !fixture.IsSqlServer() { - t.Skip() - } - fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + f := NewFixture() + defer f.Teardown() + ctx := context.Background() + + f.RunIfMssql(t, "mssql", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.sql") + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + schemas, err := SQL.ListUploaded(ctx, f.DB) + require.NoError(t, err) + require.Len(t, schemas, 1) - ctx := context.Background() - require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) }) - t.Run("pgsql", func(t *testing.T) { - if !fixture.IsPostgresql() { - t.Skip() - } - - fixture.RunMigrationFile("../migrations/0003.sqlcode.pgsql") + f.RunIfPostgres(t, "pgsql", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.pgsql") - ctx := context.Background() - - _, err := fixture.DB.Exec( - fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, fixture.DBName)) + _, err := f.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, f.DBName)) require.NoError(t, err) - require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - schemas, err := SQL.ListUploaded(ctx, fixture.DB) + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + schemas, err := SQL.ListUploaded(ctx, f.DB) require.NoError(t, err) - require.Equal(t, "code@e3b0c44298fc", schemas[0].Name) - + require.Len(t, schemas, 1) }) } diff --git a/sqltest/test.pgsql b/sqltest/test.pgsql index 49992da..bff51a5 100644 --- a/sqltest/test.pgsql +++ b/sqltest/test.pgsql @@ -1,28 +1,4 @@ --- consts can be - --- 
sqlcode --- we define schemas per deployment --- uploading all stored functions/procedures/types/consts to a schema --- pods are restarted/deployed per deployment - --- aaa bbb - 3 1 - -(iof increase in errors, stop deployment) --- aaa bbb - 3 0 --- aaa bbb - 1 2 --- aaa bbb - 0 3 --- - --- ++ both mssql and pgsql have the same architecture with schemas and stored functions/procedures - --- Q: constants? --- we have the same constants defined in both mssql and pggsql - -create procedure [code].test() -- expands to code@aaa.test +create procedure [code].test() language plpgsql as $$ begin From 9cbbe00b94f8e65e2f7c958b3b1ca47bd6d8a7a3 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:28:06 +0100 Subject: [PATCH 10/40] Updated GO workflow to test both drivers. --- .github/workflows/go.yml | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 926f7a6..e5272b4 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,7 +1,7 @@ # This workflow will build a golang project # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go -name: go-querysql-test +name: sqlcode on: pull_request: @@ -11,19 +11,16 @@ jobs: build: runs-on: ubuntu-latest - env: - SQLSERVER_DSN: "sqlserver://127.0.0.1:1433?database=master&user id=sa&password=VippsPw1" + strategy: + matrix: + drivers: ['mssql','pgsql'] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.23' - - - name: Start db - run: docker compose -f docker-compose.test.yml up -d + go-version: '1.25' - name: Test - # Skip the example folder because it has examples of what-not-to-do - run: go test -v $(go list ./... 
| grep -v './example') + run: docker compose -f docker-compose.${{ matrix.driver }}.yml run test \ No newline at end of file From f7fe0068fac597ee2c9dc3f7bfa53ea21d76dc8c Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:29:34 +0100 Subject: [PATCH 11/40] Fixed typo in GH workflow. --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index e5272b4..64b63c3 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - drivers: ['mssql','pgsql'] + driver: ['mssql','pgsql'] steps: - uses: actions/checkout@v5 From 365cffa599b98474785fe74001acffc388994323 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:41:36 +0100 Subject: [PATCH 12/40] Updated Dockerfile --- dockerfile.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dockerfile.test b/dockerfile.test index d8cdf13..286268c 100644 --- a/dockerfile.test +++ b/dockerfile.test @@ -1,7 +1,7 @@ -FROM golang:1.25.1 AS builder +FROM golang:1.25 AS builder WORKDIR /sqlcode ENV GODEBUG="x509negativeserial=1" COPY . . RUN go mod tidy # Skip the example folder because it has examples of what-not-to-do -CMD ["go", "test", "-v", "$(go list ./... 
| grep -v './example')"] \ No newline at end of file +CMD ["go", "test", "-v", "./..."] \ No newline at end of file From b3b3d137d519dd548719d2890dec4244e8963e6b Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:44:21 +0100 Subject: [PATCH 13/40] Use build tags to exclude examples from bulid & test --- example/basic/example.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/example/basic/example.go b/example/basic/example.go index abe1194..9406915 100644 --- a/example/basic/example.go +++ b/example/basic/example.go @@ -1,3 +1,6 @@ +//go:build examples +// +build examples + package example import ( From 6fb9ea5edd58db3e317f8a6a9f02d75851744862 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:46:06 +0100 Subject: [PATCH 14/40] Exclude example test --- example/basic/example_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/example/basic/example_test.go b/example/basic/example_test.go index 079bd91..0c78c63 100644 --- a/example/basic/example_test.go +++ b/example/basic/example_test.go @@ -1,13 +1,17 @@ +//go:build examples +// +build examples + package example import ( "context" "fmt" + "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vippsas/sqlcode/sqltest" - "testing" - "time" ) func TestPreprocess(t *testing.T) { From c587c0fc9577a1b7fa33551477e71809798b2eb6 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:47:16 +0100 Subject: [PATCH 15/40] Fixed failing test. 
--- preprocess_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/preprocess_test.go b/preprocess_test.go index 940b8f6..9d0d2ee 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -217,7 +217,7 @@ $$ language plpgsql; require.NoError(t, err) require.Len(t, result.Batches, 1) - assert.Contains(t, result.Batches[0].Lines, "code@abc123.") + assert.Contains(t, result.Batches[0].Lines, `"code@abc123".`) assert.NotContains(t, result.Batches[0].Lines, "[code@abc123].") }) From 1ebaa6f72c862e3e6a28a1c31ceabb564322082d Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 18:40:04 +0100 Subject: [PATCH 16/40] Moved Document structs to a separate file for better organization. --- sqlparser/document.go | 505 ++++++++++++++++++++++++++++++++++++++++++ sqlparser/parser.go | 492 ---------------------------------------- 2 files changed, 505 insertions(+), 492 deletions(-) create mode 100644 sqlparser/document.go diff --git a/sqlparser/document.go b/sqlparser/document.go new file mode 100644 index 0000000..c907aa8 --- /dev/null +++ b/sqlparser/document.go @@ -0,0 +1,505 @@ +package sqlparser + +import ( + "fmt" + "sort" + "strings" + + "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" +) + +type Document struct { + PragmaIncludeIf []string + Creates []Create + Declares []Declare + Errors []Error +} + +func (d *Document) addError(s *Scanner, msg string) { + d.Errors = append(d.Errors, Error{ + Pos: s.Start(), + Message: msg, + }) +} + +func (d *Document) unexpectedTokenError(s *Scanner) { + d.addError(s, "Unexpected: "+s.Token()) +} + +func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { + parseArgs := func() { + // parses *after* the initial (; consumes trailing ) + for { + switch { + case s.TokenType() == NumberToken: + t.Args = append(t.Args, s.Token()) + case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": + t.Args = append(t.Args, "max") + default: + 
doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + s.NextNonWhitespaceCommentToken() + switch { + case s.TokenType() == CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case s.TokenType() == RightParenToken: + s.NextNonWhitespaceCommentToken() + return + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + } + } + + if s.TokenType() != UnquotedIdentifierToken { + panic("assertion failed, bug in caller") + } + t.BaseType = s.Token() + s.NextNonWhitespaceCommentToken() + if s.TokenType() == LeftParenToken { + s.NextNonWhitespaceCommentToken() + parseArgs() + } + return +} + +func (doc *Document) parseDeclare(s *Scanner) (result []Declare) { + declareStart := s.Start() + // parse what is *after* the `declare` reserved keyword +loop: + for { + if s.TokenType() != VariableIdentifierToken { + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + + variableName := s.Token() + if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && + !strings.HasPrefix(strings.ToLower(variableName), "@global") && + !strings.HasPrefix(strings.ToLower(variableName), "@const") { + doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) + } + s.NextNonWhitespaceCommentToken() + var variableType Type + switch s.TokenType() { + case EqualToken: + doc.addError(s, "sqlcode constants needs a type declared explicitly") + s.NextNonWhitespaceCommentToken() + case UnquotedIdentifierToken: + variableType = doc.parseTypeExpression(s) + } + + if s.TokenType() != EqualToken { + doc.addError(s, "sqlcode constants needs to be assigned at once using =") + doc.recoverToNextStatement(s) + } + + switch s.NextNonWhitespaceCommentToken() { + case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: + result = append(result, Declare{ + Start: declareStart, + Stop: s.Stop(), + VariableName: variableName, + Datatype: variableType, + Literal: CreateUnparsed(s), + 
}) + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + + switch s.NextNonWhitespaceCommentToken() { + case CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case SemicolonToken: + s.NextNonWhitespaceCommentToken() + break loop + default: + break loop + } + } + if len(result) == 0 { + doc.addError(s, "incorrect syntax; no variables successfully declared") + } + return +} + +func (doc *Document) parseBatchSeparator(s *Scanner) { + // just saw a 'go'; just make sure there's nothing bad trailing it + // (if there is, convert to errors and move on until the line is consumed + errorEmitted := false + for { + switch s.NextToken() { + case WhitespaceToken: + continue + case MalformedBatchSeparatorToken: + if !errorEmitted { + doc.addError(s, "`go` should be alone on a line without any comments") + errorEmitted = true + } + continue + default: + return + } + } +} + +func (doc *Document) parseDeclareBatch(s *Scanner) (hasMore bool) { + if s.ReservedWord() != "declare" { + panic("assertion failed, incorrect use in caller") + } + for { + tt := s.TokenType() + switch { + case tt == EOFToken: + return false + case tt == ReservedWordToken && s.ReservedWord() == "declare": + s.NextNonWhitespaceCommentToken() + d := doc.parseDeclare(s) + doc.Declares = append(doc.Declares, d...) 
+ case tt == ReservedWordToken && s.ReservedWord() != "declare": + doc.addError(s, "Only 'declare' allowed in this batch") + doc.recoverToNextStatement(s) + case tt == BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + } + } +} + +func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { + var nodes []Unparsed + var docstring []PosString + newLineEncounteredInDocstring := false + + var createCountInBatch int + + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + nodes = append(nodes, CreateUnparsed(s)) + // do not reset token for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + docstring = nil + } + s.NextToken() + case SinglelineCommentToken: + // We build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + docstring = append(docstring, PosString{s.Start(), s.Token()}) + nodes = append(nodes, CreateUnparsed(s)) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + switch s.ReservedWord() { + case "declare": + // First declare-statement; enter a mode where we assume all contents + // of batch are declare statements + if !isFirst { + doc.addError(s, "'declare' statement only allowed in first batch") + } + // regardless of errors, go on and parse as far as we get... + return doc.parseDeclareBatch(s) + case "create": + // should be start of create procedure or create function... 
+ c := doc.parseCreate(s, createCountInBatch) + + if strings.HasSuffix(string(s.file), ".sql") { + c.Driver = &mssql.Driver{} + } + if strings.HasSuffix(string(s.file), ".pgsql") { + c.Driver = &stdlib.Driver{} + } + + // *prepend* what we saw before getting to the 'create' + createCountInBatch++ + c.Body = append(nodes, c.Body...) + c.Docstring = docstring + doc.Creates = append(doc.Creates, c) + default: + doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) + s.NextToken() + } + case BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + s.NextToken() + docstring = nil + } + } +} + +func (d *Document) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "declare", "create", "go": + return + } + case EOFToken: + return + default: + CopyToken(s, target) + } + } +} + +func (d *Document) recoverToNextStatement(s *Scanner) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "declare", "create", "go": + return + } + case EOFToken: + return + } + } +} + +// parseCodeschemaName parses `[code] . something`, and returns `something` +// in quoted form (`[something]`). Also copy to `target`. Empty string on error. +// Note: To follow conventions, consume one extra token at the end even if we know +// it fill not be consumed by this function... 
+func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + if s.TokenType() != DotToken { + d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) + d.recoverToNextStatementCopying(s, target) + return PosString{Value: ""} + } + CopyToken(s, target) + + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case UnquotedIdentifierToken: + // To get something uniform for comparison, quote all names + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} + NextTokenCopyingWhitespace(s, target) + return result + case QuotedIdentifierToken: + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: s.Token()} + NextTokenCopyingWhitespace(s, target) + return result + default: + d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) + d.recoverToNextStatementCopying(s, target) + return PosString{Value: ""} + } +} + +// parseCreate parses anything that starts with "create". Position is +// *on* the create token. +// At this stage in sqlcode parser development we're only interested +// in procedures/functions/types as opaque blocks of SQL code where +// we only track dependencies between them and their declared name; +// so we treat them with the same code. We consume until the end of +// the batch; only one declaration allowed per batch. Everything +// parsed here will also be added to `batch`. On any error, copying +// to batch stops / becomes erratic.. 
+func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Create) { + if s.ReservedWord() != "create" { + panic("illegal use by caller") + } + CopyToken(s, &result.Body) + + NextTokenCopyingWhitespace(s, &result.Body) + + createType := strings.ToLower(s.Token()) + if !(createType == "procedure" || createType == "function" || createType == "type") { + d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) + d.recoverToNextStatementCopying(s, &result.Body) + return + } + if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + d.recoverToNextStatementCopying(s, &result.Body) + return + } + + result.CreateType = createType + CopyToken(s, &result.Body) + + NextTokenCopyingWhitespace(s, &result.Body) + + // Insist on [code]. + if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { + d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) + d.recoverToNextStatementCopying(s, &result.Body) + return + } + result.QuotedName = d.parseCodeschemaName(s, &result.Body) + if result.QuotedName.String() == "" { + return + } + + // We have matched "create [code]."; at this + // point we copy the rest until the batch ends; *but* track dependencies + // + some other details mentioned below + + //firstAs := true // See comment below on rowcount + +tailloop: + for { + tt := s.TokenType() + switch { + case tt == ReservedWordToken && s.ReservedWord() == "create": + // So, we're currently parsing 'create ...' and we see another 'create'. 
+ // We split in two cases depending on the context we are currently in + // (createType is referring to how we entered this function, *NOT* the + // `create` statement we are looking at now + switch createType { // note: this is the *outer* create type, not the one of current scanner position + case "function", "procedure": + // Within a function/procedure we can allow 'create index', 'create table' and nothing + // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain + // about that aspect, not relevant for batch / dependency parsing) + // + // What is important is a function/procedure/type isn't started on without a 'go' + // in between; so we block those 3 from appearing in the same batch + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + tt2 := s.TokenType() + + if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || + (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { + d.recoverToNextStatementCopying(s, &result.Body) + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + return + } + case "type": + // We allow more than one type creation in a batch; and 'create' can never appear + // scoped within 'create type'. 
So at a new create we are done with the previous + // one, and return it -- the caller can then re-enter this function from the top + break tailloop + default: + panic("assertion failed") + } + + case tt == EOFToken || tt == BatchSeparatorToken: + break tailloop + case tt == QuotedIdentifierToken && s.Token() == "[code]": + // Parse a dependency + dep := d.parseCodeschemaName(s, &result.Body) + found := false + for _, existing := range result.DependsOn { + if existing.Value == dep.Value { + found = true + break + } + } + if !found { + result.DependsOn = append(result.DependsOn, dep) + } + case tt == ReservedWordToken && s.Token() == "as": + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + /* + TODO: Fix and re-enable + This code add RoutineName for convenience. So: + + create procedure [code@5420c0269aaf].Test as + begin + select 1 + end + go + + becomes: + + create procedure [code@5420c0269aaf].Test as + declare @RoutineName nvarchar(128) + set @RoutineName = 'Test' + begin + select 1 + end + go + + However, for some very strange reason, @@rowcount is 1 with the first version, + and it is 2 with the second version. + if firstAs { + // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name + // from inside the procedure (for example, when logging) + if result.CreateType == "procedure" { + procNameToken := Unparsed{ + Type: OtherToken, + RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), + } + result.Body = append(result.Body, procNameToken) + } + firstAs = false + } + */ + + default: + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + } + } + + sort.Slice(result.DependsOn, func(i, j int) bool { + return result.DependsOn[i].Value < result.DependsOn[j].Value + }) + return +} + +// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered +// it is simply copied into `target`. 
Upon return, the scanner is located at a non-whitespace +// token, and target is either unmodified or filled with some whitespace nodes. +func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { + for { + tt := s.NextToken() + switch tt { + case EOFToken, BatchSeparatorToken: + // do not copy + return + case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + // copy, and loop around + CopyToken(s, target) + continue + default: + return + } + } + +} + +func CreateUnparsed(s *Scanner) Unparsed { + return Unparsed{ + Type: s.TokenType(), + Start: s.Start(), + Stop: s.Stop(), + RawValue: s.Token(), + } +} diff --git a/sqlparser/parser.go b/sqlparser/parser.go index b88e4ea..f2b51d8 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -12,11 +12,7 @@ import ( "path/filepath" "regexp" "slices" - "sort" "strings" - - mssql "github.com/denisenkom/go-mssqldb" - "github.com/jackc/pgx/v5/stdlib" ) var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" @@ -27,27 +23,6 @@ func CopyToken(s *Scanner, target *[]Unparsed) { *target = append(*target, CreateUnparsed(s)) } -// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered -// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace -// token, and target is either unmodified or filled with some whitespace nodes. 
-func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - return - } - } - -} - // AdvanceAndCopy is like NextToken; advance to next token that is not whitespace and return // Note: The 'go' and EOF tokens are *not* copied func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { @@ -69,473 +44,6 @@ func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { } } -func CreateUnparsed(s *Scanner) Unparsed { - return Unparsed{ - Type: s.TokenType(), - Start: s.Start(), - Stop: s.Stop(), - RawValue: s.Token(), - } -} - -func (d *Document) addError(s *Scanner, msg string) { - d.Errors = append(d.Errors, Error{ - Pos: s.Start(), - Message: msg, - }) -} - -func (d *Document) unexpectedTokenError(s *Scanner) { - d.addError(s, "Unexpected: "+s.Token()) -} - -func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { - parseArgs := func() { - // parses *after* the initial (; consumes trailing ) - for { - switch { - case s.TokenType() == NumberToken: - t.Args = append(t.Args, s.Token()) - case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": - t.Args = append(t.Args, "max") - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - s.NextNonWhitespaceCommentToken() - switch { - case s.TokenType() == CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case s.TokenType() == RightParenToken: - s.NextNonWhitespaceCommentToken() - return - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - } - } - - if s.TokenType() != UnquotedIdentifierToken { - panic("assertion failed, bug in caller") - } - t.BaseType = s.Token() - s.NextNonWhitespaceCommentToken() - if s.TokenType() == LeftParenToken { - s.NextNonWhitespaceCommentToken() - parseArgs() 
- } - return -} - -func (doc *Document) parseDeclare(s *Scanner) (result []Declare) { - declareStart := s.Start() - // parse what is *after* the `declare` reserved keyword -loop: - for { - if s.TokenType() != VariableIdentifierToken { - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - variableName := s.Token() - if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && - !strings.HasPrefix(strings.ToLower(variableName), "@global") && - !strings.HasPrefix(strings.ToLower(variableName), "@const") { - doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) - } - s.NextNonWhitespaceCommentToken() - var variableType Type - switch s.TokenType() { - case EqualToken: - doc.addError(s, "sqlcode constants needs a type declared explicitly") - s.NextNonWhitespaceCommentToken() - case UnquotedIdentifierToken: - variableType = doc.parseTypeExpression(s) - } - - if s.TokenType() != EqualToken { - doc.addError(s, "sqlcode constants needs to be assigned at once using =") - doc.recoverToNextStatement(s) - } - - switch s.NextNonWhitespaceCommentToken() { - case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - result = append(result, Declare{ - Start: declareStart, - Stop: s.Stop(), - VariableName: variableName, - Datatype: variableType, - Literal: CreateUnparsed(s), - }) - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - switch s.NextNonWhitespaceCommentToken() { - case CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case SemicolonToken: - s.NextNonWhitespaceCommentToken() - break loop - default: - break loop - } - } - if len(result) == 0 { - doc.addError(s, "incorrect syntax; no variables successfully declared") - } - return -} - -func (doc *Document) parseBatchSeparator(s *Scanner) { - // just saw a 'go'; just make sure there's nothing bad trailing it - // (if there is, convert to errors and move on until the line is consumed - 
errorEmitted := false - for { - switch s.NextToken() { - case WhitespaceToken: - continue - case MalformedBatchSeparatorToken: - if !errorEmitted { - doc.addError(s, "`go` should be alone on a line without any comments") - errorEmitted = true - } - continue - default: - return - } - } -} - -func (doc *Document) parseDeclareBatch(s *Scanner) (hasMore bool) { - if s.ReservedWord() != "declare" { - panic("assertion failed, incorrect use in caller") - } - for { - tt := s.TokenType() - switch { - case tt == EOFToken: - return false - case tt == ReservedWordToken && s.ReservedWord() == "declare": - s.NextNonWhitespaceCommentToken() - d := doc.parseDeclare(s) - doc.Declares = append(doc.Declares, d...) - case tt == ReservedWordToken && s.ReservedWord() != "declare": - doc.addError(s, "Only 'declare' allowed in this batch") - doc.recoverToNextStatement(s) - case tt == BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - } - } -} - -func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - var createCountInBatch int - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset token for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // We build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case 
"declare": - // First declare-statement; enter a mode where we assume all contents - // of batch are declare statements - if !isFirst { - doc.addError(s, "'declare' statement only allowed in first batch") - } - // regardless of errors, go on and parse as far as we get... - return doc.parseDeclareBatch(s) - case "create": - // should be start of create procedure or create function... - c := doc.parseCreate(s, createCountInBatch) - - if strings.HasSuffix(string(s.file), ".sql") { - c.Driver = &mssql.Driver{} - } - if strings.HasSuffix(string(s.file), ".pgsql") { - c.Driver = &stdlib.Driver{} - } - - // *prepend* what we saw before getting to the 'create' - createCountInBatch++ - c.Body = append(nodes, c.Body...) - c.Docstring = docstring - doc.Creates = append(doc.Creates, c) - default: - doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) - s.NextToken() - } - case BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - s.NextToken() - docstring = nil - } - } -} - -func (d *Document) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - default: - CopyToken(s, target) - } - } -} - -func (d *Document) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... 
as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - } - } -} - -// parseCodeschemaName parses `[code] . something`, and returns `something` -// in quoted form (`[something]`). Also copy to `target`. Empty string on error. -// Note: To follow conventions, consume one extra token at the end even if we know -// it fill not be consumed by this function... -func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { - CopyToken(s, target) - NextTokenCopyingWhitespace(s, target) - if s.TokenType() != DotToken { - d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } - CopyToken(s, target) - - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case UnquotedIdentifierToken: - // To get something uniform for comparison, quote all names - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} - NextTokenCopyingWhitespace(s, target) - return result - case QuotedIdentifierToken: - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: s.Token()} - NextTokenCopyingWhitespace(s, target) - return result - default: - d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } -} - -// parseCreate parses anything that starts with "create". Position is -// *on* the create token. -// At this stage in sqlcode parser development we're only interested -// in procedures/functions/types as opaque blocks of SQL code where -// we only track dependencies between them and their declared name; -// so we treat them with the same code. 
We consume until the end of -// the batch; only one declaration allowed per batch. Everything -// parsed here will also be added to `batch`. On any error, copying -// to batch stops / becomes erratic.. -func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Create) { - if s.ReservedWord() != "create" { - panic("illegal use by caller") - } - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - createType := strings.ToLower(s.Token()) - if !(createType == "procedure" || createType == "function" || createType == "type") { - d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - d.recoverToNextStatementCopying(s, &result.Body) - return - } - - result.CreateType = createType - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - // Insist on [code]. - if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { - d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - result.QuotedName = d.parseCodeschemaName(s, &result.Body) - if result.QuotedName.String() == "" { - return - } - - // We have matched "create [code]."; at this - // point we copy the rest until the batch ends; *but* track dependencies - // + some other details mentioned below - - //firstAs := true // See comment below on rowcount - -tailloop: - for { - tt := s.TokenType() - switch { - case tt == ReservedWordToken && s.ReservedWord() == "create": - // So, we're currently parsing 'create ...' and we see another 'create'. 
- // We split in two cases depending on the context we are currently in - // (createType is referring to how we entered this function, *NOT* the - // `create` statement we are looking at now - switch createType { // note: this is the *outer* create type, not the one of current scanner position - case "function", "procedure": - // Within a function/procedure we can allow 'create index', 'create table' and nothing - // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain - // about that aspect, not relevant for batch / dependency parsing) - // - // What is important is a function/procedure/type isn't started on without a 'go' - // in between; so we block those 3 from appearing in the same batch - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - tt2 := s.TokenType() - - if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || - (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - d.recoverToNextStatementCopying(s, &result.Body) - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - return - } - case "type": - // We allow more than one type creation in a batch; and 'create' can never appear - // scoped within 'create type'. 
So at a new create we are done with the previous - // one, and return it -- the caller can then re-enter this function from the top - break tailloop - default: - panic("assertion failed") - } - - case tt == EOFToken || tt == BatchSeparatorToken: - break tailloop - case tt == QuotedIdentifierToken && s.Token() == "[code]": - // Parse a dependency - dep := d.parseCodeschemaName(s, &result.Body) - found := false - for _, existing := range result.DependsOn { - if existing.Value == dep.Value { - found = true - break - } - } - if !found { - result.DependsOn = append(result.DependsOn, dep) - } - case tt == ReservedWordToken && s.Token() == "as": - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - /* - TODO: Fix and re-enable - This code add RoutineName for convenience. So: - - create procedure [code@5420c0269aaf].Test as - begin - select 1 - end - go - - becomes: - - create procedure [code@5420c0269aaf].Test as - declare @RoutineName nvarchar(128) - set @RoutineName = 'Test' - begin - select 1 - end - go - - However, for some very strange reason, @@rowcount is 1 with the first version, - and it is 2 with the second version. - if firstAs { - // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name - // from inside the procedure (for example, when logging) - if result.CreateType == "procedure" { - procNameToken := Unparsed{ - Type: OtherToken, - RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), - } - result.Body = append(result.Body, procNameToken) - } - firstAs = false - } - */ - - default: - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - } - } - - sort.Slice(result.DependsOn, func(i, j int) bool { - return result.DependsOn[i].Value < result.DependsOn[j].Value - }) - return -} - func Parse(s *Scanner, result *Document) { // Top-level parse; this focuses on splitting into "batches" separated // by 'go'. 
From b3419f21fcbe59fa45cdbc4ae5f5b953bbdc2b4c Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 18:41:23 +0100 Subject: [PATCH 17/40] Updated go-mssql depedency to use microsoft fork. DropAndUpload now support postgresql. --- cli/cmd/build.go | 2 +- cli/cmd/config.go | 11 ++--- dbops.go | 23 +++++++++-- deployable.go | 2 +- docker-compose.mssql.yml | 2 +- dockerfile.test | 2 - go.mod | 16 +++++--- go.sum | 86 ++++++++++++++++++++++------------------ mssql_error.go | 2 +- preprocess_test.go | 2 +- sqlparser/parser_test.go | 2 +- sqltest/fixture.go | 36 ++++++++++------- sqltest/sqlcode_test.go | 33 ++++++++++++--- 13 files changed, 139 insertions(+), 80 deletions(-) diff --git a/cli/cmd/build.go b/cli/cmd/build.go index c0d7db3..9fd9d9a 100644 --- a/cli/cmd/build.go +++ b/cli/cmd/build.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - mssql "github.com/denisenkom/go-mssqldb" + mssql "github.com/microsoft/go-mssqldb" "github.com/spf13/cobra" "github.com/vippsas/sqlcode" ) diff --git a/cli/cmd/config.go b/cli/cmd/config.go index 6bebbf1..6968802 100644 --- a/cli/cmd/config.go +++ b/cli/cmd/config.go @@ -5,16 +5,17 @@ import ( "database/sql" "errors" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/azuread" - "golang.org/x/net/proxy" "io/ioutil" "os" "path" "strings" - _ "github.com/denisenkom/go-mssqldb/azuread" - "github.com/denisenkom/go-mssqldb/msdsn" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/azuread" + "golang.org/x/net/proxy" + + _ "github.com/microsoft/go-mssqldb/azuread" + "github.com/microsoft/go-mssqldb/msdsn" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" ) diff --git a/dbops.go b/dbops.go index 0293a2a..d63278f 100644 --- a/dbops.go +++ b/dbops.go @@ -4,8 +4,9 @@ import ( "context" "database/sql" - mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" ) func Exists(ctx context.Context, 
dbc DB, schemasuffix string) (bool, error) { @@ -33,8 +34,24 @@ func Drop(ctx context.Context, dbc DB, schemasuffix string) error { if err != nil { return err } - _, err = tx.ExecContext(ctx, `sqlcode.DropCodeSchema`, - sql.Named("schemasuffix", schemasuffix)) + + var qs string + var arg = []interface{}{} + driver := dbc.Driver() + + if _, ok := driver.(*mssql.Driver); ok { + qs = `sqlcode.DropCodeSchema` + arg = []interface{}{sql.Named("schemasuffix", schemasuffix)} + } + + if _, ok := dbc.Driver().(*stdlib.Driver); ok { + qs = `call sqlcode.dropcodeschema(@schemasuffix)` + arg = []interface{}{ + pgx.NamedArgs{"schemasuffix": schemasuffix}, + } + } + + _, err = tx.ExecContext(ctx, qs, arg...) if err != nil { _ = tx.Rollback() return err diff --git a/deployable.go b/deployable.go index ec784d7..03cc0a8 100644 --- a/deployable.go +++ b/deployable.go @@ -10,10 +10,10 @@ import ( "strings" "time" - mssql "github.com/denisenkom/go-mssqldb" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" pgxstdlib "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/vippsas/sqlcode/sqlparser" ) diff --git a/docker-compose.mssql.yml b/docker-compose.mssql.yml index 84618fc..f2e4471 100644 --- a/docker-compose.mssql.yml +++ b/docker-compose.mssql.yml @@ -17,7 +17,7 @@ services: - mssql environment: SQLSERVER_DSN: sqlserver://mssql:1433?database=master&user id=sa&password=VippsPw1 - SQLSERVER_DRIVER: sqlserver + GODEBUG: "x509negativeserial=1" depends_on: mssql: condition: service_healthy diff --git a/dockerfile.test b/dockerfile.test index 286268c..dbd2061 100644 --- a/dockerfile.test +++ b/dockerfile.test @@ -1,7 +1,5 @@ FROM golang:1.25 AS builder WORKDIR /sqlcode -ENV GODEBUG="x509negativeserial=1" COPY . . 
RUN go mod tidy -# Skip the example folder because it has examples of what-not-to-do CMD ["go", "test", "-v", "./..."] \ No newline at end of file diff --git a/go.mod b/go.mod index f07fe09..3f7df71 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,9 @@ go 1.24.3 require ( github.com/alecthomas/repr v0.5.2 - github.com/denisenkom/go-mssqldb v0.12.3 github.com/gofrs/uuid v4.4.0+incompatible + github.com/jackc/pgx/v5 v5.7.6 + github.com/microsoft/go-mssqldb v1.9.5 github.com/sirupsen/logrus v1.9.3 github.com/smasher164/xid v0.1.2 github.com/spf13/cobra v1.10.2 @@ -15,20 +16,23 @@ require ( ) require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.7.6 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - github.com/lib/pq v1.10.9 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/pflag v1.0.9 // indirect golang.org/x/crypto v0.43.0 // indirect golang.org/x/sync v0.17.0 
// indirect diff --git a/go.sum b/go.sum index 1a7b35c..0d93c80 100644 --- a/go.sum +++ b/go.sum @@ -1,25 +1,39 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 h1:lhSJz9RMbJcTgxifR1hUNJnn6CNYtbgEDtQV22/9RBA= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 h1:OYa9vmRX2XC5GXRAzeggG12sF/z5D9Ahtdm9EJ00WN4= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 h1:v9p9TfTbf7AwNb5NYQt7hI41IfPoLFiFkLtb+bmGjT0= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI= 
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw= -github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= -github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod 
h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -30,15 +44,27 @@ github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= -github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= +github.com/keybase/go-keychain v0.0.1 
h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0= +github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= +github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod 
h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smasher164/xid v0.1.2 h1:erplXSdBRIIw+MrwjJ/m8sLN2XY16UGzpTA0E2Ru6HA= @@ -52,41 +78,23 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys 
v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 
-gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mssql_error.go b/mssql_error.go index 22d4bde..d6f531e 100644 --- a/mssql_error.go +++ b/mssql_error.go @@ -5,7 +5,7 @@ import ( "fmt" "strings" - mssql "github.com/denisenkom/go-mssqldb" + mssql "github.com/microsoft/go-mssqldb" "github.com/vippsas/sqlcode/sqlparser" ) diff --git a/preprocess_test.go b/preprocess_test.go index 9d0d2ee..e68d8bf 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -4,8 +4,8 @@ import ( "strings" "testing" - mssql "github.com/denisenkom/go-mssqldb" "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vippsas/sqlcode/sqlparser" diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index fc75460..0e36b40 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -7,8 +7,8 @@ import ( "testing" "testing/fstest" - mssql "github.com/denisenkom/go-mssqldb" "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/sqltest/fixture.go b/sqltest/fixture.go index e05c418..b384be2 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -9,23 +9,23 @@ import ( "testing" "time" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/msdsn" 
"github.com/gofrs/uuid" _ "github.com/jackc/pgx/v5" _ "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/msdsn" ) type SqlDriverType int const ( - SqlDriverDenisen SqlDriverType = iota + SqlDriverMssql SqlDriverType = iota SqlDriverPgx ) var sqlDrivers = map[SqlDriverType]string{ - SqlDriverDenisen: "sqlserver", - SqlDriverPgx: "pgx", + SqlDriverMssql: "sqlserver", + SqlDriverPgx: "pgx", } type StdoutLogger struct { @@ -49,7 +49,7 @@ type Fixture struct { } func (f *Fixture) IsSqlServer() bool { - return f.Driver == SqlDriverDenisen + return f.Driver == SqlDriverMssql } func (f *Fixture) IsPostgresql() bool { @@ -89,7 +89,7 @@ func NewFixture() *Fixture { // 32: Log transaction begin/end dsn = dsn + "&log=63" mssql.SetLogger(StdoutLogger{}) - fixture.Driver = SqlDriverDenisen + fixture.Driver = SqlDriverMssql } if strings.Contains(dsn, "postgresql") { fixture.Driver = SqlDriverPgx @@ -123,7 +123,7 @@ func NewFixture() *Fixture { panic(err) } - pdsn, _, err := msdsn.Parse(dsn) + pdsn, err := msdsn.Parse(dsn) if err != nil { panic(err) } @@ -198,13 +198,21 @@ func (f *Fixture) RunMigrationFile(filename string) { } func (f *Fixture) RunIfPostgres(t *testing.T, name string, fn func(t *testing.T)) { - if f.IsPostgresql() { - t.Run(name, fn) - } + t.Run("pgsql", func(t *testing.T) { + if f.IsPostgresql() { + t.Run(name, fn) + } else { + t.Skip() + } + }) } func (f *Fixture) RunIfMssql(t *testing.T, name string, fn func(t *testing.T)) { - if f.IsSqlServer() { - t.Run(name, fn) - } + t.Run("mssql", func(t *testing.T) { + if f.IsSqlServer() { + t.Run(name, fn) + } else { + t.Skip() + } + }) } diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 910064b..b92cc7f 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -27,7 +27,7 @@ func Test_Patch(t *testing.T) { require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - fixture.RunIfMssql(t, "mssql", func(t *testing.T) { + 
fixture.RunIfMssql(t, "returns 1 affected row", func(t *testing.T) { patched := SQL.CodePatch(fixture.DB, `[code].Test`) res, err := fixture.DB.ExecContext(ctx, patched) require.NoError(t, err) @@ -37,7 +37,8 @@ func Test_Patch(t *testing.T) { assert.Equal(t, int64(1), rowsAffected) }) - fixture.RunIfPostgres(t, "pgsql", func(t *testing.T) { + // TODO: instrument a test table to perform an update operation + fixture.RunIfPostgres(t, "returns 0 affected rows", func(t *testing.T) { patched := SQL.CodePatch(fixture.DB, `call [code].Test()`) res, err := fixture.DB.ExecContext(ctx, patched) require.NoError(t, err) @@ -47,7 +48,6 @@ func Test_Patch(t *testing.T) { require.NoError(t, err) assert.Equal(t, int64(0), rowsAffected) }) - } func Test_EnsureUploaded(t *testing.T) { @@ -55,7 +55,7 @@ func Test_EnsureUploaded(t *testing.T) { defer f.Teardown() ctx := context.Background() - f.RunIfMssql(t, "mssql", func(t *testing.T) { + f.RunIfMssql(t, "uploads schema", func(t *testing.T) { f.RunMigrationFile("../migrations/0001.sqlcode.sql") require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) schemas, err := SQL.ListUploaded(ctx, f.DB) @@ -64,7 +64,7 @@ func Test_EnsureUploaded(t *testing.T) { }) - f.RunIfPostgres(t, "pgsql", func(t *testing.T) { + f.RunIfPostgres(t, "uploads schema", func(t *testing.T) { f.RunMigrationFile("../migrations/0001.sqlcode.pgsql") _, err := f.DB.Exec( @@ -77,3 +77,26 @@ func Test_EnsureUploaded(t *testing.T) { require.Len(t, schemas, 1) }) } + +func Test_DropAndUpload(t *testing.T) { + f := NewFixture() + defer f.Teardown() + ctx := context.Background() + + f.RunIfMssql(t, "drop and upload", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.sql") + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + require.NoError(t, SQL.DropAndUpload(ctx, f.DB)) + }) + + f.RunIfPostgres(t, "drop and upload", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.pgsql") + + _, err := f.DB.Exec( + fmt.Sprintf(`grant create on database 
"%s" to "sqlcode-definer-role"`, f.DBName)) + require.NoError(t, err) + + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + require.NoError(t, SQL.DropAndUpload(ctx, f.DB)) + }) +} From e51b27760b12f0a25639c98b484843b0e13d4f09 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 18:48:02 +0100 Subject: [PATCH 18/40] Initial unit tests for T-SQL syntax parsing. --- sqlparser/document_test.go | 519 +++++++++++++++++++++++++++++++++++++ sqlparser/dom.go | 7 - sqlparser/scanner.go | 9 +- 3 files changed, 526 insertions(+), 9 deletions(-) create mode 100644 sqlparser/document_test.go diff --git a/sqlparser/document_test.go b/sqlparser/document_test.go new file mode 100644 index 0000000..5ddddb6 --- /dev/null +++ b/sqlparser/document_test.go @@ -0,0 +1,519 @@ +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDocument_addError(t *testing.T) { + t.Run("adds error with position", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "select") + s.NextToken() + + doc.addError(s, "test error message") + + require.Len(t, doc.Errors, 1) + assert.Equal(t, "test error message", doc.Errors[0].Message) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.Errors[0].Pos) + }) + + t.Run("accumulates multiple errors", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "abc def") + s.NextToken() + doc.addError(s, "error 1") + s.NextToken() + doc.addError(s, "error 2") + + require.Len(t, doc.Errors, 2) + assert.Equal(t, "error 1", doc.Errors[0].Message) + assert.Equal(t, "error 2", doc.Errors[1].Message) + }) + + t.Run("creates error with token text", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "unexpected_token") + s.NextToken() + + doc.unexpectedTokenError(s) + + require.Len(t, doc.Errors, 1) + assert.Equal(t, "Unexpected: unexpected_token", doc.Errors[0].Message) + }) +} + +func 
TestDocument_parseTypeExpression(t *testing.T) { + t.Run("parses simple type without args", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "int") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "int", typ.BaseType) + assert.Empty(t, typ.Args) + }) + + t.Run("parses type with single arg", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "varchar(50)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.Equal(t, []string{"50"}, typ.Args) + }) + + t.Run("parses type with multiple args", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "decimal(10, 2)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "decimal", typ.BaseType) + assert.Equal(t, []string{"10", "2"}, typ.Args) + }) + + t.Run("parses type with max", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "nvarchar(max)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "nvarchar", typ.BaseType) + assert.Equal(t, []string{"max"}, typ.Args) + }) + + t.Run("handles invalid arg", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "varchar(invalid)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.NotEmpty(t, doc.Errors) + }) + + t.Run("panics if not on identifier", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "123") + s.NextToken() + + assert.Panics(t, func() { + doc.parseTypeExpression(s) + }) + }) +} + +func TestDocument_parseDeclare(t *testing.T) { + t.Run("parses single enum declaration", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumStatus int = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@EnumStatus", declares[0].VariableName) + assert.Equal(t, "int", declares[0].Datatype.BaseType) + 
assert.Equal(t, "42", declares[0].Literal.RawValue) + }) + + t.Run("parses multiple declarations with comma", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumA int = 1, @EnumB int = 2;") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 2) + assert.Equal(t, "@EnumA", declares[0].VariableName) + assert.Equal(t, "@EnumB", declares[1].VariableName) + }) + + t.Run("parses string literal", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumName nvarchar(50) = N'test'") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "N'test'", declares[0].Literal.RawValue) + }) + + t.Run("errors on invalid variable name", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@InvalidName int = 1") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "@InvalidName") + }) + + t.Run("errors on missing type", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumTest = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "type declared explicitly") + }) + + t.Run("errors on missing assignment", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumTest int") + s.NextToken() + + doc.parseDeclare(s) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "needs to be assigned") + }) + + t.Run("accepts @Global prefix", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@GlobalSetting int = 100") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@GlobalSetting", declares[0].VariableName) + assert.Empty(t, doc.Errors) + }) + + t.Run("accepts @Const prefix", func(t *testing.T) { 
+ doc := &Document{} + s := NewScanner("test.sql", "@ConstValue int = 200") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@ConstValue", declares[0].VariableName) + assert.Empty(t, doc.Errors) + }) +} + +func TestDocument_parseBatchSeparator(t *testing.T) { + t.Run("parses valid go separator", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "go\n") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.Empty(t, doc.Errors) + }) + + t.Run("errors on malformed separator", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "go -- comment") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "should be alone") + }) +} + +func TestDocument_parseCodeschemaName(t *testing.T) { + t.Run("parses unquoted identifier", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "[code].TestProc") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "[TestProc]", result.Value) + assert.NotEmpty(t, target) + }) + + t.Run("parses quoted identifier", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "[code].[Test Proc]") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "[Test Proc]", result.Value) + }) + + t.Run("errors on missing dot", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "[code] TestProc") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "", result.Value) + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "must be followed by '.'") + }) + + t.Run("errors on missing identifier", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "[code].123") + s.NextToken() + var target []Unparsed + + result := 
doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "", result.Value) + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "must be followed an identifier") + }) +} + +func TestDocument_parseCreate(t *testing.T) { + t.Run("parses simple procedure", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "[TestProc]", create.QuotedName.Value) + assert.NotEmpty(t, create.Body) + }) + + t.Run("parses function", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create function [code].TestFunc() returns int as begin return 1 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "[TestFunc]", create.QuotedName.Value) + }) + + t.Run("parses type", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create type [code].TestType as table (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + assert.Equal(t, "[TestType]", create.QuotedName.Value) + }) + + t.Run("tracks dependencies", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1 join [code].Table2 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + assert.Equal(t, "[Table2]", create.DependsOn[1].Value) + }) + + t.Run("deduplicates dependencies", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1; select * from 
[code].Table1 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 1) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + }) + + t.Run("errors on unsupported create type", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create table [code].TestTable (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "only supports creating procedures") + }) + + t.Run("errors on multiple procedures in batch", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc2 as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 1) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "must be alone in a batch") + }) + + t.Run("errors on missing code schema", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure dbo.TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "must be followed by [code]") + }) + + t.Run("allows create index inside procedure", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin create index IX_Test on #temp(id) end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Empty(t, doc.Errors) + }) + + t.Run("stops at batch separator", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin end\ngo") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "[Proc]", create.QuotedName.Value) + assert.Equal(t, BatchSeparatorToken, 
s.TokenType()) + }) + + t.Run("panics if not on create token", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "procedure") + s.NextToken() + + assert.Panics(t, func() { + doc.parseCreate(s, 0) + }) + }) +} + +func TestNextTokenCopyingWhitespace(t *testing.T) { + t.Run("copies whitespace tokens", func(t *testing.T) { + s := NewScanner("test.sql", " \n\t token") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("copies comments", func(t *testing.T) { + s := NewScanner("test.sql", "/* comment */ -- line\ntoken") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.True(t, len(target) >= 2) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + s := NewScanner("test.sql", " ") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestCreateUnparsed(t *testing.T) { + t.Run("creates unparsed from scanner", func(t *testing.T) { + s := NewScanner("test.sql", "select") + s.NextToken() + + unparsed := CreateUnparsed(s) + + assert.Equal(t, ReservedWordToken, unparsed.Type) + assert.Equal(t, "select", unparsed.RawValue) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) + }) +} + +func TestDocument_recoverToNextStatement(t *testing.T) { + t.Run("recovers to declare", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "declare", s.ReservedWord()) + }) + + t.Run("recovers to create", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "bad stuff create procedure") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, "create", 
s.ReservedWord()) + }) + + t.Run("recovers to go", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "error error go") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, "go", s.ReservedWord()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "no keywords") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestDocument_recoverToNextStatementCopying(t *testing.T) { + t.Run("copies tokens while recovering", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "bad token declare") + s.NextToken() + var target []Unparsed + + doc.recoverToNextStatementCopying(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, "declare", s.ReservedWord()) + }) +} diff --git a/sqlparser/dom.go b/sqlparser/dom.go index 14209ee..22afdaa 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -129,13 +129,6 @@ func (e Error) WithoutPos() Error { return Error{Message: e.Message} } -type Document struct { - PragmaIncludeIf []string - Creates []Create - Declares []Declare - Errors []Error -} - func (c Create) Serialize(w io.StringWriter) error { for _, l := range c.Body { if _, err := w.WriteString(l.RawValue); err != nil { diff --git a/sqlparser/scanner.go b/sqlparser/scanner.go index 3103894..a5fb75a 100644 --- a/sqlparser/scanner.go +++ b/sqlparser/scanner.go @@ -1,11 +1,12 @@ package sqlparser import ( - "github.com/smasher164/xid" "regexp" "strings" "unicode" "unicode/utf8" + + "github.com/smasher164/xid" ) // dedicated type for reference to file, in case we need to refactor this later.. 
@@ -40,6 +41,10 @@ type Scanner struct { reservedWord string // in the event that the token is a ReservedWordToken, this contains the lower-case version } +func NewScanner(path FileRef, input string) *Scanner { + return &Scanner{input: input, file: path} +} + type TokenType int func (s *Scanner) TokenType() TokenType { @@ -316,7 +321,7 @@ func (s *Scanner) scanIdentifier() { s.curIndex = len(s.input) } -// DRY helper to handle both '' and ]] escapes +// DRY helper to handle both ” and ]] escapes func (s *Scanner) scanUntilSingleDoubleEscapes(endmarker rune, tokenType TokenType, unterminatedTokenType TokenType) TokenType { skipnext := false for i, r := range s.input[s.curIndex:] { From e14a78360cfe04414738c4089cc8cf6a434c0d4c Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 19:57:25 +0100 Subject: [PATCH 19/40] Refactored to use a Document interface. --- cli/cmd/constants.go | 8 +- cli/cmd/dep.go | 8 +- deployable.go | 6 +- preprocess.go | 8 +- sqlparser/create.go | 98 +++++++ sqlparser/document.go | 526 +++---------------------------------- sqlparser/document_test.go | 521 ++---------------------------------- sqlparser/dom.go | 154 +---------- sqlparser/parser.go | 33 +-- sqlparser/parser_test.go | 116 ++++---- 10 files changed, 249 insertions(+), 1229 deletions(-) create mode 100644 sqlparser/create.go diff --git a/cli/cmd/constants.go b/cli/cmd/constants.go index a71b364..90d071a 100644 --- a/cli/cmd/constants.go +++ b/cli/cmd/constants.go @@ -20,18 +20,18 @@ var ( if err != nil { return err } - if len(d.CodeBase.Creates) == 0 && len(d.CodeBase.Declares) == 0 { + if d.CodeBase.Empty() { fmt.Println("No SQL code found in given paths") } - if len(d.CodeBase.Errors) > 0 { + if d.CodeBase.HasErrors() { fmt.Println("Errors:") - for _, e := range d.CodeBase.Errors { + for _, e := range d.CodeBase.Errors() { fmt.Printf("%s:%d:%d: %s\n", e.Pos.File, e.Pos.Line, e.Pos.Line, e.Message) } return nil } fmt.Println("declare") - for i, c := range 
d.CodeBase.Declares { + for i, c := range d.CodeBase.Declares() { var prefix string if i == 0 { prefix = " " diff --git a/cli/cmd/dep.go b/cli/cmd/dep.go index c0d110c..528b5b5 100644 --- a/cli/cmd/dep.go +++ b/cli/cmd/dep.go @@ -36,16 +36,16 @@ var ( fmt.Println() err = nil } - if len(d.CodeBase.Creates) == 0 && len(d.CodeBase.Declares) == 0 { + if d.CodeBase.Empty() { fmt.Println("No SQL code found in given paths") } - if len(d.CodeBase.Errors) > 0 { + if d.CodeBase.HasErrors() { fmt.Println("Errors:") - for _, e := range d.CodeBase.Errors { + for _, e := range d.CodeBase.Errors() { fmt.Printf("%s:%d:%d: %s\n", e.Pos.File, e.Pos.Line, e.Pos.Line, e.Message) } } - for _, c := range d.CodeBase.Creates { + for _, c := range d.CodeBase.Creates() { fmt.Println(c.QuotedName.String() + ":") if len(c.DependsOn) > 0 { fmt.Println(" Uses:") diff --git a/deployable.go b/deployable.go index 03cc0a8..135fd26 100644 --- a/deployable.go +++ b/deployable.go @@ -277,7 +277,7 @@ func (d *Deployable) IsUploadedFromCache(dbc DB) bool { // TODO: StringConst. 
This requires parsing a SQL literal, a bit too complex // to code up just-in-case func (d Deployable) IntConst(s string) (int, error) { - for _, declare := range d.CodeBase.Declares { + for _, declare := range d.CodeBase.Declares() { if declare.VariableName == s { // TODO: more robust integer SQL parsing than this; only works // in most common cases @@ -311,8 +311,8 @@ type Options struct { func Include(opts Options, fsys ...fs.FS) (result Deployable, err error) { parsedFiles, doc, err := sqlparser.ParseFilesystems(fsys, opts.IncludeTags) - if len(doc.Errors) > 0 && !opts.PartialParseResults { - return Deployable{}, SQLCodeParseErrors{Errors: doc.Errors} + if doc.HasErrors() && !opts.PartialParseResults { + return Deployable{}, SQLCodeParseErrors{Errors: doc.Errors()} } result.CodeBase = doc diff --git a/preprocess.go b/preprocess.go index c4a776b..9478c3f 100644 --- a/preprocess.go +++ b/preprocess.go @@ -16,10 +16,10 @@ import ( func SchemaSuffixFromHash(doc sqlparser.Document) string { hasher := sha256.New() - for _, dec := range doc.Declares { + for _, dec := range doc.Declares() { hasher.Write([]byte(dec.String() + "\n")) } - for _, c := range doc.Creates { + for _, c := range doc.Creates() { if err := c.SerializeBytes(hasher); err != nil { panic(err) // asserting that sha256 will never return a write error... 
} @@ -149,7 +149,7 @@ func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Drive } declares := make(map[string]string) - for _, dec := range doc.Declares { + for _, dec := range doc.Declares() { declares[dec.VariableName] = dec.Literal.RawValue } @@ -164,7 +164,7 @@ func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Drive target = fmt.Sprintf(`"code@%s"`, schemasuffix) } - for _, create := range doc.Creates { + for _, create := range doc.Creates() { if len(create.Body) == 0 { continue } diff --git a/sqlparser/create.go b/sqlparser/create.go new file mode 100644 index 0000000..0bdbbf2 --- /dev/null +++ b/sqlparser/create.go @@ -0,0 +1,98 @@ +package sqlparser + +import ( + "database/sql/driver" + "io" + "strings" + + "gopkg.in/yaml.v3" +) + +type Create struct { + CreateType string // "procedure", "function" or "type" + QuotedName PosString // proc/func/type name, including [] + Body []Unparsed + DependsOn []PosString + Docstring []PosString // comment lines before the create statement. Note: this is also part of Body + Driver driver.Driver // the sql driver this document is intended for +} + +func (c Create) DocstringAsString() string { + var result []string + for _, line := range c.Docstring { + result = append(result, line.Value) + } + return strings.Join(result, "\n") +} + +func (c Create) DocstringYamldoc() (string, error) { + var yamldoc []string + parsing := false + for _, line := range c.Docstring { + if strings.HasPrefix(line.Value, "--!") { + parsing = true + if !strings.HasPrefix(line.Value, "--! 
") { + return "", Error{line.Pos, "YAML document in docstring; missing space after `--!`"} + } + yamldoc = append(yamldoc, line.Value[4:]) + } else if parsing { + return "", Error{line.Pos, "once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement"} + } + } + return strings.Join(yamldoc, "\n"), nil +} + +func (c Create) ParseYamlInDocstring(out any) error { + yamldoc, err := c.DocstringYamldoc() + if err != nil { + return err + } + return yaml.Unmarshal([]byte(yamldoc), out) +} + +func (c Create) Serialize(w io.StringWriter) error { + for _, l := range c.Body { + if _, err := w.WriteString(l.RawValue); err != nil { + return err + } + } + return nil +} + +func (c Create) SerializeBytes(w io.Writer) error { + for _, l := range c.Body { + if _, err := w.Write([]byte(l.RawValue)); err != nil { + return err + } + } + return nil +} + +func (c Create) String() string { + var buf strings.Builder + err := c.Serialize(&buf) + if err != nil { + panic(err) + } + return buf.String() +} + +func (c Create) WithoutPos() Create { + var body []Unparsed + for _, x := range c.Body { + body = append(body, x.WithoutPos()) + } + return Create{ + CreateType: c.CreateType, + QuotedName: c.QuotedName, + DependsOn: c.DependsOn, + Body: body, + } +} + +func (c Create) DependsOnStrings() (result []string) { + for _, x := range c.DependsOn { + result = append(result, x.Value) + } + return +} diff --git a/sqlparser/document.go b/sqlparser/document.go index c907aa8..21839ec 100644 --- a/sqlparser/document.go +++ b/sqlparser/document.go @@ -1,505 +1,45 @@ package sqlparser import ( - "fmt" - "sort" + "path/filepath" "strings" - - "github.com/jackc/pgx/v5/stdlib" - mssql "github.com/microsoft/go-mssqldb" ) -type Document struct { - PragmaIncludeIf []string - Creates []Create - Declares []Declare - Errors []Error -} - -func (d *Document) addError(s *Scanner, msg string) { - d.Errors = append(d.Errors, Error{ - Pos: s.Start(), - Message: msg, - }) -} 
- -func (d *Document) unexpectedTokenError(s *Scanner) { - d.addError(s, "Unexpected: "+s.Token()) -} - -func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { - parseArgs := func() { - // parses *after* the initial (; consumes trailing ) - for { - switch { - case s.TokenType() == NumberToken: - t.Args = append(t.Args, s.Token()) - case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": - t.Args = append(t.Args, "max") - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - s.NextNonWhitespaceCommentToken() - switch { - case s.TokenType() == CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case s.TokenType() == RightParenToken: - s.NextNonWhitespaceCommentToken() - return - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - } - } - - if s.TokenType() != UnquotedIdentifierToken { - panic("assertion failed, bug in caller") - } - t.BaseType = s.Token() - s.NextNonWhitespaceCommentToken() - if s.TokenType() == LeftParenToken { - s.NextNonWhitespaceCommentToken() - parseArgs() - } +// Document represents a parsed SQL document, containing +// declarations, create statements, pragmas, and errors. 
+// It provides methods to access and manipulate these components +// for T-SQL and PostgreSQL +type Document interface { + Empty() bool + HasErrors() bool + + Creates() []Create + Declares() []Declare + Errors() []Error + PragmaIncludeIf() []string + Include(other Document) + Sort() + ParsePragmas(s *Scanner) + ParseBatch(s *Scanner, isFirst bool) (hasMore bool) + + WithoutPos() Document +} + +// Helper function to parse a SQL document from a string input +func ParseString(filename FileRef, input string) (result Document) { + result = NewDocumentFromExtension(filepath.Ext(strings.ToLower(string(filename)))) + Parse(&Scanner{input: input, file: filename}, result) return } -func (doc *Document) parseDeclare(s *Scanner) (result []Declare) { - declareStart := s.Start() - // parse what is *after* the `declare` reserved keyword -loop: - for { - if s.TokenType() != VariableIdentifierToken { - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - variableName := s.Token() - if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && - !strings.HasPrefix(strings.ToLower(variableName), "@global") && - !strings.HasPrefix(strings.ToLower(variableName), "@const") { - doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) - } - s.NextNonWhitespaceCommentToken() - var variableType Type - switch s.TokenType() { - case EqualToken: - doc.addError(s, "sqlcode constants needs a type declared explicitly") - s.NextNonWhitespaceCommentToken() - case UnquotedIdentifierToken: - variableType = doc.parseTypeExpression(s) - } - - if s.TokenType() != EqualToken { - doc.addError(s, "sqlcode constants needs to be assigned at once using =") - doc.recoverToNextStatement(s) - } - - switch s.NextNonWhitespaceCommentToken() { - case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - result = append(result, Declare{ - Start: declareStart, - Stop: s.Stop(), - VariableName: variableName, - Datatype: variableType, 
- Literal: CreateUnparsed(s), - }) - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - switch s.NextNonWhitespaceCommentToken() { - case CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case SemicolonToken: - s.NextNonWhitespaceCommentToken() - break loop - default: - break loop - } - } - if len(result) == 0 { - doc.addError(s, "incorrect syntax; no variables successfully declared") - } - return -} - -func (doc *Document) parseBatchSeparator(s *Scanner) { - // just saw a 'go'; just make sure there's nothing bad trailing it - // (if there is, convert to errors and move on until the line is consumed - errorEmitted := false - for { - switch s.NextToken() { - case WhitespaceToken: - continue - case MalformedBatchSeparatorToken: - if !errorEmitted { - doc.addError(s, "`go` should be alone on a line without any comments") - errorEmitted = true - } - continue - default: - return - } - } -} - -func (doc *Document) parseDeclareBatch(s *Scanner) (hasMore bool) { - if s.ReservedWord() != "declare" { - panic("assertion failed, incorrect use in caller") - } - for { - tt := s.TokenType() - switch { - case tt == EOFToken: - return false - case tt == ReservedWordToken && s.ReservedWord() == "declare": - s.NextNonWhitespaceCommentToken() - d := doc.parseDeclare(s) - doc.Declares = append(doc.Declares, d...) 
- case tt == ReservedWordToken && s.ReservedWord() != "declare": - doc.addError(s, "Only 'declare' allowed in this batch") - doc.recoverToNextStatement(s) - case tt == BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - } - } -} - -func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - var createCountInBatch int - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset token for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // We build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case "declare": - // First declare-statement; enter a mode where we assume all contents - // of batch are declare statements - if !isFirst { - doc.addError(s, "'declare' statement only allowed in first batch") - } - // regardless of errors, go on and parse as far as we get... - return doc.parseDeclareBatch(s) - case "create": - // should be start of create procedure or create function... 
- c := doc.parseCreate(s, createCountInBatch) - - if strings.HasSuffix(string(s.file), ".sql") { - c.Driver = &mssql.Driver{} - } - if strings.HasSuffix(string(s.file), ".pgsql") { - c.Driver = &stdlib.Driver{} - } - - // *prepend* what we saw before getting to the 'create' - createCountInBatch++ - c.Body = append(nodes, c.Body...) - c.Docstring = docstring - doc.Creates = append(doc.Creates, c) - default: - doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) - s.NextToken() - } - case BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - s.NextToken() - docstring = nil - } - } -} - -func (d *Document) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - default: - CopyToken(s, target) - } - } -} - -func (d *Document) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - } - } -} - -// parseCodeschemaName parses `[code] . something`, and returns `something` -// in quoted form (`[something]`). Also copy to `target`. Empty string on error. -// Note: To follow conventions, consume one extra token at the end even if we know -// it fill not be consumed by this function... 
-func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { - CopyToken(s, target) - NextTokenCopyingWhitespace(s, target) - if s.TokenType() != DotToken { - d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } - CopyToken(s, target) - - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case UnquotedIdentifierToken: - // To get something uniform for comparison, quote all names - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} - NextTokenCopyingWhitespace(s, target) - return result - case QuotedIdentifierToken: - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: s.Token()} - NextTokenCopyingWhitespace(s, target) - return result +// Based on the input file extension, create the appropriate Document type +func NewDocumentFromExtension(extension string) Document { + switch extension { + case ".sql": + return &TSqlDocument{} + case ".pgsql": + return &PGSqlDocument{} default: - d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } -} - -// parseCreate parses anything that starts with "create". Position is -// *on* the create token. -// At this stage in sqlcode parser development we're only interested -// in procedures/functions/types as opaque blocks of SQL code where -// we only track dependencies between them and their declared name; -// so we treat them with the same code. We consume until the end of -// the batch; only one declaration allowed per batch. Everything -// parsed here will also be added to `batch`. On any error, copying -// to batch stops / becomes erratic.. 
-func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Create) { - if s.ReservedWord() != "create" { - panic("illegal use by caller") - } - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - createType := strings.ToLower(s.Token()) - if !(createType == "procedure" || createType == "function" || createType == "type") { - d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - d.recoverToNextStatementCopying(s, &result.Body) - return - } - - result.CreateType = createType - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - // Insist on [code]. - if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { - d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - result.QuotedName = d.parseCodeschemaName(s, &result.Body) - if result.QuotedName.String() == "" { - return - } - - // We have matched "create [code]."; at this - // point we copy the rest until the batch ends; *but* track dependencies - // + some other details mentioned below - - //firstAs := true // See comment below on rowcount - -tailloop: - for { - tt := s.TokenType() - switch { - case tt == ReservedWordToken && s.ReservedWord() == "create": - // So, we're currently parsing 'create ...' and we see another 'create'. 
- // We split in two cases depending on the context we are currently in - // (createType is referring to how we entered this function, *NOT* the - // `create` statement we are looking at now - switch createType { // note: this is the *outer* create type, not the one of current scanner position - case "function", "procedure": - // Within a function/procedure we can allow 'create index', 'create table' and nothing - // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain - // about that aspect, not relevant for batch / dependency parsing) - // - // What is important is a function/procedure/type isn't started on without a 'go' - // in between; so we block those 3 from appearing in the same batch - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - tt2 := s.TokenType() - - if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || - (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - d.recoverToNextStatementCopying(s, &result.Body) - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - return - } - case "type": - // We allow more than one type creation in a batch; and 'create' can never appear - // scoped within 'create type'. 
So at a new create we are done with the previous - // one, and return it -- the caller can then re-enter this function from the top - break tailloop - default: - panic("assertion failed") - } - - case tt == EOFToken || tt == BatchSeparatorToken: - break tailloop - case tt == QuotedIdentifierToken && s.Token() == "[code]": - // Parse a dependency - dep := d.parseCodeschemaName(s, &result.Body) - found := false - for _, existing := range result.DependsOn { - if existing.Value == dep.Value { - found = true - break - } - } - if !found { - result.DependsOn = append(result.DependsOn, dep) - } - case tt == ReservedWordToken && s.Token() == "as": - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - /* - TODO: Fix and re-enable - This code add RoutineName for convenience. So: - - create procedure [code@5420c0269aaf].Test as - begin - select 1 - end - go - - becomes: - - create procedure [code@5420c0269aaf].Test as - declare @RoutineName nvarchar(128) - set @RoutineName = 'Test' - begin - select 1 - end - go - - However, for some very strange reason, @@rowcount is 1 with the first version, - and it is 2 with the second version. - if firstAs { - // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name - // from inside the procedure (for example, when logging) - if result.CreateType == "procedure" { - procNameToken := Unparsed{ - Type: OtherToken, - RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), - } - result.Body = append(result.Body, procNameToken) - } - firstAs = false - } - */ - - default: - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - } - } - - sort.Slice(result.DependsOn, func(i, j int) bool { - return result.DependsOn[i].Value < result.DependsOn[j].Value - }) - return -} - -// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered -// it is simply copied into `target`. 
Upon return, the scanner is located at a non-whitespace -// token, and target is either unmodified or filled with some whitespace nodes. -func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - return - } - } - -} - -func CreateUnparsed(s *Scanner) Unparsed { - return Unparsed{ - Type: s.TokenType(), - Start: s.Start(), - Stop: s.Stop(), - RawValue: s.Token(), + panic("unhandled document type: " + extension) } } diff --git a/sqlparser/document_test.go b/sqlparser/document_test.go index 5ddddb6..8e497e6 100644 --- a/sqlparser/document_test.go +++ b/sqlparser/document_test.go @@ -7,513 +7,50 @@ import ( "github.com/stretchr/testify/require" ) -func TestDocument_addError(t *testing.T) { - t.Run("adds error with position", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "select") - s.NextToken() +func TestNewDocumentFromExtension(t *testing.T) { + t.Run("returns TSqlDocument for .sql extension", func(t *testing.T) { + doc := NewDocumentFromExtension(".sql") - doc.addError(s, "test error message") - - require.Len(t, doc.Errors, 1) - assert.Equal(t, "test error message", doc.Errors[0].Message) - assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.Errors[0].Pos) - }) - - t.Run("accumulates multiple errors", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "abc def") - s.NextToken() - doc.addError(s, "error 1") - s.NextToken() - doc.addError(s, "error 2") - - require.Len(t, doc.Errors, 2) - assert.Equal(t, "error 1", doc.Errors[0].Message) - assert.Equal(t, "error 2", doc.Errors[1].Message) + _, ok := doc.(*TSqlDocument) + assert.True(t, ok, "Expected TSqlDocument type") + assert.NotNil(t, doc) }) - t.Run("creates error with token text", func(t *testing.T) { - 
doc := &Document{} - s := NewScanner("test.sql", "unexpected_token") - s.NextToken() - - doc.unexpectedTokenError(s) + t.Run("returns PGSqlDocument for .pgsql extension", func(t *testing.T) { + doc := NewDocumentFromExtension(".pgsql") - require.Len(t, doc.Errors, 1) - assert.Equal(t, "Unexpected: unexpected_token", doc.Errors[0].Message) + _, ok := doc.(*PGSqlDocument) + assert.True(t, ok, "Expected PGSqlDocument type") + assert.NotNil(t, doc) }) -} - -func TestDocument_parseTypeExpression(t *testing.T) { - t.Run("parses simple type without args", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "int") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "int", typ.BaseType) - assert.Empty(t, typ.Args) - }) - - t.Run("parses type with single arg", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "varchar(50)") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "varchar", typ.BaseType) - assert.Equal(t, []string{"50"}, typ.Args) - }) - - t.Run("parses type with multiple args", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "decimal(10, 2)") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "decimal", typ.BaseType) - assert.Equal(t, []string{"10", "2"}, typ.Args) - }) - - t.Run("parses type with max", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "nvarchar(max)") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "nvarchar", typ.BaseType) - assert.Equal(t, []string{"max"}, typ.Args) - }) - - t.Run("handles invalid arg", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "varchar(invalid)") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "varchar", typ.BaseType) - assert.NotEmpty(t, doc.Errors) - }) - - t.Run("panics if not on identifier", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "123") - s.NextToken() + 
t.Run("panics for unsupported extension", func(t *testing.T) { assert.Panics(t, func() { - doc.parseTypeExpression(s) - }) - }) -} - -func TestDocument_parseDeclare(t *testing.T) { - t.Run("parses single enum declaration", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumStatus int = 42") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.Equal(t, "@EnumStatus", declares[0].VariableName) - assert.Equal(t, "int", declares[0].Datatype.BaseType) - assert.Equal(t, "42", declares[0].Literal.RawValue) - }) - - t.Run("parses multiple declarations with comma", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumA int = 1, @EnumB int = 2;") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 2) - assert.Equal(t, "@EnumA", declares[0].VariableName) - assert.Equal(t, "@EnumB", declares[1].VariableName) - }) - - t.Run("parses string literal", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumName nvarchar(50) = N'test'") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.Equal(t, "N'test'", declares[0].Literal.RawValue) - }) - - t.Run("errors on invalid variable name", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@InvalidName int = 1") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "@InvalidName") - }) - - t.Run("errors on missing type", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumTest = 42") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "type declared explicitly") - }) - - t.Run("errors on missing assignment", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumTest int") 
- s.NextToken() - - doc.parseDeclare(s) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "needs to be assigned") - }) - - t.Run("accepts @Global prefix", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@GlobalSetting int = 100") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.Equal(t, "@GlobalSetting", declares[0].VariableName) - assert.Empty(t, doc.Errors) - }) - - t.Run("accepts @Const prefix", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@ConstValue int = 200") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.Equal(t, "@ConstValue", declares[0].VariableName) - assert.Empty(t, doc.Errors) - }) -} - -func TestDocument_parseBatchSeparator(t *testing.T) { - t.Run("parses valid go separator", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "go\n") - s.NextToken() - - doc.parseBatchSeparator(s) - - assert.Empty(t, doc.Errors) - }) - - t.Run("errors on malformed separator", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "go -- comment") - s.NextToken() - - doc.parseBatchSeparator(s) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "should be alone") - }) -} - -func TestDocument_parseCodeschemaName(t *testing.T) { - t.Run("parses unquoted identifier", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "[code].TestProc") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "[TestProc]", result.Value) - assert.NotEmpty(t, target) - }) - - t.Run("parses quoted identifier", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "[code].[Test Proc]") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "[Test Proc]", result.Value) - }) - - t.Run("errors on missing dot", func(t 
*testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "[code] TestProc") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "", result.Value) - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "must be followed by '.'") + NewDocumentFromExtension(".txt") + }, "Expected panic for unsupported extension") }) - t.Run("errors on missing identifier", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "[code].123") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "", result.Value) - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "must be followed an identifier") - }) -} - -func TestDocument_parseCreate(t *testing.T) { - t.Run("parses simple procedure", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure [code].TestProc as begin end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "procedure", create.CreateType) - assert.Equal(t, "[TestProc]", create.QuotedName.Value) - assert.NotEmpty(t, create.Body) - }) - - t.Run("parses function", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create function [code].TestFunc() returns int as begin return 1 end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - assert.Equal(t, "[TestFunc]", create.QuotedName.Value) - }) - - t.Run("parses type", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create type [code].TestType as table (id int)") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "type", create.CreateType) - assert.Equal(t, "[TestType]", create.QuotedName.Value) - }) - - t.Run("tracks dependencies", func(t *testing.T) { - doc := 
&Document{} - s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1 join [code].Table2 end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - require.Len(t, create.DependsOn, 2) - assert.Equal(t, "[Table1]", create.DependsOn[0].Value) - assert.Equal(t, "[Table2]", create.DependsOn[1].Value) - }) - - t.Run("deduplicates dependencies", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1; select * from [code].Table1 end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - require.Len(t, create.DependsOn, 1) - assert.Equal(t, "[Table1]", create.DependsOn[0].Value) - }) - - t.Run("errors on unsupported create type", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create table [code].TestTable (id int)") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - doc.parseCreate(s, 0) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "only supports creating procedures") - }) - - t.Run("errors on multiple procedures in batch", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure [code].Proc2 as begin end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - doc.parseCreate(s, 1) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "must be alone in a batch") - }) - - t.Run("errors on missing code schema", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure dbo.TestProc as begin end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - doc.parseCreate(s, 0) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "must be followed by [code]") - }) - - t.Run("allows create index inside procedure", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure 
[code].Proc as begin create index IX_Test on #temp(id) end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "procedure", create.CreateType) - assert.Empty(t, doc.Errors) - }) - - t.Run("stops at batch separator", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure [code].Proc as begin end\ngo") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "[Proc]", create.QuotedName.Value) - assert.Equal(t, BatchSeparatorToken, s.TokenType()) - }) - - t.Run("panics if not on create token", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "procedure") - s.NextToken() - + t.Run("panics for empty extension", func(t *testing.T) { assert.Panics(t, func() { - doc.parseCreate(s, 0) - }) + NewDocumentFromExtension("") + }, "Expected panic for empty extension") }) -} - -func TestNextTokenCopyingWhitespace(t *testing.T) { - t.Run("copies whitespace tokens", func(t *testing.T) { - s := NewScanner("test.sql", " \n\t token") - var target []Unparsed - - NextTokenCopyingWhitespace(s, &target) - - assert.NotEmpty(t, target) - assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) - }) - - t.Run("copies comments", func(t *testing.T) { - s := NewScanner("test.sql", "/* comment */ -- line\ntoken") - var target []Unparsed - - NextTokenCopyingWhitespace(s, &target) - - assert.True(t, len(target) >= 2) - assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) - }) - - t.Run("stops at EOF", func(t *testing.T) { - s := NewScanner("test.sql", " ") - var target []Unparsed - - NextTokenCopyingWhitespace(s, &target) - - assert.Equal(t, EOFToken, s.TokenType()) - }) -} - -func TestCreateUnparsed(t *testing.T) { - t.Run("creates unparsed from scanner", func(t *testing.T) { - s := NewScanner("test.sql", "select") - s.NextToken() - - unparsed := CreateUnparsed(s) - assert.Equal(t, ReservedWordToken, unparsed.Type) - 
assert.Equal(t, "select", unparsed.RawValue) - assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) - }) -} - -func TestDocument_recoverToNextStatement(t *testing.T) { - t.Run("recovers to declare", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, ReservedWordToken, s.TokenType()) - assert.Equal(t, "declare", s.ReservedWord()) - }) - - t.Run("recovers to create", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "bad stuff create procedure") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, "create", s.ReservedWord()) - }) - - t.Run("recovers to go", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "error error go") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, "go", s.ReservedWord()) + t.Run("panics for unknown SQL extension", func(t *testing.T) { + assert.Panics(t, func() { + NewDocumentFromExtension(".mysql") + }, "Expected panic for .mysql extension") }) - t.Run("stops at EOF", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "no keywords") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, EOFToken, s.TokenType()) + t.Run("extension matching is case insensitive", func(t *testing.T) { + assert.Panics(t, func() { + NewDocumentFromExtension(".SQL") + }, "Expected panic for uppercase .SQL") }) -} - -func TestDocument_recoverToNextStatementCopying(t *testing.T) { - t.Run("copies tokens while recovering", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "bad token declare") - s.NextToken() - var target []Unparsed - - doc.recoverToNextStatementCopying(s, &target) - assert.NotEmpty(t, target) - assert.Equal(t, "declare", s.ReservedWord()) + t.Run("returned documents implement Document interface", func(t *testing.T) { + sqlDoc := NewDocumentFromExtension(".sql") + 
pgsqlDoc := NewDocumentFromExtension(".pgsql") + require.NotEqual(t, sqlDoc, pgsqlDoc) }) } diff --git a/sqlparser/dom.go b/sqlparser/dom.go index 22afdaa..0c72587 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -1,12 +1,8 @@ package sqlparser import ( - "database/sql/driver" "fmt" - "io" "strings" - - "gopkg.in/yaml.v3" ) type Unparsed struct { @@ -61,48 +57,6 @@ func (p PosString) String() string { return p.Value } -type Create struct { - CreateType string // "procedure", "function" or "type" - QuotedName PosString // proc/func/type name, including [] - Body []Unparsed - DependsOn []PosString - Docstring []PosString // comment lines before the create statement. Note: this is also part of Body - Driver driver.Driver // the sql driver this document is intended for -} - -func (c Create) DocstringAsString() string { - var result []string - for _, line := range c.Docstring { - result = append(result, line.Value) - } - return strings.Join(result, "\n") -} - -func (c Create) DocstringYamldoc() (string, error) { - var yamldoc []string - parsing := false - for _, line := range c.Docstring { - if strings.HasPrefix(line.Value, "--!") { - parsing = true - if !strings.HasPrefix(line.Value, "--! 
") { - return "", Error{line.Pos, "YAML document in docstring; missing space after `--!`"} - } - yamldoc = append(yamldoc, line.Value[4:]) - } else if parsing { - return "", Error{line.Pos, "once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement"} - } - } - return strings.Join(yamldoc, "\n"), nil -} - -func (c Create) ParseYamlInDocstring(out any) error { - yamldoc, err := c.DocstringYamldoc() - if err != nil { - return err - } - return yaml.Unmarshal([]byte(yamldoc), out) -} - type Type struct { BaseType string Args []string @@ -129,107 +83,11 @@ func (e Error) WithoutPos() Error { return Error{Message: e.Message} } -func (c Create) Serialize(w io.StringWriter) error { - for _, l := range c.Body { - if _, err := w.WriteString(l.RawValue); err != nil { - return err - } - } - return nil -} - -func (c Create) SerializeBytes(w io.Writer) error { - for _, l := range c.Body { - if _, err := w.Write([]byte(l.RawValue)); err != nil { - return err - } - } - return nil -} - -func (c Create) String() string { - var buf strings.Builder - err := c.Serialize(&buf) - if err != nil { - panic(err) - } - return buf.String() -} - -func (c Create) WithoutPos() Create { - var body []Unparsed - for _, x := range c.Body { - body = append(body, x.WithoutPos()) - } - return Create{ - CreateType: c.CreateType, - QuotedName: c.QuotedName, - DependsOn: c.DependsOn, - Body: body, - } -} - -func (c Create) DependsOnStrings() (result []string) { - for _, x := range c.DependsOn { - result = append(result, x.Value) - } - return -} - -// Transform a Document to remove all Position information; this is used -// to 'unclutter' a DOM to more easily write assertions on it. 
-func (d Document) WithoutPos() Document { - var cs []Create - for _, x := range d.Creates { - cs = append(cs, x.WithoutPos()) - } - var ds []Declare - for _, x := range d.Declares { - ds = append(ds, x.WithoutPos()) - } - var es []Error - for _, x := range d.Errors { - es = append(es, x.WithoutPos()) - } - return Document{ - Creates: cs, - Declares: ds, - Errors: es, - } -} - -func (d *Document) Include(other Document) { - // Do not copy PragmaIncludeIf, since that is local to a single file. - // Its contents is also present in each Create. - d.Declares = append(d.Declares, other.Declares...) - d.Creates = append(d.Creates, other.Creates...) - d.Errors = append(d.Errors, other.Errors...) -} - -func (d *Document) parseSinglePragma(s *Scanner) { - pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) - if pragma == "" { - return - } - parts := strings.Split(pragma, " ") - if len(parts) != 2 { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - if parts[0] != "include-if" { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - d.PragmaIncludeIf = append(d.PragmaIncludeIf, strings.Split(parts[1], ",")...) -} - -func (d *Document) parsePragmas(s *Scanner) { - for s.TokenType() == PragmaToken { - d.parseSinglePragma(s) - s.NextNonWhitespaceToken() +func CreateUnparsed(s *Scanner) Unparsed { + return Unparsed{ + Type: s.TokenType(), + Start: s.Start(), + Stop: s.Stop(), + RawValue: s.Token(), } } - -func (d Document) Empty() bool { - return len(d.Creates) > 0 || len(d.Declares) > 0 -} diff --git a/sqlparser/parser.go b/sqlparser/parser.go index f2b51d8..1a49a08 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -44,7 +44,7 @@ func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { } } -func Parse(s *Scanner, result *Document) { +func Parse(s *Scanner, result Document) { // Top-level parse; this focuses on splitting into "batches" separated // by 'go'. 
@@ -62,20 +62,17 @@ func Parse(s *Scanner, result *Document) { // `s` will typically never be positioned on whitespace except in // whitespace-preserving parsing + filepath.Ext(s.input) + s.NextNonWhitespaceToken() - result.parsePragmas(s) - hasMore := result.parseBatch(s, true) + result.ParsePragmas(s) + hasMore := result.ParseBatch(s, true) for hasMore { - hasMore = result.parseBatch(s, false) + hasMore = result.ParseBatch(s, false) } return } -func ParseString(filename FileRef, input string) (result Document) { - Parse(&Scanner{input: input, file: filename}, &result) - return -} - // ParseFileystems iterates through a list of filesystems and parses all supported // SQL files and returns the combination of all of them. // @@ -129,10 +126,10 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, } hashes[hash] = pathDesc - var fdoc Document - Parse(&Scanner{input: string(buf), file: FileRef(path)}, &fdoc) + fdoc := NewDocumentFromExtension(extension) + Parse(&Scanner{input: string(buf), file: FileRef(path)}, fdoc) - if matchesIncludeTags(fdoc.PragmaIncludeIf, includeTags) { + if matchesIncludeTags(fdoc.PragmaIncludeIf(), includeTags) { filenames = append(filenames, pathDesc) result.Include(fdoc) } @@ -144,17 +141,7 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, } } - // Do the topological sort; and include any error with it as part - // of `result`, *not* return it as err - sortedCreates, errpos, sortErr := TopologicalSort(result.Creates) - if sortErr != nil { - result.Errors = append(result.Errors, Error{ - Pos: errpos, - Message: sortErr.Error(), - }) - } else { - result.Creates = sortedCreates - } + result.Sort() return } diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index 0e36b40..39f1e58 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -24,8 +24,8 @@ end; $$; `) - require.Len(t, doc.Creates, 1) - require.Equal(t, &stdlib.Driver{}, doc.Creates[0].Driver) + 
require.Len(t, doc.Creates(), 1) + require.Equal(t, &stdlib.Driver{}, doc.Creates()[0].Driver) } func TestParserSmokeTest(t *testing.T) { @@ -60,8 +60,8 @@ end; `) docNoPos := doc.WithoutPos() - require.Equal(t, 1, len(doc.Creates)) - c := doc.Creates[0] + require.Equal(t, 1, len(doc.Creates())) + c := doc.Creates()[0] require.Equal(t, &mssql.Driver{}, c.Driver) assert.Equal(t, "[TestFunc]", c.QuotedName.Value) @@ -82,7 +82,7 @@ end; { Message: "'declare' statement only allowed in first batch", }, - }, docNoPos.Errors) + }, docNoPos.Errors()) assert.Equal(t, []Declare{ @@ -130,7 +130,7 @@ end; }, }, }, - docNoPos.Declares, + docNoPos.Declares(), ) // repr.Println(doc) } @@ -150,21 +150,21 @@ create function [code].Two(); { Message: "a procedure/function must be alone in a batch; use 'go' to split batches", }, - }, doc.Errors) + }, doc.Errors()) } func TestBuggyDeclare(t *testing.T) { // this caused parses to infinitely loop; regression test... doc := ParseString("test.sql", `declare @EnumA int = 4 @EnumB tinyint = 5 @ENUM_C bigint = 435;`) - assert.Equal(t, 1, len(doc.Errors)) - assert.Equal(t, "Unexpected: @EnumB", doc.Errors[0].Message) + assert.Equal(t, 1, len(doc.Errors())) + assert.Equal(t, "Unexpected: @EnumB", doc.Errors()[0].Message) } func TestCreateType(t *testing.T) { doc := ParseString("test.sql", `create type [code].MyType as table (x int not null primary key);`) - assert.Equal(t, 1, len(doc.Creates)) - assert.Equal(t, "type", doc.Creates[0].CreateType) - assert.Equal(t, "[MyType]", doc.Creates[0].QuotedName.Value) + assert.Equal(t, 1, len(doc.Creates())) + assert.Equal(t, "type", doc.Creates()[0].CreateType) + assert.Equal(t, "[MyType]", doc.Creates()[0].QuotedName.Value) } func TestPragma(t *testing.T) { @@ -179,7 +179,7 @@ create procedure [code].ProcedureShouldAlsoHavePragmasAnnotated() func TestInfiniteLoopRegression(t *testing.T) { // success if we terminate!... 
doc := ParseString("test.sql", `@declare`) - assert.Equal(t, 1, len(doc.Errors)) + assert.Equal(t, 1, len(doc.Errors())) } func TestDeclareSeparation(t *testing.T) { @@ -190,7 +190,7 @@ func TestDeclareSeparation(t *testing.T) { doc := ParseString("test.sql", ` declare @EnumFirst int = 3, @EnumSecond varchar(max) = 'hello'declare @EnumThird int=3 declare @EnumFourth int=4;declare @EnumFifth int =5 `) - //repr.Println(doc.Declares) + //repr.Println(doc.Declares()) require.Equal(t, []Declare{ { VariableName: "@EnumFirst", @@ -217,7 +217,7 @@ declare @EnumFirst int = 3, @EnumSecond varchar(max) = 'hello'declare @EnumThird Datatype: Type{BaseType: "int"}, Literal: Unparsed{Type: NumberToken, RawValue: "5"}, }, - }, doc.WithoutPos().Declares) + }, doc.WithoutPos().Declares()) } func TestBatchDivisionsAndCreateStatements(t *testing.T) { @@ -232,7 +232,7 @@ go create type [code].Batch3 as table (x int); `) commentCount := 0 - for _, c := range doc.Creates { + for _, c := range doc.Creates() { for _, b := range c.Body { if strings.Contains(b.RawValue, "2nd") { commentCount++ @@ -251,13 +251,13 @@ create type [code].Type1 as table (x int); create type [code].Type2 as table (x int); create type [code].Type3 as table (x int); `) - require.Equal(t, 3, len(doc.Creates)) - assert.Equal(t, "[Type1]", doc.Creates[0].QuotedName.Value) - assert.Equal(t, "[Type3]", doc.Creates[2].QuotedName.Value) + require.Equal(t, 3, len(doc.Creates())) + assert.Equal(t, "[Type1]", doc.Creates()[0].QuotedName.Value) + assert.Equal(t, "[Type3]", doc.Creates()[2].QuotedName.Value) // There was a bug that the last item in the body would be the 'create' // of the next statement; regression test.. 
- assert.Equal(t, "\n", doc.Creates[0].Body[len(doc.Creates[0].Body)-1].RawValue) - assert.Equal(t, "create", doc.Creates[1].Body[0].RawValue) + assert.Equal(t, "\n", doc.Creates()[0].Body[len(doc.Creates()[0].Body)-1].RawValue) + assert.Equal(t, "create", doc.Creates()[1].Body[0].RawValue) } func TestCreateProcs(t *testing.T) { @@ -270,10 +270,10 @@ create type [code].MyType () create procedure [code].MyProcedure () `) // First function and last procedure triggers errors. - require.Equal(t, 2, len(doc.Errors)) + require.Equal(t, 2, len(doc.Errors())) emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" - assert.Equal(t, emsg, doc.Errors[0].Message) - assert.Equal(t, emsg, doc.Errors[1].Message) + assert.Equal(t, emsg, doc.Errors()[0].Message) + assert.Equal(t, emsg, doc.Errors()[1].Message) } @@ -283,14 +283,14 @@ func TestCreateProcs2(t *testing.T) { create type [code].MyType () create procedure [code].FirstProc as table (x int) `) - //repr.Println(doc.Errors) + //repr.Println(doc.Errors()) // Code above was mainly to be able to step through parser in a given way. // First function triggers an error. Then create type is parsed which is // fine sharing a batch with others. 
- require.Equal(t, 1, len(doc.Errors)) + require.Equal(t, 1, len(doc.Errors())) emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" - assert.Equal(t, emsg, doc.Errors[0].Message) + assert.Equal(t, emsg, doc.Errors()[0].Message) } func TestCreateProcsAndCheckForRoutineName(t *testing.T) { @@ -322,12 +322,12 @@ create procedure [code].[transform:safeguarding.Calculation/HEAD](@now datetime2 }, } for _, tc := range testcases { - require.Equal(t, 0, len(tc.doc.Errors)) - assert.Len(t, tc.doc.Creates, 1) - assert.Greater(t, len(tc.doc.Creates[0].Body), tc.expectedIndex) + require.Equal(t, 0, len(tc.doc.Errors())) + assert.Len(t, tc.doc.Creates(), 1) + assert.Greater(t, len(tc.doc.Creates()[0].Body), tc.expectedIndex) assert.Equal(t, fmt.Sprintf(templateRoutineName, tc.expectedProcName), - tc.doc.Creates[0].Body[tc.expectedIndex].RawValue, + tc.doc.Creates()[0].Body[tc.expectedIndex].RawValue, ) } } @@ -342,9 +342,9 @@ end // Code above was mainly to be able to step through parser in a given way. // First function triggers an error. Then create type is parsed which is // fine sharing a batch with others. - require.Equal(t, 2, len(doc.Errors)) - assert.Equal(t, "`go` should be alone on a line without any comments", doc.Errors[0].Message) - assert.Equal(t, "Expected 'declare' or 'create', got: end", doc.Errors[1].Message) + require.Equal(t, 2, len(doc.Errors())) + assert.Equal(t, "`go` should be alone on a line without any comments", doc.Errors()[0].Message) + assert.Equal(t, "Expected 'declare' or 'create', got: end", doc.Errors()[1].Message) } func TestCreateAnnotationHappyDay(t *testing.T) { @@ -362,8 +362,8 @@ create procedure [code].Foo as begin end `) assert.Equal(t, "-- This is part of annotation\n--! key1: a\n--! key2: b\n--! 
key3: [1,2,3]", - doc.Creates[0].DocstringAsString()) - s, err := doc.Creates[0].DocstringYamldoc() + doc.Creates()[0].DocstringAsString()) + s, err := doc.Creates()[0].DocstringYamldoc() assert.NoError(t, err) assert.Equal(t, "key1: a\nkey2: b\nkey3: [1,2,3]", @@ -372,7 +372,7 @@ create procedure [code].Foo as begin end var x struct { Key1 string `yaml:"key1"` } - require.NoError(t, doc.Creates[0].ParseYamlInDocstring(&x)) + require.NoError(t, doc.Creates()[0].ParseYamlInDocstring(&x)) assert.Equal(t, "a", x.Key1) } @@ -387,7 +387,7 @@ create procedure [code].Foo as begin end `) assert.Equal(t, "-- docstring here", - doc.Creates[0].DocstringAsString()) + doc.Creates()[0].DocstringAsString()) } func TestCreateAnnotationErrors(t *testing.T) { @@ -397,7 +397,7 @@ func TestCreateAnnotationErrors(t *testing.T) { -- This comment after yamldoc is illegal; this also prevents multiple embedded YAML documents create procedure [code].Foo as begin end `) - _, err := doc.Creates[0].DocstringYamldoc() + _, err := doc.Creates()[0].DocstringYamldoc() assert.Equal(t, "test.sql:3:1 once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement", err.Error()) @@ -407,7 +407,7 @@ create procedure [code].Foo as begin end --!key4: 1 create procedure [code].Foo as begin end `) - _, err = doc.Creates[0].DocstringYamldoc() + _, err = doc.Creates()[0].DocstringYamldoc() assert.Equal(t, "test.sql:3:1 YAML document in docstring; missing space after `--!`", err.Error()) @@ -433,8 +433,8 @@ create function [code].Func1() returns int as begin return 1 end filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) require.NoError(t, err) assert.Len(t, filenames, 2) - assert.Len(t, doc.Creates, 2) - assert.Len(t, doc.Declares, 1) + assert.Len(t, doc.Creates(), 2) + assert.Len(t, doc.Declares(), 1) }) t.Run("filters by include tags", func(t *testing.T) { @@ -455,8 +455,8 @@ create procedure [code].Excluded as begin end require.NoError(t, err) 
assert.Len(t, filenames, 1) assert.Contains(t, filenames[0], "included.sql") - assert.Len(t, doc.Creates, 1) - assert.Equal(t, "[Included]", doc.Creates[0].QuotedName.Value) + assert.Len(t, doc.Creates(), 1) + assert.Equal(t, "[Included]", doc.Creates()[0].QuotedName.Value) }) t.Run("detects duplicate files with same hash", func(t *testing.T) { @@ -488,7 +488,7 @@ create procedure [code].Excluded as begin end require.NoError(t, err) assert.Len(t, filenames, 1) assert.Contains(t, filenames[0], "sqlcode.sql") - assert.Len(t, doc.Creates, 1) + assert.Len(t, doc.Creates(), 1) }) t.Run("skips hidden directories", func(t *testing.T) { @@ -508,7 +508,7 @@ create procedure [code].Excluded as begin end require.NoError(t, err) assert.Len(t, filenames, 1) assert.Contains(t, filenames[0], "visible.sql") - assert.Len(t, doc.Creates, 1) + assert.Len(t, doc.Creates(), 1) }) t.Run("handles dependencies and topological sort", func(t *testing.T) { @@ -524,10 +524,10 @@ create procedure [code].Excluded as begin end filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) require.NoError(t, err) assert.Len(t, filenames, 2) - assert.Len(t, doc.Creates, 2) + assert.Len(t, doc.Creates(), 2) // Proc2 should come before Proc1 due to dependency - assert.Equal(t, "[Proc2]", doc.Creates[0].QuotedName.Value) - assert.Equal(t, "[Proc1]", doc.Creates[1].QuotedName.Value) + assert.Equal(t, "[Proc2]", doc.Creates()[0].QuotedName.Value) + assert.Equal(t, "[Proc1]", doc.Creates()[1].QuotedName.Value) }) t.Run("reports topological sort errors", func(t *testing.T) { @@ -541,9 +541,9 @@ create procedure [code].Excluded as begin end } _, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) - require.NoError(t, err) // filesystem error should be nil - assert.NotEmpty(t, doc.Errors) // but parsing errors should exist - assert.Contains(t, doc.Errors[0].Message, "Detected a dependency cycle") + require.NoError(t, err) // filesystem error should be nil + assert.NotEmpty(t, doc.Errors()) // but parsing errors 
should exist + assert.Contains(t, doc.Errors()[0].Message, "Detected a dependency cycle") }) t.Run("handles multiple filesystems", func(t *testing.T) { @@ -563,7 +563,7 @@ create procedure [code].Excluded as begin end assert.Len(t, filenames, 2) assert.Contains(t, filenames[0], "fs[0]:") assert.Contains(t, filenames[1], "fs[1]:") - assert.Len(t, doc.Creates, 2) + assert.Len(t, doc.Creates(), 2) }) t.Run("detects sqlcode files by pragma header", func(t *testing.T) { @@ -578,7 +578,7 @@ create procedure NotInCodeSchema.Test as begin end`), require.NoError(t, err) assert.Len(t, filenames, 1) // Should still parse even though it will have errors (not in [code] schema) - assert.NotEmpty(t, doc.Errors) + assert.NotEmpty(t, doc.Errors()) }) t.Run("handles pgsql extension", func(t *testing.T) { @@ -599,8 +599,8 @@ $$; filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) require.NoError(t, err) assert.Len(t, filenames, 1) - assert.Len(t, doc.Creates, 1) - assert.Equal(t, &stdlib.Driver{}, doc.Creates[0].Driver) + assert.Len(t, doc.Creates(), 1) + assert.Equal(t, &stdlib.Driver{}, doc.Creates()[0].Driver) }) t.Run("empty filesystem returns empty results", func(t *testing.T) { @@ -609,8 +609,8 @@ $$; filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) require.NoError(t, err) assert.Empty(t, filenames) - assert.Empty(t, doc.Creates) - assert.Empty(t, doc.Declares) + assert.Empty(t, doc.Creates()) + assert.Empty(t, doc.Declares()) }) } From cd7f93ffa59b4055d1dbb9155fda8ec0754bc523 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 19:58:05 +0100 Subject: [PATCH 20/40] Renamed the existing Document struct to be specific for T-SQL. 
--- sqlparser/tsql_document.go | 588 ++++++++++++++++++++++++ sqlparser/tsql_document_test.go | 764 ++++++++++++++++++++++++++++++++ 2 files changed, 1352 insertions(+) create mode 100644 sqlparser/tsql_document.go create mode 100644 sqlparser/tsql_document_test.go diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go new file mode 100644 index 0000000..343ef7c --- /dev/null +++ b/sqlparser/tsql_document.go @@ -0,0 +1,588 @@ +package sqlparser + +import ( + "fmt" + "sort" + "strings" + + "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" +) + +type TSqlDocument struct { + pragmaIncludeIf []string + creates []Create + declares []Declare + errors []Error +} + +func (d TSqlDocument) HasErrors() bool { + return len(d.errors) > 0 +} + +func (d TSqlDocument) Creates() []Create { + return d.creates +} + +func (d TSqlDocument) Declares() []Declare { + return d.declares +} + +func (d TSqlDocument) Errors() []Error { + return d.errors +} +func (d TSqlDocument) PragmaIncludeIf() []string { + return d.pragmaIncludeIf +} + +func (d *TSqlDocument) Sort() { + // Do the topological sort; and include any error with it as part + // of `result`, *not* return it as err + sortedCreates, errpos, sortErr := TopologicalSort(d.creates) + + if sortErr != nil { + d.errors = append(d.errors, Error{ + Pos: errpos, + Message: sortErr.Error(), + }) + } else { + d.creates = sortedCreates + } +} + +// Transform a TSqlDocument to remove all Position information; this is used +// to 'unclutter' a DOM to more easily write assertions on it. 
+func (d TSqlDocument) WithoutPos() Document { + var cs []Create + for _, x := range d.creates { + cs = append(cs, x.WithoutPos()) + } + var ds []Declare + for _, x := range d.declares { + ds = append(ds, x.WithoutPos()) + } + var es []Error + for _, x := range d.errors { + es = append(es, x.WithoutPos()) + } + return &TSqlDocument{ + creates: cs, + declares: ds, + errors: es, + } +} + +func (d *TSqlDocument) Include(other Document) { + // Do not copy pragmaIncludeIf, since that is local to a single file. + // Its contents is also present in each Create. + d.declares = append(d.declares, other.Declares()...) + d.creates = append(d.creates, other.Creates()...) + d.errors = append(d.errors, other.Errors()...) +} + +func (d *TSqlDocument) parseSinglePragma(s *Scanner) { + pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) + if pragma == "" { + return + } + parts := strings.Split(pragma, " ") + if len(parts) != 2 { + d.addError(s, "Illegal pragma: "+s.Token()) + return + } + if parts[0] != "include-if" { + d.addError(s, "Illegal pragma: "+s.Token()) + return + } + d.pragmaIncludeIf = append(d.pragmaIncludeIf, strings.Split(parts[1], ",")...) 
+} + +func (d *TSqlDocument) ParsePragmas(s *Scanner) { + for s.TokenType() == PragmaToken { + d.parseSinglePragma(s) + s.NextNonWhitespaceToken() + } +} + +func (d TSqlDocument) Empty() bool { + return len(d.creates) == 0 || len(d.declares) == 0 +} + +func (d *TSqlDocument) addError(s *Scanner, msg string) { + d.errors = append(d.errors, Error{ + Pos: s.Start(), + Message: msg, + }) +} + +func (d *TSqlDocument) unexpectedTokenError(s *Scanner) { + d.addError(s, "Unexpected: "+s.Token()) +} + +func (doc *TSqlDocument) parseTypeExpression(s *Scanner) (t Type) { + parseArgs := func() { + // parses *after* the initial (; consumes trailing ) + for { + switch { + case s.TokenType() == NumberToken: + t.Args = append(t.Args, s.Token()) + case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": + t.Args = append(t.Args, "max") + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + s.NextNonWhitespaceCommentToken() + switch { + case s.TokenType() == CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case s.TokenType() == RightParenToken: + s.NextNonWhitespaceCommentToken() + return + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + } + } + + if s.TokenType() != UnquotedIdentifierToken { + panic("assertion failed, bug in caller") + } + t.BaseType = s.Token() + s.NextNonWhitespaceCommentToken() + if s.TokenType() == LeftParenToken { + s.NextNonWhitespaceCommentToken() + parseArgs() + } + return +} + +func (doc *TSqlDocument) parseDeclare(s *Scanner) (result []Declare) { + declareStart := s.Start() + // parse what is *after* the `declare` reserved keyword +loop: + for { + if s.TokenType() != VariableIdentifierToken { + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + + variableName := s.Token() + if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && + !strings.HasPrefix(strings.ToLower(variableName), "@global") && + 
!strings.HasPrefix(strings.ToLower(variableName), "@const") { + doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) + } + s.NextNonWhitespaceCommentToken() + var variableType Type + switch s.TokenType() { + case EqualToken: + doc.addError(s, "sqlcode constants needs a type declared explicitly") + s.NextNonWhitespaceCommentToken() + case UnquotedIdentifierToken: + variableType = doc.parseTypeExpression(s) + } + + if s.TokenType() != EqualToken { + doc.addError(s, "sqlcode constants needs to be assigned at once using =") + doc.recoverToNextStatement(s) + } + + switch s.NextNonWhitespaceCommentToken() { + case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: + result = append(result, Declare{ + Start: declareStart, + Stop: s.Stop(), + VariableName: variableName, + Datatype: variableType, + Literal: CreateUnparsed(s), + }) + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + + switch s.NextNonWhitespaceCommentToken() { + case CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case SemicolonToken: + s.NextNonWhitespaceCommentToken() + break loop + default: + break loop + } + } + if len(result) == 0 { + doc.addError(s, "incorrect syntax; no variables successfully declared") + } + return +} + +func (doc *TSqlDocument) parseBatchSeparator(s *Scanner) { + // just saw a 'go'; just make sure there's nothing bad trailing it + // (if there is, convert to errors and move on until the line is consumed + errorEmitted := false + for { + switch s.NextToken() { + case WhitespaceToken: + continue + case MalformedBatchSeparatorToken: + if !errorEmitted { + doc.addError(s, "`go` should be alone on a line without any comments") + errorEmitted = true + } + continue + default: + return + } + } +} + +func (doc *TSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { + if s.ReservedWord() != "declare" { + panic("assertion failed, incorrect use in caller") + } + for { + tt := 
s.TokenType() + switch { + case tt == EOFToken: + return false + case tt == ReservedWordToken && s.ReservedWord() == "declare": + s.NextNonWhitespaceCommentToken() + d := doc.parseDeclare(s) + doc.declares = append(doc.declares, d...) + case tt == ReservedWordToken && s.ReservedWord() != "declare": + doc.addError(s, "Only 'declare' allowed in this batch") + doc.recoverToNextStatement(s) + case tt == BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + } + } +} + +func (doc *TSqlDocument) ParseBatch(s *Scanner, isFirst bool) (hasMore bool) { + var nodes []Unparsed + var docstring []PosString + newLineEncounteredInDocstring := false + + var createCountInBatch int + + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + nodes = append(nodes, CreateUnparsed(s)) + // do not reset token for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + docstring = nil + } + s.NextToken() + case SinglelineCommentToken: + // We build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + docstring = append(docstring, PosString{s.Start(), s.Token()}) + nodes = append(nodes, CreateUnparsed(s)) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + switch s.ReservedWord() { + case "declare": + // First declare-statement; enter a mode where we assume all contents + // of batch are declare statements + if !isFirst { + doc.addError(s, "'declare' statement only allowed in first batch") + } + // regardless of errors, go on and parse as far as we get... + return doc.parseDeclareBatch(s) + case "create": + // should be start of create procedure or create function... 
+ c := doc.parseCreate(s, createCountInBatch) + + if strings.HasSuffix(string(s.file), ".sql") { + c.Driver = &mssql.Driver{} + } + if strings.HasSuffix(string(s.file), ".pgsql") { + c.Driver = &stdlib.Driver{} + } + + // *prepend* what we saw before getting to the 'create' + createCountInBatch++ + c.Body = append(nodes, c.Body...) + c.Docstring = docstring + doc.creates = append(doc.creates, c) + default: + doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) + s.NextToken() + } + case BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + s.NextToken() + docstring = nil + } + } +} + +func (d *TSqlDocument) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "declare", "create", "go": + return + } + case EOFToken: + return + default: + CopyToken(s, target) + } + } +} + +func (d *TSqlDocument) recoverToNextStatement(s *Scanner) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "declare", "create", "go": + return + } + case EOFToken: + return + } + } +} + +// parseCodeschemaName parses `[code] . something`, and returns `something` +// in quoted form (`[something]`). Also copy to `target`. Empty string on error. +// Note: To follow conventions, consume one extra token at the end even if we know +// it fill not be consumed by this function... 
+func (d *TSqlDocument) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + if s.TokenType() != DotToken { + d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) + d.recoverToNextStatementCopying(s, target) + return PosString{Value: ""} + } + CopyToken(s, target) + + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case UnquotedIdentifierToken: + // To get something uniform for comparison, quote all names + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} + NextTokenCopyingWhitespace(s, target) + return result + case QuotedIdentifierToken: + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: s.Token()} + NextTokenCopyingWhitespace(s, target) + return result + default: + d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) + d.recoverToNextStatementCopying(s, target) + return PosString{Value: ""} + } +} + +// parseCreate parses anything that starts with "create". Position is +// *on* the create token. +// At this stage in sqlcode parser development we're only interested +// in procedures/functions/types as opaque blocks of SQL code where +// we only track dependencies between them and their declared name; +// so we treat them with the same code. We consume until the end of +// the batch; only one declaration allowed per batch. Everything +// parsed here will also be added to `batch`. On any error, copying +// to batch stops / becomes erratic.. 
+func (d *TSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) { + if s.ReservedWord() != "create" { + panic("illegal use by caller") + } + CopyToken(s, &result.Body) + + NextTokenCopyingWhitespace(s, &result.Body) + + createType := strings.ToLower(s.Token()) + if !(createType == "procedure" || createType == "function" || createType == "type") { + d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) + d.recoverToNextStatementCopying(s, &result.Body) + return + } + if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + d.recoverToNextStatementCopying(s, &result.Body) + return + } + + result.CreateType = createType + CopyToken(s, &result.Body) + + NextTokenCopyingWhitespace(s, &result.Body) + + // Insist on [code]. + if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { + d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) + d.recoverToNextStatementCopying(s, &result.Body) + return + } + result.QuotedName = d.parseCodeschemaName(s, &result.Body) + if result.QuotedName.String() == "" { + return + } + + // We have matched "create [code]."; at this + // point we copy the rest until the batch ends; *but* track dependencies + // + some other details mentioned below + + //firstAs := true // See comment below on rowcount + +tailloop: + for { + tt := s.TokenType() + switch { + case tt == ReservedWordToken && s.ReservedWord() == "create": + // So, we're currently parsing 'create ...' and we see another 'create'. 
+ // We split in two cases depending on the context we are currently in + // (createType is referring to how we entered this function, *NOT* the + // `create` statement we are looking at now + switch createType { // note: this is the *outer* create type, not the one of current scanner position + case "function", "procedure": + // Within a function/procedure we can allow 'create index', 'create table' and nothing + // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain + // about that aspect, not relevant for batch / dependency parsing) + // + // What is important is a function/procedure/type isn't started on without a 'go' + // in between; so we block those 3 from appearing in the same batch + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + tt2 := s.TokenType() + + if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || + (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { + d.recoverToNextStatementCopying(s, &result.Body) + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + return + } + case "type": + // We allow more than one type creation in a batch; and 'create' can never appear + // scoped within 'create type'. 
So at a new create we are done with the previous + // one, and return it -- the caller can then re-enter this function from the top + break tailloop + default: + panic("assertion failed") + } + + case tt == EOFToken || tt == BatchSeparatorToken: + break tailloop + case tt == QuotedIdentifierToken && s.Token() == "[code]": + // Parse a dependency + dep := d.parseCodeschemaName(s, &result.Body) + found := false + for _, existing := range result.DependsOn { + if existing.Value == dep.Value { + found = true + break + } + } + if !found { + result.DependsOn = append(result.DependsOn, dep) + } + case tt == ReservedWordToken && s.Token() == "as": + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + /* + TODO: Fix and re-enable + This code add RoutineName for convenience. So: + + create procedure [code@5420c0269aaf].Test as + begin + select 1 + end + go + + becomes: + + create procedure [code@5420c0269aaf].Test as + declare @RoutineName nvarchar(128) + set @RoutineName = 'Test' + begin + select 1 + end + go + + However, for some very strange reason, @@rowcount is 1 with the first version, + and it is 2 with the second version. + if firstAs { + // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name + // from inside the procedure (for example, when logging) + if result.CreateType == "procedure" { + procNameToken := Unparsed{ + Type: OtherToken, + RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), + } + result.Body = append(result.Body, procNameToken) + } + firstAs = false + } + */ + + default: + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + } + } + + sort.Slice(result.DependsOn, func(i, j int) bool { + return result.DependsOn[i].Value < result.DependsOn[j].Value + }) + return +} + +// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered +// it is simply copied into `target`. 
Upon return, the scanner is located at a non-whitespace +// token, and target is either unmodified or filled with some whitespace nodes. +func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { + for { + tt := s.NextToken() + switch tt { + case EOFToken, BatchSeparatorToken: + // do not copy + return + case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + // copy, and loop around + CopyToken(s, target) + continue + default: + return + } + } + +} diff --git a/sqlparser/tsql_document_test.go b/sqlparser/tsql_document_test.go new file mode 100644 index 0000000..b4ef303 --- /dev/null +++ b/sqlparser/tsql_document_test.go @@ -0,0 +1,764 @@ +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDocument_addError(t *testing.T) { + t.Run("adds error with position", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "select") + s.NextToken() + + doc.addError(s, "test error message") + require.True(t, doc.HasErrors()) + assert.Equal(t, "test error message", doc.errors[0].Message) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.errors[0].Pos) + }) + + t.Run("accumulates multiple errors", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "abc def") + s.NextToken() + doc.addError(s, "error 1") + s.NextToken() + doc.addError(s, "error 2") + + require.Len(t, doc.errors, 2) + assert.Equal(t, "error 1", doc.errors[0].Message) + assert.Equal(t, "error 2", doc.errors[1].Message) + }) + + t.Run("creates error with token text", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "unexpected_token") + s.NextToken() + + doc.unexpectedTokenError(s) + + require.Len(t, doc.errors, 1) + assert.Equal(t, "Unexpected: unexpected_token", doc.errors[0].Message) + }) +} + +func TestDocument_parseTypeExpression(t *testing.T) { + t.Run("parses simple type without args", func(t *testing.T) { + doc := 
&TSqlDocument{} + s := NewScanner("test.sql", "int") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "int", typ.BaseType) + assert.Empty(t, typ.Args) + }) + + t.Run("parses type with single arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(50)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.Equal(t, []string{"50"}, typ.Args) + }) + + t.Run("parses type with multiple args", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "decimal(10, 2)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "decimal", typ.BaseType) + assert.Equal(t, []string{"10", "2"}, typ.Args) + }) + + t.Run("parses type with max", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "nvarchar(max)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "nvarchar", typ.BaseType) + assert.Equal(t, []string{"max"}, typ.Args) + }) + + t.Run("handles invalid arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(invalid)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.NotEmpty(t, doc.errors) + }) + + t.Run("panics if not on identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "123") + s.NextToken() + + assert.Panics(t, func() { + doc.parseTypeExpression(s) + }) + }) +} + +func TestDocument_parseDeclare(t *testing.T) { + t.Run("parses single enum declaration", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumStatus int = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@EnumStatus", declares[0].VariableName) + assert.Equal(t, "int", declares[0].Datatype.BaseType) + assert.Equal(t, "42", declares[0].Literal.RawValue) + }) + + t.Run("parses multiple declarations 
with comma", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumA int = 1, @EnumB int = 2;") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 2) + assert.Equal(t, "@EnumA", declares[0].VariableName) + assert.Equal(t, "@EnumB", declares[1].VariableName) + }) + + t.Run("parses string literal", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumName nvarchar(50) = N'test'") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "N'test'", declares[0].Literal.RawValue) + }) + + t.Run("errors on invalid variable name", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@InvalidName int = 1") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "@InvalidName") + }) + + t.Run("errors on missing type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumTest = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "type declared explicitly") + }) + + t.Run("errors on missing assignment", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumTest int") + s.NextToken() + + doc.parseDeclare(s) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "needs to be assigned") + }) + + t.Run("accepts @Global prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@GlobalSetting int = 100") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@GlobalSetting", declares[0].VariableName) + assert.Empty(t, doc.errors) + }) + + t.Run("accepts @Const prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@ConstValue int 
= 200") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@ConstValue", declares[0].VariableName) + assert.Empty(t, doc.errors) + }) +} + +func TestDocument_parseBatchSeparator(t *testing.T) { + t.Run("parses valid go separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "go\n") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.Empty(t, doc.errors) + }) + + t.Run("errors on malformed separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "go -- comment") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "should be alone") + }) +} + +func TestDocument_parseCodeschemaName(t *testing.T) { + t.Run("parses unquoted identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "[code].TestProc") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "[TestProc]", result.Value) + assert.NotEmpty(t, target) + }) + + t.Run("parses quoted identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "[code].[Test Proc]") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "[Test Proc]", result.Value) + }) + + t.Run("errors on missing dot", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "[code] TestProc") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "", result.Value) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be followed by '.'") + }) + + t.Run("errors on missing identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "[code].123") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "", 
result.Value) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be followed an identifier") + }) +} + +func TestDocument_parseCreate(t *testing.T) { + t.Run("parses simple procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "[TestProc]", create.QuotedName.Value) + assert.NotEmpty(t, create.Body) + }) + + t.Run("parses function", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create function [code].TestFunc() returns int as begin return 1 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "[TestFunc]", create.QuotedName.Value) + }) + + t.Run("parses type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create type [code].TestType as table (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + assert.Equal(t, "[TestType]", create.QuotedName.Value) + }) + + t.Run("tracks dependencies", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1 join [code].Table2 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + assert.Equal(t, "[Table2]", create.DependsOn[1].Value) + }) + + t.Run("deduplicates dependencies", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1; select * from [code].Table1 end") + s.NextToken() + 
s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 1) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + }) + + t.Run("errors on unsupported create type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create table [code].TestTable (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "only supports creating procedures") + }) + + t.Run("errors on multiple procedures in batch", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc2 as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 1) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be alone in a batch") + }) + + t.Run("errors on missing code schema", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure dbo.TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be followed by [code]") + }) + + t.Run("allows create index inside procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin create index IX_Test on #temp(id) end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Empty(t, doc.errors) + }) + + t.Run("stops at batch separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin end\ngo") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "[Proc]", create.QuotedName.Value) + assert.Equal(t, BatchSeparatorToken, s.TokenType()) + 
}) + + t.Run("panics if not on create token", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "procedure") + s.NextToken() + + assert.Panics(t, func() { + doc.parseCreate(s, 0) + }) + }) +} + +func TestNextTokenCopyingWhitespace(t *testing.T) { + t.Run("copies whitespace tokens", func(t *testing.T) { + s := NewScanner("test.sql", " \n\t token") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("copies comments", func(t *testing.T) { + s := NewScanner("test.sql", "/* comment */ -- line\ntoken") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.True(t, len(target) >= 2) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + s := NewScanner("test.sql", " ") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestCreateUnparsed(t *testing.T) { + t.Run("creates unparsed from scanner", func(t *testing.T) { + s := NewScanner("test.sql", "select") + s.NextToken() + + unparsed := CreateUnparsed(s) + + assert.Equal(t, ReservedWordToken, unparsed.Type) + assert.Equal(t, "select", unparsed.RawValue) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) + }) +} + +func TestDocument_recoverToNextStatement(t *testing.T) { + t.Run("recovers to declare", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "declare", s.ReservedWord()) + }) + + t.Run("recovers to create", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "bad stuff create procedure") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, "create", s.ReservedWord()) + 
}) + + t.Run("recovers to go", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "error error go") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, "go", s.ReservedWord()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "no keywords") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestDocument_recoverToNextStatementCopying(t *testing.T) { + t.Run("copies tokens while recovering", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "bad token declare") + s.NextToken() + var target []Unparsed + + doc.recoverToNextStatementCopying(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, "declare", s.ReservedWord()) + }) +} + +func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { + t.Run("parses PostgreSQL function with dollar quoting", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "test_func", create.QuotedName.Value) + }) + + t.Run("parses PostgreSQL procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create procedure insert_data(a integer, b integer) language sql as $$ insert into tbl values (a, b); $$") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "insert_data", create.QuotedName.Value) + }) + + t.Run("parses CREATE OR REPLACE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create or replace function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + 
s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses schema-qualified name", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function public.test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "test_func") + }) + + t.Run("parses RETURNS TABLE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function get_users() returns table(id int, name text) as $$ select id, name from users; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("tracks dependencies with schema prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test() returns int as $$ select * from public.table1 join public.table2 on table1.id = table2.id; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + }) + + t.Run("parses volatility categories", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int immutable as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses PARALLEL SAFE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int parallel safe as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) +} + +// 
func TestDocument_PostgreSQL17_Types(t *testing.T) { +// t.Run("parses composite type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create type address_type as (street text, city text, zip varchar(10))") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "type", create.CreateType) +// }) + +// t.Run("parses enum type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create type mood as enum ('sad', 'ok', 'happy')") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "type", create.CreateType) +// }) + +// t.Run("parses range type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create type float_range as range (subtype = float8, subtype_diff = float8mi)") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "type", create.CreateType) +// }) +// } + +// func TestDocument_PostgreSQL17_Extensions(t *testing.T) { +// t.Run("parses JSON functions PostgreSQL 17", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create function test() returns jsonb as $$ select json_serialize(data) from table1; $$ language sql") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "function", create.CreateType) +// }) + +// t.Run("parses MERGE statement (PostgreSQL 15+)", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create function do_merge() returns void as $$ merge into target using source on target.id = source.id when matched then update set value = source.value; $$ language sql") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "function", create.CreateType) +// }) +// } + +// func 
TestDocument_PostgreSQL17_Identifiers(t *testing.T) { +// t.Run("parses double-quoted identifiers", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", `create function "Test Func"() returns int as $$ begin return 1; end; $$ language plpgsql`) +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Contains(t, create.QuotedName.Value, "Test Func") +// }) + +// t.Run("parses case-sensitive identifiers", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", `create function "TestFunc"() returns int as $$ begin return 1; end; $$ language plpgsql`) +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Contains(t, create.QuotedName.Value, "TestFunc") +// }) +// } + +// func TestDocument_PostgreSQL17_Datatypes(t *testing.T) { +// t.Run("parses array types", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "integer[]") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "integer[]", typ.BaseType) +// }) + +// t.Run("parses serial types", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "serial") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "serial", typ.BaseType) +// }) + +// t.Run("parses text type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "text") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "text", typ.BaseType) +// }) + +// t.Run("parses jsonb type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "jsonb") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "jsonb", typ.BaseType) +// }) + +// t.Run("parses uuid type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "uuid") +// s.NextToken() + +// typ := 
doc.parseTypeExpression(s) + +// assert.Equal(t, "uuid", typ.BaseType) +// }) +// } + +// func TestDocument_PostgreSQL17_BatchSeparator(t *testing.T) { +// t.Run("PostgreSQL uses semicolon not GO", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create function test1() returns int as $$ begin return 1; end; $$ language plpgsql; create function test2() returns int as $$ begin return 2; end; $$ language plpgsql;") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create1 := doc.parseCreate(s, 0) +// assert.Equal(t, "test1", create1.QuotedName.Value) + +// // Move to next statement +// s.NextNonWhitespaceCommentToken() +// s.NextNonWhitespaceCommentToken() + +// create2 := doc.parseCreate(s, 1) +// assert.Equal(t, "test2", create2.QuotedName.Value) +// }) +// } From 37da069b9480ae5b0513d57eb5ba054059ea8d2b Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 19:58:27 +0100 Subject: [PATCH 21/40] Created initial PGSqlDocument for PostgreSQL. 
--- sqlparser/pgsql_document.go | 50 +++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 sqlparser/pgsql_document.go diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go new file mode 100644 index 0000000..9aacf98 --- /dev/null +++ b/sqlparser/pgsql_document.go @@ -0,0 +1,50 @@ +package sqlparser + +type PGSqlDocument struct { + pragmaIncludeIf []string + creates []Create + errors []Error +} + +func (d PGSqlDocument) HasErrors() bool { + return len(d.errors) > 0 +} + +func (d PGSqlDocument) Creates() []Create { + return d.creates +} + +func (d PGSqlDocument) Declares() []Declare { + return nil +} + +func (d PGSqlDocument) Errors() []Error { + return d.errors +} +func (d PGSqlDocument) PragmaIncludeIf() []string { + return d.pragmaIncludeIf +} + +func (d PGSqlDocument) Empty() bool { + return len(d.creates) == 0 +} + +func (d PGSqlDocument) Sort() { + +} + +func (d PGSqlDocument) Include(other Document) { + +} + +func (d PGSqlDocument) ParsePragmas(s *Scanner) { + +} + +func (d PGSqlDocument) WithoutPos() Document { + return &PGSqlDocument{} +} + +func (d PGSqlDocument) ParseBatch(s *Scanner, isFirst bool) bool { + return false +} From 0ceb98b0a2abd1f3b77c9a317300fb5f28903355 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 20:13:29 +0100 Subject: [PATCH 22/40] Updated unit test. 
--- sqlparser/tsql_document.go | 6 +- sqlparser/tsql_document_test.go | 320 +++++++++----------------------- 2 files changed, 90 insertions(+), 236 deletions(-) diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 343ef7c..23d43a4 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -178,6 +178,7 @@ loop: !strings.HasPrefix(strings.ToLower(variableName), "@const") { doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) } + s.NextNonWhitespaceCommentToken() var variableType Type switch s.TokenType() { @@ -195,13 +196,14 @@ loop: switch s.NextNonWhitespaceCommentToken() { case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - result = append(result, Declare{ + declare := Declare{ Start: declareStart, Stop: s.Stop(), VariableName: variableName, Datatype: variableType, Literal: CreateUnparsed(s), - }) + } + result = append(result, declare) default: doc.unexpectedTokenError(s) doc.recoverToNextStatement(s) diff --git a/sqlparser/tsql_document_test.go b/sqlparser/tsql_document_test.go index b4ef303..0d5256c 100644 --- a/sqlparser/tsql_document_test.go +++ b/sqlparser/tsql_document_test.go @@ -7,106 +7,108 @@ import ( "github.com/stretchr/testify/require" ) -func TestDocument_addError(t *testing.T) { - t.Run("adds error with position", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "select") - s.NextToken() - - doc.addError(s, "test error message") - require.True(t, doc.HasErrors()) - assert.Equal(t, "test error message", doc.errors[0].Message) - assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.errors[0].Pos) - }) - - t.Run("accumulates multiple errors", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "abc def") - s.NextToken() - doc.addError(s, "error 1") - s.NextToken() - doc.addError(s, "error 2") +func TestTSqlDocument(t *testing.T) { + t.Run("addError", func(t *testing.T) { + t.Run("adds 
error with position", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "select") + s.NextToken() + + doc.addError(s, "test error message") + require.True(t, doc.HasErrors()) + assert.Equal(t, "test error message", doc.errors[0].Message) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.errors[0].Pos) + }) - require.Len(t, doc.errors, 2) - assert.Equal(t, "error 1", doc.errors[0].Message) - assert.Equal(t, "error 2", doc.errors[1].Message) - }) + t.Run("accumulates multiple errors", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "abc def") + s.NextToken() + doc.addError(s, "error 1") + s.NextToken() + doc.addError(s, "error 2") + + require.Len(t, doc.errors, 2) + assert.Equal(t, "error 1", doc.errors[0].Message) + assert.Equal(t, "error 2", doc.errors[1].Message) + }) - t.Run("creates error with token text", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "unexpected_token") - s.NextToken() + t.Run("creates error with token text", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "unexpected_token") + s.NextToken() - doc.unexpectedTokenError(s) + doc.unexpectedTokenError(s) - require.Len(t, doc.errors, 1) - assert.Equal(t, "Unexpected: unexpected_token", doc.errors[0].Message) + require.Len(t, doc.errors, 1) + assert.Equal(t, "Unexpected: unexpected_token", doc.errors[0].Message) + }) }) -} -func TestDocument_parseTypeExpression(t *testing.T) { - t.Run("parses simple type without args", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "int") - s.NextToken() + t.Run("parseTypeExpression", func(t *testing.T) { + t.Run("parses simple type without args", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "int") + s.NextToken() - typ := doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "int", typ.BaseType) - assert.Empty(t, typ.Args) - }) + assert.Equal(t, "int", 
typ.BaseType) + assert.Empty(t, typ.Args) + }) - t.Run("parses type with single arg", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "varchar(50)") - s.NextToken() + t.Run("parses type with single arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(50)") + s.NextToken() - typ := doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "varchar", typ.BaseType) - assert.Equal(t, []string{"50"}, typ.Args) - }) + assert.Equal(t, "varchar", typ.BaseType) + assert.Equal(t, []string{"50"}, typ.Args) + }) - t.Run("parses type with multiple args", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "decimal(10, 2)") - s.NextToken() + t.Run("parses type with multiple args", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "decimal(10, 2)") + s.NextToken() - typ := doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "decimal", typ.BaseType) - assert.Equal(t, []string{"10", "2"}, typ.Args) - }) + assert.Equal(t, "decimal", typ.BaseType) + assert.Equal(t, []string{"10", "2"}, typ.Args) + }) - t.Run("parses type with max", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "nvarchar(max)") - s.NextToken() + t.Run("parses type with max", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "nvarchar(max)") + s.NextToken() - typ := doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "nvarchar", typ.BaseType) - assert.Equal(t, []string{"max"}, typ.Args) - }) + assert.Equal(t, "nvarchar", typ.BaseType) + assert.Equal(t, []string{"max"}, typ.Args) + }) - t.Run("handles invalid arg", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "varchar(invalid)") - s.NextToken() + t.Run("handles invalid arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(invalid)") + s.NextToken() - typ 
:= doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "varchar", typ.BaseType) - assert.NotEmpty(t, doc.errors) - }) + assert.Equal(t, "varchar", typ.BaseType) + assert.NotEmpty(t, doc.errors) + }) - t.Run("panics if not on identifier", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "123") - s.NextToken() + t.Run("panics if not on identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "123") + s.NextToken() - assert.Panics(t, func() { - doc.parseTypeExpression(s) + assert.Panics(t, func() { + doc.parseTypeExpression(s) + }) }) }) } @@ -155,6 +157,9 @@ func TestDocument_parseDeclare(t *testing.T) { declares := doc.parseDeclare(s) + // in this case when we detect the missing prefix, + // we add an error and continue parsing the declaration. + // this results with it being added require.Len(t, declares, 1) assert.NotEmpty(t, doc.errors) assert.Contains(t, doc.errors[0].Message, "@InvalidName") @@ -167,7 +172,7 @@ func TestDocument_parseDeclare(t *testing.T) { declares := doc.parseDeclare(s) - require.Len(t, declares, 1) + require.Len(t, declares, 0) assert.NotEmpty(t, doc.errors) assert.Contains(t, doc.errors[0].Message, "type declared explicitly") }) @@ -177,8 +182,9 @@ func TestDocument_parseDeclare(t *testing.T) { s := NewScanner("test.sql", "@EnumTest int") s.NextToken() - doc.parseDeclare(s) + declares := doc.parseDeclare(s) + require.Len(t, declares, 0) assert.NotEmpty(t, doc.errors) assert.Contains(t, doc.errors[0].Message, "needs to be assigned") }) @@ -608,157 +614,3 @@ func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { assert.Equal(t, "function", create.CreateType) }) } - -// func TestDocument_PostgreSQL17_Types(t *testing.T) { -// t.Run("parses composite type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create type address_type as (street text, city text, zip varchar(10))") -// s.NextToken() -// 
s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "type", create.CreateType) -// }) - -// t.Run("parses enum type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create type mood as enum ('sad', 'ok', 'happy')") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "type", create.CreateType) -// }) - -// t.Run("parses range type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create type float_range as range (subtype = float8, subtype_diff = float8mi)") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "type", create.CreateType) -// }) -// } - -// func TestDocument_PostgreSQL17_Extensions(t *testing.T) { -// t.Run("parses JSON functions PostgreSQL 17", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create function test() returns jsonb as $$ select json_serialize(data) from table1; $$ language sql") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "function", create.CreateType) -// }) - -// t.Run("parses MERGE statement (PostgreSQL 15+)", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create function do_merge() returns void as $$ merge into target using source on target.id = source.id when matched then update set value = source.value; $$ language sql") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "function", create.CreateType) -// }) -// } - -// func TestDocument_PostgreSQL17_Identifiers(t *testing.T) { -// t.Run("parses double-quoted identifiers", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", `create function "Test Func"() returns int as $$ begin return 1; end; $$ language plpgsql`) 
-// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Contains(t, create.QuotedName.Value, "Test Func") -// }) - -// t.Run("parses case-sensitive identifiers", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", `create function "TestFunc"() returns int as $$ begin return 1; end; $$ language plpgsql`) -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Contains(t, create.QuotedName.Value, "TestFunc") -// }) -// } - -// func TestDocument_PostgreSQL17_Datatypes(t *testing.T) { -// t.Run("parses array types", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "integer[]") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "integer[]", typ.BaseType) -// }) - -// t.Run("parses serial types", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "serial") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "serial", typ.BaseType) -// }) - -// t.Run("parses text type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "text") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "text", typ.BaseType) -// }) - -// t.Run("parses jsonb type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "jsonb") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "jsonb", typ.BaseType) -// }) - -// t.Run("parses uuid type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "uuid") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "uuid", typ.BaseType) -// }) -// } - -// func TestDocument_PostgreSQL17_BatchSeparator(t *testing.T) { -// t.Run("PostgreSQL uses semicolon not GO", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", 
"create function test1() returns int as $$ begin return 1; end; $$ language plpgsql; create function test2() returns int as $$ begin return 2; end; $$ language plpgsql;") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create1 := doc.parseCreate(s, 0) -// assert.Equal(t, "test1", create1.QuotedName.Value) - -// // Move to next statement -// s.NextNonWhitespaceCommentToken() -// s.NextNonWhitespaceCommentToken() - -// create2 := doc.parseCreate(s, 1) -// assert.Equal(t, "test2", create2.QuotedName.Value) -// }) -// } From 286f97625453407e2b64f63e045931a92a5677f4 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 22:34:10 +0100 Subject: [PATCH 23/40] Updated tests. --- preprocess_test.go | 39 +---- sqlparser/pgsql_document_test.go | 254 +++++++++++++++++++++++++++++++ sqlparser/tsql_document.go | 9 +- sqlparser/tsql_document_test.go | 92 ----------- 4 files changed, 257 insertions(+), 137 deletions(-) create mode 100644 sqlparser/pgsql_document_test.go diff --git a/preprocess_test.go b/preprocess_test.go index e68d8bf..85132e6 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -58,10 +58,7 @@ func TestLineNumberInInput(t *testing.T) { func TestSchemaSuffixFromHash(t *testing.T) { t.Run("returns a unique hash", func(t *testing.T) { - doc := sqlparser.Document{ - Declares: []sqlparser.Declare{}, - } - + doc := sqlparser.NewDocumentFromExtension(".sql") value := SchemaSuffixFromHash(doc) require.Equal(t, value, SchemaSuffixFromHash(doc)) }) @@ -99,7 +96,7 @@ create procedure [code].Test2 as begin end }) t.Run("empty document has hash", func(t *testing.T) { - doc := sqlparser.Document{} + doc := sqlparser.NewDocumentFromExtension(".pgsql") suffix := SchemaSuffixFromHash(doc) assert.Len(t, suffix, 12) }) @@ -193,7 +190,6 @@ begin select 1 end `) - doc.Creates[0].Driver = &mssql.Driver{} result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) @@ -211,8 +207,6 @@ begin end; $$ language plpgsql; `) - 
doc.Creates[0].Driver = &stdlib.Driver{} - result, err := Preprocess(doc, "abc123", &stdlib.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -230,8 +224,6 @@ begin select @EnumStatus end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -252,8 +244,6 @@ begin select @EnumMulti end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -272,8 +262,6 @@ begin select @EnumUndeclared end `) - doc.Creates[0].Driver = &mssql.Driver{} - _, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.Error(t, err) @@ -287,34 +275,17 @@ end doc := sqlparser.ParseString("test.sql", ` create procedure [code].Test as begin end `) - doc.Creates[0].Driver = &mssql.Driver{} - _, err := Preprocess(doc, "abc]123", &mssql.Driver{}) require.Error(t, err) assert.Contains(t, err.Error(), "schemasuffix cannot contain") }) - t.Run("skips creates with empty body", func(t *testing.T) { - doc := sqlparser.Document{ - Creates: []sqlparser.Create{ - {Body: []sqlparser.Unparsed{}}, - }, - } - - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) - require.NoError(t, err) - assert.Empty(t, result.Batches) - }) - t.Run("handles multiple creates", func(t *testing.T) { doc := sqlparser.ParseString("test.sql", ` create procedure [code].Proc1 as begin select 1 end go create procedure [code].Proc2 as begin select 2 end `) - doc.Creates[0].Driver = &mssql.Driver{} - doc.Creates[1].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) assert.Len(t, result.Batches, 2) @@ -332,8 +303,6 @@ begin select @EnumA, @EnumB end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -354,8 +323,6 @@ 
begin select 1 end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -375,8 +342,6 @@ begin select @ConstValue, @GlobalSetting end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) diff --git a/sqlparser/pgsql_document_test.go b/sqlparser/pgsql_document_test.go new file mode 100644 index 0000000..2d3c8af --- /dev/null +++ b/sqlparser/pgsql_document_test.go @@ -0,0 +1,254 @@ +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { + t.Run("parses PostgreSQL function with dollar quoting", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "test_func", create.QuotedName.Value) + }) + + t.Run("parses PostgreSQL procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create procedure insert_data(a integer, b integer) language sql as $$ insert into tbl values (a, b); $$") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "insert_data", create.QuotedName.Value) + }) + + t.Run("parses CREATE OR REPLACE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create or replace function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", 
create.CreateType) + }) + + t.Run("parses schema-qualified name", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function public.test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "test_func") + }) + + t.Run("parses RETURNS TABLE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function get_users() returns table(id int, name text) as $$ select id, name from users; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("tracks dependencies with schema prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test() returns int as $$ select * from public.table1 join public.table2 on table1.id = table2.id; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + }) + + t.Run("parses volatility categories", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int immutable as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses PARALLEL SAFE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int parallel safe as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Types(t *testing.T) { + t.Run("parses composite type", func(t *testing.T) { + 
doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type address_type as (street text, city text, zip varchar(10))") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) + + t.Run("parses enum type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type mood as enum ('sad', 'ok', 'happy')") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) + + t.Run("parses range type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type float_range as range (subtype = float8, subtype_diff = float8mi)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Extensions(t *testing.T) { + t.Run("parses JSON functions PostgreSQL 17", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test() returns jsonb as $$ select json_serialize(data) from table1; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses MERGE statement (PostgreSQL 15+)", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function do_merge() returns void as $$ merge into target using source on target.id = source.id when matched then update set value = source.value; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Identifiers(t *testing.T) { + t.Run("parses double-quoted identifiers", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", `create function "Test Func"() 
returns int as $$ begin return 1; end; $$ language plpgsql`) + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "Test Func") + }) + + t.Run("parses case-sensitive identifiers", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", `create function "TestFunc"() returns int as $$ begin return 1; end; $$ language plpgsql`) + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "TestFunc") + }) +} + +func TestDocument_PostgreSQL17_Datatypes(t *testing.T) { + t.Run("parses array types", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "integer[]") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "integer[]", typ.BaseType) + }) + + t.Run("parses serial types", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "serial") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "serial", typ.BaseType) + }) + + t.Run("parses text type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "text") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "text", typ.BaseType) + }) + + t.Run("parses jsonb type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "jsonb") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "jsonb", typ.BaseType) + }) + + t.Run("parses uuid type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "uuid") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "uuid", typ.BaseType) + }) +} + +func TestDocument_PostgreSQL17_BatchSeparator(t *testing.T) { + t.Run("PostgreSQL uses semicolon not GO", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test1() returns int as $$ begin return 
1; end; $$ language plpgsql; create function test2() returns int as $$ begin return 2; end; $$ language plpgsql;") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create1 := doc.parseCreate(s, 0) + assert.Equal(t, "test1", create1.QuotedName.Value) + + // Move to next statement + s.NextNonWhitespaceCommentToken() + s.NextNonWhitespaceCommentToken() + + create2 := doc.parseCreate(s, 1) + assert.Equal(t, "test2", create2.QuotedName.Value) + }) +} diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 23d43a4..2aa35a9 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -5,7 +5,6 @@ import ( "sort" "strings" - "github.com/jackc/pgx/v5/stdlib" mssql "github.com/microsoft/go-mssqldb" ) @@ -315,13 +314,7 @@ func (doc *TSqlDocument) ParseBatch(s *Scanner, isFirst bool) (hasMore bool) { case "create": // should be start of create procedure or create function... c := doc.parseCreate(s, createCountInBatch) - - if strings.HasSuffix(string(s.file), ".sql") { - c.Driver = &mssql.Driver{} - } - if strings.HasSuffix(string(s.file), ".pgsql") { - c.Driver = &stdlib.Driver{} - } + c.Driver = &mssql.Driver{} // *prepend* what we saw before getting to the 'create' createCountInBatch++ diff --git a/sqlparser/tsql_document_test.go b/sqlparser/tsql_document_test.go index 0d5256c..b44d575 100644 --- a/sqlparser/tsql_document_test.go +++ b/sqlparser/tsql_document_test.go @@ -522,95 +522,3 @@ func TestDocument_recoverToNextStatementCopying(t *testing.T) { assert.Equal(t, "declare", s.ReservedWord()) }) } - -func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { - t.Run("parses PostgreSQL function with dollar quoting", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - 
assert.Equal(t, "test_func", create.QuotedName.Value) - }) - - t.Run("parses PostgreSQL procedure", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create procedure insert_data(a integer, b integer) language sql as $$ insert into tbl values (a, b); $$") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "procedure", create.CreateType) - assert.Equal(t, "insert_data", create.QuotedName.Value) - }) - - t.Run("parses CREATE OR REPLACE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create or replace function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("parses schema-qualified name", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function public.test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Contains(t, create.QuotedName.Value, "test_func") - }) - - t.Run("parses RETURNS TABLE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function get_users() returns table(id int, name text) as $$ select id, name from users; $$ language sql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("tracks dependencies with schema prefix", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test() returns int as $$ select * from public.table1 join public.table2 on table1.id = table2.id; $$ language sql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - require.Len(t, create.DependsOn, 2) - }) - - 
t.Run("parses volatility categories", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int immutable as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("parses PARALLEL SAFE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int parallel safe as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) -} From 910279d725a1e72e5f5fd818d4f6fdbc958ae90f Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 22:49:56 +0100 Subject: [PATCH 24/40] Simplified Document interface. Created Pragma struct. --- sqlparser/document.go | 4 +--- sqlparser/parser.go | 10 +++------ sqlparser/pgsql_document.go | 4 ++++ sqlparser/pragma.go | 41 +++++++++++++++++++++++++++++++++++ sqlparser/tsql_document.go | 43 ++++++++++++++----------------------- 5 files changed, 65 insertions(+), 37 deletions(-) create mode 100644 sqlparser/pragma.go diff --git a/sqlparser/document.go b/sqlparser/document.go index 21839ec..94b9020 100644 --- a/sqlparser/document.go +++ b/sqlparser/document.go @@ -19,9 +19,7 @@ type Document interface { PragmaIncludeIf() []string Include(other Document) Sort() - ParsePragmas(s *Scanner) - ParseBatch(s *Scanner, isFirst bool) (hasMore bool) - + Parse(s *Scanner) error WithoutPos() Document } diff --git a/sqlparser/parser.go b/sqlparser/parser.go index 1a49a08..c15e25f 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -61,14 +61,10 @@ func Parse(s *Scanner, result Document) { // // `s` will typically never be positioned on whitespace except in // whitespace-preserving parsing - - filepath.Ext(s.input) - s.NextNonWhitespaceToken() - 
result.ParsePragmas(s)
-	hasMore := result.ParseBatch(s, true)
-	for hasMore {
-		hasMore = result.ParseBatch(s, false)
+	err := result.Parse(s)
+	if err != nil {
+		panic(fmt.Sprintf("failed to parse document: %s: %v", s.file, err))
+	}
 	return
 }
diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go
index 9aacf98..24ff627 100644
--- a/sqlparser/pgsql_document.go
+++ b/sqlparser/pgsql_document.go
@@ -10,6 +10,10 @@ func (d PGSqlDocument) HasErrors() bool {
 	return len(d.errors) > 0
 }
 
+func (d *PGSqlDocument) Parse(s *Scanner) error {
+	return nil
+}
+
 func (d PGSqlDocument) Creates() []Create {
 	return d.creates
 }
diff --git a/sqlparser/pragma.go b/sqlparser/pragma.go
new file mode 100644
index 0000000..f0ee990
--- /dev/null
+++ b/sqlparser/pragma.go
@@ -0,0 +1,41 @@
+package sqlparser
+
+import (
+	"fmt"
+	"strings"
+)
+
+type Pragma struct {
+	pragmas []string
+}
+
+func (d Pragma) PragmaIncludeIf() []string {
+	return d.pragmas
+}
+
+func (d *Pragma) parseSinglePragma(s *Scanner) error {
+	pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:"))
+	if pragma == "" {
+		return nil
+	}
+	parts := strings.Split(pragma, " ")
+
+	if len(parts) != 2 || parts[0] != "include-if" {
+		return fmt.Errorf("Illegal pragma: %s", s.Token())
+	}
+
+	d.pragmas = append(d.pragmas, strings.Split(parts[1], ",")...)
+ return nil +} + +func (d *Pragma) ParsePragmas(s *Scanner) error { + for s.TokenType() == PragmaToken { + err := d.parseSinglePragma(s) + if err != nil { + return err + } + s.NextNonWhitespaceToken() + } + + return nil +} diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 2aa35a9..1038de7 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -13,6 +13,8 @@ type TSqlDocument struct { creates []Create declares []Declare errors []Error + + Pragma } func (d TSqlDocument) HasErrors() bool { @@ -30,8 +32,19 @@ func (d TSqlDocument) Declares() []Declare { func (d TSqlDocument) Errors() []Error { return d.errors } -func (d TSqlDocument) PragmaIncludeIf() []string { - return d.pragmaIncludeIf + +func (d *TSqlDocument) Parse(s *Scanner) error { + err := d.ParsePragmas(s) + if err != nil { + d.addError(s, err.Error()) + } + + hasMore := d.parseBatch(s, true) + for hasMore { + hasMore = d.parseBatch(s, false) + } + + return nil } func (d *TSqlDocument) Sort() { @@ -79,30 +92,6 @@ func (d *TSqlDocument) Include(other Document) { d.errors = append(d.errors, other.Errors()...) } -func (d *TSqlDocument) parseSinglePragma(s *Scanner) { - pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) - if pragma == "" { - return - } - parts := strings.Split(pragma, " ") - if len(parts) != 2 { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - if parts[0] != "include-if" { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - d.pragmaIncludeIf = append(d.pragmaIncludeIf, strings.Split(parts[1], ",")...) 
-} - -func (d *TSqlDocument) ParsePragmas(s *Scanner) { - for s.TokenType() == PragmaToken { - d.parseSinglePragma(s) - s.NextNonWhitespaceToken() - } -} - func (d TSqlDocument) Empty() bool { return len(d.creates) == 0 || len(d.declares) == 0 } @@ -272,7 +261,7 @@ func (doc *TSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { } } -func (doc *TSqlDocument) ParseBatch(s *Scanner, isFirst bool) (hasMore bool) { +func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { var nodes []Unparsed var docstring []PosString newLineEncounteredInDocstring := false From fb8fffb061a8dc353c0fb697e04a0a7ea3ea3312 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 23:21:53 +0100 Subject: [PATCH 25/40] [wip] pgsql document parsing --- sqlparser/pgsql_document.go | 533 +++++++++++++++++++++++++++++++++++- sqlparser/scanner.go | 2 + 2 files changed, 525 insertions(+), 10 deletions(-) diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index 24ff627..88ae487 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -1,9 +1,16 @@ package sqlparser +import ( + "fmt" + + "github.com/jackc/pgx/v5/stdlib" +) + type PGSqlDocument struct { - pragmaIncludeIf []string - creates []Create - errors []Error + creates []Create + errors []Error + + Pragma } func (d PGSqlDocument) HasErrors() bool { @@ -11,6 +18,11 @@ func (d PGSqlDocument) HasErrors() bool { } func (d *PGSqlDocument) Parse(s *Scanner) error { + err := d.ParsePragmas(s) + if err != nil { + d.errors = append(d.errors, Error{s.Start(), err.Error()}) + } + return nil } @@ -18,6 +30,7 @@ func (d PGSqlDocument) Creates() []Create { return d.creates } +// Not yet implemented func (d PGSqlDocument) Declares() []Declare { return nil } @@ -25,9 +38,6 @@ func (d PGSqlDocument) Declares() []Declare { func (d PGSqlDocument) Errors() []Error { return d.errors } -func (d PGSqlDocument) PragmaIncludeIf() []string { - return d.pragmaIncludeIf -} func (d PGSqlDocument) 
Empty() bool { return len(d.creates) == 0 @@ -41,14 +51,517 @@ func (d PGSqlDocument) Include(other Document) { } -func (d PGSqlDocument) ParsePragmas(s *Scanner) { +func (d PGSqlDocument) WithoutPos() Document { + return &PGSqlDocument{} +} + +// No GO batch separator: +// +// PostgreSQL uses semicolons (;) to separate statements, not GO. +// Multiple CREATE statements can exist in the same file. +// +// No top-level DECLARE: +// +// In PostgreSQL, DECLARE is only used inside function/procedure bodies within BEGIN...END blocks, not as top-level batch statements. +// +// Multiple CREATEs per batch: +// +// Unlike T-SQL which requires procedures/functions to be alone in a batch, PostgreSQL allows multiple CREATE statements separated by semicolons. +// +// Semicolon handling: +// +// The semicolon is a statement terminator, not a batch separator, so parsing continues after encountering one. +// +// Dollar quoting: +// +// PostgreSQL uses $$ or $tag$ for quoting function bodies instead of BEGIN...END (this would be handled in parseCreate). +// +// CREATE OR REPLACE: +// +// PostgreSQL commonly uses CREATE OR REPLACE which would need special handling in parseCreate. +// +// Schema qualification: +// +// PostgreSQL uses schema.object notation rather than [schema].[object]. 
+func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { + var nodes []Unparsed + var docstring []PosString + newLineEncounteredInDocstring := false + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + nodes = append(nodes, CreateUnparsed(s)) + // do not reset docstring for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + docstring = nil + } + s.NextToken() + case SinglelineCommentToken: + // Build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + docstring = append(docstring, PosString{s.Start(), s.Token()}) + nodes = append(nodes, CreateUnparsed(s)) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + switch s.ReservedWord() { + case "declare": + // PostgreSQL doesn't have top-level DECLARE batches like T-SQL + // DECLARE is only used inside function/procedure bodies + if isFirst { + doc.addError(s, "PostgreSQL 'declare' is used inside function bodies, not as top-level batch statements") + } + nodes = append(nodes, CreateUnparsed(s)) + s.NextToken() + docstring = nil + case "create": + // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. + createStart := len(doc.creates) + c := doc.parseCreate(s, createStart) + c.Driver = &stdlib.Driver{} + + // Prepend any leading comments/whitespace + c.Body = append(nodes, c.Body...) 
+ c.Docstring = docstring + doc.creates = append(doc.creates, c) + + // Reset for next statement + nodes = nil + docstring = nil + newLineEncounteredInDocstring = false + default: + doc.addError(s, "Expected 'create', got: "+s.ReservedWord()) + s.NextToken() + docstring = nil + } + case SemicolonToken: + // PostgreSQL uses semicolons as statement terminators + // Multiple CREATE statements can exist in same file + nodes = append(nodes, CreateUnparsed(s)) + s.NextToken() + // Continue parsing - don't return like T-SQL does with GO + case BatchSeparatorToken: + // PostgreSQL doesn't use GO batch separators + // Q: Do we want to use GO batch separators as a feature of sqlcode? + doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons instead") + s.NextToken() + docstring = nil + default: + doc.addError(s, fmt.Sprintf("Unexpected token in PostgreSQL document: %s", s.Token())) + s.NextToken() + docstring = nil + } + } } -func (d PGSqlDocument) WithoutPos() Document { - return &PGSqlDocument{} +// parseCreate parses PostgreSQL CREATE statements (FUNCTION, PROCEDURE, TYPE, etc.) +// Position is *on* the CREATE token. +// +// PostgreSQL CREATE syntax differences from T-SQL: +// - Supports CREATE OR REPLACE for functions/procedures +// - Uses dollar quoting ($$...$$) or $tag$...$tag$ for function bodies +// - Schema qualification uses dot notation: schema.function_name +// - Double-quoted identifiers preserve case: "MyFunction" +// - Function parameters use different syntax: func(param1 type1, param2 type2) +// - RETURNS clause specifies return type +// - LANGUAGE clause (plpgsql, sql, etc.) is required +// - Function characteristics: IMMUTABLE, STABLE, VOLATILE, PARALLEL SAFE, etc. +// +// We parse until we hit a semicolon or EOF, tracking dependencies on other objects. 
+func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) { + var body []Unparsed + pos := s.Start() + + // Copy the CREATE token + CopyToken(s, &body) + s.NextNonWhitespaceCommentToken() + + // Check for OR REPLACE + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "or" { + CopyToken(s, &body) + s.NextNonWhitespaceCommentToken() + + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "replace" { + CopyToken(s, &body) + s.NextNonWhitespaceCommentToken() + } else { + doc.addError(s, "Expected 'REPLACE' after 'OR'") + doc.recoverToNextStatementCopying(s, &body) + result.Body = body + return + } + } + + // Parse the object type (FUNCTION, PROCEDURE, TYPE, etc.) + if s.TokenType() != ReservedWordToken { + doc.addError(s, "Expected object type after CREATE (e.g., FUNCTION, PROCEDURE, TYPE)") + doc.recoverToNextStatementCopying(s, &body) + result.Body = body + return + } + + createType := s.ReservedWord() + result.CreateType = createType + CopyToken(s, &body) + s.NextNonWhitespaceCommentToken() + + // Validate supported CREATE types + switch createType { + case "function", "procedure", "type": + // Supported types + default: + doc.addError(s, fmt.Sprintf("Unsupported CREATE type for PostgreSQL: %s", createType)) + doc.recoverToNextStatementCopying(s, &body) + result.Body = body + return + } + + // Parse the object name (with optional schema qualification) + // objectName := doc.parseQualifiedName(s, &body) + // if objectName == "" { + // doc.addError(s, "Expected object name after CREATE "+createType) + // doc.recoverToNextStatementCopying(s, &body) + // result.Body = body + // return + // } + + // result.QuotedName = PosString{pos, objectName} + + // Insist on [code]. 
+ if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { + doc.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) + doc.recoverToNextStatementCopying(s, &result.Body) + return + } + result.QuotedName = doc.parseCodeschemaName(s, &result.Body) + if result.QuotedName.String() == "" { + return + } + + // Parse function/procedure signature or type definition + switch createType { + case "function", "procedure": + doc.parseFunctionSignature(s, &body, &result) + case "type": + doc.parseTypeDefinition(s, &body, &result) + } + + // Parse the rest of the CREATE statement body until semicolon or EOF + doc.parseCreateBody(s, &body, &result) + + result.Body = body + return +} + +// parseQualifiedName parses schema-qualified or simple object names +// Supports: simple_name, schema.name, "Quoted Name", schema."Quoted Name" +func (doc *PGSqlDocument) parseQualifiedName(s *Scanner, body *[]Unparsed) string { + var nameParts []string + + for { + switch s.TokenType() { + case UnquotedIdentifierToken: + nameParts = append(nameParts, s.Token()) + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + case QuotedIdentifierToken: + // PostgreSQL uses double quotes for case-sensitive identifiers + nameParts = append(nameParts, s.Token()) + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + default: + if len(nameParts) == 0 { + return "" + } + // Return the last part as the object name (without schema) + return nameParts[len(nameParts)-1] + } + + // Check for dot separator (schema.object) + if s.TokenType() == DotToken { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + continue + } + + break + } + + if len(nameParts) == 0 { + return "" + } + return nameParts[len(nameParts)-1] +} + +// parseFunctionSignature parses function/procedure parameters and RETURNS clause +func (doc *PGSqlDocument) parseFunctionSignature(s *Scanner, body *[]Unparsed, result *Create) { + // Expect opening parenthesis for parameters + if s.TokenType() 
!= LeftParenToken { + doc.addError(s, "Expected '(' for function parameters") + return + } + + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Parse parameters until closing parenthesis + parenDepth := 1 + for parenDepth > 0 { + switch s.TokenType() { + case EOFToken: + doc.addError(s, "Unexpected EOF in function parameters") + return + case LeftParenToken: + parenDepth++ + CopyToken(s, body) + s.NextToken() + case RightParenToken: + parenDepth-- + CopyToken(s, body) + s.NextToken() + case SemicolonToken: + doc.addError(s, "Unexpected semicolon in function parameters") + return + default: + CopyToken(s, body) + s.NextToken() + } + } + + s.SkipWhitespaceComments() + + // Parse RETURNS clause (for functions, not procedures) + if result.CreateType == "function" { + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "returns" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Handle RETURNS TABLE(...) + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "table" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + if s.TokenType() == LeftParenToken { + doc.parseReturnTable(s, body) + } + } else { + // Parse simple return type + doc.parseTypeExpression(s, body) + } + } + } +} + +// parseReturnTable parses RETURNS TABLE(...) syntax +func (doc *PGSqlDocument) parseReturnTable(s *Scanner, body *[]Unparsed) { + parenDepth := 0 + for { + switch s.TokenType() { + case EOFToken, SemicolonToken: + return + case LeftParenToken: + parenDepth++ + case RightParenToken: + parenDepth-- + CopyToken(s, body) + s.NextToken() + if parenDepth == 0 { + return + } + continue + } + CopyToken(s, body) + s.NextToken() + } +} + +// parseTypeExpression parses PostgreSQL type expressions +// Supports: int, integer, text, varchar(n), numeric(p,s), arrays (int[]), etc. 
+func (doc *PGSqlDocument) parseTypeExpression(s *Scanner, body *[]Unparsed) { + // Parse base type + if s.TokenType() != UnquotedIdentifierToken && s.TokenType() != ReservedWordToken { + return + } + + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Handle array notation: type[] + // if s.TokenType() == LeftBracketToken { + // CopyToken(s, body) + // s.NextNonWhitespaceCommentToken() + + // if s.TokenType() == RightBracketToken { + // CopyToken(s, body) + // s.NextNonWhitespaceCommentToken() + // } + // } + + // Handle type parameters: varchar(100), numeric(10,2) + if s.TokenType() == LeftParenToken { + parenDepth := 1 + CopyToken(s, body) + s.NextToken() + + for parenDepth > 0 { + switch s.TokenType() { + case EOFToken, SemicolonToken: + return + case LeftParenToken: + parenDepth++ + case RightParenToken: + parenDepth-- + } + CopyToken(s, body) + s.NextToken() + } + } +} + +// parseTypeDefinition parses CREATE TYPE syntax +// Supports: ENUM, composite types, range types +func (doc *PGSqlDocument) parseTypeDefinition(s *Scanner, body *[]Unparsed, result *Create) { + // TYPE definitions use AS keyword + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "as" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Check for ENUM, RANGE, or composite type + if s.TokenType() == ReservedWordToken { + typeKind := s.ReservedWord() + switch typeKind { + case "enum", "range": + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + } + } + } } -func (d PGSqlDocument) ParseBatch(s *Scanner, isFirst bool) bool { +// parseCreateBody parses the body of a CREATE statement +// Handles dollar-quoted strings, tracks dependencies, continues until semicolon/EOF +func (doc *PGSqlDocument) parseCreateBody(s *Scanner, body *[]Unparsed, result *Create) { + dollarQuoteDepth := 0 + var currentDollarTag string + + for { + switch s.TokenType() { + case EOFToken: + return + case SemicolonToken: + // Statement terminator - we're done + CopyToken(s, body) + 
s.NextToken() + return + case DollarQuotedStringStartToken: + // PostgreSQL dollar quoting: $$...$$ or $tag$...$tag$ + currentDollarTag = s.Token() + dollarQuoteDepth++ + CopyToken(s, body) + s.NextToken() + case DollarQuotedStringEndToken: + if s.Token() == currentDollarTag { + dollarQuoteDepth-- + } + CopyToken(s, body) + s.NextToken() + if dollarQuoteDepth == 0 { + currentDollarTag = "" + } + case UnquotedIdentifierToken, QuotedIdentifierToken: + // Track dependencies on tables/views/functions + // In PostgreSQL, identifiers can be qualified: schema.object + identifier := s.Token() + + // Check if this might be a dependency (after FROM, JOIN, etc.) + if doc.mightBeDependency(s) { + // Extract just the object name (without schema prefix) + objectName := doc.extractObjectName(identifier) + result.DependsOn = append(result.DependsOn, PosString{s.Start(), objectName}) + } + + CopyToken(s, body) + s.NextToken() + default: + CopyToken(s, body) + s.NextToken() + } + } +} + +// mightBeDependency checks if current context suggests a table/view/function reference +func (doc *PGSqlDocument) mightBeDependency(s *Scanner) bool { + // Simple heuristic: look back for FROM, JOIN, INTO, etc. 
+ // This would need to track parse context for accurate dependency detection + return false // Placeholder - implement context-aware dependency tracking +} + +// extractObjectName extracts object name from schema-qualified identifier +func (doc *PGSqlDocument) extractObjectName(identifier string) string { + // Handle schema.object notation + // For now, return as-is; proper implementation would split on dot + return identifier +} + +// recoverToNextStatementCopying recovers from parse errors by skipping to next statement +func (doc *PGSqlDocument) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "create", "drop", "alter": + return + } + case EOFToken, SemicolonToken: + return + default: + CopyToken(s, target) + } + } +} + +func (doc *PGSqlDocument) addError(s *Scanner, err string) { + doc.errors = append(doc.errors, Error{ + s.Start(), err, + }) +} + +func (doc *PGSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { + // PostgreSQL doesn't have top-level DECLARE batches like T-SQL + // DECLARE is only used inside function/procedure bodies (in BEGIN...END blocks) + doc.addError(s, "PostgreSQL does not support top-level DECLARE statements outside of function bodies") + doc.recoverToNextStatement(s) return false } + +func (doc *PGSqlDocument) parseBatchSeparator(s *Scanner) { + // PostgreSQL doesn't use GO batch separators + doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons") + s.NextToken() +} + +func (doc *PGSqlDocument) recoverToNextStatement(s *Scanner) { + // We hit an unexpected token ... 
as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "create", "drop", "alter": + return + } + case EOFToken, SemicolonToken: + return + } + } +} diff --git a/sqlparser/scanner.go b/sqlparser/scanner.go index a5fb75a..9e6c27f 100644 --- a/sqlparser/scanner.go +++ b/sqlparser/scanner.go @@ -565,6 +565,8 @@ var reservedWords = map[string]struct{}{ "writetext": struct{}{}, "exit": struct{}{}, "proc": struct{}{}, + // pgsql + "replace": struct{}{}, } // apparently 'within group' is also reserved but dropping that.. From 59967551899aa49e1b3b3f4d88f39d38246c6829 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 16 Dec 2025 19:23:28 +0100 Subject: [PATCH 26/40] Simplify the interfaces for parsing a SQL document. --- sqlparser/document.go | 92 +++++++++++++++++ sqlparser/dom.go | 24 ----- sqlparser/nodes.go | 91 +++++++++++++++++ sqlparser/pgsql_document.go | 95 ++++++++---------- sqlparser/tsql_document.go | 191 ++++++++---------------------------- sqlparser/unparsed.go | 25 +++++ 6 files changed, 287 insertions(+), 231 deletions(-) create mode 100644 sqlparser/nodes.go create mode 100644 sqlparser/unparsed.go diff --git a/sqlparser/document.go b/sqlparser/document.go index 94b9020..1721875 100644 --- a/sqlparser/document.go +++ b/sqlparser/document.go @@ -1,7 +1,9 @@ package sqlparser import ( + "fmt" "path/filepath" + "slices" "strings" ) @@ -41,3 +43,93 @@ func NewDocumentFromExtension(extension string) Document { panic("unhandled document type: " + extension) } } + +// parseCodeschemaName parses `[code] . something`, and returns `something` +// in quoted form (`[something]`). Also copy to `target`. Empty string on error. +// Note: To follow conventions, consume one extra token at the end even if we know +// it fill not be consumed by this function... 
+func ParseCodeschemaName(s *Scanner, target *[]Unparsed, statementTokens []string) (PosString, error) { + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + if s.TokenType() != DotToken { + RecoverToNextStatementCopying(s, target, statementTokens) + return PosString{Value: ""}, fmt.Errorf("[code] must be followed by '.'") + } + + CopyToken(s, target) + + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case UnquotedIdentifierToken: + // To get something uniform for comparison, quote all names + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} + NextTokenCopyingWhitespace(s, target) + return result, nil + case QuotedIdentifierToken: + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: s.Token()} + NextTokenCopyingWhitespace(s, target) + return result, nil + default: + RecoverToNextStatementCopying(s, target, statementTokens) + return PosString{Value: ""}, fmt.Errorf("[code]. must be followed an identifier") + } +} + +// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered +// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace +// token, and target is either unmodified or filled with some whitespace nodes. +func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { + for { + tt := s.NextToken() + switch tt { + case EOFToken, BatchSeparatorToken: + // do not copy + return + case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + // copy, and loop around + CopyToken(s, target) + continue + default: + return + } + } + +} + +func RecoverToNextStatementCopying(s *Scanner, target *[]Unparsed, StatementTokens []string) { + // We hit an unexpected token ... 
as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + if slices.Contains(StatementTokens, s.ReservedWord()) { + return + } + case EOFToken: + return + default: + CopyToken(s, target) + } + } +} + +func RecoverToNextStatement(s *Scanner, StatementTokens []string) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + if slices.Contains(StatementTokens, s.ReservedWord()) { + return + } + case EOFToken: + return + } + } +} diff --git a/sqlparser/dom.go b/sqlparser/dom.go index 0c72587..a75db38 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -5,21 +5,6 @@ import ( "strings" ) -type Unparsed struct { - Type TokenType - Start, Stop Pos - RawValue string -} - -func (u Unparsed) WithoutPos() Unparsed { - return Unparsed{ - Type: u.Type, - Start: Pos{}, - Stop: Pos{}, - RawValue: u.RawValue, - } -} - type Declare struct { Start Pos Stop Pos @@ -82,12 +67,3 @@ func (e Error) Error() string { func (e Error) WithoutPos() Error { return Error{Message: e.Message} } - -func CreateUnparsed(s *Scanner) Unparsed { - return Unparsed{ - Type: s.TokenType(), - Start: s.Start(), - Stop: s.Stop(), - RawValue: s.Token(), - } -} diff --git a/sqlparser/nodes.go b/sqlparser/nodes.go new file mode 100644 index 0000000..7b2033e --- /dev/null +++ b/sqlparser/nodes.go @@ -0,0 +1,91 @@ +package sqlparser + +import ( + "fmt" +) + +type Nodes struct { + Nodes []Unparsed + DocString []PosString + CreateStatements int + TokenHandlers map[string]func(*Scanner, *Nodes) bool + Errors []Error + BatchSeparatorToken TokenType +} + +func (n *Nodes) Create(s *Scanner) { + n.Nodes = append(n.Nodes, CreateUnparsed(s)) +} + 
+func (n *Nodes) HasErrors() bool { + return len(n.Errors) > 0 +} + +// Agnostic parser that handles comments, whitespace, and reserved words +func (n *Nodes) Parse(s *Scanner) bool { + newLineEncounteredInDocstring := false + + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + n.Create(s) + // do not reset token for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + n.DocString = nil + } + s.NextToken() + case SinglelineCommentToken: + // We build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + n.DocString = append(n.DocString, PosString{s.Start(), s.Token()}) + n.Create(s) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + token := s.ReservedWord() + handler, exists := n.TokenHandlers[token] + if !exists { + n.Errors = append(n.Errors, Error{ + s.Start(), fmt.Sprintf("Expected , got: %s", token), + }) + s.NextToken() + } else { + if handler(s, n) { + // regardless of errors, go on and parse as far as we get... 
+ return true + } + } + case BatchSeparatorToken: + // TODO + errorEmitted := false + for { + switch s.NextToken() { + case WhitespaceToken: + continue + case MalformedBatchSeparatorToken: + if !errorEmitted { + n.Errors = append(n.Errors, Error{ + s.Start(), "`go` should be alone on a line without any comments", + }) + errorEmitted = true + } + continue + default: + return true + } + } + default: + n.Errors = append(n.Errors, Error{ + s.Start(), fmt.Sprintf("Unexpected token: %s", s.Token()), + }) + s.NextToken() + n.DocString = nil + } + } +} diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index 88ae487..181e4a0 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -6,6 +6,8 @@ import ( "github.com/jackc/pgx/v5/stdlib" ) +var PGSQLStatementTokens = []string{"create"} + type PGSqlDocument struct { creates []Create errors []Error @@ -84,6 +86,30 @@ func (d PGSqlDocument) WithoutPos() Document { // // PostgreSQL uses schema.object notation rather than [schema].[object]. func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { + nodes := &Nodes{ + TokenHandlers: map[string]func(*Scanner, *Nodes) bool{ + "create": func(s *Scanner, n *Nodes) bool { + // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. + c := doc.parseCreate(s, n.CreateStatements) + c.Driver = &stdlib.Driver{} + + // Prepend any leading comments/whitespace + c.Body = append(n.Nodes, c.Body...) + c.Docstring = n.DocString + doc.creates = append(doc.creates, c) + + return false + }, + }, + } + + hasMore = nodes.Parse(s) + if nodes.HasErrors() { + doc.errors = append(doc.errors, nodes.Errors...) + } + + return hasMore + var nodes []Unparsed var docstring []PosString newLineEncounteredInDocstring := false @@ -177,13 +203,14 @@ func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { // We parse until we hit a semicolon or EOF, tracking dependencies on other objects. 
func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) { var body []Unparsed - pos := s.Start() // Copy the CREATE token CopyToken(s, &body) s.NextNonWhitespaceCommentToken() // Check for OR REPLACE + // NOTE: "or replace" doesn't make sense within sqlcode as this will be created within a new + // schema. if s.TokenType() == ReservedWordToken && s.ReservedWord() == "or" { CopyToken(s, &body) s.NextNonWhitespaceCommentToken() @@ -193,7 +220,7 @@ func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (resul s.NextNonWhitespaceCommentToken() } else { doc.addError(s, "Expected 'REPLACE' after 'OR'") - doc.recoverToNextStatementCopying(s, &body) + RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) result.Body = body return } @@ -202,7 +229,7 @@ func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (resul // Parse the object type (FUNCTION, PROCEDURE, TYPE, etc.) if s.TokenType() != ReservedWordToken { doc.addError(s, "Expected object type after CREATE (e.g., FUNCTION, PROCEDURE, TYPE)") - doc.recoverToNextStatementCopying(s, &body) + RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) result.Body = body return } @@ -218,29 +245,23 @@ func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (resul // Supported types default: doc.addError(s, fmt.Sprintf("Unsupported CREATE type for PostgreSQL: %s", createType)) - doc.recoverToNextStatementCopying(s, &body) + RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) result.Body = body return } - // Parse the object name (with optional schema qualification) - // objectName := doc.parseQualifiedName(s, &body) - // if objectName == "" { - // doc.addError(s, "Expected object name after CREATE "+createType) - // doc.recoverToNextStatementCopying(s, &body) - // result.Body = body - // return - // } - - // result.QuotedName = PosString{pos, objectName} - - // Insist on [code]. 
+ // Insist on [code] to provide the ability for sqlcode to patch function bodies + // with references to other sqlcode objects. if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { doc.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - doc.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, PGSQLStatementTokens) return } - result.QuotedName = doc.parseCodeschemaName(s, &result.Body) + var err error + result.QuotedName, err = ParseCodeschemaName(s, &result.Body, PGSQLStatementTokens) + if err != nil { + doc.addError(s, err.Error()) + } if result.QuotedName.String() == "" { return } @@ -510,24 +531,6 @@ func (doc *PGSqlDocument) extractObjectName(identifier string) string { return identifier } -// recoverToNextStatementCopying recovers from parse errors by skipping to next statement -func (doc *PGSqlDocument) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "create", "drop", "alter": - return - } - case EOFToken, SemicolonToken: - return - default: - CopyToken(s, target) - } - } -} - func (doc *PGSqlDocument) addError(s *Scanner, err string) { doc.errors = append(doc.errors, Error{ s.Start(), err, @@ -538,7 +541,7 @@ func (doc *PGSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { // PostgreSQL doesn't have top-level DECLARE batches like T-SQL // DECLARE is only used inside function/procedure bodies (in BEGIN...END blocks) doc.addError(s, "PostgreSQL does not support top-level DECLARE statements outside of function bodies") - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, PGSQLStatementTokens) return false } @@ -547,21 +550,3 @@ func (doc *PGSqlDocument) parseBatchSeparator(s *Scanner) { doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons") s.NextToken() } - -func (doc 
*PGSqlDocument) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "create", "drop", "alter": - return - } - case EOFToken, SemicolonToken: - return - } - } -} diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 1038de7..8c17499 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -8,6 +8,8 @@ import ( mssql "github.com/microsoft/go-mssqldb" ) +var TSQLStatementTokens = []string{"create", "declare", "go"} + type TSqlDocument struct { pragmaIncludeIf []string creates []Create @@ -118,7 +120,7 @@ func (doc *TSqlDocument) parseTypeExpression(s *Scanner) (t Type) { t.Args = append(t.Args, "max") default: doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) return } s.NextNonWhitespaceCommentToken() @@ -131,7 +133,7 @@ func (doc *TSqlDocument) parseTypeExpression(s *Scanner) (t Type) { return default: doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) return } } @@ -156,7 +158,7 @@ loop: for { if s.TokenType() != VariableIdentifierToken { doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) return } @@ -179,7 +181,7 @@ loop: if s.TokenType() != EqualToken { doc.addError(s, "sqlcode constants needs to be assigned at once using =") - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) } switch s.NextNonWhitespaceCommentToken() { @@ -194,7 +196,7 @@ loop: result = append(result, declare) default: doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) return } @@ -250,151 +252,50 @@ func (doc *TSqlDocument) 
parseDeclareBatch(s *Scanner) (hasMore bool) { doc.declares = append(doc.declares, d...) case tt == ReservedWordToken && s.ReservedWord() != "declare": doc.addError(s, "Only 'declare' allowed in this batch") - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) case tt == BatchSeparatorToken: doc.parseBatchSeparator(s) return true default: doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) } } } func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - var createCountInBatch int - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset token for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // We build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case "declare": + nodes := &Nodes{ + TokenHandlers: map[string]func(*Scanner, *Nodes) bool{ + "declare": func(s *Scanner, n *Nodes) bool { // First declare-statement; enter a mode where we assume all contents // of batch are declare statements if !isFirst { doc.addError(s, "'declare' statement only allowed in first batch") } + // regardless of errors, go on and parse as far as we get... 
return doc.parseDeclareBatch(s) - case "create": + }, + "create": func(s *Scanner, n *Nodes) bool { // should be start of create procedure or create function... - c := doc.parseCreate(s, createCountInBatch) + c := doc.parseCreate(s, n.CreateStatements) c.Driver = &mssql.Driver{} // *prepend* what we saw before getting to the 'create' - createCountInBatch++ - c.Body = append(nodes, c.Body...) - c.Docstring = docstring + n.CreateStatements++ + c.Body = append(n.Nodes, c.Body...) + c.Docstring = n.DocString doc.creates = append(doc.creates, c) - default: - doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) - s.NextToken() - } - case BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - s.NextToken() - docstring = nil - } + return false + }, + }, } -} - -func (d *TSqlDocument) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - default: - CopyToken(s, target) - } + hasMore = nodes.Parse(s) + if nodes.HasErrors() { + doc.errors = append(doc.errors, nodes.Errors...) } -} -func (d *TSqlDocument) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - } - } -} - -// parseCodeschemaName parses `[code] . something`, and returns `something` -// in quoted form (`[something]`). 
Also copy to `target`. Empty string on error. -// Note: To follow conventions, consume one extra token at the end even if we know -// it fill not be consumed by this function... -func (d *TSqlDocument) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { - CopyToken(s, target) - NextTokenCopyingWhitespace(s, target) - if s.TokenType() != DotToken { - d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } - CopyToken(s, target) - - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case UnquotedIdentifierToken: - // To get something uniform for comparison, quote all names - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} - NextTokenCopyingWhitespace(s, target) - return result - case QuotedIdentifierToken: - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: s.Token()} - NextTokenCopyingWhitespace(s, target) - return result - default: - d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } + return hasMore } // parseCreate parses anything that starts with "create". 
Position is @@ -417,12 +318,12 @@ func (d *TSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result C createType := strings.ToLower(s.Token()) if !(createType == "procedure" || createType == "function" || createType == "type") { d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - d.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - d.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } @@ -434,10 +335,14 @@ func (d *TSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result C // Insist on [code]. if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - d.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } - result.QuotedName = d.parseCodeschemaName(s, &result.Body) + var err error + result.QuotedName, err = ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) + if err != nil { + d.addError(s, err.Error()) + } if result.QuotedName.String() == "" { return } @@ -471,7 +376,7 @@ tailloop: if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - d.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") return } @@ -488,7 +393,10 @@ tailloop: break tailloop case tt == QuotedIdentifierToken && s.Token() == "[code]": // Parse a 
dependency - dep := d.parseCodeschemaName(s, &result.Body) + dep, err := ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) + if err != nil { + d.addError(s, err.Error()) + } found := false for _, existing := range result.DependsOn { if existing.Value == dep.Value { @@ -549,24 +457,3 @@ tailloop: }) return } - -// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered -// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace -// token, and target is either unmodified or filled with some whitespace nodes. -func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - return - } - } - -} diff --git a/sqlparser/unparsed.go b/sqlparser/unparsed.go new file mode 100644 index 0000000..45a85bf --- /dev/null +++ b/sqlparser/unparsed.go @@ -0,0 +1,25 @@ +package sqlparser + +type Unparsed struct { + Type TokenType + Start, Stop Pos + RawValue string +} + +func CreateUnparsed(s *Scanner) Unparsed { + return Unparsed{ + Type: s.TokenType(), + Start: s.Start(), + Stop: s.Stop(), + RawValue: s.Token(), + } +} + +func (u Unparsed) WithoutPos() Unparsed { + return Unparsed{ + Type: u.Type, + Start: Pos{}, + Stop: Pos{}, + RawValue: u.RawValue, + } +} From fb55399b19aacbfb148353096e6f99cff016faf5 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 16 Dec 2025 19:24:21 +0100 Subject: [PATCH 27/40] Refactored pgsql document to use node parser. 
--- sqlparser/pgsql_document.go | 150 ++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index 181e4a0..f6e2e8f 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -110,81 +110,81 @@ func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { return hasMore - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset docstring for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // Build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case "declare": - // PostgreSQL doesn't have top-level DECLARE batches like T-SQL - // DECLARE is only used inside function/procedure bodies - if isFirst { - doc.addError(s, "PostgreSQL 'declare' is used inside function bodies, not as top-level batch statements") - } - nodes = append(nodes, CreateUnparsed(s)) - s.NextToken() - docstring = nil - case "create": - // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. - createStart := len(doc.creates) - c := doc.parseCreate(s, createStart) - c.Driver = &stdlib.Driver{} - - // Prepend any leading comments/whitespace - c.Body = append(nodes, c.Body...) 
- c.Docstring = docstring - doc.creates = append(doc.creates, c) - - // Reset for next statement - nodes = nil - docstring = nil - newLineEncounteredInDocstring = false - default: - doc.addError(s, "Expected 'create', got: "+s.ReservedWord()) - s.NextToken() - docstring = nil - } - case SemicolonToken: - // PostgreSQL uses semicolons as statement terminators - // Multiple CREATE statements can exist in same file - nodes = append(nodes, CreateUnparsed(s)) - s.NextToken() - // Continue parsing - don't return like T-SQL does with GO - case BatchSeparatorToken: - // PostgreSQL doesn't use GO batch separators - // Q: Do we want to use GO batch separators as a feature of sqlcode? - doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons instead") - s.NextToken() - docstring = nil - default: - doc.addError(s, fmt.Sprintf("Unexpected token in PostgreSQL document: %s", s.Token())) - s.NextToken() - docstring = nil - } - } + // var nodes []Unparsed + // var docstring []PosString + // newLineEncounteredInDocstring := false + + // for { + // tt := s.TokenType() + // switch tt { + // case EOFToken: + // return false + // case WhitespaceToken, MultilineCommentToken: + // nodes = append(nodes, CreateUnparsed(s)) + // // do not reset docstring for a single trailing newline + // t := s.Token() + // if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + // newLineEncounteredInDocstring = true + // } else { + // docstring = nil + // } + // s.NextToken() + // case SinglelineCommentToken: + // // Build up a list of single line comments for the "docstring"; + // // it is reset whenever we encounter something else + // docstring = append(docstring, PosString{s.Start(), s.Token()}) + // nodes = append(nodes, CreateUnparsed(s)) + // newLineEncounteredInDocstring = false + // s.NextToken() + // case ReservedWordToken: + // switch s.ReservedWord() { + // case "declare": + // // PostgreSQL doesn't have top-level DECLARE batches like T-SQL + // // DECLARE is 
only used inside function/procedure bodies + // if isFirst { + // doc.addError(s, "PostgreSQL 'declare' is used inside function bodies, not as top-level batch statements") + // } + // nodes = append(nodes, CreateUnparsed(s)) + // s.NextToken() + // docstring = nil + // case "create": + // // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. + // createStart := len(doc.creates) + // c := doc.parseCreate(s, createStart) + // c.Driver = &stdlib.Driver{} + + // // Prepend any leading comments/whitespace + // c.Body = append(nodes, c.Body...) + // c.Docstring = docstring + // doc.creates = append(doc.creates, c) + + // // Reset for next statement + // nodes = nil + // docstring = nil + // newLineEncounteredInDocstring = false + // default: + // doc.addError(s, "Expected 'create', got: "+s.ReservedWord()) + // s.NextToken() + // docstring = nil + // } + // case SemicolonToken: + // // PostgreSQL uses semicolons as statement terminators + // // Multiple CREATE statements can exist in same file + // nodes = append(nodes, CreateUnparsed(s)) + // s.NextToken() + // // Continue parsing - don't return like T-SQL does with GO + // case BatchSeparatorToken: + // // PostgreSQL doesn't use GO batch separators + // // Q: Do we want to use GO batch separators as a feature of sqlcode? + // doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons instead") + // s.NextToken() + // docstring = nil + // default: + // doc.addError(s, fmt.Sprintf("Unexpected token in PostgreSQL document: %s", s.Token())) + // s.NextToken() + // docstring = nil + // } + // } } // parseCreate parses PostgreSQL CREATE statements (FUNCTION, PROCEDURE, TYPE, etc.) 
From 9eacf7e726bae2a67b8581256f1196cc0f14f682 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 18 Dec 2025 20:21:29 +0100 Subject: [PATCH 28/40] [wip] --- sqlparser/{nodes.go => batch.go} | 10 +-- sqlparser/document_test.go | 95 ++++++++++++++++++++++++++ sqlparser/node_test.go | 1 + sqlparser/pgsql_document.go | 12 ++-- sqlparser/scanner.go | 13 ++-- sqlparser/tokentype.go | 7 ++ sqlparser/tsql_document.go | 15 ++-- sqlparser/tsql_document_test.go | 114 ++----------------------------- 8 files changed, 135 insertions(+), 132 deletions(-) rename sqlparser/{nodes.go => batch.go} (91%) create mode 100644 sqlparser/node_test.go diff --git a/sqlparser/nodes.go b/sqlparser/batch.go similarity index 91% rename from sqlparser/nodes.go rename to sqlparser/batch.go index 7b2033e..25504f5 100644 --- a/sqlparser/nodes.go +++ b/sqlparser/batch.go @@ -4,25 +4,25 @@ import ( "fmt" ) -type Nodes struct { +type Batch struct { Nodes []Unparsed DocString []PosString CreateStatements int - TokenHandlers map[string]func(*Scanner, *Nodes) bool + TokenHandlers map[string]func(*Scanner, *Batch) bool Errors []Error BatchSeparatorToken TokenType } -func (n *Nodes) Create(s *Scanner) { +func (n *Batch) Create(s *Scanner) { n.Nodes = append(n.Nodes, CreateUnparsed(s)) } -func (n *Nodes) HasErrors() bool { +func (n *Batch) HasErrors() bool { return len(n.Errors) > 0 } // Agnostic parser that handles comments, whitespace, and reserved words -func (n *Nodes) Parse(s *Scanner) bool { +func (n *Batch) Parse(s *Scanner) bool { newLineEncounteredInDocstring := false for { diff --git a/sqlparser/document_test.go b/sqlparser/document_test.go index 8e497e6..ec011c5 100644 --- a/sqlparser/document_test.go +++ b/sqlparser/document_test.go @@ -1,6 +1,7 @@ package sqlparser import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -54,3 +55,97 @@ func TestNewDocumentFromExtension(t *testing.T) { require.NotEqual(t, sqlDoc, pgsqlDoc) }) } + +func TestDocument_parseCodeschemaName(t *testing.T) 
{ + t.Run("parses unquoted identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].TestProc") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.NoError(t, err) + assert.Equal(t, "[TestProc]", result.Value) + assert.NotEmpty(t, target) + }) + + t.Run("parses quoted identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].[Test Proc]") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.NoError(t, err) + + assert.Equal(t, "[Test Proc]", result.Value) + }) + + t.Run("errors on missing dot", func(t *testing.T) { + s := NewScanner("test.sql", "[code] TestProc") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.Error(t, err) + + assert.Equal(t, "", result.Value) + assert.ErrorContains(t, err, "must be followed by '.'") + }) + + t.Run("errors on missing identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].123") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + + assert.Error(t, err) + assert.Equal(t, "", result.Value) + assert.ErrorContains(t, err, "must be followed an identifier") + }) +} + +func TestDocument_recoverToNextStatement(t *testing.T) { + t.Run("recovers to declare", func(t *testing.T) { + s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") + s.NextToken() + + RecoverToNextStatement(s, []string{"declare"}) + + fmt.Printf("%#v\n", s) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "declare", s.ReservedWord()) + }) + + t.Run("recovers to create", func(t *testing.T) { + s := NewScanner("test.sql", "bad stuff create procedure") + s.NextToken() + + RecoverToNextStatement(s, []string{"create"}) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "create", s.ReservedWord()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + s := 
NewScanner("test.sql", "no keywords") + s.NextToken() + + RecoverToNextStatement(s, []string{}) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestDocument_recoverToNextStatementCopying(t *testing.T) { + t.Run("copies tokens while recovering", func(t *testing.T) { + s := NewScanner("test.sql", "bad token declare") + s.NextToken() + var target []Unparsed + + RecoverToNextStatementCopying(s, &target, []string{"declare"}) + + assert.NotEmpty(t, target) + assert.Equal(t, "declare", s.ReservedWord()) + }) +} diff --git a/sqlparser/node_test.go b/sqlparser/node_test.go new file mode 100644 index 0000000..04173c6 --- /dev/null +++ b/sqlparser/node_test.go @@ -0,0 +1 @@ +package sqlparser diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index f6e2e8f..e97ec32 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -86,9 +86,9 @@ func (d PGSqlDocument) WithoutPos() Document { // // PostgreSQL uses schema.object notation rather than [schema].[object]. func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - nodes := &Nodes{ - TokenHandlers: map[string]func(*Scanner, *Nodes) bool{ - "create": func(s *Scanner, n *Nodes) bool { + batch := &Batch{ + TokenHandlers: map[string]func(*Scanner, *Batch) bool{ + "create": func(s *Scanner, n *Batch) bool { // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. c := doc.parseCreate(s, n.CreateStatements) c.Driver = &stdlib.Driver{} @@ -103,9 +103,9 @@ func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { }, } - hasMore = nodes.Parse(s) - if nodes.HasErrors() { - doc.errors = append(doc.errors, nodes.Errors...) + hasMore = batch.Parse(s) + if batch.HasErrors() { + doc.errors = append(doc.errors, batch.Errors...) 
} return hasMore diff --git a/sqlparser/scanner.go b/sqlparser/scanner.go index 9e6c27f..18b2b77 100644 --- a/sqlparser/scanner.go +++ b/sqlparser/scanner.go @@ -1,6 +1,7 @@ package sqlparser import ( + "fmt" "regexp" "strings" "unicode" @@ -221,7 +222,9 @@ func (s *Scanner) nextToken() TokenType { return VariableIdentifierToken } else { rw := strings.ToLower(s.Token()) - if _, ok := reservedWords[rw]; ok { + _, ok := reservedWords[rw] + fmt.Printf("%#v %t\n", rw, ok) + if ok { s.reservedWord = rw return ReservedWordToken } else { @@ -243,7 +246,10 @@ func (s *Scanner) nextToken() TokenType { // no, it is instead an identifier starting with N... s.scanIdentifier() rw := strings.ToLower(s.Token()) - if _, ok := reservedWords[rw]; ok { + _, ok := reservedWords[rw] + fmt.Printf("%#v %t\n", rw, ok) + + if ok { s.reservedWord = rw return ReservedWordToken } else { @@ -380,6 +386,7 @@ func (s *Scanner) scanWhitespace() TokenType { return WhitespaceToken } +// tsql (mssql) reservered words var reservedWords = map[string]struct{}{ "add": struct{}{}, "external": struct{}{}, @@ -565,8 +572,6 @@ var reservedWords = map[string]struct{}{ "writetext": struct{}{}, "exit": struct{}{}, "proc": struct{}{}, - // pgsql - "replace": struct{}{}, } // apparently 'within group' is also reserved but dropping that.. 
diff --git a/sqlparser/tokentype.go b/sqlparser/tokentype.go index 835e605..644d8da 100644 --- a/sqlparser/tokentype.go +++ b/sqlparser/tokentype.go @@ -38,6 +38,10 @@ const ( UnexpectedCharacterToken NonUTF8ErrorToken + // PGSQL specific + DollarQuotedStringStartToken + DollarQuotedStringEndToken + BatchSeparatorToken MalformedBatchSeparatorToken EOFToken @@ -90,6 +94,9 @@ var tokenToDescription = map[TokenType]string{ UnexpectedCharacterToken: "UnexpectedCharacterToken", NonUTF8ErrorToken: "NonUTF8ErrorToken", + DollarQuotedStringStartToken: "DollarQuotedStringEndToken", + DollarQuotedStringEndToken: "DollarQuotedStringEndToken", + // After a lot of back and forth we added the batch separater to the scanner. // We implement sqlcmd's use of the go // do separate batches. sqlcmd will only support GO at the start of diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 8c17499..81f85bb 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -221,6 +221,7 @@ func (doc *TSqlDocument) parseBatchSeparator(s *Scanner) { // just saw a 'go'; just make sure there's nothing bad trailing it // (if there is, convert to errors and move on until the line is consumed errorEmitted := false + // continuously process tokens until a non-whitespace, non-malformed token is encountered. 
for { switch s.NextToken() { case WhitespaceToken: @@ -264,9 +265,9 @@ func (doc *TSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { } func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - nodes := &Nodes{ - TokenHandlers: map[string]func(*Scanner, *Nodes) bool{ - "declare": func(s *Scanner, n *Nodes) bool { + batch := &Batch{ + TokenHandlers: map[string]func(*Scanner, *Batch) bool{ + "declare": func(s *Scanner, n *Batch) bool { // First declare-statement; enter a mode where we assume all contents // of batch are declare statements if !isFirst { @@ -276,7 +277,7 @@ func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { // regardless of errors, go on and parse as far as we get... return doc.parseDeclareBatch(s) }, - "create": func(s *Scanner, n *Nodes) bool { + "create": func(s *Scanner, n *Batch) bool { // should be start of create procedure or create function... c := doc.parseCreate(s, n.CreateStatements) c.Driver = &mssql.Driver{} @@ -290,9 +291,9 @@ func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { }, }, } - hasMore = nodes.Parse(s) - if nodes.HasErrors() { - doc.errors = append(doc.errors, nodes.Errors...) + hasMore = batch.Parse(s) + if batch.HasErrors() { + doc.errors = append(doc.errors, batch.Errors...) 
} return hasMore diff --git a/sqlparser/tsql_document_test.go b/sqlparser/tsql_document_test.go index b44d575..3d07ad3 100644 --- a/sqlparser/tsql_document_test.go +++ b/sqlparser/tsql_document_test.go @@ -1,6 +1,7 @@ package sqlparser import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -228,66 +229,16 @@ func TestDocument_parseBatchSeparator(t *testing.T) { t.Run("errors on malformed separator", func(t *testing.T) { doc := &TSqlDocument{} s := NewScanner("test.sql", "go -- comment") - s.NextToken() - + tt := s.NextToken() + fmt.Printf("%#v %#v\n", s, tt) doc.parseBatchSeparator(s) + fmt.Printf("%#v\n", s) assert.NotEmpty(t, doc.errors) assert.Contains(t, doc.errors[0].Message, "should be alone") }) } -func TestDocument_parseCodeschemaName(t *testing.T) { - t.Run("parses unquoted identifier", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "[code].TestProc") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "[TestProc]", result.Value) - assert.NotEmpty(t, target) - }) - - t.Run("parses quoted identifier", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "[code].[Test Proc]") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "[Test Proc]", result.Value) - }) - - t.Run("errors on missing dot", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "[code] TestProc") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "", result.Value) - assert.NotEmpty(t, doc.errors) - assert.Contains(t, doc.errors[0].Message, "must be followed by '.'") - }) - - t.Run("errors on missing identifier", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "[code].123") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "", result.Value) - 
assert.NotEmpty(t, doc.errors) - assert.Contains(t, doc.errors[0].Message, "must be followed an identifier") - }) -} - func TestDocument_parseCreate(t *testing.T) { t.Run("parses simple procedure", func(t *testing.T) { doc := &TSqlDocument{} @@ -465,60 +416,3 @@ func TestCreateUnparsed(t *testing.T) { assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) }) } - -func TestDocument_recoverToNextStatement(t *testing.T) { - t.Run("recovers to declare", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, ReservedWordToken, s.TokenType()) - assert.Equal(t, "declare", s.ReservedWord()) - }) - - t.Run("recovers to create", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "bad stuff create procedure") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, "create", s.ReservedWord()) - }) - - t.Run("recovers to go", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "error error go") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, "go", s.ReservedWord()) - }) - - t.Run("stops at EOF", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "no keywords") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, EOFToken, s.TokenType()) - }) -} - -func TestDocument_recoverToNextStatementCopying(t *testing.T) { - t.Run("copies tokens while recovering", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "bad token declare") - s.NextToken() - var target []Unparsed - - doc.recoverToNextStatementCopying(s, &target) - - assert.NotEmpty(t, target) - assert.Equal(t, "declare", s.ReservedWord()) - }) -} From 606925cad38c9ad0fc57517da0951a9ab28373e4 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Sat, 3 Jan 2026 15:51:08 +0100 Subject: [PATCH 29/40] cleaning up --- deployable.go | 3 +- 
mssql_error.go | 4 +- preprocess.go | 15 +- sqlparser/document_test.go | 151 -- sqlparser/mssql/README.md | 48 + .../{tsql_document.go => mssql/document.go} | 268 ++-- sqlparser/mssql/document_test.go | 1273 +++++++++++++++++ sqlparser/{ => mssql}/scanner.go | 213 +-- sqlparser/mssql/scanner_test.go | 600 ++++++++ sqlparser/mssql/tokens.go | 57 + sqlparser/node_test.go | 1 - sqlparser/parser.go | 79 +- sqlparser/parser_test.go | 15 - sqlparser/pgsql_document.go | 552 ------- sqlparser/pgsql_document_test.go | 254 ---- sqlparser/scanner_test.go | 292 ---- sqlparser/{ => sqldocument}/batch.go | 19 +- sqlparser/sqldocument/batch_test.go | 300 ++++ sqlparser/{ => sqldocument}/create.go | 15 +- sqlparser/{ => sqldocument}/document.go | 50 +- sqlparser/sqldocument/document_test.go | 102 ++ sqlparser/{ => sqldocument}/dom.go | 19 +- sqlparser/{ => sqldocument}/pragma.go | 6 +- sqlparser/sqldocument/scanner.go | 56 + sqlparser/sqldocument/tokens.go | 102 ++ .../{ => sqldocument}/topological_sort.go | 2 +- .../topological_sort_test.go | 7 +- sqlparser/sqldocument/unparsed.go | 16 + sqlparser/tokentype.go | 121 -- sqlparser/tsql_document_test.go | 418 ------ sqlparser/unparsed.go | 25 - 31 files changed, 2916 insertions(+), 2167 deletions(-) delete mode 100644 sqlparser/document_test.go create mode 100644 sqlparser/mssql/README.md rename sqlparser/{tsql_document.go => mssql/document.go} (53%) create mode 100644 sqlparser/mssql/document_test.go rename sqlparser/{ => mssql}/scanner.go (70%) create mode 100644 sqlparser/mssql/scanner_test.go create mode 100644 sqlparser/mssql/tokens.go delete mode 100644 sqlparser/node_test.go delete mode 100644 sqlparser/pgsql_document.go delete mode 100644 sqlparser/pgsql_document_test.go delete mode 100644 sqlparser/scanner_test.go rename sqlparser/{ => sqldocument}/batch.go (76%) create mode 100644 sqlparser/sqldocument/batch_test.go rename sqlparser/{ => sqldocument}/create.go (88%) rename sqlparser/{ => sqldocument}/document.go (68%) 
create mode 100644 sqlparser/sqldocument/document_test.go rename sqlparser/{ => sqldocument}/dom.go (73%) rename sqlparser/{ => sqldocument}/pragma.go (83%) create mode 100644 sqlparser/sqldocument/scanner.go create mode 100644 sqlparser/sqldocument/tokens.go rename sqlparser/{ => sqldocument}/topological_sort.go (98%) rename sqlparser/{ => sqldocument}/topological_sort_test.go (98%) create mode 100644 sqlparser/sqldocument/unparsed.go delete mode 100644 sqlparser/tokentype.go delete mode 100644 sqlparser/tsql_document_test.go delete mode 100644 sqlparser/unparsed.go diff --git a/deployable.go b/deployable.go index 135fd26..e46a31d 100644 --- a/deployable.go +++ b/deployable.go @@ -15,12 +15,13 @@ import ( pgxstdlib "github.com/jackc/pgx/v5/stdlib" mssql "github.com/microsoft/go-mssqldb" "github.com/vippsas/sqlcode/sqlparser" + "github.com/vippsas/sqlcode/sqlparser/sqldocument" ) type Deployable struct { SchemaSuffix string ParsedFiles []string // mainly for use in error messages etc - CodeBase sqlparser.Document + CodeBase sqldocument.Document // cache over whether it has been uploaded to a given DB // (the same physical DB can be in this map multiple times under diff --git a/mssql_error.go b/mssql_error.go index d6f531e..26a3774 100644 --- a/mssql_error.go +++ b/mssql_error.go @@ -6,7 +6,7 @@ import ( "strings" mssql "github.com/microsoft/go-mssqldb" - "github.com/vippsas/sqlcode/sqlparser" + "github.com/vippsas/sqlcode/sqlparser/sqldocument" ) type MSSQLUserError struct { @@ -33,7 +33,7 @@ func (s MSSQLUserError) Error() string { } type SQLCodeParseErrors struct { - Errors []sqlparser.Error + Errors []sqldocument.Error } func (e SQLCodeParseErrors) Error() string { diff --git a/preprocess.go b/preprocess.go index 9478c3f..4b3b072 100644 --- a/preprocess.go +++ b/preprocess.go @@ -12,9 +12,10 @@ import ( "github.com/jackc/pgx/v5/stdlib" "github.com/vippsas/sqlcode/sqlparser" + "github.com/vippsas/sqlcode/sqlparser/sqldocument" ) -func SchemaSuffixFromHash(doc 
sqlparser.Document) string { +func SchemaSuffixFromHash(doc sqldocument.Document) string { hasher := sha256.New() for _, dec := range doc.Declares() { hasher.Write([]byte(dec.String() + "\n")) @@ -45,7 +46,7 @@ type lineNumberCorrection struct { } type Batch struct { - StartPos sqlparser.Pos + StartPos sqldocument.Pos Lines string // lineNumberCorrections contains data that helps us map from errors in the `Lines` @@ -90,7 +91,7 @@ type PreprocessedFile struct { } type PreprocessorError struct { - Pos sqlparser.Pos + Pos sqldocument.Pos Message string } @@ -100,7 +101,7 @@ func (p PreprocessorError) Error() string { var codeSchemaRegexp = regexp.MustCompile(`(?i)\[code\]`) -func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quotedTargetSchema string) (result Batch, err error) { +func sqlcodeTransformCreate(declares map[string]string, c sqldocument.Create, quotedTargetSchema string) (result Batch, err error) { var w strings.Builder if len(c.Body) > 0 { @@ -117,9 +118,9 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot for _, u := range c.Body { token := u.RawValue switch { - case u.Type == sqlparser.QuotedIdentifierToken && u.RawValue == "[code]": + case u.Type == sqldocument.QuotedIdentifierToken && u.RawValue == "[code]": token = quotedTargetSchema - case u.Type == sqlparser.VariableIdentifierToken && sqlparser.IsSqlcodeConstVariable(u.RawValue): + case u.Type == sqldocument.VariableIdentifierToken && sqlparser.IsSqlcodeConstVariable(u.RawValue): constLiteral, ok := declares[u.RawValue] if !ok { err = PreprocessorError{u.Start, fmt.Sprintf("sqlcode constant `%s` not declared", u.RawValue)} @@ -141,7 +142,7 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot return } -func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Driver) (PreprocessedFile, error) { +func Preprocess(doc sqldocument.Document, schemasuffix string, driver driver.Driver) 
(PreprocessedFile, error) { var result PreprocessedFile if strings.Contains(schemasuffix, "]") { diff --git a/sqlparser/document_test.go b/sqlparser/document_test.go deleted file mode 100644 index ec011c5..0000000 --- a/sqlparser/document_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package sqlparser - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewDocumentFromExtension(t *testing.T) { - t.Run("returns TSqlDocument for .sql extension", func(t *testing.T) { - doc := NewDocumentFromExtension(".sql") - - _, ok := doc.(*TSqlDocument) - assert.True(t, ok, "Expected TSqlDocument type") - assert.NotNil(t, doc) - }) - - t.Run("returns PGSqlDocument for .pgsql extension", func(t *testing.T) { - doc := NewDocumentFromExtension(".pgsql") - - _, ok := doc.(*PGSqlDocument) - assert.True(t, ok, "Expected PGSqlDocument type") - assert.NotNil(t, doc) - }) - - t.Run("panics for unsupported extension", func(t *testing.T) { - assert.Panics(t, func() { - NewDocumentFromExtension(".txt") - }, "Expected panic for unsupported extension") - }) - - t.Run("panics for empty extension", func(t *testing.T) { - assert.Panics(t, func() { - NewDocumentFromExtension("") - }, "Expected panic for empty extension") - }) - - t.Run("panics for unknown SQL extension", func(t *testing.T) { - assert.Panics(t, func() { - NewDocumentFromExtension(".mysql") - }, "Expected panic for .mysql extension") - }) - - t.Run("extension matching is case insensitive", func(t *testing.T) { - assert.Panics(t, func() { - NewDocumentFromExtension(".SQL") - }, "Expected panic for uppercase .SQL") - }) - - t.Run("returned documents implement Document interface", func(t *testing.T) { - sqlDoc := NewDocumentFromExtension(".sql") - pgsqlDoc := NewDocumentFromExtension(".pgsql") - require.NotEqual(t, sqlDoc, pgsqlDoc) - }) -} - -func TestDocument_parseCodeschemaName(t *testing.T) { - t.Run("parses unquoted identifier", func(t *testing.T) { - s := 
NewScanner("test.sql", "[code].TestProc") - s.NextToken() - var target []Unparsed - - result, err := ParseCodeschemaName(s, &target, nil) - assert.NoError(t, err) - assert.Equal(t, "[TestProc]", result.Value) - assert.NotEmpty(t, target) - }) - - t.Run("parses quoted identifier", func(t *testing.T) { - s := NewScanner("test.sql", "[code].[Test Proc]") - s.NextToken() - var target []Unparsed - - result, err := ParseCodeschemaName(s, &target, nil) - assert.NoError(t, err) - - assert.Equal(t, "[Test Proc]", result.Value) - }) - - t.Run("errors on missing dot", func(t *testing.T) { - s := NewScanner("test.sql", "[code] TestProc") - s.NextToken() - var target []Unparsed - - result, err := ParseCodeschemaName(s, &target, nil) - assert.Error(t, err) - - assert.Equal(t, "", result.Value) - assert.ErrorContains(t, err, "must be followed by '.'") - }) - - t.Run("errors on missing identifier", func(t *testing.T) { - s := NewScanner("test.sql", "[code].123") - s.NextToken() - var target []Unparsed - - result, err := ParseCodeschemaName(s, &target, nil) - - assert.Error(t, err) - assert.Equal(t, "", result.Value) - assert.ErrorContains(t, err, "must be followed an identifier") - }) -} - -func TestDocument_recoverToNextStatement(t *testing.T) { - t.Run("recovers to declare", func(t *testing.T) { - s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") - s.NextToken() - - RecoverToNextStatement(s, []string{"declare"}) - - fmt.Printf("%#v\n", s) - - assert.Equal(t, ReservedWordToken, s.TokenType()) - assert.Equal(t, "declare", s.ReservedWord()) - }) - - t.Run("recovers to create", func(t *testing.T) { - s := NewScanner("test.sql", "bad stuff create procedure") - s.NextToken() - - RecoverToNextStatement(s, []string{"create"}) - - assert.Equal(t, ReservedWordToken, s.TokenType()) - assert.Equal(t, "create", s.ReservedWord()) - }) - - t.Run("stops at EOF", func(t *testing.T) { - s := NewScanner("test.sql", "no keywords") - s.NextToken() - - RecoverToNextStatement(s, 
[]string{}) - - assert.Equal(t, EOFToken, s.TokenType()) - }) -} - -func TestDocument_recoverToNextStatementCopying(t *testing.T) { - t.Run("copies tokens while recovering", func(t *testing.T) { - s := NewScanner("test.sql", "bad token declare") - s.NextToken() - var target []Unparsed - - RecoverToNextStatementCopying(s, &target, []string{"declare"}) - - assert.NotEmpty(t, target) - assert.Equal(t, "declare", s.ReservedWord()) - }) -} diff --git a/sqlparser/mssql/README.md b/sqlparser/mssql/README.md new file mode 100644 index 0000000..6f470bb --- /dev/null +++ b/sqlparser/mssql/README.md @@ -0,0 +1,48 @@ + +Package mssql provides a T-SQL (Microsoft SQL Server) parser for the sqlcode library. + +# Overview +This package implements a lexical scanner and document parser specifically designed +for T-SQL syntax. It is part of the sqlcode toolchain that manages SQL database +objects (procedures, functions, types) with dependency tracking and code generation. + +# Architecture +The parser follows a two-layer architecture: + 1. Scanner (scanner.go): A lexical tokenizer that breaks T-SQL source into tokens. + It handles T-SQL-specific constructs like N'unicode strings', [bracketed identifiers], + and the GO batch separator. + 2. Document (document.go): A higher-level parser that processes token streams to + extract CREATE statements, DECLARE constants, and dependency information. + +# Token System +T-SQL tokens are divided into two categories: + - Common tokens (defined in sqldocument): Shared across SQL dialects (e.g., parentheses, + whitespace, identifiers). These use token type values 0-999. + - T-SQL-specific tokens (defined in tokens.go): Dialect-specific tokens like + VarcharLiteralToken ('...') and NVarcharLiteralToken (N'...'). These use values 1000-1999. 
+ +# Batch Separator Handling +T-SQL uses GO as a batch separator with special rules: + - GO must appear at the start of a line (only whitespace/comments before it) + - Nothing except whitespace may follow GO on the same line + - GO is not a reserved word; it's a client tool command +The scanner tracks line position state to correctly identify GO as a BatchSeparatorToken +rather than an identifier. Malformed separators (GO followed by non-whitespace) are +reported as MalformedBatchSeparatorToken. + +# Document Structure +The parser recognizes: + - CREATE PROCEDURE/FUNCTION/TYPE statements in the [code] schema + - DECLARE statements for constants (variables starting with @Enum, @Global, or @Const) + - Dependencies between objects via [code].ObjectName references + - Pragma comments (--sqlcode:...) for build-time directives + +# Dependency Tracking +When parsing CREATE statements, the parser scans for [code].ObjectName patterns +to build a dependency graph. This enables topological sorting of objects so they +are created in the correct order during deployment. + +# Error Recovery +The parser uses a recovery strategy that skips to the next statement-starting +keyword (CREATE, DECLARE, GO) when encountering syntax errors. This allows +partial parsing of files with errors while collecting all error messages. diff --git a/sqlparser/tsql_document.go b/sqlparser/mssql/document.go similarity index 53% rename from sqlparser/tsql_document.go rename to sqlparser/mssql/document.go index 81f85bb..16fe254 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/mssql/document.go @@ -1,4 +1,4 @@ -package sqlparser +package mssql import ( "fmt" @@ -6,36 +6,50 @@ import ( "strings" mssql "github.com/microsoft/go-mssqldb" + "github.com/vippsas/sqlcode/sqlparser/sqldocument" ) +// TSQLStatementTokens defines the keywords that start new statements. +// Used by error recovery to find a safe point to resume parsing. 
var TSQLStatementTokens = []string{"create", "declare", "go"} +// TSqlDocument represents a T-SQL source file. +// +// The document contains: +// - creates: CREATE PROCEDURE/FUNCTION/TYPE statements with dependency info +// - declares: DECLARE statements for sqlcode constants (@Enum*, @Global*, @Const*) +// - errors: Syntax and semantic errors encountered during parsing +// - pragmaIncludeIf: Conditional compilation directives from --sqlcode:include-if +// +// Parsing follows T-SQL batch semantics where batches are separated by GO. +// The first batch may contain DECLARE statements for constants. +// Subsequent batches contain CREATE statements for database objects. type TSqlDocument struct { pragmaIncludeIf []string - creates []Create - declares []Declare - errors []Error + creates []sqldocument.Create + declares []sqldocument.Declare + errors []sqldocument.Error - Pragma + sqldocument.Pragma } -func (d TSqlDocument) HasErrors() bool { - return len(d.errors) > 0 -} - -func (d TSqlDocument) Creates() []Create { - return d.creates -} - -func (d TSqlDocument) Declares() []Declare { - return d.declares -} - -func (d TSqlDocument) Errors() []Error { - return d.errors -} +// Parse processes a T-SQL source file from the given input. +// +// Parsing proceeds in phases: +// 1. Parse pragma comments at the file start (--sqlcode:...) +// 2. Parse batches sequentially, separated by GO +// +// The first batch has special rules: it may contain DECLARE statements +// for sqlcode constants. CREATE statements may appear in any batch, +// but procedures/functions must be alone in their batch (T-SQL requirement). +// +// Errors are accumulated in the document rather than stopping parsing, +// allowing partial results even with syntax errors. 
+func (d *TSqlDocument) Parse(input []byte, file sqldocument.FileRef) error { + s := &Scanner{} + s.SetInput(input) + s.SetFile(file) -func (d *TSqlDocument) Parse(s *Scanner) error { err := d.ParsePragmas(s) if err != nil { d.addError(s, err.Error()) @@ -49,13 +63,29 @@ func (d *TSqlDocument) Parse(s *Scanner) error { return nil } +func (d TSqlDocument) HasErrors() bool { + return len(d.errors) > 0 +} + +func (d TSqlDocument) Creates() []sqldocument.Create { + return d.creates +} + +func (d TSqlDocument) Declares() []sqldocument.Declare { + return d.declares +} + +func (d TSqlDocument) Errors() []sqldocument.Error { + return d.errors +} + func (d *TSqlDocument) Sort() { // Do the topological sort; and include any error with it as part // of `result`, *not* return it as err - sortedCreates, errpos, sortErr := TopologicalSort(d.creates) + sortedCreates, errpos, sortErr := sqldocument.TopologicalSort(d.creates) if sortErr != nil { - d.errors = append(d.errors, Error{ + d.errors = append(d.errors, sqldocument.Error{ Pos: errpos, Message: sortErr.Error(), }) @@ -64,29 +94,7 @@ func (d *TSqlDocument) Sort() { } } -// Transform a TSqlDocument to remove all Position information; this is used -// to 'unclutter' a DOM to more easily write assertions on it. -func (d TSqlDocument) WithoutPos() Document { - var cs []Create - for _, x := range d.creates { - cs = append(cs, x.WithoutPos()) - } - var ds []Declare - for _, x := range d.declares { - ds = append(ds, x.WithoutPos()) - } - var es []Error - for _, x := range d.errors { - es = append(es, x.WithoutPos()) - } - return &TSqlDocument{ - creates: cs, - declares: ds, - errors: es, - } -} - -func (d *TSqlDocument) Include(other Document) { +func (d *TSqlDocument) Include(other sqldocument.Document) { // Do not copy pragmaIncludeIf, since that is local to a single file. // Its contents is also present in each Create. d.declares = append(d.declares, other.Declares()...) 
@@ -98,67 +106,67 @@ func (d TSqlDocument) Empty() bool { return len(d.creates) == 0 || len(d.declares) == 0 } -func (d *TSqlDocument) addError(s *Scanner, msg string) { - d.errors = append(d.errors, Error{ +func (d *TSqlDocument) addError(s sqldocument.Scanner, msg string) { + d.errors = append(d.errors, sqldocument.Error{ Pos: s.Start(), Message: msg, }) } -func (d *TSqlDocument) unexpectedTokenError(s *Scanner) { +func (d *TSqlDocument) unexpectedTokenError(s sqldocument.Scanner) { d.addError(s, "Unexpected: "+s.Token()) } -func (doc *TSqlDocument) parseTypeExpression(s *Scanner) (t Type) { +func (doc *TSqlDocument) parseTypeExpression(s sqldocument.Scanner) (t sqldocument.Type) { parseArgs := func() { // parses *after* the initial (; consumes trailing ) for { switch { - case s.TokenType() == NumberToken: + case s.TokenType() == sqldocument.NumberToken: t.Args = append(t.Args, s.Token()) - case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": + case s.TokenType() == sqldocument.UnquotedIdentifierToken && s.TokenLower() == "max": t.Args = append(t.Args, "max") default: doc.unexpectedTokenError(s) - RecoverToNextStatement(s, TSQLStatementTokens) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) return } s.NextNonWhitespaceCommentToken() switch { - case s.TokenType() == CommaToken: + case s.TokenType() == sqldocument.CommaToken: s.NextNonWhitespaceCommentToken() continue - case s.TokenType() == RightParenToken: + case s.TokenType() == sqldocument.RightParenToken: s.NextNonWhitespaceCommentToken() return default: doc.unexpectedTokenError(s) - RecoverToNextStatement(s, TSQLStatementTokens) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) return } } } - if s.TokenType() != UnquotedIdentifierToken { + if s.TokenType() != sqldocument.UnquotedIdentifierToken { panic("assertion failed, bug in caller") } t.BaseType = s.Token() s.NextNonWhitespaceCommentToken() - if s.TokenType() == LeftParenToken { + if s.TokenType() == 
sqldocument.LeftParenToken { s.NextNonWhitespaceCommentToken() parseArgs() } return } -func (doc *TSqlDocument) parseDeclare(s *Scanner) (result []Declare) { +func (doc *TSqlDocument) parseDeclare(s sqldocument.Scanner) (result []sqldocument.Declare) { declareStart := s.Start() // parse what is *after* the `declare` reserved keyword loop: for { - if s.TokenType() != VariableIdentifierToken { + if s.TokenType() != sqldocument.VariableIdentifierToken { doc.unexpectedTokenError(s) - RecoverToNextStatement(s, TSQLStatementTokens) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) return } @@ -170,41 +178,41 @@ loop: } s.NextNonWhitespaceCommentToken() - var variableType Type + var variableType sqldocument.Type switch s.TokenType() { - case EqualToken: + case sqldocument.EqualToken: doc.addError(s, "sqlcode constants needs a type declared explicitly") s.NextNonWhitespaceCommentToken() - case UnquotedIdentifierToken: + case sqldocument.UnquotedIdentifierToken: variableType = doc.parseTypeExpression(s) } - if s.TokenType() != EqualToken { + if s.TokenType() != sqldocument.EqualToken { doc.addError(s, "sqlcode constants needs to be assigned at once using =") - RecoverToNextStatement(s, TSQLStatementTokens) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) } switch s.NextNonWhitespaceCommentToken() { - case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - declare := Declare{ + case sqldocument.NumberToken, NVarcharLiteralToken, VarcharLiteralToken: + declare := sqldocument.Declare{ Start: declareStart, Stop: s.Stop(), VariableName: variableName, Datatype: variableType, - Literal: CreateUnparsed(s), + Literal: sqldocument.CreateUnparsed(s), } result = append(result, declare) default: doc.unexpectedTokenError(s) - RecoverToNextStatement(s, TSQLStatementTokens) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) return } switch s.NextNonWhitespaceCommentToken() { - case CommaToken: + case sqldocument.CommaToken: 
s.NextNonWhitespaceCommentToken() continue - case SemicolonToken: + case sqldocument.SemicolonToken: s.NextNonWhitespaceCommentToken() break loop default: @@ -217,16 +225,16 @@ loop: return } -func (doc *TSqlDocument) parseBatchSeparator(s *Scanner) { +func (doc *TSqlDocument) parseBatchSeparator(s sqldocument.Scanner) { // just saw a 'go'; just make sure there's nothing bad trailing it // (if there is, convert to errors and move on until the line is consumed errorEmitted := false // continuously process tokens until a non-whitespace, non-malformed token is encountered. for { switch s.NextToken() { - case WhitespaceToken: + case sqldocument.WhitespaceToken: continue - case MalformedBatchSeparatorToken: + case sqldocument.MalformedBatchSeparatorToken: if !errorEmitted { doc.addError(s, "`go` should be alone on a line without any comments") errorEmitted = true @@ -238,36 +246,45 @@ func (doc *TSqlDocument) parseBatchSeparator(s *Scanner) { } } -func (doc *TSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { +func (doc *TSqlDocument) parseDeclareBatch(s sqldocument.Scanner) (hasMore bool) { if s.ReservedWord() != "declare" { panic("assertion failed, incorrect use in caller") } for { tt := s.TokenType() switch { - case tt == EOFToken: + case tt == sqldocument.EOFToken: return false - case tt == ReservedWordToken && s.ReservedWord() == "declare": + case tt == sqldocument.ReservedWordToken && s.ReservedWord() == "declare": s.NextNonWhitespaceCommentToken() d := doc.parseDeclare(s) doc.declares = append(doc.declares, d...) 
- case tt == ReservedWordToken && s.ReservedWord() != "declare": + case tt == sqldocument.ReservedWordToken && s.ReservedWord() != "declare": doc.addError(s, "Only 'declare' allowed in this batch") - RecoverToNextStatement(s, TSQLStatementTokens) - case tt == BatchSeparatorToken: + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) + case tt == sqldocument.BatchSeparatorToken: doc.parseBatchSeparator(s) return true default: doc.unexpectedTokenError(s) - RecoverToNextStatement(s, TSQLStatementTokens) + sqldocument.RecoverToNextStatement(s, TSQLStatementTokens) } } } -func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - batch := &Batch{ - TokenHandlers: map[string]func(*Scanner, *Batch) bool{ - "declare": func(s *Scanner, n *Batch) bool { +// parseBatch processes a single T-SQL batch (content between GO separators). +// +// Batch processing strategy: +// - Track tokens before the first significant statement for docstrings +// - Dispatch to specialized parsers based on statement type (CREATE, DECLARE) +// - Handle batch separator (GO) to signal batch boundary +// +// The isFirst parameter indicates whether this is the first batch in the file, +// which affects whether DECLARE statements are allowed. +func (doc *TSqlDocument) parseBatch(s sqldocument.Scanner, isFirst bool) (hasMore bool) { + batch := &sqldocument.Batch{ + TokenHandlers: map[string]func(sqldocument.Scanner, *sqldocument.Batch) bool{ + "declare": func(s sqldocument.Scanner, n *sqldocument.Batch) bool { // First declare-statement; enter a mode where we assume all contents // of batch are declare statements if !isFirst { @@ -277,7 +294,7 @@ func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { // regardless of errors, go on and parse as far as we get... 
return doc.parseDeclareBatch(s) }, - "create": func(s *Scanner, n *Batch) bool { + "create": func(s sqldocument.Scanner, n *sqldocument.Batch) bool { // should be start of create procedure or create function... c := doc.parseCreate(s, n.CreateStatements) c.Driver = &mssql.Driver{} @@ -287,6 +304,10 @@ func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { c.Body = append(n.Nodes, c.Body...) c.Docstring = n.DocString doc.creates = append(doc.creates, c) + + // fmt.Printf("%#v\n", s) + // fmt.Printf("%#v\n", n) + // fmt.Printf("%#v\n", doc) return false }, }, @@ -299,48 +320,55 @@ func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { return hasMore } -// parseCreate parses anything that starts with "create". Position is -// *on* the create token. -// At this stage in sqlcode parser development we're only interested -// in procedures/functions/types as opaque blocks of SQL code where -// we only track dependencies between them and their declared name; -// so we treat them with the same code. We consume until the end of -// the batch; only one declaration allowed per batch. Everything -// parsed here will also be added to `batch`. On any error, copying -// to batch stops / becomes erratic.. -func (d *TSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) { +// parseCreate parses CREATE PROCEDURE/FUNCTION/TYPE statements. +// +// This is the core of the sqlcode parser. It: +// 1. Validates the CREATE type is one we support (procedure/function/type) +// 2. Extracts the object name from [code].ObjectName syntax +// 3. Copies the entire statement body for later emission +// 4. Tracks dependencies by finding [code].OtherObject references +// +// The parser is intentionally permissive about T-SQL syntax details, +// delegating full validation to SQL Server. It focuses on extracting +// the structural information needed for dependency ordering and code generation. 
+// +// Parameters: +// - s: Scanner positioned on the CREATE keyword +// - createCountInBatch: Number of CREATE statements already seen in this batch +// (used to enforce "one procedure/function per batch" rule) +func (d *TSqlDocument) parseCreate(s sqldocument.Scanner, createCountInBatch int) (result sqldocument.Create) { if s.ReservedWord() != "create" { panic("illegal use by caller") } - CopyToken(s, &result.Body) + sqldocument.CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) createType := strings.ToLower(s.Token()) if !(createType == "procedure" || createType == "function" || createType == "type") { d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) + sqldocument.RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) + sqldocument.RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } result.CreateType = createType - CopyToken(s, &result.Body) + sqldocument.CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) // Insist on [code]. 
- if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { + if s.TokenType() != sqldocument.QuotedIdentifierToken || s.Token() != "[code]" { d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) + sqldocument.RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } var err error - result.QuotedName, err = ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) + result.QuotedName, err = sqldocument.ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) if err != nil { d.addError(s, err.Error()) } @@ -358,7 +386,7 @@ tailloop: for { tt := s.TokenType() switch { - case tt == ReservedWordToken && s.ReservedWord() == "create": + case tt == sqldocument.ReservedWordToken && s.ReservedWord() == "create": // So, we're currently parsing 'create ...' and we see another 'create'. // We split in two cases depending on the context we are currently in // (createType is referring to how we entered this function, *NOT* the @@ -371,13 +399,13 @@ tailloop: // // What is important is a function/procedure/type isn't started on without a 'go' // in between; so we block those 3 from appearing in the same batch - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) + sqldocument.CopyToken(s, &result.Body) + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) tt2 := s.TokenType() - if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || - (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) + if (tt2 == sqldocument.ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || + (tt2 == sqldocument.UnquotedIdentifierToken && s.TokenLower() == "type") { + sqldocument.RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) d.addError(s, "a procedure/function 
must be alone in a batch; use 'go' to split batches") return } @@ -390,11 +418,11 @@ tailloop: panic("assertion failed") } - case tt == EOFToken || tt == BatchSeparatorToken: + case tt == sqldocument.EOFToken || tt == sqldocument.BatchSeparatorToken: break tailloop - case tt == QuotedIdentifierToken && s.Token() == "[code]": + case tt == sqldocument.QuotedIdentifierToken && s.Token() == "[code]": // Parse a dependency - dep, err := ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) + dep, err := sqldocument.ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) if err != nil { d.addError(s, err.Error()) } @@ -408,9 +436,9 @@ tailloop: if !found { result.DependsOn = append(result.DependsOn, dep) } - case tt == ReservedWordToken && s.Token() == "as": - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) + case tt == sqldocument.ReservedWordToken && s.Token() == "as": + sqldocument.CopyToken(s, &result.Body) + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) /* TODO: Fix and re-enable This code add RoutineName for convenience. 
So: @@ -448,8 +476,8 @@ tailloop: */ default: - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) + sqldocument.CopyToken(s, &result.Body) + sqldocument.NextTokenCopyingWhitespace(s, &result.Body) } } diff --git a/sqlparser/mssql/document_test.go b/sqlparser/mssql/document_test.go new file mode 100644 index 0000000..51122fa --- /dev/null +++ b/sqlparser/mssql/document_test.go @@ -0,0 +1,1273 @@ +package mssql + +import ( + "testing" + + mssql "github.com/microsoft/go-mssqldb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vippsas/sqlcode/sqlparser/sqldocument" +) + +func ParseString(t *testing.T, file, input string) *TSqlDocument { + d := &TSqlDocument{} + assert.NoError(t, d.Parse([]byte(input), sqldocument.FileRef(file))) + return d +} + +func TestParserSmokeTest(t *testing.T) { + doc := ParseString(t, "test.sql", ` +/* test is a test + +declare @EnumFoo int = 2; + +*/ + +declare/*comment*/@EnumBar1 varchar (max) = N'declare @EnumThisIsInString'; +declare + + + @EnumBar2 int = 20, + @EnumBar3 int=21; + +GO + +declare @EnumNextBatch int = 3; + +go +-- preceding comment 1 +/* preceding comment 2 + +asdfasdf */create procedure [code].TestFunc as begin + refers to [code].OtherFunc [code].HelloFunc; + create table x ( int x not null ); -- should be ok +end; + +/* trailing comment */ +`) + require.Equal(t, 1, len(doc.Creates())) + c := doc.Creates()[0] + require.Equal(t, &mssql.Driver{}, c.Driver) + + assert.Equal(t, "[TestFunc]", c.QuotedName.Value) + assert.Equal(t, []string{"[HelloFunc]", "[OtherFunc]"}, c.DependsOnStrings()) + assert.Equal(t, `-- preceding comment 1 +/* preceding comment 2 + +asdfasdf */create procedure [code].TestFunc as begin + refers to [code].OtherFunc [code].HelloFunc; + create table x ( int x not null ); -- should be ok +end; + +/* trailing comment */ +`, c.String()) + + assert.Equal(t, + []sqldocument.Error{ + { + Message: "'declare' statement only allowed in first batch", 
+ }, + }, doc.Errors()) + + assert.Equal(t, + []sqldocument.Declare{ + { + VariableName: "@EnumBar1", + Datatype: sqldocument.Type{ + BaseType: "varchar", + Args: []string{ + "max", + }, + }, + Literal: sqldocument.Unparsed{ + Type: NVarcharLiteralToken, + RawValue: "N'declare @EnumThisIsInString'", + }, + }, + { + VariableName: "@EnumBar2", + Datatype: sqldocument.Type{ + BaseType: "int", + }, + Literal: sqldocument.Unparsed{ + Type: sqldocument.NumberToken, + RawValue: "20", + }, + }, + { + VariableName: "@EnumBar3", + Datatype: sqldocument.Type{ + BaseType: "int", + }, + Literal: sqldocument.Unparsed{ + Type: sqldocument.NumberToken, + RawValue: "21", + }, + }, + { + VariableName: "@EnumNextBatch", + Datatype: sqldocument.Type{ + BaseType: "int", + }, + Literal: sqldocument.Unparsed{ + Type: sqldocument.NumberToken, + RawValue: "3", + }, + }, + }, + doc.Declares(), + ) + // repr.Println(doc) +} + +// import ( +// "fmt" +// "strings" +// "testing" + +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/require" +// "github.com/vippsas/sqlcode/sqlparser/sqldocument" +// ) + +// // Helper to parse a document from input string +// func parseDocument(input string) *TSqlDocument { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", input) +// s.NextToken() +// doc.Parse(s) +// return doc +// } + +// func TestTSqlDocument(t *testing.T) { +// t.Run("addError", func(t *testing.T) { +// t.Run("adds error with position", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "select") +// s.NextToken() + +// doc.addError(s, "test error message") +// require.True(t, doc.HasErrors()) +// assert.Equal(t, "test error message", doc.errors[0].Message) +// assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.errors[0].Pos) +// }) + +// t.Run("accumulates multiple errors", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "abc def") +// s.NextToken() +// doc.addError(s, "error 1") +// 
s.NextToken() +// doc.addError(s, "error 2") + +// require.Len(t, doc.errors, 2) +// assert.Equal(t, "error 1", doc.errors[0].Message) +// assert.Equal(t, "error 2", doc.errors[1].Message) +// }) + +// t.Run("creates error with token text", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "unexpected_token") +// s.NextToken() + +// doc.unexpectedTokenError(s) + +// require.Len(t, doc.errors, 1) +// assert.Equal(t, "Unexpected: unexpected_token", doc.errors[0].Message) +// }) +// }) + +// t.Run("parseTypeExpression", func(t *testing.T) { +// t.Run("parses simple type without args", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "int") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "int", typ.BaseType) +// assert.Empty(t, typ.Args) +// }) + +// t.Run("parses type with single arg", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "varchar(50)") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "varchar", typ.BaseType) +// assert.Equal(t, []string{"50"}, typ.Args) +// }) + +// t.Run("parses type with multiple args", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "decimal(10, 2)") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "decimal", typ.BaseType) +// assert.Equal(t, []string{"10", "2"}, typ.Args) +// }) + +// t.Run("parses type with max", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "nvarchar(max)") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "nvarchar", typ.BaseType) +// assert.Equal(t, []string{"max"}, typ.Args) +// }) + +// t.Run("handles invalid arg", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "varchar(invalid)") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "varchar", typ.BaseType) +// 
assert.NotEmpty(t, doc.errors) +// }) + +// t.Run("panics if not on identifier", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "123") +// s.NextToken() + +// assert.Panics(t, func() { +// doc.parseTypeExpression(s) +// }) +// }) +// }) +// } + +// func TestDocument_parseDeclare(t *testing.T) { +// t.Run("parses single enum declaration", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "@EnumStatus int = 42") +// s.NextToken() + +// declares := doc.parseDeclare(s) + +// require.Len(t, declares, 1) +// assert.Equal(t, "@EnumStatus", declares[0].VariableName) +// assert.Equal(t, "int", declares[0].Datatype.BaseType) +// assert.Equal(t, "42", declares[0].Literal.RawValue) +// }) + +// t.Run("parses multiple declarations with comma", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "@EnumA int = 1, @EnumB int = 2;") +// s.NextToken() + +// declares := doc.parseDeclare(s) + +// require.Len(t, declares, 2) +// assert.Equal(t, "@EnumA", declares[0].VariableName) +// assert.Equal(t, "@EnumB", declares[1].VariableName) +// }) + +// t.Run("parses string literal", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "@EnumName nvarchar(50) = N'test'") +// s.NextToken() + +// declares := doc.parseDeclare(s) + +// require.Len(t, declares, 1) +// assert.Equal(t, "N'test'", declares[0].Literal.RawValue) +// }) + +// t.Run("errors on invalid variable name", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "@InvalidName int = 1") +// s.NextToken() + +// declares := doc.parseDeclare(s) + +// // in this case when we detect the missing prefix, +// // we add an error and continue parsing the declaration. 
+// // this results with it being added +// require.Len(t, declares, 1) +// assert.NotEmpty(t, doc.errors) +// assert.Contains(t, doc.errors[0].Message, "@InvalidName") +// }) + +// t.Run("errors on missing type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "@EnumTest = 42") +// s.NextToken() + +// declares := doc.parseDeclare(s) + +// require.Len(t, declares, 0) +// assert.NotEmpty(t, doc.errors) +// assert.Contains(t, doc.errors[0].Message, "type declared explicitly") +// }) + +// t.Run("errors on missing assignment", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "@EnumTest int") +// s.NextToken() + +// declares := doc.parseDeclare(s) + +// require.Len(t, declares, 0) +// assert.NotEmpty(t, doc.errors) +// assert.Contains(t, doc.errors[0].Message, "needs to be assigned") +// }) + +// t.Run("accepts @Global prefix", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "@GlobalSetting int = 100") +// s.NextToken() + +// declares := doc.parseDeclare(s) + +// require.Len(t, declares, 1) +// assert.Equal(t, "@GlobalSetting", declares[0].VariableName) +// assert.Empty(t, doc.errors) +// }) + +// t.Run("accepts @Const prefix", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "@ConstValue int = 200") +// s.NextToken() + +// declares := doc.parseDeclare(s) + +// require.Len(t, declares, 1) +// assert.Equal(t, "@ConstValue", declares[0].VariableName) +// assert.Empty(t, doc.errors) +// }) +// } + +// func TestDocument_parseBatchSeparator(t *testing.T) { +// t.Run("parses valid go separator", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "go\n") +// s.NextToken() + +// doc.parseBatchSeparator(s) + +// assert.Empty(t, doc.errors) +// }) + +// t.Run("errors on malformed separator", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "go -- comment") +// tt := s.NextToken() +// 
fmt.Printf("%#v %#v\n", s, tt) +// doc.parseBatchSeparator(s) +// fmt.Printf("%#v\n", s) + +// assert.NotEmpty(t, doc.errors) +// assert.Contains(t, doc.errors[0].Message, "should be alone") +// }) +// } + +// func TestDocument_parseCreate(t *testing.T) { +// t.Run("parses simple procedure", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create procedure [code].TestProc as begin end") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "procedure", create.CreateType) +// assert.Equal(t, "[TestProc]", create.QuotedName.Value) +// assert.NotEmpty(t, create.Body) +// }) + +// t.Run("parses function", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create function [code].TestFunc() returns int as begin return 1 end") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "function", create.CreateType) +// assert.Equal(t, "[TestFunc]", create.QuotedName.Value) +// }) + +// t.Run("parses type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create type [code].TestType as table (id int)") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "type", create.CreateType) +// assert.Equal(t, "[TestType]", create.QuotedName.Value) +// }) + +// t.Run("tracks dependencies", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1 join [code].Table2 end") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// require.Len(t, create.DependsOn, 2) +// assert.Equal(t, "[Table1]", create.DependsOn[0].Value) +// assert.Equal(t, "[Table2]", create.DependsOn[1].Value) +// }) + +// t.Run("deduplicates dependencies", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := 
NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1; select * from [code].Table1 end") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// require.Len(t, create.DependsOn, 1) +// assert.Equal(t, "[Table1]", create.DependsOn[0].Value) +// }) + +// t.Run("errors on unsupported create type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create table [code].TestTable (id int)") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// doc.parseCreate(s, 0) + +// assert.NotEmpty(t, doc.errors) +// assert.Contains(t, doc.errors[0].Message, "only supports creating procedures") +// }) + +// t.Run("errors on multiple procedures in batch", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create procedure [code].Proc2 as begin end") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// doc.parseCreate(s, 1) + +// assert.NotEmpty(t, doc.errors) +// assert.Contains(t, doc.errors[0].Message, "must be alone in a batch") +// }) + +// t.Run("errors on missing code schema", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create procedure dbo.TestProc as begin end") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// doc.parseCreate(s, 0) + +// assert.NotEmpty(t, doc.errors) +// assert.Contains(t, doc.errors[0].Message, "must be followed by [code]") +// }) + +// t.Run("allows create index inside procedure", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create procedure [code].Proc as begin create index IX_Test on #temp(id) end") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "procedure", create.CreateType) +// assert.Empty(t, doc.errors) +// }) + +// t.Run("stops at batch separator", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "create 
procedure [code].Proc as begin end\ngo") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "[Proc]", create.QuotedName.Value) +// assert.Equal(t, BatchSeparatorToken, s.TokenType()) +// }) + +// t.Run("panics if not on create token", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.sql", "procedure") +// s.NextToken() + +// assert.Panics(t, func() { +// doc.parseCreate(s, 0) +// }) +// }) +// } + +// func TestNextTokenCopyingWhitespace(t *testing.T) { +// t.Run("copies whitespace tokens", func(t *testing.T) { +// s := NewScanner("test.sql", " \n\t token") +// var target []Unparsed + +// NextTokenCopyingWhitespace(s, &target) + +// assert.NotEmpty(t, target) +// assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) +// }) + +// t.Run("copies comments", func(t *testing.T) { +// s := NewScanner("test.sql", "/* comment */ -- line\ntoken") +// var target []Unparsed + +// NextTokenCopyingWhitespace(s, &target) + +// assert.True(t, len(target) >= 2) +// assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) +// }) + +// t.Run("stops at EOF", func(t *testing.T) { +// s := NewScanner("test.sql", " ") +// var target []Unparsed + +// NextTokenCopyingWhitespace(s, &target) + +// assert.Equal(t, EOFToken, s.TokenType()) +// }) +// } + +// func TestCreateUnparsed(t *testing.T) { +// t.Run("creates unparsed from scanner", func(t *testing.T) { +// s := NewScanner("test.sql", "select") +// s.NextToken() + +// unparsed := CreateUnparsed(s) + +// assert.Equal(t, ReservedWordToken, unparsed.Type) +// assert.Equal(t, "select", unparsed.RawValue) +// assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) +// }) +// } + +// func TestDocument_CreateProcedure(t *testing.T) { +// input := ` +// CREATE PROCEDURE [code].[MyProc] +// @Param1 nvarchar(100) +// AS +// BEGIN +// SELECT @Param1 +// END +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// 
t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 1 { +// t.Fatalf("expected 1 create, got %d", len(creates)) +// } + +// c := creates[0] +// if c.CreateType != "procedure" { +// t.Errorf("expected createType 'procedure', got %q", c.CreateType) +// } +// if c.QuotedName.Value != "[MyProc]" { +// t.Errorf("expected name '[MyProc]', got %q", c.QuotedName.Value) +// } +// } + +// func TestDocument_CreateFunction(t *testing.T) { +// input := ` +// CREATE FUNCTION [code].[GetValue] +// ( +// @Input int +// ) +// RETURNS int +// AS +// BEGIN +// RETURN @Input * 2 +// END +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 1 { +// t.Fatalf("expected 1 create, got %d", len(creates)) +// } + +// c := creates[0] +// if c.CreateType != "function" { +// t.Errorf("expected createType 'function', got %q", c.CreateType) +// } +// if c.QuotedName.Value != "[GetValue]" { +// t.Errorf("expected name '[GetValue]', got %q", c.QuotedName.Value) +// } +// } + +// func TestDocument_CreateType(t *testing.T) { +// input := ` +// CREATE TYPE [code].[MyTableType] AS TABLE +// ( +// Id int, +// Name nvarchar(100) +// ) +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 1 { +// t.Fatalf("expected 1 create, got %d", len(creates)) +// } + +// c := creates[0] +// if c.CreateType != "type" { +// t.Errorf("expected createType 'type', got %q", c.CreateType) +// } +// } + +// func TestDocument_DeclareConstants(t *testing.T) { +// input := ` +// DECLARE @EnumStatus int = 1, +// @GlobalTimeout int = 30, +// @ConstName nvarchar(50) = N'TestValue'; +// go + +// CREATE PROCEDURE [code].[MyProc] +// AS +// BEGIN +// SELECT 1 +// END +// go +// ` +// doc := parseDocument(input) + +// if 
doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// declares := doc.Declares() +// if len(declares) != 3 { +// t.Fatalf("expected 3 declares, got %d", len(declares)) +// } + +// // Check first declare +// if declares[0].VariableName != "@EnumStatus" { +// t.Errorf("expected '@EnumStatus', got %q", declares[0].VariableName) +// } +// if declares[0].Datatype.BaseType != "int" { +// t.Errorf("expected type 'int', got %q", declares[0].Datatype.BaseType) +// } + +// // Check third declare with nvarchar type +// if declares[2].VariableName != "@ConstName" { +// t.Errorf("expected '@ConstName', got %q", declares[2].VariableName) +// } +// if declares[2].Datatype.BaseType != "nvarchar" { +// t.Errorf("expected type 'nvarchar', got %q", declares[2].Datatype.BaseType) +// } +// if len(declares[2].Datatype.Args) != 1 || declares[2].Datatype.Args[0] != "50" { +// t.Errorf("expected type args ['50'], got %v", declares[2].Datatype.Args) +// } +// } + +// func TestDocument_Dependencies(t *testing.T) { +// input := ` +// CREATE PROCEDURE [code].[ProcA] +// AS +// BEGIN +// EXEC [code].[ProcB] +// SELECT * FROM [code].[TableFunc]() +// EXEC [code].[ProcB] -- duplicate reference +// END +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 1 { +// t.Fatalf("expected 1 create, got %d", len(creates)) +// } + +// deps := creates[0].DependsOn +// if len(deps) != 2 { +// t.Fatalf("expected 2 unique dependencies, got %d: %+v", len(deps), deps) +// } + +// // Dependencies should be sorted +// if deps[0].Value != "[ProcB]" { +// t.Errorf("expected first dep '[ProcB]', got %q", deps[0].Value) +// } +// if deps[1].Value != "[TableFunc]" { +// t.Errorf("expected second dep '[TableFunc]', got %q", deps[1].Value) +// } +// } + +// func TestDocument_MultipleBatches(t *testing.T) { +// input := ` +// CREATE PROCEDURE [code].[ProcA] +// AS 
+// BEGIN +// SELECT 1 +// END +// go + +// CREATE PROCEDURE [code].[ProcB] +// AS +// BEGIN +// EXEC [code].[ProcA] +// END +// go + +// CREATE FUNCTION [code].[FuncC]() +// RETURNS int +// AS +// BEGIN +// RETURN 42 +// END +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 3 { +// t.Fatalf("expected 3 creates, got %d", len(creates)) +// } + +// // Verify order preserved +// if creates[0].QuotedName.Value != "[ProcA]" { +// t.Errorf("expected first create '[ProcA]', got %q", creates[0].QuotedName.Value) +// } +// if creates[1].QuotedName.Value != "[ProcB]" { +// t.Errorf("expected second create '[ProcB]', got %q", creates[1].QuotedName.Value) +// } +// if creates[2].QuotedName.Value != "[FuncC]" { +// t.Errorf("expected third create '[FuncC]', got %q", creates[2].QuotedName.Value) +// } + +// // ProcB should depend on ProcA +// if len(creates[1].DependsOn) != 1 || creates[1].DependsOn[0].Value != "[ProcA]" { +// t.Errorf("ProcB should depend on ProcA, got %+v", creates[1].DependsOn) +// } +// } + +// func TestDocument_MultipleTypesInBatch(t *testing.T) { +// input := ` +// CREATE TYPE [code].[Type1] AS TABLE (Id int) +// CREATE TYPE [code].[Type2] AS TABLE (Name nvarchar(50)) +// CREATE TYPE [code].[Type3] AS TABLE (Value decimal(10,2)) +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 3 { +// t.Fatalf("expected 3 type creates, got %d", len(creates)) +// } + +// for i, c := range creates { +// if c.CreateType != "type" { +// t.Errorf("create %d: expected type 'type', got %q", i, c.CreateType) +// } +// } +// } + +// func TestDocument_ErrorInvalidDeclarePrefix(t *testing.T) { +// input := ` +// DECLARE @InvalidName int = 1; +// go +// ` +// doc := parseDocument(input) + +// if !doc.HasErrors() { +// 
t.Fatal("expected error for invalid variable name prefix") +// } + +// errors := doc.Errors() +// found := false +// for _, err := range errors { +// if strings.Contains(err.Message, "@Enum") || strings.Contains(err.Message, "@Global") || strings.Contains(err.Message, "@Const") { +// found = true +// break +// } +// } +// if !found { +// t.Errorf("expected error about variable name prefix, got: %+v", errors) +// } +// } + +// func TestDocument_ErrorMissingCodeSchema(t *testing.T) { +// input := ` +// CREATE PROCEDURE [dbo].[MyProc] +// AS +// BEGIN +// SELECT 1 +// END +// go +// ` +// doc := parseDocument(input) + +// if !doc.HasErrors() { +// t.Fatal("expected error for missing [code] schema") +// } + +// errors := doc.Errors() +// found := false +// for _, err := range errors { +// if strings.Contains(err.Message, "[code]") { +// found = true +// break +// } +// } +// if !found { +// t.Errorf("expected error about [code] schema, got: %+v", errors) +// } +// } + +// func TestDocument_ErrorProcedureNotAloneInBatch(t *testing.T) { +// input := ` +// CREATE PROCEDURE [code].[ProcA] +// AS +// BEGIN +// SELECT 1 +// END + +// CREATE PROCEDURE [code].[ProcB] +// AS +// BEGIN +// SELECT 2 +// END +// go +// ` +// doc := parseDocument(input) + +// if !doc.HasErrors() { +// t.Fatal("expected error for multiple procedures in one batch") +// } + +// errors := doc.Errors() +// found := false +// for _, err := range errors { +// if strings.Contains(err.Message, "alone in a batch") { +// found = true +// break +// } +// } +// if !found { +// t.Errorf("expected error about procedure alone in batch, got: %+v", errors) +// } +// } + +// func TestDocument_ErrorDeclareInSecondBatch(t *testing.T) { +// input := ` +// CREATE PROCEDURE [code].[MyProc] +// AS +// BEGIN +// SELECT 1 +// END +// go + +// DECLARE @EnumValue int = 1; +// go +// ` +// doc := parseDocument(input) + +// if !doc.HasErrors() { +// t.Fatal("expected error for declare in second batch") +// } + +// errors := 
doc.Errors() +// found := false +// for _, err := range errors { +// if strings.Contains(err.Message, "first batch") { +// found = true +// break +// } +// } +// if !found { +// t.Errorf("expected error about first batch, got: %+v", errors) +// } +// } + +// func TestDocument_Pragma(t *testing.T) { +// input := `--sqlcode:include-if feature-flag +// CREATE PROCEDURE [code].[MyProc] +// AS +// BEGIN +// SELECT 1 +// END +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 1 { +// t.Fatalf("expected 1 create, got %d", len(creates)) +// } +// } + +// func TestDocument_ComplexProcedure(t *testing.T) { +// input := ` +// -- This procedure demonstrates complex T-SQL features +// CREATE PROCEDURE [code].[ComplexProc] +// @TableInput [code].[MyTableType] READONLY, +// @Status int OUTPUT, +// @Message nvarchar(max) OUTPUT +// AS +// BEGIN +// SET NOCOUNT ON; + +// BEGIN TRY +// BEGIN TRANSACTION; + +// -- Use CTE +// WITH CTE AS ( +// SELECT Id, Name, ROW_NUMBER() OVER (ORDER BY Id) AS RowNum +// FROM @TableInput +// ) +// INSERT INTO SomeTable (Id, Name) +// SELECT Id, Name FROM CTE WHERE RowNum <= 100; + +// -- Call another procedure +// EXEC [code].[HelperProc] @Status OUTPUT; + +// -- Use table-valued function +// SELECT * FROM [code].[GetItems](@Status); + +// COMMIT TRANSACTION; +// SET @Message = N'Success'; +// END TRY +// BEGIN CATCH +// IF @@TRANCOUNT > 0 +// ROLLBACK TRANSACTION; + +// SET @Status = ERROR_NUMBER(); +// SET @Message = ERROR_MESSAGE(); +// END CATCH +// END +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 1 { +// t.Fatalf("expected 1 create, got %d", len(creates)) +// } + +// c := creates[0] +// if c.CreateType != "procedure" { +// t.Errorf("expected 'procedure', got %q", c.CreateType) 
+// } + +// // Should have dependencies on HelperProc, GetItems, and MyTableType +// if len(c.DependsOn) != 3 { +// t.Errorf("expected 3 dependencies, got %d: %+v", c.DependsOn) +// } + +// depNames := make(map[string]bool) +// for _, dep := range c.DependsOn { +// depNames[dep.Value] = true +// } + +// expectedDeps := []string{"[GetItems]", "[HelperProc]", "[MyTableType]"} +// for _, exp := range expectedDeps { +// if !depNames[exp] { +// t.Errorf("missing expected dependency %s", exp) +// } +// } +// } + +// func TestDocument_WithoutPos(t *testing.T) { +// input := ` +// DECLARE @EnumValue int = 1; +// go +// CREATE PROCEDURE [code].[MyProc] +// AS +// SELECT 1 +// go +// ` +// doc := parseDocument(input) +// docWithoutPos := doc.WithoutPos().(*TSqlDocument) + +// // Verify positions are zeroed +// for _, c := range docWithoutPos.Creates() { +// if c.QuotedName.Pos.Line != 0 { +// t.Error("expected zero position in WithoutPos result") +// } +// } +// for _, d := range docWithoutPos.Declares() { +// if d.Start.Line != 0 { +// t.Error("expected zero position in WithoutPos result") +// } +// } +// } + +// func TestDocument_Include(t *testing.T) { +// input1 := ` +// CREATE PROCEDURE [code].[ProcA] +// AS +// SELECT 1 +// go +// ` +// input2 := ` +// CREATE PROCEDURE [code].[ProcB] +// AS +// SELECT 2 +// go +// ` +// doc1 := parseDocument(input1) +// doc2 := parseDocument(input2) + +// doc1.Include(doc2) + +// creates := doc1.Creates() +// if len(creates) != 2 { +// t.Fatalf("expected 2 creates after Include, got %d", len(creates)) +// } +// } + +// func TestDocument_Sort(t *testing.T) { +// // ProcB depends on ProcA, so ProcA should come first after sort +// input := ` +// CREATE PROCEDURE [code].[ProcB] +// AS +// BEGIN +// EXEC [code].[ProcA] +// END +// go + +// CREATE PROCEDURE [code].[ProcA] +// AS +// SELECT 1 +// go +// ` +// doc := parseDocument(input) +// doc.Sort() + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors after sort: %+v", doc.Errors()) 
+// } + +// creates := doc.Creates() +// if len(creates) != 2 { +// t.Fatalf("expected 2 creates, got %d", len(creates)) +// } + +// // After topological sort, ProcA should come before ProcB +// if creates[0].QuotedName.Value != "[ProcA]" { +// t.Errorf("expected ProcA first after sort, got %q", creates[0].QuotedName.Value) +// } +// if creates[1].QuotedName.Value != "[ProcB]" { +// t.Errorf("expected ProcB second after sort, got %q", creates[1].QuotedName.Value) +// } +// } + +// func TestDocument_SortCyclicDependency(t *testing.T) { +// // Create a cycle: A -> B -> C -> A +// input := ` +// CREATE PROCEDURE [code].[ProcA] +// AS +// EXEC [code].[ProcC] +// go + +// CREATE PROCEDURE [code].[ProcB] +// AS +// EXEC [code].[ProcA] +// go + +// CREATE PROCEDURE [code].[ProcC] +// AS +// EXEC [code].[ProcB] +// go +// ` +// doc := parseDocument(input) +// doc.Sort() + +// // Should have an error about cyclic dependency +// if !doc.HasErrors() { +// t.Fatal("expected error for cyclic dependency") +// } + +// found := false +// for _, err := range doc.Errors() { +// if strings.Contains(strings.ToLower(err.Message), "cycle") || +// strings.Contains(strings.ToLower(err.Message), "circular") { +// found = true +// break +// } +// } +// if !found { +// t.Errorf("expected cycle-related error, got: %+v", doc.Errors()) +// } +// } + +// func TestDocument_NewScanner(t *testing.T) { +// doc := &TSqlDocument{} +// input := "SELECT 1" +// file := sqldocument.FileRef("test.sql") + +// scanner := doc.NewScanner(input, file) + +// if scanner == nil { +// t.Fatal("NewScanner returned nil") +// } + +// scanner.NextToken() +// if scanner.TokenType() != sqldocument.ReservedWordToken { +// t.Errorf("expected ReservedWordToken for SELECT, got %v", scanner.TokenType()) +// } +// } + +// func TestDocument_Empty(t *testing.T) { +// emptyDoc := &TSqlDocument{} +// if !emptyDoc.Empty() { +// t.Error("empty document should report Empty() = true") +// } + +// input := ` +// CREATE PROCEDURE 
[code].[MyProc] +// AS +// SELECT 1 +// go +// ` +// doc := parseDocument(input) +// // Note: Empty() returns true if EITHER creates or declares is empty (uses ||) +// // This might be a bug in the original code - should probably be && +// // Testing current behavior +// if doc.Empty() { +// // Has creates but no declares, so with || it would be true +// // Let's check the actual logic +// t.Log("Empty() behavior may need review - currently uses || instead of &&") +// } +// } + +// func TestDocument_UnicodeIdentifiers(t *testing.T) { +// input := ` +// CREATE PROCEDURE [code].[日本語プロシージャ] +// @パラメータ nvarchar(100) +// AS +// BEGIN +// SELECT @パラメータ +// END +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 1 { +// t.Fatalf("expected 1 create, got %d", len(creates)) +// } + +// if creates[0].QuotedName.Value != "[日本語プロシージャ]" { +// t.Errorf("expected Unicode name, got %q", creates[0].QuotedName.Value) +// } +// } + +// func TestDocument_NestedCreateStatements(t *testing.T) { +// // Procedure containing CREATE TABLE (should be allowed) +// input := ` +// CREATE PROCEDURE [code].[MyProc] +// AS +// BEGIN +// CREATE TABLE #TempTable (Id int) +// INSERT INTO #TempTable SELECT 1 +// SELECT * FROM #TempTable +// DROP TABLE #TempTable +// END +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// creates := doc.Creates() +// if len(creates) != 1 { +// t.Fatalf("expected 1 create (procedure only), got %d", len(creates)) +// } +// } + +// func TestDocument_TypeWithMaxArg(t *testing.T) { +// input := ` +// DECLARE @ConstValue nvarchar(max) = N'test'; +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// declares := doc.Declares() +// if len(declares) != 1 { +// t.Fatalf("expected 1 
declare, got %d", len(declares)) +// } + +// if declares[0].Datatype.BaseType != "nvarchar" { +// t.Errorf("expected 'nvarchar', got %q", declares[0].Datatype.BaseType) +// } +// if len(declares[0].Datatype.Args) != 1 || declares[0].Datatype.Args[0] != "max" { +// t.Errorf("expected ['max'], got %v", declares[0].Datatype.Args) +// } +// } + +// func TestDocument_TypeWithMultipleArgs(t *testing.T) { +// input := ` +// DECLARE @ConstValue decimal(18,4) = 123.4567; +// go +// ` +// doc := parseDocument(input) + +// if doc.HasErrors() { +// t.Fatalf("unexpected errors: %+v", doc.Errors()) +// } + +// declares := doc.Declares() +// if len(declares) != 1 { +// t.Fatalf("expected 1 declare, got %d", len(declares)) +// } + +// dt := declares[0].Datatype +// if dt.BaseType != "decimal" { +// t.Errorf("expected 'decimal', got %q", dt.BaseType) +// } +// if len(dt.Args) != 2 || dt.Args[0] != "18" || dt.Args[1] != "4" { +// t.Errorf("expected ['18', '4'], got %v", dt.Args) +// } +// } diff --git a/sqlparser/scanner.go b/sqlparser/mssql/scanner.go similarity index 70% rename from sqlparser/scanner.go rename to sqlparser/mssql/scanner.go index 18b2b77..32104df 100644 --- a/sqlparser/scanner.go +++ b/sqlparser/mssql/scanner.go @@ -1,101 +1,133 @@ -package sqlparser +package mssql import ( - "fmt" "regexp" "strings" "unicode" "unicode/utf8" "github.com/smasher164/xid" + "github.com/vippsas/sqlcode/sqlparser/sqldocument" ) -// dedicated type for reference to file, in case we need to refactor this later.. -type FileRef string - -type Pos struct { - File FileRef - Line, Col int -} - -// We don't do the lexer/parser split / token stream, but simply use the -// Scanner directly from the recursive descent parser; it is simply a cursor -// in the buffer with associated utility methods +// Scanner is a lexical scanner for T-SQL source code. 
+// +// Unlike traditional lexer/parser architectures with a token stream, Scanner +// is used directly by the recursive descent parser as a cursor into the input +// buffer. It provides utility methods for tokenization and position tracking. +// +// The scanner handles T-SQL specific constructs including: +// - String literals ('...' and N'...') +// - Quoted identifiers ([...]) +// - Single-line (--) and multi-line (/* */) comments +// - Batch separators (GO) +// - Reserved words +// - Variables (@identifier) type Scanner struct { - input string - file FileRef + input string // The complete source code being scanned + file sqldocument.FileRef // Reference to the source file for error reporting - startIndex int // start of this item - curIndex int // current position of the Scanner - tokenType TokenType + startIndex int // Byte index where current token starts + curIndex int // Current byte position in Input + tokenType sqldocument.TokenType // Type of the current token - // NextToken() has a small state machine to implement the rules of batch seperators - // using these two states - startOfLine bool // have we seen anything non-whitespace, non-comment since start of line? Only used for BatchSeparatorToken - afterBatchSeparator bool // raise an error if we see anything except whitespace and comments after 'go' + // Batch separator state machine fields. + // The GO batch separator has special rules: it must appear at the start + // of a line and nothing except whitespace can follow it on the same line. 
+ startOfLine bool // True if no non-whitespace/comment seen since start of line + afterBatchSeparator bool // True if we just saw GO; used to detect malformed separators - startLine int - stopLine int - indexAtStartLine int // value of `curIndex` after newline char - indexAtStopLine int // value of `curIndex` after newline char + startLine int // Line number (0-indexed) where current token starts + stopLine int // Line number (0-indexed) where current token ends + indexAtStartLine int // Byte index at the start of startLine (after newline) + indexAtStopLine int // Byte index at the start of stopLine (after newline) - reservedWord string // in the event that the token is a ReservedWordToken, this contains the lower-case version + reservedWord string // Lowercase version of token if it's a reserved word, empty otherwise } -func NewScanner(path FileRef, input string) *Scanner { +// NewScanner creates a new Scanner for the given source file and input string. +// The scanner is positioned before the first token; call NextToken() to advance. +func NewScanner(path sqldocument.FileRef, input string) *Scanner { return &Scanner{input: input, file: path} } -type TokenType int - -func (s *Scanner) TokenType() TokenType { +// TokenType returns the type of the current token. +func (s *Scanner) TokenType() sqldocument.TokenType { return s.tokenType } -// Returns a clone of the scanner; this is used to do look-ahead parsing +func (s *Scanner) SetInput(input []byte) { + s.input = string(input) +} + +func (s *Scanner) SetFile(file sqldocument.FileRef) { + s.file = file +} + +func (s *Scanner) File() sqldocument.FileRef { + return s.file +} + +// Clone returns a copy of the scanner at its current position. +// This is used for look-ahead parsing where we need to tentatively +// scan tokens without committing to consuming them. 
func (s Scanner) Clone() *Scanner { result := new(Scanner) *result = s return result } +// Token returns the text of the current token as a substring of Input. func (s *Scanner) Token() string { return s.input[s.startIndex:s.curIndex] } +// TokenLower returns the current token text converted to lowercase. +// Useful for case-insensitive keyword matching. func (s *Scanner) TokenLower() string { return strings.ToLower(s.Token()) } +// ReservedWord returns the lowercase reserved word if the current token +// is a ReservedWordToken, or an empty string otherwise. func (s *Scanner) ReservedWord() string { return s.reservedWord } -func (s *Scanner) Start() Pos { - return Pos{ +// Start returns the position where the current token begins. +// Line and column are 1-indexed. +func (s *Scanner) Start() sqldocument.Pos { + return sqldocument.Pos{ Line: s.startLine + 1, Col: s.startIndex - s.indexAtStartLine + 1, File: s.file, } } -func (s *Scanner) Stop() Pos { - return Pos{ +// Stop returns the position where the current token ends. +// Line and column are 1-indexed. +func (s *Scanner) Stop() sqldocument.Pos { + return sqldocument.Pos{ Line: s.stopLine + 1, Col: s.curIndex - s.indexAtStopLine + 1, File: s.file, } } +// bumpLine increments the line counter and records the byte position +// after the newline character. The offset parameter is the position +// of the newline within the current scan operation. func (s *Scanner) bumpLine(offset int) { s.stopLine++ s.indexAtStopLine = s.curIndex + offset + 1 } +// SkipWhitespaceComments advances past any whitespace and comment tokens. +// Stops when a non-whitespace, non-comment token is encountered. 
func (s *Scanner) SkipWhitespaceComments() { for { switch s.TokenType() { - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + case sqldocument.WhitespaceToken, sqldocument.MultilineCommentToken, sqldocument.SinglelineCommentToken: default: return } @@ -103,10 +135,13 @@ func (s *Scanner) SkipWhitespaceComments() { } } +// SkipWhitespace advances past any whitespace tokens. +// Stops when a non-whitespace token is encountered. +// Unlike SkipWhitespaceComments, this preserves comments. func (s *Scanner) SkipWhitespace() { for { switch s.TokenType() { - case WhitespaceToken: + case sqldocument.WhitespaceToken: default: return } @@ -114,21 +149,35 @@ func (s *Scanner) SkipWhitespace() { } } -func (s *Scanner) NextNonWhitespaceToken() TokenType { +// NextNonWhitespaceToken advances to the next token and then skips +// any whitespace, returning the type of the first non-whitespace token. +func (s *Scanner) NextNonWhitespaceToken() sqldocument.TokenType { s.NextToken() s.SkipWhitespace() return s.TokenType() } -func (s *Scanner) NextNonWhitespaceCommentToken() TokenType { +// NextNonWhitespaceCommentToken advances to the next token and then skips +// any whitespace and comments, returning the type of the first significant token. +func (s *Scanner) NextNonWhitespaceCommentToken() sqldocument.TokenType { s.NextToken() s.SkipWhitespaceComments() return s.TokenType() } -// NextToken scans the NextToken token and advances the Scanner's position to -// after the token -func (s *Scanner) NextToken() TokenType { +// NextToken scans the next token and advances the scanner's position. +// +// This method wraps the raw tokenization with batch separator handling. 
+// The GO batch separator has special rules in T-SQL: +// - It must appear at the start of a line (only whitespace/comments before it) +// - Nothing except whitespace may follow it on the same line +// - It is not processed inside [names], 'strings', or /*comments*/ +// +// If GO is followed by non-whitespace on the same line, subsequent tokens +// are returned as MalformedBatchSeparatorToken until end of line. +// +// Returns the TokenType of the scanned token. +func (s *Scanner) NextToken() sqldocument.TokenType { // handle startOfLine flag here; this is used to parse the 'go' batch separator s.tokenType = s.nextToken() @@ -145,12 +194,12 @@ func (s *Scanner) NextToken() TokenType { // on the same line as 'go'. And doing so will not turn it into a literal, // but instead return MalformedBatchSeparatorToken - if s.startOfLine && s.tokenType == UnquotedIdentifierToken && s.TokenLower() == "go" { - s.tokenType = BatchSeparatorToken + if s.startOfLine && s.tokenType == sqldocument.UnquotedIdentifierToken && s.TokenLower() == "go" { + s.tokenType = sqldocument.BatchSeparatorToken s.afterBatchSeparator = true - } else if s.afterBatchSeparator && s.tokenType != WhitespaceToken && s.tokenType != EOFToken { - s.tokenType = MalformedBatchSeparatorToken - } else if s.tokenType == WhitespaceToken { + } else if s.afterBatchSeparator && s.tokenType != sqldocument.WhitespaceToken && s.tokenType != sqldocument.EOFToken { + s.tokenType = sqldocument.MalformedBatchSeparatorToken + } else if s.tokenType == sqldocument.WhitespaceToken { // If we just saw the whitespace token that bumped the linecount, // we are at the "start of line", even if this contains some space after the \n: if s.stopLine > s.startLine { @@ -161,10 +210,11 @@ func (s *Scanner) NextToken() TokenType { } else { s.startOfLine = false } + return s.tokenType } -func (s *Scanner) nextToken() TokenType { +func (s *Scanner) nextToken() sqldocument.TokenType { s.startIndex = s.curIndex s.reservedWord = "" s.startLine 
= s.stopLine @@ -174,29 +224,29 @@ func (s *Scanner) nextToken() TokenType { // First, decisions that can be made after one character: switch { case r == utf8.RuneError && w == 0: - return EOFToken + return sqldocument.EOFToken case r == utf8.RuneError && w == -1: // not UTF-8, we can't really proceed so not advancing Scanner, // caller should take care to always exit.. - return NonUTF8ErrorToken + return sqldocument.NonUTF8ErrorToken case r == '(': s.curIndex += w - return LeftParenToken + return sqldocument.LeftParenToken case r == ')': s.curIndex += w - return RightParenToken + return sqldocument.RightParenToken case r == ';': s.curIndex += w - return SemicolonToken + return sqldocument.SemicolonToken case r == '=': s.curIndex += w - return EqualToken + return sqldocument.EqualToken case r == ',': s.curIndex += w - return CommaToken + return sqldocument.CommaToken case r == '.': s.curIndex += w - return DotToken + return sqldocument.DotToken case r == '\'': s.curIndex += w return s.scanStringLiteral(VarcharLiteralToken) @@ -219,16 +269,15 @@ func (s *Scanner) nextToken() TokenType { s.curIndex += w s.scanIdentifier() if r == '@' { - return VariableIdentifierToken + return sqldocument.VariableIdentifierToken } else { rw := strings.ToLower(s.Token()) _, ok := reservedWords[rw] - fmt.Printf("%#v %t\n", rw, ok) if ok { s.reservedWord = rw - return ReservedWordToken + return sqldocument.ReservedWordToken } else { - return UnquotedIdentifierToken + return sqldocument.UnquotedIdentifierToken } } } @@ -247,13 +296,11 @@ func (s *Scanner) nextToken() TokenType { s.scanIdentifier() rw := strings.ToLower(s.Token()) _, ok := reservedWords[rw] - fmt.Printf("%#v %t\n", rw, ok) - if ok { s.reservedWord = rw - return ReservedWordToken + return sqldocument.ReservedWordToken } else { - return UnquotedIdentifierToken + return sqldocument.UnquotedIdentifierToken } case r == '/' && r2 == '*': s.curIndex += w + w2 @@ -266,28 +313,28 @@ func (s *Scanner) nextToken() TokenType { } 
s.curIndex += w - return OtherToken + return sqldocument.OtherToken } // scanMultilineComment assumes one has advanced over '/*' -func (s *Scanner) scanMultilineComment() TokenType { +func (s *Scanner) scanMultilineComment() sqldocument.TokenType { prevWasStar := false for i, r := range s.input[s.curIndex:] { if r == '*' { prevWasStar = true } else if prevWasStar && r == '/' { s.curIndex += i + 1 - return MultilineCommentToken + return sqldocument.MultilineCommentToken } else if r == '\n' { s.bumpLine(i) } } s.curIndex = len(s.input) - return MultilineCommentToken + return sqldocument.MultilineCommentToken } // scanSinglelineComment assumes one has advanced over -- -func (s *Scanner) scanSinglelineComment() TokenType { +func (s *Scanner) scanSinglelineComment() sqldocument.TokenType { isPragma := strings.HasPrefix(s.input[s.curIndex:], "sqlcode:") end := strings.Index(s.input[s.curIndex:], "\n") if end == -1 { @@ -299,20 +346,20 @@ func (s *Scanner) scanSinglelineComment() TokenType { s.curIndex += end } if isPragma { - return PragmaToken + return sqldocument.PragmaToken } else { - return SinglelineCommentToken + return sqldocument.SinglelineCommentToken } } // scanStringLiteral assumes one has scanned ' or N' (depending on param); // then scans until the end of the string -func (s *Scanner) scanStringLiteral(tokenType TokenType) TokenType { +func (s *Scanner) scanStringLiteral(tokenType sqldocument.TokenType) sqldocument.TokenType { return s.scanUntilSingleDoubleEscapes('\'', tokenType, UnterminatedVarcharLiteralErrorToken) } -func (s *Scanner) scanQuotedIdentifier() TokenType { - return s.scanUntilSingleDoubleEscapes(']', QuotedIdentifierToken, UnterminatedQuotedIdentifierErrorToken) +func (s *Scanner) scanQuotedIdentifier() sqldocument.TokenType { + return s.scanUntilSingleDoubleEscapes(']', sqldocument.QuotedIdentifierToken, UnterminatedQuotedIdentifierErrorToken) } // scanIdentifier assumes first character of an identifier has been identified, @@ -328,7 
+375,11 @@ func (s *Scanner) scanIdentifier() { } // DRY helper to handle both ” and ]] escapes -func (s *Scanner) scanUntilSingleDoubleEscapes(endmarker rune, tokenType TokenType, unterminatedTokenType TokenType) TokenType { +func (s *Scanner) scanUntilSingleDoubleEscapes( + endmarker rune, + tokenType sqldocument.TokenType, + unterminatedTokenType sqldocument.TokenType, +) sqldocument.TokenType { skipnext := false for i, r := range s.input[s.curIndex:] { if skipnext { @@ -356,7 +407,7 @@ func (s *Scanner) scanUntilSingleDoubleEscapes(endmarker rune, tokenType TokenTy var numberRegexp = regexp.MustCompile(`^[+-]?\d+\.?\d*([eE][+-]?\d*)?`) -func (s *Scanner) scanNumber() TokenType { +func (s *Scanner) scanNumber() sqldocument.TokenType { // T-SQL seems to scan a number until the // end and then allowing a literal to start without whitespace or other things // in between... @@ -368,22 +419,22 @@ func (s *Scanner) scanNumber() TokenType { panic("should always have a match according to regex and conditions in caller") } s.curIndex += loc[1] - return NumberToken + return sqldocument.NumberToken } -func (s *Scanner) scanWhitespace() TokenType { +func (s *Scanner) scanWhitespace() sqldocument.TokenType { for i, r := range s.input[s.curIndex:] { if r == '\n' { s.bumpLine(i) } if !unicode.IsSpace(r) { s.curIndex += i - return WhitespaceToken + return sqldocument.WhitespaceToken } } // eof s.curIndex = len(s.input) - return WhitespaceToken + return sqldocument.WhitespaceToken } // tsql (mssql) reservered words diff --git a/sqlparser/mssql/scanner_test.go b/sqlparser/mssql/scanner_test.go new file mode 100644 index 0000000..19a561f --- /dev/null +++ b/sqlparser/mssql/scanner_test.go @@ -0,0 +1,600 @@ +package mssql + +import ( + "testing" + + "github.com/vippsas/sqlcode/sqlparser/sqldocument" +) + +// Helper to collect all tokens from input +func collectTokens(input string) []struct { + Type sqldocument.TokenType + Value string +} { + s := NewScanner("test.sql", input) + var 
tokens []struct { + Type sqldocument.TokenType + Value string + } + for { + tt := s.NextToken() + tokens = append(tokens, struct { + Type sqldocument.TokenType + Value string + }{tt, s.Token()}) + if tt == sqldocument.EOFToken { + break + } + } + return tokens +} + +func TestScanner_SimpleTokens(t *testing.T) { + tests := []struct { + name string + input string + expected []sqldocument.TokenType + }{ + { + name: "parentheses and punctuation", + input: "( ) ; = , .", + expected: []sqldocument.TokenType{ + sqldocument.LeftParenToken, + sqldocument.WhitespaceToken, + sqldocument.RightParenToken, + sqldocument.WhitespaceToken, + sqldocument.SemicolonToken, + sqldocument.WhitespaceToken, + sqldocument.EqualToken, + sqldocument.WhitespaceToken, + sqldocument.CommaToken, + sqldocument.WhitespaceToken, + sqldocument.DotToken, + sqldocument.EOFToken, + }, + }, + { + name: "empty input", + input: "", + expected: []sqldocument.TokenType{ + sqldocument.EOFToken, + }, + }, + { + name: "whitespace only", + input: " \t\n ", + expected: []sqldocument.TokenType{ + sqldocument.WhitespaceToken, + sqldocument.EOFToken, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens := collectTokens(tt.input) + if len(tokens) != len(tt.expected) { + t.Fatalf("expected %d tokens, got %d", len(tt.expected), len(tokens)) + } + for i, exp := range tt.expected { + if tokens[i].Type != exp { + t.Errorf("token %d: expected type %v, got %v (value: %q)", + i, exp, tokens[i].Type, tokens[i].Value) + } + } + }) + } +} + +func TestScanner_StringLiterals(t *testing.T) { + tests := []struct { + name string + input string + expectedType sqldocument.TokenType + expectedValue string + }{ + { + name: "simple varchar", + input: "'hello world'", + expectedType: VarcharLiteralToken, + expectedValue: "'hello world'", + }, + { + name: "varchar with escaped quote", + input: "'it''s working'", + expectedType: VarcharLiteralToken, + expectedValue: "'it''s working'", + }, + { + name: 
"empty varchar", + input: "''", + expectedType: VarcharLiteralToken, + expectedValue: "''", + }, + { + name: "simple nvarchar", + input: "N'unicode string'", + expectedType: NVarcharLiteralToken, + expectedValue: "N'unicode string'", + }, + { + name: "nvarchar with unicode", + input: "N'こんにちは'", + expectedType: NVarcharLiteralToken, + expectedValue: "N'こんにちは'", + }, + { + name: "nvarchar with escaped quote", + input: "N'say ''hello'''", + expectedType: NVarcharLiteralToken, + expectedValue: "N'say ''hello'''", + }, + { + name: "multiline varchar", + input: "'line1\nline2\nline3'", + expectedType: VarcharLiteralToken, + expectedValue: "'line1\nline2\nline3'", + }, + { + name: "unterminated varchar", + input: "'unterminated", + expectedType: UnterminatedVarcharLiteralErrorToken, + expectedValue: "'unterminated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != tt.expectedType { + t.Errorf("expected type %v, got %v", tt.expectedType, s.TokenType()) + } + if s.Token() != tt.expectedValue { + t.Errorf("expected value %q, got %q", tt.expectedValue, s.Token()) + } + }) + } +} + +func TestScanner_QuotedIdentifiers(t *testing.T) { + tests := []struct { + name string + input string + expectedType sqldocument.TokenType + expectedValue string + }{ + { + name: "simple bracket identifier", + input: "[MyTable]", + expectedType: sqldocument.QuotedIdentifierToken, + expectedValue: "[MyTable]", + }, + { + name: "bracket identifier with space", + input: "[My Table Name]", + expectedType: sqldocument.QuotedIdentifierToken, + expectedValue: "[My Table Name]", + }, + { + name: "bracket identifier with escaped bracket", + input: "[My]]Table]", + expectedType: sqldocument.QuotedIdentifierToken, + expectedValue: "[My]]Table]", + }, + { + name: "code schema identifier", + input: "[code]", + expectedType: sqldocument.QuotedIdentifierToken, + expectedValue: "[code]", + }, + { + name: 
"unterminated bracket identifier", + input: "[unterminated", + expectedType: UnterminatedQuotedIdentifierErrorToken, + expectedValue: "[unterminated", + }, + { + name: "double quote error", + input: "\"identifier\"", + expectedType: DoubleQuoteErrorToken, + expectedValue: "\"", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != tt.expectedType { + t.Errorf("expected type %v, got %v", tt.expectedType, s.TokenType()) + } + if s.Token() != tt.expectedValue { + t.Errorf("expected value %q, got %q", tt.expectedValue, s.Token()) + } + }) + } +} + +func TestScanner_Identifiers(t *testing.T) { + tests := []struct { + name string + input string + expectedType sqldocument.TokenType + expectedWord string // for reserved words + }{ + { + name: "simple identifier", + input: "MyProc", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "identifier with underscore", + input: "my_procedure_name", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "identifier with numbers", + input: "Proc123", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "identifier starting with underscore", + input: "_private", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "identifier with hash (temp table)", + input: "#TempTable", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "global temp table", + input: "##GlobalTemp", + expectedType: sqldocument.UnquotedIdentifierToken, + }, + { + name: "variable identifier", + input: "@myVariable", + expectedType: sqldocument.VariableIdentifierToken, + }, + { + name: "system variable", + input: "@@ROWCOUNT", + expectedType: sqldocument.VariableIdentifierToken, + }, + { + name: "reserved word CREATE", + input: "CREATE", + expectedType: sqldocument.ReservedWordToken, + expectedWord: "create", + }, + { + name: "reserved word lowercase", + input: "select", + 
expectedType: sqldocument.ReservedWordToken, + expectedWord: "select", + }, + { + name: "reserved word mixed case", + input: "DeClaRe", + expectedType: sqldocument.ReservedWordToken, + expectedWord: "declare", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != tt.expectedType { + t.Errorf("expected type %v, got %v", tt.expectedType, s.TokenType()) + } + if tt.expectedWord != "" && s.ReservedWord() != tt.expectedWord { + t.Errorf("expected reserved word %q, got %q", tt.expectedWord, s.ReservedWord()) + } + }) + } +} + +func TestScanner_Numbers(t *testing.T) { + tests := []struct { + name string + input string + expectedValue string + }{ + {"integer", "42", "42"}, + {"negative integer", "-42", "-42"}, + {"positive integer", "+42", "+42"}, + {"decimal", "3.14159", "3.14159"}, + {"negative decimal", "-3.14", "-3.14"}, + {"scientific notation", "1.5e10", "1.5e10"}, + {"scientific negative exponent", "1.5e-10", "1.5e-10"}, + {"scientific positive exponent", "1.5e+10", "1.5e+10"}, + {"integer scientific", "1e5", "1e5"}, + {"leading decimal", "123.", "123."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != sqldocument.NumberToken { + t.Errorf("expected NumberToken, got %v", s.TokenType()) + } + if s.Token() != tt.expectedValue { + t.Errorf("expected %q, got %q", tt.expectedValue, s.Token()) + } + }) + } +} + +func TestScanner_Comments(t *testing.T) { + tests := []struct { + name string + input string + expectedType sqldocument.TokenType + expectedValue string + }{ + { + name: "single line comment", + input: "-- this is a comment", + expectedType: sqldocument.SinglelineCommentToken, + expectedValue: "-- this is a comment", + }, + { + name: "single line comment before newline", + input: "-- comment\ncode", + expectedType: sqldocument.SinglelineCommentToken, + 
expectedValue: "-- comment", + }, + { + name: "multiline comment", + input: "/* this is\na multiline\ncomment */", + expectedType: sqldocument.MultilineCommentToken, + expectedValue: "/* this is\na multiline\ncomment */", + }, + { + name: "multiline comment with asterisks", + input: "/* * * * */", + expectedType: sqldocument.MultilineCommentToken, + expectedValue: "/* * * * */", + }, + { + name: "pragma comment", + input: "--sqlcode:include-if foo", + expectedType: sqldocument.PragmaToken, + expectedValue: "--sqlcode:include-if foo", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewScanner("test.sql", tt.input) + s.NextToken() + if s.TokenType() != tt.expectedType { + t.Errorf("expected type %v, got %v", tt.expectedType, s.TokenType()) + } + if s.Token() != tt.expectedValue { + t.Errorf("expected value %q, got %q", tt.expectedValue, s.Token()) + } + }) + } +} + +func TestScanner_BatchSeparator(t *testing.T) { + tests := []struct { + name string + input string + expected []sqldocument.TokenType + }{ + { + name: "go at start of file", + input: "go", + expected: []sqldocument.TokenType{ + sqldocument.BatchSeparatorToken, + sqldocument.EOFToken, + }, + }, + { + name: "GO uppercase at start", + input: "GO", + expected: []sqldocument.TokenType{ + sqldocument.BatchSeparatorToken, + sqldocument.EOFToken, + }, + }, + { + name: "go after newline", + input: "SELECT 1\ngo", + expected: []sqldocument.TokenType{ + sqldocument.ReservedWordToken, // SELECT + sqldocument.WhitespaceToken, + sqldocument.NumberToken, // 1 + sqldocument.WhitespaceToken, + sqldocument.BatchSeparatorToken, // go + sqldocument.EOFToken, + }, + }, + { + name: "go with trailing whitespace", + input: "go \nSELECT", + expected: []sqldocument.TokenType{ + sqldocument.BatchSeparatorToken, + sqldocument.WhitespaceToken, + sqldocument.ReservedWordToken, + sqldocument.EOFToken, + }, + }, + { + name: "go mid-line is identifier", + input: "SELECT go FROM", + expected: 
[]sqldocument.TokenType{ + sqldocument.ReservedWordToken, // SELECT + sqldocument.WhitespaceToken, + sqldocument.UnquotedIdentifierToken, // go (not batch separator) + sqldocument.WhitespaceToken, + sqldocument.ReservedWordToken, // FROM + sqldocument.EOFToken, + }, + }, + { + name: "go with comment after is malformed", + input: "go -- comment", + expected: []sqldocument.TokenType{ + sqldocument.BatchSeparatorToken, + sqldocument.WhitespaceToken, + sqldocument.MalformedBatchSeparatorToken, // -- comment + sqldocument.EOFToken, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens := collectTokens(tt.input) + if len(tokens) != len(tt.expected) { + t.Fatalf("expected %d tokens, got %d: %+v", len(tt.expected), len(tokens), tokens) + } + for i, exp := range tt.expected { + if tokens[i].Type != exp { + t.Errorf("token %d: expected type %v, got %v (value: %q)", + i, exp, tokens[i].Type, tokens[i].Value) + } + } + }) + } +} + +func TestScanner_Position(t *testing.T) { + input := "SELECT\n @var\n FROM" + s := NewScanner("test.sql", input) + + // SELECT + s.NextToken() + start := s.Start() + if start.Line != 1 || start.Col != 1 { + t.Errorf("SELECT start: expected (1,1), got (%d,%d)", start.Line, start.Col) + } + + // whitespace (includes newline) + s.NextToken() + + // @var + s.NextToken() + start = s.Start() + if start.Line != 2 || start.Col != 3 { + t.Errorf("@var start: expected (2,3), got (%d,%d)", start.Line, start.Col) + } + + // whitespace + s.NextToken() + + // FROM + s.NextToken() + start = s.Start() + if start.Line != 3 || start.Col != 3 { + t.Errorf("FROM start: expected (3,3), got (%d,%d)", start.Line, start.Col) + } +} + +func TestScanner_ComplexStatement(t *testing.T) { + input := `CREATE PROCEDURE [code].[MyProc] + @Param1 nvarchar(100), + @Param2 int = 42 +AS +BEGIN + SELECT @Param1, @Param2 +END` + + s := NewScanner("test.sql", input) + + // Verify we can tokenize the entire statement without errors + tokenCount := 0 + 
for { + tt := s.NextToken() + tokenCount++ + if tt == sqldocument.EOFToken { + break + } + if tt == sqldocument.NonUTF8ErrorToken { + t.Fatalf("unexpected non-UTF8 error at token %d", tokenCount) + } + } + + if tokenCount < 30 { + t.Errorf("expected at least 30 tokens, got %d", tokenCount) + } +} + +func TestScanner_ToCommonToken(t *testing.T) { + tests := []struct { + tsqlToken sqldocument.TokenType + commonToken sqldocument.TokenType + }{ + {VarcharLiteralToken, sqldocument.StringLiteralToken}, + {NVarcharLiteralToken, sqldocument.StringLiteralToken}, + {BracketQuotedIdentifierToken, sqldocument.QuotedIdentifierToken}, + {UnterminatedVarcharLiteralErrorToken, sqldocument.UnterminatedStringErrorToken}, + {UnterminatedQuotedIdentifierErrorToken, sqldocument.UnterminatedStringErrorToken}, + {sqldocument.NumberToken, sqldocument.NumberToken}, // passthrough + {sqldocument.WhitespaceToken, sqldocument.WhitespaceToken}, // passthrough + } + + for _, tt := range tests { + result := ToCommonToken(tt.tsqlToken) + if result != tt.commonToken { + t.Errorf("ToCommonToken(%v): expected %v, got %v", + tt.tsqlToken, tt.commonToken, result) + } + } +} + +func TestScanner_Clone(t *testing.T) { + input := "SELECT FROM WHERE" + s := NewScanner("test.sql", input) + + s.NextToken() // SELECT + s.NextToken() // whitespace + + clone := s.Clone() + + // Advance original + s.NextToken() // FROM + + // Clone should still be at whitespace position + if clone.Token() != " " { + t.Errorf("clone should still be at whitespace, got %q", clone.Token()) + } + + // Advance clone independently + clone.NextToken() + if clone.Token() != "FROM" { + t.Errorf("clone should now be at FROM, got %q", clone.Token()) + } +} + +func TestScanner_SkipMethods(t *testing.T) { + input := "SELECT /* comment */ @var" + s := NewScanner("test.sql", input) + + s.NextToken() // SELECT + s.NextToken() // whitespace + + // SkipWhitespace should stop at comment + s.SkipWhitespace() + if s.TokenType() != 
sqldocument.MultilineCommentToken { + t.Errorf("SkipWhitespace should stop at comment, got %v", s.TokenType()) + } + + // Reset and test SkipWhitespaceComments + s = NewScanner("test.sql", input) + s.NextToken() // SELECT + tt := s.NextNonWhitespaceCommentToken() + if tt != sqldocument.VariableIdentifierToken { + t.Errorf("NextNonWhitespaceCommentToken should return @var token type, got %v", tt) + } + if s.Token() != "@var" { + t.Errorf("should be at @var, got %q", s.Token()) + } +} diff --git a/sqlparser/mssql/tokens.go b/sqlparser/mssql/tokens.go new file mode 100644 index 0000000..db9fd89 --- /dev/null +++ b/sqlparser/mssql/tokens.go @@ -0,0 +1,57 @@ +package mssql + +import "github.com/vippsas/sqlcode/sqlparser/sqldocument" + +// T-SQL specific tokens (range 1000-1999) +// +// Token values are partitioned by dialect to avoid collisions: +// - 0-999: Common tokens shared across dialects (sqldocument package) +// - 1000-1999: T-SQL specific tokens (this package) +// - 2000-2999: Reserved for other dialects (e.g., PostgreSQL) +// +// This design allows dialect-specific code to use concrete token types +// while common code can use ToCommonToken() for abstraction. +const ( + // T-SQL specific string literals + // + // T-SQL distinguishes between varchar ('...') and nvarchar (N'...') + // string literals. Both use single quotes with '' as the escape sequence. + VarcharLiteralToken sqldocument.TokenType = iota + sqldocument.TSQLTokenStart + NVarcharLiteralToken + + // T-SQL specific identifier styles + // + // T-SQL uses square brackets for quoted identifiers: [My Table] + // Brackets are escaped by doubling: [My]]Table] represents "My]Table" + BracketQuotedIdentifierToken // [identifier] + + // T-SQL specific errors + // + // Unlike standard SQL, T-SQL does not support double-quoted strings. + // Double quotes are reserved for QUOTED_IDENTIFIER mode identifiers, + // but sqlcode requires bracket notation for consistency. 
+ DoubleQuoteErrorToken // T-SQL doesn't support double-quoted strings + UnterminatedVarcharLiteralErrorToken + UnterminatedQuotedIdentifierErrorToken +) + +// ToCommonToken maps T-SQL specific tokens to their common equivalents +// for dialect-agnostic processing. +// +// This abstraction layer allows higher-level code to work with logical +// token categories (e.g., "string literal") without knowing the specific +// dialect syntax (varchar vs nvarchar, brackets vs double quotes). +// +// Tokens that are already common tokens pass through unchanged. +func ToCommonToken(tt sqldocument.TokenType) sqldocument.TokenType { + switch tt { + case VarcharLiteralToken, NVarcharLiteralToken: + return sqldocument.StringLiteralToken + case BracketQuotedIdentifierToken: + return sqldocument.QuotedIdentifierToken + case UnterminatedVarcharLiteralErrorToken, UnterminatedQuotedIdentifierErrorToken: + return sqldocument.UnterminatedStringErrorToken + default: + return tt + } +} diff --git a/sqlparser/node_test.go b/sqlparser/node_test.go deleted file mode 100644 index 04173c6..0000000 --- a/sqlparser/node_test.go +++ /dev/null @@ -1 +0,0 @@ -package sqlparser diff --git a/sqlparser/parser.go b/sqlparser/parser.go index c15e25f..25ca035 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -13,60 +13,27 @@ import ( "regexp" "slices" "strings" -) - -var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" -var supportedSqlExtensions []string = []string{".sql", ".pgsql"} - -func CopyToken(s *Scanner, target *[]Unparsed) { - *target = append(*target, CreateUnparsed(s)) -} + "github.com/vippsas/sqlcode/sqlparser/mssql" + "github.com/vippsas/sqlcode/sqlparser/sqldocument" +) -// AdvanceAndCopy is like NextToken; advance to next token that is not whitespace and return -// Note: The 'go' and EOF tokens are *not* copied -func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, 
BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - // copy, and return - CopyToken(s, target) - return - } - } -} +var ( + templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" + supportedSqlExtensions []string = []string{".sql", ".pgsql"} + // consider something a "sqlcode source file" if it contains [code] + // or a --sqlcode: header + isSqlCodeRegex = regexp.MustCompile(`^--sqlcode:|\[code\]`) +) -func Parse(s *Scanner, result Document) { - // Top-level parse; this focuses on splitting into "batches" separated - // by 'go'. - - // CONVENTION: - // All functions should expect `s` positioned on what they are documented - // to consume/parse. - // - // Functions typically consume *after* the keyword that triggered their - // invoication; e.g. parseCreate parses from first non-whitespace-token - // *after* `create`. - // - // On return, `s` is positioned at the token that starts the next statement/ - // sub-expression. In particular trailing ';' and whitespace has been consumed. - // - // `s` will typically never be positioned on whitespace except in - // whitespace-preserving parsing - s.NextNonWhitespaceToken() - err := result.Parse(s) - if err != nil { - panic(fmt.Sprintf("failed to parse document: %s: %e", s.file, err)) +// Based on the input file extension, create the appropriate Document type +func NewDocumentFromExtension(extension string) sqldocument.Document { + switch extension { + case ".sql": + return &mssql.TSqlDocument{} + default: + panic("unhandled document type: " + extension) } - return } // ParseFileystems iterates through a list of filesystems and parses all supported @@ -76,7 +43,7 @@ func Parse(s *Scanner, result Document) { // related to parsing/sorting will be in result.Errors. // // ParseFilesystems will also sort create statements topologically. 
-func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, result Document, err error) { +func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, result sqldocument.Document, err error) { // We are being passed several *filesystems* here. It may be easy to pass in the same // directory twice but that should not be encouraged, so if we get the same hash from // two files, return an error. Only files containing [code] in some way will be @@ -123,8 +90,12 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, hashes[hash] = pathDesc fdoc := NewDocumentFromExtension(extension) - Parse(&Scanner{input: string(buf), file: FileRef(path)}, fdoc) + err = fdoc.Parse(buf, sqldocument.FileRef(path)) + if err != nil { + return fmt.Errorf("error parsing file %s: %w", pathDesc, err) + } + // only include if include tags match if matchesIncludeTags(fdoc.PragmaIncludeIf(), includeTags) { filenames = append(filenames, pathDesc) result.Include(fdoc) @@ -166,7 +137,3 @@ func IsSqlcodeConstVariable(varname string) bool { strings.HasPrefix(varname, "@CONST_") || strings.HasPrefix(varname, "@const_") } - -// consider something a "sqlcode source file" if it contains [code] -// or a --sqlcode: header -var isSqlCodeRegex = regexp.MustCompile(`^--sqlcode:|\[code\]`) diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index 39f1e58..d9e7ff6 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -13,21 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestPostgresqlCreate(t *testing.T) { - doc := ParseString("test.pgsql", ` -create procedure [code].test() -language plpgsql -as $$ -begin - perform 1; -end; -$$; - `) - - require.Len(t, doc.Creates(), 1) - require.Equal(t, &stdlib.Driver{}, doc.Creates()[0].Driver) -} - func TestParserSmokeTest(t *testing.T) { doc := ParseString("test.sql", ` /* test is a test diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go 
deleted file mode 100644 index e97ec32..0000000 --- a/sqlparser/pgsql_document.go +++ /dev/null @@ -1,552 +0,0 @@ -package sqlparser - -import ( - "fmt" - - "github.com/jackc/pgx/v5/stdlib" -) - -var PGSQLStatementTokens = []string{"create"} - -type PGSqlDocument struct { - creates []Create - errors []Error - - Pragma -} - -func (d PGSqlDocument) HasErrors() bool { - return len(d.errors) > 0 -} - -func (d *PGSqlDocument) Parse(s *Scanner) error { - err := d.ParsePragmas(s) - if err != nil { - d.errors = append(d.errors, Error{s.Start(), err.Error()}) - } - - return nil -} - -func (d PGSqlDocument) Creates() []Create { - return d.creates -} - -// Not yet implemented -func (d PGSqlDocument) Declares() []Declare { - return nil -} - -func (d PGSqlDocument) Errors() []Error { - return d.errors -} - -func (d PGSqlDocument) Empty() bool { - return len(d.creates) == 0 -} - -func (d PGSqlDocument) Sort() { - -} - -func (d PGSqlDocument) Include(other Document) { - -} - -func (d PGSqlDocument) WithoutPos() Document { - return &PGSqlDocument{} -} - -// No GO batch separator: -// -// PostgreSQL uses semicolons (;) to separate statements, not GO. -// Multiple CREATE statements can exist in the same file. -// -// No top-level DECLARE: -// -// In PostgreSQL, DECLARE is only used inside function/procedure bodies within BEGIN...END blocks, not as top-level batch statements. -// -// Multiple CREATEs per batch: -// -// Unlike T-SQL which requires procedures/functions to be alone in a batch, PostgreSQL allows multiple CREATE statements separated by semicolons. -// -// Semicolon handling: -// -// The semicolon is a statement terminator, not a batch separator, so parsing continues after encountering one. -// -// Dollar quoting: -// -// PostgreSQL uses $$ or $tag$ for quoting function bodies instead of BEGIN...END (this would be handled in parseCreate). -// -// CREATE OR REPLACE: -// -// PostgreSQL commonly uses CREATE OR REPLACE which would need special handling in parseCreate. 
-// -// Schema qualification: -// -// PostgreSQL uses schema.object notation rather than [schema].[object]. -func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - batch := &Batch{ - TokenHandlers: map[string]func(*Scanner, *Batch) bool{ - "create": func(s *Scanner, n *Batch) bool { - // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. - c := doc.parseCreate(s, n.CreateStatements) - c.Driver = &stdlib.Driver{} - - // Prepend any leading comments/whitespace - c.Body = append(n.Nodes, c.Body...) - c.Docstring = n.DocString - doc.creates = append(doc.creates, c) - - return false - }, - }, - } - - hasMore = batch.Parse(s) - if batch.HasErrors() { - doc.errors = append(doc.errors, batch.Errors...) - } - - return hasMore - - // var nodes []Unparsed - // var docstring []PosString - // newLineEncounteredInDocstring := false - - // for { - // tt := s.TokenType() - // switch tt { - // case EOFToken: - // return false - // case WhitespaceToken, MultilineCommentToken: - // nodes = append(nodes, CreateUnparsed(s)) - // // do not reset docstring for a single trailing newline - // t := s.Token() - // if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - // newLineEncounteredInDocstring = true - // } else { - // docstring = nil - // } - // s.NextToken() - // case SinglelineCommentToken: - // // Build up a list of single line comments for the "docstring"; - // // it is reset whenever we encounter something else - // docstring = append(docstring, PosString{s.Start(), s.Token()}) - // nodes = append(nodes, CreateUnparsed(s)) - // newLineEncounteredInDocstring = false - // s.NextToken() - // case ReservedWordToken: - // switch s.ReservedWord() { - // case "declare": - // // PostgreSQL doesn't have top-level DECLARE batches like T-SQL - // // DECLARE is only used inside function/procedure bodies - // if isFirst { - // doc.addError(s, "PostgreSQL 'declare' is used inside function bodies, not as top-level batch statements") - // } - // 
nodes = append(nodes, CreateUnparsed(s)) - // s.NextToken() - // docstring = nil - // case "create": - // // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. - // createStart := len(doc.creates) - // c := doc.parseCreate(s, createStart) - // c.Driver = &stdlib.Driver{} - - // // Prepend any leading comments/whitespace - // c.Body = append(nodes, c.Body...) - // c.Docstring = docstring - // doc.creates = append(doc.creates, c) - - // // Reset for next statement - // nodes = nil - // docstring = nil - // newLineEncounteredInDocstring = false - // default: - // doc.addError(s, "Expected 'create', got: "+s.ReservedWord()) - // s.NextToken() - // docstring = nil - // } - // case SemicolonToken: - // // PostgreSQL uses semicolons as statement terminators - // // Multiple CREATE statements can exist in same file - // nodes = append(nodes, CreateUnparsed(s)) - // s.NextToken() - // // Continue parsing - don't return like T-SQL does with GO - // case BatchSeparatorToken: - // // PostgreSQL doesn't use GO batch separators - // // Q: Do we want to use GO batch separators as a feature of sqlcode? - // doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons instead") - // s.NextToken() - // docstring = nil - // default: - // doc.addError(s, fmt.Sprintf("Unexpected token in PostgreSQL document: %s", s.Token())) - // s.NextToken() - // docstring = nil - // } - // } -} - -// parseCreate parses PostgreSQL CREATE statements (FUNCTION, PROCEDURE, TYPE, etc.) -// Position is *on* the CREATE token. 
-// -// PostgreSQL CREATE syntax differences from T-SQL: -// - Supports CREATE OR REPLACE for functions/procedures -// - Uses dollar quoting ($$...$$) or $tag$...$tag$ for function bodies -// - Schema qualification uses dot notation: schema.function_name -// - Double-quoted identifiers preserve case: "MyFunction" -// - Function parameters use different syntax: func(param1 type1, param2 type2) -// - RETURNS clause specifies return type -// - LANGUAGE clause (plpgsql, sql, etc.) is required -// - Function characteristics: IMMUTABLE, STABLE, VOLATILE, PARALLEL SAFE, etc. -// -// We parse until we hit a semicolon or EOF, tracking dependencies on other objects. -func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) { - var body []Unparsed - - // Copy the CREATE token - CopyToken(s, &body) - s.NextNonWhitespaceCommentToken() - - // Check for OR REPLACE - // NOTE: "or replace" doesn't make sense within sqlcode as this will be created within a new - // schema. - if s.TokenType() == ReservedWordToken && s.ReservedWord() == "or" { - CopyToken(s, &body) - s.NextNonWhitespaceCommentToken() - - if s.TokenType() == ReservedWordToken && s.ReservedWord() == "replace" { - CopyToken(s, &body) - s.NextNonWhitespaceCommentToken() - } else { - doc.addError(s, "Expected 'REPLACE' after 'OR'") - RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) - result.Body = body - return - } - } - - // Parse the object type (FUNCTION, PROCEDURE, TYPE, etc.) 
- if s.TokenType() != ReservedWordToken { - doc.addError(s, "Expected object type after CREATE (e.g., FUNCTION, PROCEDURE, TYPE)") - RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) - result.Body = body - return - } - - createType := s.ReservedWord() - result.CreateType = createType - CopyToken(s, &body) - s.NextNonWhitespaceCommentToken() - - // Validate supported CREATE types - switch createType { - case "function", "procedure", "type": - // Supported types - default: - doc.addError(s, fmt.Sprintf("Unsupported CREATE type for PostgreSQL: %s", createType)) - RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) - result.Body = body - return - } - - // Insist on [code] to provide the ability for sqlcode to patch function bodies - // with references to other sqlcode objects. - if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { - doc.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - RecoverToNextStatementCopying(s, &result.Body, PGSQLStatementTokens) - return - } - var err error - result.QuotedName, err = ParseCodeschemaName(s, &result.Body, PGSQLStatementTokens) - if err != nil { - doc.addError(s, err.Error()) - } - if result.QuotedName.String() == "" { - return - } - - // Parse function/procedure signature or type definition - switch createType { - case "function", "procedure": - doc.parseFunctionSignature(s, &body, &result) - case "type": - doc.parseTypeDefinition(s, &body, &result) - } - - // Parse the rest of the CREATE statement body until semicolon or EOF - doc.parseCreateBody(s, &body, &result) - - result.Body = body - return -} - -// parseQualifiedName parses schema-qualified or simple object names -// Supports: simple_name, schema.name, "Quoted Name", schema."Quoted Name" -func (doc *PGSqlDocument) parseQualifiedName(s *Scanner, body *[]Unparsed) string { - var nameParts []string - - for { - switch s.TokenType() { - case UnquotedIdentifierToken: - nameParts = append(nameParts, 
s.Token()) - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - case QuotedIdentifierToken: - // PostgreSQL uses double quotes for case-sensitive identifiers - nameParts = append(nameParts, s.Token()) - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - default: - if len(nameParts) == 0 { - return "" - } - // Return the last part as the object name (without schema) - return nameParts[len(nameParts)-1] - } - - // Check for dot separator (schema.object) - if s.TokenType() == DotToken { - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - continue - } - - break - } - - if len(nameParts) == 0 { - return "" - } - return nameParts[len(nameParts)-1] -} - -// parseFunctionSignature parses function/procedure parameters and RETURNS clause -func (doc *PGSqlDocument) parseFunctionSignature(s *Scanner, body *[]Unparsed, result *Create) { - // Expect opening parenthesis for parameters - if s.TokenType() != LeftParenToken { - doc.addError(s, "Expected '(' for function parameters") - return - } - - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - - // Parse parameters until closing parenthesis - parenDepth := 1 - for parenDepth > 0 { - switch s.TokenType() { - case EOFToken: - doc.addError(s, "Unexpected EOF in function parameters") - return - case LeftParenToken: - parenDepth++ - CopyToken(s, body) - s.NextToken() - case RightParenToken: - parenDepth-- - CopyToken(s, body) - s.NextToken() - case SemicolonToken: - doc.addError(s, "Unexpected semicolon in function parameters") - return - default: - CopyToken(s, body) - s.NextToken() - } - } - - s.SkipWhitespaceComments() - - // Parse RETURNS clause (for functions, not procedures) - if result.CreateType == "function" { - if s.TokenType() == ReservedWordToken && s.ReservedWord() == "returns" { - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - - // Handle RETURNS TABLE(...) 
- if s.TokenType() == ReservedWordToken && s.ReservedWord() == "table" { - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - - if s.TokenType() == LeftParenToken { - doc.parseReturnTable(s, body) - } - } else { - // Parse simple return type - doc.parseTypeExpression(s, body) - } - } - } -} - -// parseReturnTable parses RETURNS TABLE(...) syntax -func (doc *PGSqlDocument) parseReturnTable(s *Scanner, body *[]Unparsed) { - parenDepth := 0 - for { - switch s.TokenType() { - case EOFToken, SemicolonToken: - return - case LeftParenToken: - parenDepth++ - case RightParenToken: - parenDepth-- - CopyToken(s, body) - s.NextToken() - if parenDepth == 0 { - return - } - continue - } - CopyToken(s, body) - s.NextToken() - } -} - -// parseTypeExpression parses PostgreSQL type expressions -// Supports: int, integer, text, varchar(n), numeric(p,s), arrays (int[]), etc. -func (doc *PGSqlDocument) parseTypeExpression(s *Scanner, body *[]Unparsed) { - // Parse base type - if s.TokenType() != UnquotedIdentifierToken && s.TokenType() != ReservedWordToken { - return - } - - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - - // Handle array notation: type[] - // if s.TokenType() == LeftBracketToken { - // CopyToken(s, body) - // s.NextNonWhitespaceCommentToken() - - // if s.TokenType() == RightBracketToken { - // CopyToken(s, body) - // s.NextNonWhitespaceCommentToken() - // } - // } - - // Handle type parameters: varchar(100), numeric(10,2) - if s.TokenType() == LeftParenToken { - parenDepth := 1 - CopyToken(s, body) - s.NextToken() - - for parenDepth > 0 { - switch s.TokenType() { - case EOFToken, SemicolonToken: - return - case LeftParenToken: - parenDepth++ - case RightParenToken: - parenDepth-- - } - CopyToken(s, body) - s.NextToken() - } - } -} - -// parseTypeDefinition parses CREATE TYPE syntax -// Supports: ENUM, composite types, range types -func (doc *PGSqlDocument) parseTypeDefinition(s *Scanner, body *[]Unparsed, result *Create) { - // TYPE definitions use AS 
keyword - if s.TokenType() == ReservedWordToken && s.ReservedWord() == "as" { - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - - // Check for ENUM, RANGE, or composite type - if s.TokenType() == ReservedWordToken { - typeKind := s.ReservedWord() - switch typeKind { - case "enum", "range": - CopyToken(s, body) - s.NextNonWhitespaceCommentToken() - } - } - } -} - -// parseCreateBody parses the body of a CREATE statement -// Handles dollar-quoted strings, tracks dependencies, continues until semicolon/EOF -func (doc *PGSqlDocument) parseCreateBody(s *Scanner, body *[]Unparsed, result *Create) { - dollarQuoteDepth := 0 - var currentDollarTag string - - for { - switch s.TokenType() { - case EOFToken: - return - case SemicolonToken: - // Statement terminator - we're done - CopyToken(s, body) - s.NextToken() - return - case DollarQuotedStringStartToken: - // PostgreSQL dollar quoting: $$...$$ or $tag$...$tag$ - currentDollarTag = s.Token() - dollarQuoteDepth++ - CopyToken(s, body) - s.NextToken() - case DollarQuotedStringEndToken: - if s.Token() == currentDollarTag { - dollarQuoteDepth-- - } - CopyToken(s, body) - s.NextToken() - if dollarQuoteDepth == 0 { - currentDollarTag = "" - } - case UnquotedIdentifierToken, QuotedIdentifierToken: - // Track dependencies on tables/views/functions - // In PostgreSQL, identifiers can be qualified: schema.object - identifier := s.Token() - - // Check if this might be a dependency (after FROM, JOIN, etc.) 
- if doc.mightBeDependency(s) { - // Extract just the object name (without schema prefix) - objectName := doc.extractObjectName(identifier) - result.DependsOn = append(result.DependsOn, PosString{s.Start(), objectName}) - } - - CopyToken(s, body) - s.NextToken() - default: - CopyToken(s, body) - s.NextToken() - } - } -} - -// mightBeDependency checks if current context suggests a table/view/function reference -func (doc *PGSqlDocument) mightBeDependency(s *Scanner) bool { - // Simple heuristic: look back for FROM, JOIN, INTO, etc. - // This would need to track parse context for accurate dependency detection - return false // Placeholder - implement context-aware dependency tracking -} - -// extractObjectName extracts object name from schema-qualified identifier -func (doc *PGSqlDocument) extractObjectName(identifier string) string { - // Handle schema.object notation - // For now, return as-is; proper implementation would split on dot - return identifier -} - -func (doc *PGSqlDocument) addError(s *Scanner, err string) { - doc.errors = append(doc.errors, Error{ - s.Start(), err, - }) -} - -func (doc *PGSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { - // PostgreSQL doesn't have top-level DECLARE batches like T-SQL - // DECLARE is only used inside function/procedure bodies (in BEGIN...END blocks) - doc.addError(s, "PostgreSQL does not support top-level DECLARE statements outside of function bodies") - RecoverToNextStatement(s, PGSQLStatementTokens) - return false -} - -func (doc *PGSqlDocument) parseBatchSeparator(s *Scanner) { - // PostgreSQL doesn't use GO batch separators - doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons") - s.NextToken() -} diff --git a/sqlparser/pgsql_document_test.go b/sqlparser/pgsql_document_test.go deleted file mode 100644 index 2d3c8af..0000000 --- a/sqlparser/pgsql_document_test.go +++ /dev/null @@ -1,254 +0,0 @@ -package sqlparser - -import ( - "testing" - - "github.com/stretchr/testify/assert" - 
"github.com/stretchr/testify/require" -) - -func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { - t.Run("parses PostgreSQL function with dollar quoting", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - assert.Equal(t, "test_func", create.QuotedName.Value) - }) - - t.Run("parses PostgreSQL procedure", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create procedure insert_data(a integer, b integer) language sql as $$ insert into tbl values (a, b); $$") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "procedure", create.CreateType) - assert.Equal(t, "insert_data", create.QuotedName.Value) - }) - - t.Run("parses CREATE OR REPLACE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create or replace function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("parses schema-qualified name", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function public.test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Contains(t, create.QuotedName.Value, "test_func") - }) - - t.Run("parses RETURNS TABLE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function get_users() returns table(id int, name text) as $$ select id, name from users; $$ language sql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := 
doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("tracks dependencies with schema prefix", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test() returns int as $$ select * from public.table1 join public.table2 on table1.id = table2.id; $$ language sql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - require.Len(t, create.DependsOn, 2) - }) - - t.Run("parses volatility categories", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int immutable as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("parses PARALLEL SAFE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int parallel safe as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) -} - -func TestDocument_PostgreSQL17_Types(t *testing.T) { - t.Run("parses composite type", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create type address_type as (street text, city text, zip varchar(10))") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "type", create.CreateType) - }) - - t.Run("parses enum type", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create type mood as enum ('sad', 'ok', 'happy')") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "type", create.CreateType) - }) - - t.Run("parses range type", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create 
type float_range as range (subtype = float8, subtype_diff = float8mi)") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "type", create.CreateType) - }) -} - -func TestDocument_PostgreSQL17_Extensions(t *testing.T) { - t.Run("parses JSON functions PostgreSQL 17", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test() returns jsonb as $$ select json_serialize(data) from table1; $$ language sql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("parses MERGE statement (PostgreSQL 15+)", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function do_merge() returns void as $$ merge into target using source on target.id = source.id when matched then update set value = source.value; $$ language sql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) -} - -func TestDocument_PostgreSQL17_Identifiers(t *testing.T) { - t.Run("parses double-quoted identifiers", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", `create function "Test Func"() returns int as $$ begin return 1; end; $$ language plpgsql`) - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Contains(t, create.QuotedName.Value, "Test Func") - }) - - t.Run("parses case-sensitive identifiers", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", `create function "TestFunc"() returns int as $$ begin return 1; end; $$ language plpgsql`) - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Contains(t, create.QuotedName.Value, "TestFunc") - }) -} - -func TestDocument_PostgreSQL17_Datatypes(t *testing.T) { - t.Run("parses array types", func(t 
*testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "integer[]") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "integer[]", typ.BaseType) - }) - - t.Run("parses serial types", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "serial") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "serial", typ.BaseType) - }) - - t.Run("parses text type", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "text") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "text", typ.BaseType) - }) - - t.Run("parses jsonb type", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "jsonb") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "jsonb", typ.BaseType) - }) - - t.Run("parses uuid type", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "uuid") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "uuid", typ.BaseType) - }) -} - -func TestDocument_PostgreSQL17_BatchSeparator(t *testing.T) { - t.Run("PostgreSQL uses semicolon not GO", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test1() returns int as $$ begin return 1; end; $$ language plpgsql; create function test2() returns int as $$ begin return 2; end; $$ language plpgsql;") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create1 := doc.parseCreate(s, 0) - assert.Equal(t, "test1", create1.QuotedName.Value) - - // Move to next statement - s.NextNonWhitespaceCommentToken() - s.NextNonWhitespaceCommentToken() - - create2 := doc.parseCreate(s, 1) - assert.Equal(t, "test2", create2.QuotedName.Value) - }) -} diff --git a/sqlparser/scanner_test.go b/sqlparser/scanner_test.go deleted file mode 100644 index cff40fc..0000000 --- a/sqlparser/scanner_test.go +++ /dev/null @@ -1,292 +0,0 @@ -package sqlparser - -import ( - 
"github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "strings" - "testing" -) - -func TestNextToken(t *testing.T) { - // just check that regexp should return nil if we didn't start to match... - assert.Equal(t, []int(nil), numberRegexp.FindStringIndex("a123")) - - testExt := func(startOfLine bool, prefix, input string, expectedTokenType TokenType, expected string, extraAssertion ...func(s *Scanner)) func(*testing.T) { - return func(t *testing.T) { - s := &Scanner{input: prefix + input, curIndex: len(prefix), startOfLine: startOfLine} - tt := s.NextToken() - assert.Equal(t, expectedTokenType, tt) - assert.Equal(t, expected, s.Token()) - for _, a := range extraAssertion { - a(s) - } - } - } - - test := func(input string, expectedTokenType TokenType, expected string, extraAssertion ...func(s *Scanner)) func(*testing.T) { - return testExt(false, "abcd", input, expectedTokenType, expected, extraAssertion...) - } - - t.Run("", test(" ", WhitespaceToken, " ")) - t.Run("", test(" a ", WhitespaceToken, " ")) - t.Run("", test(" \t\t\n\n \t \nasdf", WhitespaceToken, " \t\t\n\n \t \n")) - - t.Run("", test("123", NumberToken, "123")) - t.Run("", test("123;\n", NumberToken, "123")) - t.Run("", test("123\n", NumberToken, "123")) - t.Run("", test("123 ", NumberToken, "123")) - t.Run("", test("+123.e-3_asdf", NumberToken, "+123.e-3")) - t.Run("", test("-123.e+3+a", NumberToken, "-123.e+3")) - t.Run("", test("-123.12e3+a", NumberToken, "-123.12e3")) - t.Run("", test("-123.12e-35+a", NumberToken, "-123.12e-35")) - t.Run("", test("-123.12ea", NumberToken, "-123.12e")) - t.Run("", test("-123.12;\n", NumberToken, "-123.12")) - - t.Run("", test("'hello world'", VarcharLiteralToken, "'hello world'")) - t.Run("", test("'hello world'after", VarcharLiteralToken, "'hello world'")) - t.Run("", test("'hello '' world'after", VarcharLiteralToken, "'hello '' world'")) - t.Run("", test("''''", VarcharLiteralToken, "''''")) - t.Run("", test("''", VarcharLiteralToken, "''")) 
- - t.Run("", test("N'hello world'after", NVarcharLiteralToken, "N'hello world'")) - t.Run("", test("N''", NVarcharLiteralToken, "N''")) - - t.Run("", test("'''hello", UnterminatedVarcharLiteralErrorToken, "'''hello")) - t.Run("", test("N'''hello", UnterminatedVarcharLiteralErrorToken, "N'''hello")) - - t.Run("", test("[ quote \n quote]] hi]asdf", QuotedIdentifierToken, "[ quote \n quote]] hi]")) - t.Run("", test("[][]", QuotedIdentifierToken, "[]")) - t.Run("", test("[]]]", QuotedIdentifierToken, "[]]]")) - t.Run("", test("[]]test", UnterminatedQuotedIdentifierErrorToken, "[]]test")) - - t.Run("", test("/* comment\n\n */asdf", MultilineCommentToken, "/* comment\n\n */")) - t.Run("", test("/* comment\n\n ****/asdf", MultilineCommentToken, "/* comment\n\n ****/")) - // unterminated multiline comment is treated like a comment - t.Run("", test("/* comment\n\n asdf", MultilineCommentToken, "/* comment\n\n asdf")) - - // single stopLine comment .. trailing \n is not considered part of token - t.Run("", test("-- test\nhello", SinglelineCommentToken, "-- test")) - t.Run("", test("-- test", SinglelineCommentToken, "-- test")) - - t.Run("", test(`"asdf`, DoubleQuoteErrorToken, `"`)) - - t.Run("", test(``, EOFToken, ``)) - - t.Run("", test("abc", UnquotedIdentifierToken, "abc")) - t.Run("", test("@a#$$__bc a", VariableIdentifierToken, "@a#$$__bc")) - // identifier starting with N is special branch - t.Run("", test("N@a#$$__bc a", UnquotedIdentifierToken, "N@a#$$__bc")) - - t.Run("", test("