From d870c46526aee3ce21fb0ab9c083d1aaa327e883 Mon Sep 17 00:00:00 2001 From: hbc Date: Fri, 2 Jan 2026 01:08:35 -0800 Subject: [PATCH 1/2] Add unix socket listening option --- README.md | 10 ++++- server.go | 64 ++++++++++++++++++++++++++---- server_socket_test.go | 90 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 server_socket_test.go diff --git a/README.md b/README.md index 2ef0ef4..b59c196 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,14 @@ $ sqlite-rest serve --auth-token-file test.token --security-allow-table books -- ... ``` +### Start server with Unix domain socket + +``` +$ sqlite-rest serve --auth-token-file test.token --security-allow-table books --db-dsn ./bookstore.sqlite3 --http-socket /tmp/sqlite-rest.sock +{"level":"info","ts":1672528510.825417,"logger":"db-server","caller":"sqlite-rest/server.go:121","msg":"server started","socket":"/tmp/sqlite-rest.sock"} +... +``` + ### Generate authentication token **NOTE: the following steps create a sample token for testing only, please use a strong password in production.** @@ -178,4 +186,4 @@ $ sqlite-rest migrate --db-dsn ./bookstore.sqlite3 --direction down --step 1 ./e ## License -MIT \ No newline at end of file +MIT diff --git a/server.go b/server.go index 6be2857..e5cb9e8 100644 --- a/server.go +++ b/server.go @@ -5,9 +5,11 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/http" "os" "os/signal" + "path/filepath" "syscall" "time" @@ -27,6 +29,7 @@ const ( type ServerOptions struct { Logger logr.Logger Addr string + SocketPath string AuthOptions ServerAuthOptions SecurityOptions ServerSecurityOptions Queryer sqlx.QueryerContext @@ -35,6 +38,7 @@ type ServerOptions struct { func (opts *ServerOptions) bindCLIFlags(fs *pflag.FlagSet) { fs.StringVar(&opts.Addr, "http-addr", ":8080", "server listen address") + fs.StringVar(&opts.SocketPath, "http-socket", "", "server listen unix socket path. If set, http-addr will be ignored") opts.AuthOptions.bindCLIFlags(fs) opts.SecurityOptions.bindCLIFlags(fs) @@ -52,7 +56,11 @@ func (opts *ServerOptions) defaults() error { opts.Logger = logr.Discard() } - if opts.Addr == "" { + if opts.SocketPath != "" { + opts.Addr = "" + } + + if opts.Addr == "" && opts.SocketPath == "" { opts.Addr = ":8080" } @@ -68,10 +76,12 @@ func (opts *ServerOptions) defaults() error { } type dbServer struct { - logger logr.Logger - server *http.Server - queryer sqlx.QueryerContext - execer sqlx.ExecerContext + logger logr.Logger + server *http.Server + listener net.Listener + socket string + queryer sqlx.QueryerContext + execer sqlx.ExecerContext } func NewServer(opts *ServerOptions) (*dbServer, error) { @@ -86,6 +96,7 @@ func NewServer(opts *ServerOptions) (*dbServer, error) { // TODO: make it configurable ReadHeaderTimeout: 5 * time.Second, }, + socket: opts.SocketPath, queryer: opts.Queryer, execer: opts.Execer, } @@ -128,15 +139,54 @@ func NewServer(opts *ServerOptions) (*dbServer, error) { } func (server *dbServer) Start(done <-chan struct{}) { - go server.server.ListenAndServe() + if server.socket != "" { + if err := os.MkdirAll(filepath.Dir(server.socket), 0755); err != nil { + server.logger.Error(err, "failed to ensure unix socket directory", "socket", server.socket) + return + } + + if err := os.RemoveAll(server.socket); err != nil { + server.logger.Error(err, "failed to remove stale unix socket", "socket", server.socket) + return + } + + l, err := net.Listen("unix", server.socket) + if err != nil { + server.logger.Error(err, "failed to listen on unix socket", "socket", server.socket) + return + } + server.listener = l + + go server.server.Serve(l) + server.logger.Info("server started", "socket", server.socket) + } else { + l, err := net.Listen("tcp", server.server.Addr) + if err != nil { + server.logger.Error(err, "failed to listen on tcp address", "addr", server.server.Addr) + return + } + server.listener = l + + go server.server.Serve(l) + server.logger.Info("server started", "addr", server.server.Addr) + } - server.logger.Info("server started", "addr", server.server.Addr) <-done server.logger.Info("shutting down server") shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() server.server.Shutdown(shutdownCtx) + + if server.listener != nil { + server.listener.Close() + } + + if server.socket != "" { + if err := os.Remove(server.socket); err != nil && !errors.Is(err, os.ErrNotExist) { + server.logger.Error(err, "failed to clean up unix socket", "socket", server.socket) + } + } } func (server *dbServer) responseHeader(w http.ResponseWriter, statusCode int) { diff --git a/server_socket_test.go b/server_socket_test.go new file mode 100644 index 0000000..2347763 --- /dev/null +++ b/server_socket_test.go @@ -0,0 +1,90 @@ +package main + +import ( + "context" + "encoding/json" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/assert" +) + +func TestServerWithUnixSocket(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + socketPath := filepath.Join(dir, "sqlite-rest.sock") + + dbPath := filepath.Join(dir, "test.db") + db, err := sqlx.Open("sqlite3", dbPath) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE test (id int)") + if err != nil { + t.Fatal(err) + } + _, err = db.Exec(`INSERT INTO test (id) VALUES (1)`) + if err != nil { + t.Fatal(err) + } + + serverOpts := &ServerOptions{ + Logger: createTestLogger(t).WithName("test"), + Queryer: db, + Execer: db, + SocketPath: socketPath, + } + serverOpts.AuthOptions.disableAuth = true + serverOpts.SecurityOptions.EnabledTableOrViews = []string{"test"} + + server, err := NewServer(serverOpts) + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + go server.Start(done) + + assert.Eventually(t, func() bool { + _, err := os.Stat(socketPath) + return err == nil + }, 5*time.Second, 100*time.Millisecond) + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath) + }, + }, + } + + req, err := http.NewRequest(http.MethodGet, "http://unix/test", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var rows []map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&rows) + assert.NoError(t, err) + assert.Len(t, rows, 1) + assert.EqualValues(t, 1, rows[0]["id"]) + + close(done) +} From 613ec72777f8498ea95ef69949214322c578698b Mon Sep 17 00:00:00 2001 From: hbc Date: Fri, 2 Jan 2026 01:29:20 -0800 Subject: [PATCH 2/2] Address review feedback for unix socket --- README.md | 4 ++-- server.go | 13 ++++++------- server_socket_test.go | 20 +++++++++++++++++++- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index b59c196..53279e5 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ $ sqlite3 bookstore.sqlite3 < examples/bookstore/data.sql ``` $ echo -n "topsecret" > test.token $ sqlite-rest serve --auth-token-file test.token --security-allow-table books --db-dsn ./bookstore.sqlite3 -{"level":"info","ts":1672528510.825417,"logger":"db-server","caller":"sqlite-rest/server.go:121","msg":"server started","addr":":8080"} +{"level":"info","ts":1672528510.825417,"logger":"db-server","msg":"server started","addr":":8080"} ... ``` @@ -62,7 +62,7 @@ $ sqlite-rest serve --auth-token-file test.token --security-allow-table books -- ``` $ sqlite-rest serve --auth-token-file test.token --security-allow-table books --db-dsn ./bookstore.sqlite3 --http-socket /tmp/sqlite-rest.sock -{"level":"info","ts":1672528510.825417,"logger":"db-server","caller":"sqlite-rest/server.go:121","msg":"server started","socket":"/tmp/sqlite-rest.sock"} +{"level":"info","ts":1672528510.825417,"logger":"db-server","msg":"server started","socket":"/tmp/sqlite-rest.sock"} ... ``` diff --git a/server.go b/server.go index e5cb9e8..c10458d 100644 --- a/server.go +++ b/server.go @@ -140,9 +140,12 @@ func NewServer(opts *ServerOptions) (*dbServer, error) { func (server *dbServer) Start(done <-chan struct{}) { if server.socket != "" { - if err := os.MkdirAll(filepath.Dir(server.socket), 0755); err != nil { - server.logger.Error(err, "failed to ensure unix socket directory", "socket", server.socket) - return + sockDir := filepath.Dir(server.socket) + if sockDir != "" && sockDir != "." { + if err := os.MkdirAll(sockDir, 0755); err != nil { + server.logger.Error(err, "failed to ensure unix socket directory", "socket", server.socket) + return + } } if err := os.RemoveAll(server.socket); err != nil { @@ -178,10 +181,6 @@ func (server *dbServer) Start(done <-chan struct{}) { defer cancel() server.server.Shutdown(shutdownCtx) - if server.listener != nil { - server.listener.Close() - } - if server.socket != "" { if err := os.Remove(server.socket); err != nil && !errors.Is(err, os.ErrNotExist) { server.logger.Error(err, "failed to clean up unix socket", "socket", server.socket) diff --git a/server_socket_test.go b/server_socket_test.go index 2347763..bbc5c9e 100644 --- a/server_socket_test.go +++ b/server_socket_test.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/json" + "errors" "net" "net/http" "os" @@ -51,7 +52,11 @@ func TestServerWithUnixSocket(t *testing.T) { } done := make(chan struct{}) - go server.Start(done) + serverDone := make(chan struct{}) + go func() { + server.Start(done) + close(serverDone) + }() assert.Eventually(t, func() bool { _, err := os.Stat(socketPath) @@ -87,4 +92,17 @@ func TestServerWithUnixSocket(t *testing.T) { assert.EqualValues(t, 1, rows[0]["id"]) close(done) + assert.Eventually(t, func() bool { + select { + case <-serverDone: + return true + default: + return false + } + }, 2*time.Second, 50*time.Millisecond) + + assert.Eventually(t, func() bool { + _, err := os.Stat(socketPath) + return errors.Is(err, os.ErrNotExist) + }, 5*time.Second, 100*time.Millisecond) }