diff --git a/astral/router.go b/astral/router.go index b83ba2f14..344215f6e 100644 --- a/astral/router.go +++ b/astral/router.go @@ -1,7 +1,6 @@ package astral import ( - "context" "io" ) @@ -9,4 +8,4 @@ type Router interface { RouteQuery(ctx *Context, q *Query, w io.WriteCloser) (io.WriteCloser, error) } -type RouteQueryFunc func(ctx context.Context, q *Query, w io.WriteCloser) (io.WriteCloser, error) +type RouteQueryFunc func(ctx *Context, q *Query, w io.WriteCloser) (io.WriteCloser, error) diff --git a/mod/gateway/README.md b/mod/gateway/README.md new file mode 100644 index 000000000..be77fd38a --- /dev/null +++ b/mod/gateway/README.md @@ -0,0 +1,16 @@ +# gateway + +## Configuration + +`gateway.yaml`: + +```yaml +gateway: + enabled: true + listen: + - tcp::6000 + +visibility: public +init_conns: 1 +max_conns: 8 +``` diff --git a/mod/gateway/client/bind.go b/mod/gateway/client/bind.go new file mode 100644 index 000000000..3a484aa74 --- /dev/null +++ b/mod/gateway/client/bind.go @@ -0,0 +1,24 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/query" + gw "github.com/cryptopunkscc/astrald/mod/gateway" +) + +func (c *Client) Bind(ctx *astral.Context, visibility gw.Visibility) (*gw.Socket, error) { + ch, err := c.queryCh(ctx, gw.MethodBind, query.Args{"visibility": string(visibility)}) + if err != nil { + return nil, err + } + defer ch.Close() + + var socket *gw.Socket + err = ch.Switch( + channel.Expect(&socket), + channel.PassErrors, + ) + + return socket, err +} diff --git a/mod/gateway/client/client.go b/mod/gateway/client/client.go new file mode 100644 index 000000000..756aa8820 --- /dev/null +++ b/mod/gateway/client/client.go @@ -0,0 +1,36 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/astrald" +) + +type Client struct { + astral *astrald.Client + targetID *astral.Identity +} + +var defaultClient *Client + +func New(targetID *astral.Identity, a *astrald.Client) *Client { + if a == nil { + a = astrald.Default() + } + return &Client{astral: a, targetID: targetID} +} + +func Default() *Client { + if defaultClient == nil { + defaultClient = New(nil, nil) + } + return defaultClient +} + +func SetDefault(client *Client) { + defaultClient = client +} + +func (c *Client) queryCh(ctx *astral.Context, method string, args any, cfg ...channel.ConfigFunc) (*channel.Channel, error) { + return c.astral.WithTarget(c.targetID).QueryChannel(ctx, method, args, cfg...) +} diff --git a/mod/gateway/client/connect.go b/mod/gateway/client/connect.go new file mode 100644 index 000000000..e190a30a7 --- /dev/null +++ b/mod/gateway/client/connect.go @@ -0,0 +1,24 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/query" + gw "github.com/cryptopunkscc/astrald/mod/gateway" +) + +func (c *Client) Connect(ctx *astral.Context, target *astral.Identity) (*gw.Socket, error) { + ch, err := c.queryCh(ctx, gw.MethodConnect, query.Args{"target": target.String()}) + if err != nil { + return nil, err + } + defer ch.Close() + + var socket *gw.Socket + err = ch.Switch( + channel.Expect(&socket), + channel.PassErrors, + ) + + return socket, err +} diff --git a/mod/gateway/client/list.go b/mod/gateway/client/list.go new file mode 100644 index 000000000..945948ae6 --- /dev/null +++ b/mod/gateway/client/list.go @@ -0,0 +1,25 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/query" + gw "github.com/cryptopunkscc/astrald/mod/gateway" +) + +func (c *Client) List(ctx *astral.Context) ([]*astral.Identity, error) { + ch, err := c.queryCh(ctx, gw.MethodList, query.Args{}) + if err != nil { + return nil, err + } + defer ch.Close() + + var list []*astral.Identity + err = ch.Switch( + channel.Collect(&list), + channel.StopOnEOS, + channel.PassErrors, + ) + + return list, err +} diff --git a/mod/gateway/errors.go b/mod/gateway/errors.go new file mode 100644 index 000000000..f944a8cae --- /dev/null +++ b/mod/gateway/errors.go @@ -0,0 +1,8 @@ +package gateway + +import "errors" + +var ErrUnauthorized = errors.New("unauthorized") +var ErrTargetNotReachable = errors.New("target not reachable") +var ErrInvalidGateway = errors.New("invalid gateway") +var ErrSocketUnreachable = errors.New("socket unreachable") diff --git a/mod/gateway/maintain_binding_task.go b/mod/gateway/maintain_binding_task.go new file mode 100644 index 000000000..7877bdab3 --- /dev/null +++ b/mod/gateway/maintain_binding_task.go @@ -0,0 +1,7 @@ +package gateway + +import "github.com/cryptopunkscc/astrald/mod/scheduler" + +type MaintainBindingTask interface { + scheduler.Task +} diff --git a/mod/gateway/module.go b/mod/gateway/module.go new file mode 100644 index 000000000..80bfa6787 --- /dev/null +++ b/mod/gateway/module.go @@ -0,0 +1,13 @@ +package gateway + +const ModuleName = "gateway" + +const ( + MethodBind = "gateway.node_bind" + MethodConnect = "gateway.node_connect" + MethodList = "gateway.node_list" + MethodRoute = "gateway.route" +) + +type Module interface { +} diff --git a/mod/gateway/socket.go b/mod/gateway/socket.go new file mode 100644 index 000000000..9a6afa371 --- /dev/null +++ b/mod/gateway/socket.go @@ -0,0 +1,25 @@ +package gateway + +import ( + "io" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" +) + +// Socket describes a raw connection point at the gateway. The recipient opens +// a raw exonet connection to the Endpoint and sends Nonce as the first bytes +// to identify itself to the gateway. +type Socket struct { + Endpoint exonet.Endpoint + Nonce astral.Nonce +} + +func (Socket) ObjectType() string { return "mod.gateway.socket" } + +func (s Socket) WriteTo(w io.Writer) (int64, error) { return astral.Objectify(&s).WriteTo(w) } +func (s *Socket) ReadFrom(r io.Reader) (int64, error) { return astral.Objectify(s).ReadFrom(r) } + +func init() { + astral.Add(&Socket{}) +} diff --git a/mod/gateway/src/accept.go b/mod/gateway/src/accept.go new file mode 100644 index 000000000..74814fa6a --- /dev/null +++ b/mod/gateway/src/accept.go @@ -0,0 +1,60 @@ +package gateway + +import ( + "context" + "fmt" + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" +) + +const ( + socketDeadTimeout = 10 * time.Second + socketProbeTimeout = 5 * time.Second +) + +// acceptSocketConn dispatches an incoming socket connection to either the binder +// or connector path based on the nonce it presents. +func (mod *Module) acceptSocketConn(_ context.Context, conn exonet.Conn) (stopListener bool, err error) { + mod.log.Logv(2, "accepting socket connection from %v", conn.RemoteEndpoint()) + + var nonce astral.Nonce + if _, err := nonce.ReadFrom(conn); err != nil { + mod.log.Errorv(1, "read nonce from %v: %v", conn.RemoteEndpoint(), err) + conn.Close() + return stopListener, nil + } + + if b, ok := mod.binderByNonce(nonce); ok { + mod.log.Infov(2, "added idle conn to binder %v", b.Identity) + bc := b.addConn(conn) + go bc.keepalive(nil, nil) + return stopListener, nil + } + + c, ok := mod.connectorByNonce(nonce) + if !ok { + mod.log.Errorv(1, "unknown nonce %v from %v", nonce, conn.RemoteEndpoint()) + conn.Close() + return stopListener, nil + } + + mod.connectors.Remove(c) + + reserved := c.takeReserved() + if reserved == nil { + conn.Close() + return stopListener, fmt.Errorf("no reserved conn for %v", c.Target) + } + + if !reserved.signal() { + mod.log.Errorv(1, "reserved conn for %v is dead", c.Target) + conn.Close() + return stopListener, nil + } + + mod.log.Infov(2, "pipe from %v to %v created", c.Identity, c.Target) + go pipe(reserved, conn) + return stopListener, nil +} diff --git a/mod/gateway/src/bind.go b/mod/gateway/src/bind.go new file mode 100644 index 000000000..d4b006da2 --- /dev/null +++ b/mod/gateway/src/bind.go @@ -0,0 +1,43 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/gateway" +) + +func (mod *Module) bind(ctx *astral.Context, identity *astral.Identity, visibility gateway.Visibility, network string) (gateway.Socket, error) { + if !mod.canGateway(identity) { + return gateway.Socket{}, gateway.ErrUnauthorized + } + + endpoint, err := mod.getGatewayEndpoint(ctx, network) + if err != nil { + return gateway.Socket{}, err + } + + newBinder := &binder{ + Identity: identity, + Nonce: astral.NewNonce(), + Visibility: visibility, + } + + oldBinder, ok := mod.binders.Replace(identity.String(), newBinder) + if ok { + if err = oldBinder.Close(); err != nil { + mod.log.Error("failed to close old binder: %v", err) + } + + targetID := oldBinder.Identity.String() + for _, c := range mod.connectors.Clone() { + if c.Target.String() == targetID { + mod.connectors.Remove(c) + c.Close() + } + } + } + + return gateway.Socket{ + Nonce: newBinder.Nonce, + Endpoint: endpoint, + }, nil +} diff --git a/mod/gateway/src/binder.go b/mod/gateway/src/binder.go new file mode 100644 index 000000000..3ddba2896 --- /dev/null +++ b/mod/gateway/src/binder.go @@ -0,0 +1,228 @@ +package gateway + +import ( + "fmt" + "io" + "sync/atomic" + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/gateway" + "github.com/cryptopunkscc/astrald/mod/gateway/src/frames" + "github.com/cryptopunkscc/astrald/sig" +) + +// binder represents a node registered as reachable through the gateway. +// Only one binder registration per identity is allowed. +type binder struct { + Identity *astral.Identity + Nonce astral.Nonce + Visibility gateway.Visibility + conns sig.Set[*bindingConn] +} + +func (b *binder) addConn(conn exonet.Conn) *bindingConn { + var bc *bindingConn + bc = newGatewayConn(conn, func() { b.conns.Remove(bc) }) + b.conns.Add(bc) + return bc +} + +// takeConn reserves an idle bindingConn for a connector via atomic CAS. +// CAS on active is sufficient — no mutex needed. +func (b *binder) takeConn() (*bindingConn, bool) { + for _, bc := range b.conns.Clone() { + if bc.active.CompareAndSwap(false, true) { + return bc, true + } + } + return nil, false +} + +func (b *binder) Close() error { + for _, bc := range b.conns.Clone() { + bc.Close() + } + return nil +} + +type connRole uint8 + +const ( + roleBinder connRole = iota + roleGateway +) + +// bindingConn is a unified idle socket connection for both binder and gateway sides +type bindingConn struct { + exonet.Conn + role connRole + + closed atomic.Bool + active atomic.Bool // set on idle→active; guards idle counters against double-decrement + + dead chan struct{} + signalCh chan chan error + onClose func() +} + +func newGatewayConn(conn exonet.Conn, onClose func()) *bindingConn { + return &bindingConn{ + Conn: conn, + role: roleGateway, + dead: make(chan struct{}), + signalCh: make(chan chan error, 1), + onClose: onClose, + } +} + +func newBinderConn(conn exonet.Conn) *bindingConn { + return &bindingConn{ + Conn: conn, + role: roleBinder, + dead: make(chan struct{}), + } +} + +func (bc *bindingConn) SetReadDeadline(t time.Time) error { + if dl, ok := bc.Conn.(deadliner); ok { + return dl.SetReadDeadline(t) + } + return nil +} + +func (bc *bindingConn) SetWriteDeadline(t time.Time) error { + if dl, ok := bc.Conn.(deadliner); ok { + return dl.SetWriteDeadline(t) + } + return nil +} + +// readFrame reads a single control byte within the gateway keepalive/control phase. +// Must not be called after activation, when the connection becomes a raw stream. +func (bc *bindingConn) readFrame(timeout time.Duration) (byte, error) { + bc.SetReadDeadline(time.Now().Add(timeout)) + var b [1]byte + _, err := io.ReadFull(bc.Conn, b[:]) + bc.SetReadDeadline(time.Time{}) + return b[0], err +} + +func (bc *bindingConn) Close() error { + err := bc.Conn.Close() + if !bc.closed.Swap(true) && bc.onClose != nil { + bc.onClose() + } + return err +} + +// keepalive runs the ping/pong loop until activation or connection loss. +// done stops the binder-side ping sleep on shutdown (pass ctx.Done()). +// onActivate is called after WriteSignalReady; returned error causes bc to be closed. +func (bc *bindingConn) keepalive(done <-chan struct{}, onActivate func() error) { + defer close(bc.dead) + activated := false + defer func() { + if !activated { + bc.Close() + } + }() + + for { + if bc.role == roleBinder { + if err := frames.WritePing(bc.Conn); err != nil { + return + } + } + + timeout := socketDeadTimeout + if bc.role == roleBinder { + timeout = socketPingTimeout + } + frame, err := bc.readFrame(timeout) + if err != nil { + return + } + + switch frame { + case frames.BytePing: // roleGateway only + select { + case respCh := <-bc.signalCh: + bc.sendSignalGo(respCh) + activated = true + return + default: + if err := frames.WritePong(bc.Conn); err != nil { + return + } + } + case frames.BytePong: // roleBinder only + select { + case <-time.After(socketPingInterval): + case <-done: + return + } + case frames.ByteSignalGo: // roleBinder only + bc.active.Store(true) + if err := frames.WriteSignalReady(bc.Conn); err != nil { + bc.Close() + activated = true // active is set; defer must not double-close + return + } + if onActivate != nil { + if err := onActivate(); err != nil { + bc.Close() + } + } + activated = true + return + default: + return + } + } +} + +// sendSignalGo sends ByteSignalGo and waits for ByteSignalReady, reporting the result on respCh. +func (bc *bindingConn) sendSignalGo(respCh chan error) { + defer bc.SetWriteDeadline(time.Time{}) + bc.SetWriteDeadline(time.Now().Add(socketProbeTimeout)) + + if err := frames.WriteSignalGo(bc.Conn); err != nil { + respCh <- err + return + } + + frame, err := bc.readFrame(socketProbeTimeout) + if err != nil { + respCh <- err + return + } + if frame != frames.ByteSignalReady { + respCh <- fmt.Errorf("expected signalReady, got 0x%02x", frame) + return + } + respCh <- nil +} + +func (bc *bindingConn) signal() bool { + respCh := make(chan error, 1) + select { + case bc.signalCh <- respCh: + case <-time.After(socketProbeTimeout): + bc.Close() + return false + } + select { + case err := <-respCh: + if err != nil { + bc.Close() + } + return err == nil + case <-bc.dead: + return false + case <-time.After(socketProbeTimeout + socketPingInterval): + bc.Close() + return false + } +} diff --git a/mod/gateway/src/config.go b/mod/gateway/src/config.go index 5e51fee44..9bf2170fe 100644 --- a/mod/gateway/src/config.go +++ b/mod/gateway/src/config.go @@ -1,13 +1,23 @@ package gateway +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/gateway" +) + const defaultGateway = "node1f3AwbE1gJAgAqEx98FMipokcaE9ZapIphzDUkAceE7Pmw8ghmFV19QKCATeC7uyoLszQA" +type GatewayConfig struct { + Enabled bool `yaml:"enabled"` + Listen []string `yaml:"listen"` +} + type Config struct { - Subscribe []string `yaml:"subscribe"` + Gateway GatewayConfig `yaml:"gateway"` + Visibility gateway.Visibility `yaml:"visibility"` + Gateways []*astral.Identity `yaml:"gateways"` } var defaultConfig = Config{ - Subscribe: []string{ - defaultGateway, - }, + Visibility: gateway.VisibilityPublic, } diff --git a/mod/gateway/src/conn.go b/mod/gateway/src/conn.go index 8eda9e9f3..73159f729 100644 --- a/mod/gateway/src/conn.go +++ b/mod/gateway/src/conn.go @@ -1,42 +1,42 @@ package gateway import ( - "github.com/cryptopunkscc/astrald/astral" + "io" + "time" + "github.com/cryptopunkscc/astrald/mod/exonet" - "github.com/cryptopunkscc/astrald/mod/gateway" ) -var _ exonet.Conn = &Conn{} +var _ exonet.Conn = (*gwConn)(nil) -type Conn struct { - astral.Conn - localEndpoint *gateway.Endpoint - remoteEndpoint *gateway.Endpoint - outbound bool +// note: maybe can be part of exonet +type deadliner interface { + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error } -func newConn(conn astral.Conn, localEndpoint *gateway.Endpoint, remoteEndpoint *gateway.Endpoint, outbound bool) *Conn { - c := &Conn{ - Conn: conn, - localEndpoint: localEndpoint, - remoteEndpoint: remoteEndpoint, - outbound: outbound, - } - return c +// gwConn wraps any io.ReadWriteCloser with gateway endpoint metadata. +type gwConn struct { + io.ReadWriteCloser + local exonet.Endpoint + remote exonet.Endpoint + outbound bool } -func (conn Conn) LocalEndpoint() exonet.Endpoint { - return conn.localEndpoint -} +func (c *gwConn) LocalEndpoint() exonet.Endpoint { return c.local } +func (c *gwConn) RemoteEndpoint() exonet.Endpoint { return c.remote } +func (c *gwConn) Outbound() bool { return c.outbound } -func (conn Conn) RemoteEndpoint() exonet.Endpoint { - return conn.remoteEndpoint -} - -func (conn Conn) Outbound() bool { - return conn.outbound +func (c *gwConn) SetReadDeadline(t time.Time) error { + if dl, ok := c.ReadWriteCloser.(deadliner); ok { + return dl.SetReadDeadline(t) + } + return nil } -func (Conn) Network() string { - return NetworkName +func (c *gwConn) SetWriteDeadline(t time.Time) error { + if dl, ok := c.ReadWriteCloser.(deadliner); ok { + return dl.SetWriteDeadline(t) + } + return nil } diff --git a/mod/gateway/src/connect.go b/mod/gateway/src/connect.go new file mode 100644 index 000000000..bc802e871 --- /dev/null +++ b/mod/gateway/src/connect.go @@ -0,0 +1,61 @@ +package gateway + +import ( + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/gateway" +) + +const connectTimeout = 30 * time.Second + +func (mod *Module) connectTo(caller *astral.Identity, target *astral.Identity, network string) (gateway.Socket, error) { + if !mod.canGateway(caller) { + return gateway.Socket{}, gateway.ErrUnauthorized + } + + endpoint, err := mod.getGatewayEndpoint(mod.ctx, network) + if err != nil { + return gateway.Socket{}, err + } + + binder, ok := mod.binderByIdentity(target) + if !ok { + return gateway.Socket{}, gateway.ErrTargetNotReachable + } + + reserved, ok := binder.takeConn() + if !ok { + return gateway.Socket{}, gateway.ErrTargetNotReachable + } + + c := &connector{ + Identity: caller, + Nonce: astral.NewNonce(), + Target: target, + reserved: reserved, + } + + mod.connectors.Add(c) + + go func() { + t := time.NewTimer(connectTimeout) + defer t.Stop() + <-t.C + + bc := c.takeReserved() + if bc == nil { + return + } + + mod.connectors.Remove(c) + if err := bc.Close(); err != nil { + mod.log.Error("failed to close reserved conn: %v", err) + } + }() + + return gateway.Socket{ + Nonce: c.Nonce, + Endpoint: endpoint, + }, nil +} diff --git a/mod/gateway/src/connector.go b/mod/gateway/src/connector.go new file mode 100644 index 000000000..2b1789f34 --- /dev/null +++ b/mod/gateway/src/connector.go @@ -0,0 +1,40 @@ +package gateway + +import ( + "sync" + + "github.com/cryptopunkscc/astrald/astral" +) + +// connector represents a pending connection request from a node that wants +// to reach a binder through the gateway. Multiple connectors per identity +// are allowed. +type connector struct { + mu sync.Mutex + Identity *astral.Identity + Nonce astral.Nonce + Target *astral.Identity + reserved *bindingConn +} + +// takeReserved atomically takes the reserved bindingConn, returning nil if +// already taken (connection already established or timed out). +func (c *connector) takeReserved() *bindingConn { + c.mu.Lock() + defer c.mu.Unlock() + + bc := c.reserved + c.reserved = nil + return bc +} + +func (c *connector) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.reserved != nil { + return c.reserved.Close() + } + + return nil +} diff --git a/mod/gateway/src/deps.go b/mod/gateway/src/deps.go index 23cf6f795..c8023cd89 100644 --- a/mod/gateway/src/deps.go +++ b/mod/gateway/src/deps.go @@ -11,9 +11,11 @@ func (mod *Module) LoadDependencies(*astral.Context) (err error) { return } - mod.Exonet.SetDialer("gw", mod.dialer) + mod.Exonet.SetDialer("gw", mod) mod.Exonet.SetUnpacker("gw", mod) mod.Exonet.SetParser("gw", mod) + mod.ops.AddStructPrefix(mod, "Op") + mod.Services.AddDiscoverer(mod) return } diff --git a/mod/gateway/src/dial.go b/mod/gateway/src/dial.go new file mode 100644 index 000000000..be3f98f02 --- /dev/null +++ b/mod/gateway/src/dial.go @@ -0,0 +1,75 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/lib/astrald" + "github.com/cryptopunkscc/astrald/lib/query" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/gateway" + gatewayClient "github.com/cryptopunkscc/astrald/mod/gateway/client" +) + +var _ exonet.Dialer = &Module{} + +func (mod *Module) Dial(ctx *astral.Context, endpoint exonet.Endpoint) (exonet.Conn, error) { + if endpoint.Network() != NetworkName { + return nil, exonet.ErrUnsupportedNetwork + } + + gwEndpoint, ok := endpoint.(*gateway.Endpoint) + if !ok { + return nil, exonet.ErrUnsupportedNetwork + } + + if gwEndpoint.GatewayID.IsEqual(mod.node.Identity()) { + return nil, gateway.ErrInvalidGateway + } + + ctx = ctx.IncludeZone(astral.ZoneNetwork) + + client := gatewayClient.New(gwEndpoint.GatewayID, astrald.Default()) + socket, err := client.Connect(ctx, gwEndpoint.TargetID) + if err != nil { + return mod.route(ctx, gwEndpoint) + } + + conn, err := mod.Exonet.Dial(ctx, socket.Endpoint) + if err != nil { + return mod.route(ctx, gwEndpoint) + } + + if _, err := socket.Nonce.WriteTo(conn); err != nil { + conn.Close() + return mod.route(ctx, gwEndpoint) + } + + return &gwConn{ + ReadWriteCloser: conn, + local: conn.LocalEndpoint(), + remote: gwEndpoint, + outbound: conn.Outbound(), + }, nil +} + +func (mod *Module) route(ctx *astral.Context, gwEndpoint *gateway.Endpoint) (exonet.Conn, error) { + mod.log.Logv(1, "socket path unavailable, trying link path to %v via %v", gwEndpoint.TargetID, gwEndpoint.GatewayID) + + q := &astral.Query{ + Nonce: astral.NewNonce(), + Caller: mod.node.Identity(), + Target: gwEndpoint.GatewayID, + Query: gateway.MethodRoute + "." + gwEndpoint.TargetID.String(), + } + + conn, err := query.Route(ctx, mod.node, q) + if err != nil { + return nil, err + } + + return &gwConn{ + ReadWriteCloser: conn, + local: gateway.NewEndpoint(mod.node.Identity(), mod.node.Identity()), + remote: gwEndpoint, + outbound: true, + }, nil +} diff --git a/mod/gateway/src/dialer.go b/mod/gateway/src/dialer.go deleted file mode 100644 index 1d1dcac00..000000000 --- a/mod/gateway/src/dialer.go +++ /dev/null @@ -1,41 +0,0 @@ -package gateway - -import ( - "github.com/cryptopunkscc/astrald/astral" - "github.com/cryptopunkscc/astrald/lib/query" - "github.com/cryptopunkscc/astrald/mod/exonet" - "github.com/cryptopunkscc/astrald/mod/gateway" -) - -type Dialer struct { - node astral.Node -} - -func NewDialer(node astral.Node) *Dialer { - return &Dialer{node: node} -} - -func (dialer *Dialer) Dial(ctx *astral.Context, endpoint exonet.Endpoint) (exonet.Conn, error) { - e, err := Unpack(endpoint.Pack()) - if err != nil { - return nil, err - } - - if e.GatewayID.IsEqual(dialer.node.Identity()) { - return nil, ErrInvalidGateway - } - - var q = astral.NewQuery(dialer.node.Identity(), e.GatewayID, RouteServiceName+"."+e.TargetID.String()) - - conn, err := query.Route(ctx, dialer.node, q) - if err != nil { - return nil, err - } - - return newConn( - conn, - gateway.NewEndpoint(dialer.node.Identity(), dialer.node.Identity()), - e, - true, - ), err -} diff --git a/mod/gateway/src/endpoint_resolvers.go b/mod/gateway/src/endpoint_resolvers.go new file mode 100644 index 000000000..9ca841f93 --- /dev/null +++ b/mod/gateway/src/endpoint_resolvers.go @@ -0,0 +1,26 @@ +package gateway + +import ( + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/gateway" + "github.com/cryptopunkscc/astrald/mod/nodes" + "github.com/cryptopunkscc/astrald/sig" +) + +var _ nodes.EndpointResolver = &Module{} + +func (mod *Module) ResolveEndpoints(context *astral.Context, nodeID *astral.Identity) (<-chan *nodes.EndpointWithTTL, error) { + if !nodeID.IsEqual(mod.node.Identity()) { + // note: we might resolve endpoints if we act as their gateway + return sig.ArrayToChan([]*nodes.EndpointWithTTL{}), nil + } + + var endpoints []*nodes.EndpointWithTTL + for _, gw := range mod.gateways.Clone() { + endpoints = append(endpoints, nodes.NewEndpointWithTTL(gateway.NewEndpoint(gw, mod.node.Identity()), 7*30*24*time.Hour)) + } + + return sig.ArrayToChan(endpoints), nil +} diff --git a/mod/gateway/src/errors.go b/mod/gateway/src/errors.go deleted file mode 100644 index 5b0b9af38..000000000 --- a/mod/gateway/src/errors.go +++ /dev/null @@ -1,21 +0,0 @@ -package gateway - -import ( - "errors" - "fmt" -) - -var ErrInvalidGateway = errors.New("invalid gateway") -var ErrAlreadySubscribed = errors.New("already subscribed") -var ErrNotSubscribed = errors.New("subscription not found") - -type ErrParseError struct { - msg string -} - -func (e ErrParseError) Error() string { - if len(e.msg) == 0 { - return "parse error" - } - return fmt.Sprintf("parse error: %s", e.msg) -} diff --git a/mod/gateway/src/frames/ping_frame.go b/mod/gateway/src/frames/ping_frame.go new file mode 100644 index 000000000..35946d082 --- /dev/null +++ b/mod/gateway/src/frames/ping_frame.go @@ -0,0 +1,15 @@ +package frames + +import "io" + +const ( + BytePing = byte(0x00) + BytePong = byte(0x01) + ByteSignalGo = byte(0x02) + ByteSignalReady = byte(0x03) +) + +func WritePing(w io.Writer) error { _, err := w.Write([]byte{BytePing}); return err } +func WritePong(w io.Writer) error { _, err := w.Write([]byte{BytePong}); return err } +func WriteSignalGo(w io.Writer) error { _, err := w.Write([]byte{ByteSignalGo}); return err } +func WriteSignalReady(w io.Writer) error { _, err := w.Write([]byte{ByteSignalReady}); return err } diff --git a/mod/gateway/src/loader.go b/mod/gateway/src/loader.go index c367d565e..c7bea3e59 100644 --- a/mod/gateway/src/loader.go +++ b/mod/gateway/src/loader.go @@ -6,29 +6,26 @@ import ( "github.com/cryptopunkscc/astrald/core" "github.com/cryptopunkscc/astrald/core/assets" "github.com/cryptopunkscc/astrald/lib/routers" + "github.com/cryptopunkscc/astrald/mod/gateway" ) -const ModuleName = "gateway" - type Loader struct{} func (Loader) Load(node astral.Node, assets assets.Assets, log *log.Logger) (core.Module, error) { mod := &Module{ - node: node, - log: log, - PathRouter: routers.NewPathRouter(node.Identity(), false), - config: defaultConfig, - dialer: NewDialer(node), - subscribers: make(map[string]*Subscriber), + node: node, + log: log, + PathRouter: routers.NewPathRouter(node.Identity(), false), + config: defaultConfig, } - _ = assets.LoadYAML(ModuleName, &mod.config) + _ = assets.LoadYAML(gateway.ModuleName, &mod.config) return mod, nil } func init() { - if err := core.RegisterModule(ModuleName, Loader{}); err != nil { + if err := core.RegisterModule(gateway.ModuleName, Loader{}); err != nil { panic(err) } } diff --git a/mod/gateway/src/maintain_binding_task.go b/mod/gateway/src/maintain_binding_task.go new file mode 100644 index 000000000..fccbf79ee --- /dev/null +++ b/mod/gateway/src/maintain_binding_task.go @@ -0,0 +1,92 @@ +package gateway + +import ( + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/lib/astrald" + "github.com/cryptopunkscc/astrald/mod/events" + "github.com/cryptopunkscc/astrald/mod/gateway" + gatewayClient "github.com/cryptopunkscc/astrald/mod/gateway/client" + "github.com/cryptopunkscc/astrald/mod/ip" + "github.com/cryptopunkscc/astrald/mod/scheduler" + "github.com/cryptopunkscc/astrald/sig" +) + +var _ scheduler.Task = &MaintainBindingTask{} +var _ scheduler.EventReceiver = &MaintainBindingTask{} + +type MaintainBindingTask struct { + mod *Module + GatewayID *astral.Identity + Visibility gateway.Visibility + retry *sig.Retry + triggerCh chan struct{} +} + +func (mod *Module) NewMaintainBindingTask(gatewayID *astral.Identity, visibility gateway.Visibility) *MaintainBindingTask { + retry, _ := sig.NewRetry(time.Second, 15*time.Minute, 2) + return &MaintainBindingTask{ + mod: mod, + GatewayID: gatewayID, + Visibility: visibility, + retry: retry, + triggerCh: make(chan struct{}, 1), + } +} + +func (task *MaintainBindingTask) String() string { + return "maintain_binding_task" +} + +func (task *MaintainBindingTask) Run(ctx *astral.Context) error { + task.mod.log.Log("starting to maintain binding to %v", task.GatewayID) + client := gatewayClient.New(task.GatewayID, astrald.Default()) + + count := -1 + for { + switch { + case count == 0: + task.mod.log.Log("binding to %v lost, rebinding", task.GatewayID) + case count > 0 && count%5 == 0: + task.mod.log.Log("still trying to bind to %v (attempt %v)", task.GatewayID, count) + } + + socket, err := client.Bind(ctx.IncludeZone(astral.ZoneNetwork), task.Visibility) + if err != nil { + select { + case <-ctx.Done(): + return ctx.Err() + case count = <-task.retry.Retry(): + case <-task.triggerCh: + } + continue + } + + task.retry.Reset() + if count > 0 { + task.mod.log.Log("rebound to %v after %v attempts", task.GatewayID, count) + } else if count < 0 { + task.mod.log.Log("bound to %v", task.GatewayID) + } + count = 0 + + err = task.mod.newSocketPool(ctx, task.GatewayID, *socket).Run() + if err != nil { + task.mod.log.Error("rebinding to %v due to: %v", task.GatewayID, err) + } + } +} + +func (task *MaintainBindingTask) ReceiveEvent(e *events.Event) { + switch typed := e.Data.(type) { + case *ip.EventNetworkAddressChanged: + if len(typed.Added) > 0 { + task.retry.Reset() + select { + case task.triggerCh <- struct{}{}: + default: + } + } + } +} diff --git a/mod/gateway/src/module.go b/mod/gateway/src/module.go index 0b52c3ffc..056bbaeea 100644 --- a/mod/gateway/src/module.go +++ b/mod/gateway/src/module.go @@ -1,156 +1,132 @@ package gateway import ( - "strings" - "sync" - "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/astral/log" + "github.com/cryptopunkscc/astrald/lib/ops" "github.com/cryptopunkscc/astrald/lib/routers" "github.com/cryptopunkscc/astrald/mod/dir" "github.com/cryptopunkscc/astrald/mod/exonet" - gateway2 "github.com/cryptopunkscc/astrald/mod/gateway" + "github.com/cryptopunkscc/astrald/mod/gateway" + ipmod "github.com/cryptopunkscc/astrald/mod/ip" "github.com/cryptopunkscc/astrald/mod/nodes" - "github.com/cryptopunkscc/astrald/tasks" + "github.com/cryptopunkscc/astrald/mod/scheduler" + "github.com/cryptopunkscc/astrald/mod/services" + tcpmod "github.com/cryptopunkscc/astrald/mod/tcp" + "github.com/cryptopunkscc/astrald/sig" ) const NetworkName = "gw" type Deps struct { - Dir dir.Module - Exonet exonet.Module - Nodes nodes.Module + Dir dir.Module + Exonet exonet.Module + Nodes nodes.Module + Scheduler scheduler.Module + Services services.Module + TCP tcpmod.Module + IP ipmod.Module } type Module struct { Deps *routers.PathRouter - config Config - node astral.Node - log *log.Logger - ctx *astral.Context - dialer *Dialer - subscribers map[string]*Subscriber - mu sync.Mutex -} -func (mod *Module) Run(ctx *astral.Context) error { - mod.ctx = ctx.IncludeZone(astral.ZoneNetwork) + ops ops.Set + config Config + node astral.Node + log *log.Logger + ctx *astral.Context - mod.subscribeToGateways() + gateways sig.Set[*astral.Identity] + binders sig.Map[string, *binder] + connectors sig.Set[*connector] - return tasks.Group( - &SubscribeService{Module: mod}, - &RouteService{Module: mod, router: mod.node}, - ).Run(ctx) + listenEndpoints sig.Map[string, exonet.Endpoint] } -func (mod *Module) subscribeToGateways() { - for _, gateName := range mod.config.Subscribe { - var gateID *astral.Identity - - if after, found := strings.CutPrefix(gateName, "node1"); found && len(after) > 32 { - var info nodes.NodeInfo - - err := info.UnmarshalText([]byte(after)) - if err != nil { - mod.log.Error("parse node info: %v", err) - continue - } - - // try to set alias - err = mod.Dir.SetAlias(info.Identity, string(info.Alias)) - if err != nil { - mod.log.Error("set alias: %v", err) - } - - // save endpoints - for _, ep := range info.Endpoints { - err = mod.Nodes.AddEndpoint(info.Identity, nodes.NewEndpointWithTTL(ep)) - if err != nil { - mod.log.Error("add endpoint: %v", err) - continue - } - } - - // subscribe - err = mod.Subscribe(info.Identity) - if err != nil { - mod.log.Error("subscribe: %v", err) - } - continue - } +var _ gateway.Module = &Module{} - gateID, err := mod.Dir.ResolveIdentity(gateName) - if err != nil { - mod.log.Error("resolve identity %v: %v", gateName, err) - continue - } - - err = mod.Subscribe(gateID) - if err != nil { - mod.log.Error("subscribe: %v", err) - } - } +func (mod *Module) GetOpSet() *ops.Set { + return &mod.ops } -func (mod *Module) Subscribe(gateway *astral.Identity) error { - mod.mu.Lock() - defer mod.mu.Unlock() +func (mod *Module) Run(ctx *astral.Context) error { + mod.ctx = ctx.IncludeZone(astral.ZoneNetwork) - switch { - case gateway.IsZero(): - return ErrInvalidGateway - case gateway.IsEqual(mod.node.Identity()): - return ErrInvalidGateway + err := mod.AddRoute(gateway.MethodRoute+".*", routers.Func(mod.routeQuery)) + if err != nil { + return err } - var hex = gateway.String() - - if _, found := mod.subscribers[hex]; found { - return ErrAlreadySubscribed + if mod.config.Gateway.Enabled { + mod.startServers(mod.ctx) } - var s = NewSubscriber(gateway, mod.node, mod.log) - mod.subscribers[hex] = s + <-mod.Scheduler.Ready() - go func() { - err := s.Run(mod.ctx) - if err != nil { - mod.log.Errorv(1, "gateway %v subscriber ended with error: %v", gateway, err) - } - mod.mu.Lock() - defer mod.mu.Unlock() + for _, gw := range mod.config.Gateways { + mod.addPersistentGateway(gw) + } - delete(mod.subscribers, hex) - }() + <-ctx.Done() + for _, b := range mod.binders.Values() { + b.Close() + } + for _, c := range mod.connectors.Clone() { + c.Close() + } return nil } -func (mod *Module) Unsubscribe(gateway *astral.Identity) error { - mod.mu.Lock() - defer mod.mu.Unlock() +func (mod *Module) Endpoints() []exonet.Endpoint { + var list = make([]exonet.Endpoint, 0) + + return list +} - s, found := mod.subscribers[gateway.String()] - if !found { - return ErrNotSubscribed +func (mod *Module) getGatewayEndpoint(ctx *astral.Context, network string) (endpoint exonet.Endpoint, err error) { + endpoint, ok := mod.listenEndpoints.Get(network) + if !ok { + // fixme: return public error (no gateway endpoint available) + return } - s.Cancel() - return nil + return endpoint, nil } -func (mod *Module) Endpoints() []exonet.Endpoint { - var list = make([]exonet.Endpoint, 0) +func (mod *Module) binderByIdentity(identity *astral.Identity) (*binder, bool) { + return mod.binders.Get(identity.String()) +} - for _, s := range mod.subscribers { - list = append(list, gateway2.NewEndpoint(s.Gateway(), mod.node.Identity())) +func (mod *Module) binderByNonce(nonce astral.Nonce) (*binder, bool) { + for _, b := range mod.binders.Values() { + if b.Nonce == nonce { + return b, true + } } + return nil, false +} - return list +func (mod *Module) connectorByNonce(nonce astral.Nonce) (*connector, bool) { + for _, c := range mod.connectors.Clone() { + if c.Nonce == nonce { + return c, true + } + } + return nil, false +} + +func (mod *Module) canGateway(identity *astral.Identity) bool { + return mod.config.Gateway.Enabled +} + +func (mod *Module) addPersistentGateway(gatewayID *astral.Identity) { + mod.gateways.Add(gatewayID) + mod.Scheduler.Schedule(mod.NewMaintainBindingTask(gatewayID, mod.config.Visibility)) } func (mod *Module) String() string { - return ModuleName + return gateway.ModuleName } diff --git a/mod/gateway/src/op_list.go b/mod/gateway/src/op_list.go new file mode 100644 index 000000000..2e3285801 --- /dev/null +++ b/mod/gateway/src/op_list.go @@ -0,0 +1,28 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/ops" + "github.com/cryptopunkscc/astrald/mod/gateway" +) + +type opListArgs struct { + Out string `query:"optional"` +} + +func (mod *Module) OpList(ctx *astral.Context, q *ops.Query, args opListArgs) error { + ch := q.AcceptChannel(channel.WithOutputFormat(args.Out)) + defer ch.Close() + + for _, client := range mod.binders.Values() { + if client.Visibility != gateway.VisibilityPublic { + continue + } + if err := ch.Send(client.Identity); err != nil { + return err + } + } + + return ch.Send(&astral.EOS{}) +} diff --git a/mod/gateway/src/op_node_bind.go b/mod/gateway/src/op_node_bind.go new file mode 100644 index 000000000..3dc0ecb4c --- /dev/null +++ b/mod/gateway/src/op_node_bind.go @@ -0,0 +1,30 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/ops" + "github.com/cryptopunkscc/astrald/mod/gateway" +) + +type opNodeBindArgs struct { + Visibility gateway.Visibility + In string `query:"optional"` + Out string `query:"optional"` +} + +func (mod *Module) OpNodeBind( + ctx *astral.Context, + q *ops.Query, + args opNodeBindArgs, +) (err error) { + ch := channel.New(q.Accept(), channel.WithFormats(args.In, args.Out)) + defer ch.Close() + + socket, err := mod.bind(ctx, q.Caller(), args.Visibility, "tcp") + if err != nil { + return ch.Send(astral.NewError(err.Error())) + } + + return ch.Send(&socket) +} diff --git a/mod/gateway/src/op_node_connect.go b/mod/gateway/src/op_node_connect.go new file mode 100644 index 000000000..93ed7b48f --- /dev/null +++ b/mod/gateway/src/op_node_connect.go @@ -0,0 +1,29 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/ops" +) + +type opNodeConnectArgs struct { + Target *astral.Identity + In string `query:"optional"` + Out string `query:"optional"` +} + +func (mod *Module) OpNodeConnect( + ctx *astral.Context, + q *ops.Query, + args opNodeConnectArgs, +) (err error) { + ch := channel.New(q.Accept(), channel.WithFormats(args.In, args.Out)) + defer ch.Close() + + socket, err := mod.connectTo(q.Caller(), args.Target, "tcp") + if err != nil { + return ch.Send(astral.NewError(err.Error())) + } + + return ch.Send(&socket) +} diff --git a/mod/gateway/src/parser.go b/mod/gateway/src/parser.go index 4db0c8b71..1d91cc889 100644 --- a/mod/gateway/src/parser.go +++ b/mod/gateway/src/parser.go @@ -2,10 +2,11 @@ package gateway import ( "errors" - "github.com/cryptopunkscc/astrald/astral" + "fmt" + "strings" + "github.com/cryptopunkscc/astrald/mod/exonet" "github.com/cryptopunkscc/astrald/mod/gateway" - "strings" ) var _ exonet.Parser = &Module{} @@ -17,7 +18,7 @@ func (mod *Module) Parse(network string, address string) (exonet.Endpoint, error var ids = strings.SplitN(address, ":", 2) if len(ids) != 2 { - return nil, ErrParseError{msg: "invalid address string"} + return nil, fmt.Errorf("invalid endpoint: %s", address) } var err error @@ -38,27 +39,3 @@ func (mod *Module) Parse(network string, address string) (exonet.Endpoint, error return &endpoint, nil } - -// Parse converts a text representation of a gateway address to an Endpoint struct -func Parse(str string) (endpoint *gateway.Endpoint, err error) { - if len(str) != (2*66)+1 { // two public key hex strings and a separator ":" - return endpoint, ErrParseError{msg: "invalid address length"} - } - var ids = strings.SplitN(str, ":", 2) - if len(ids) != 2 { - return nil, ErrParseError{msg: "invalid address string"} - } - endpoint.GatewayID, err = astral.ParseIdentity(ids[0]) - if err != nil { - return nil, err - } - endpoint.TargetID, err = astral.ParseIdentity(ids[1]) - if err != nil { - return nil, err - } - if endpoint.GatewayID.IsEqual(endpoint.TargetID) { - return nil, errors.New("invalid endpoint") - } - - return -} diff --git a/mod/gateway/src/pipe.go b/mod/gateway/src/pipe.go new file mode 100644 index 000000000..4d8504d9f --- /dev/null +++ b/mod/gateway/src/pipe.go @@ -0,0 +1,44 @@ +package gateway + +import ( + "io" + "time" +) + +func pipe(a, b io.ReadWriteCloser) { + const idle = 30 * time.Second + + done := make(chan struct{}, 2) + + forward := func(dst, src io.ReadWriteCloser) { + // note: sync.Pool could reduce per-connection allocations under high concurrency (pattern used by nginx, envoy, traefik) + buf := make([]byte, 32*1024) + srcD, srcOk := src.(deadliner) + dstD, dstOk := dst.(deadliner) + for { + if srcOk { + srcD.SetReadDeadline(time.Now().Add(idle)) + } + n, err := src.Read(buf) + if n > 0 { + if dstOk { + dstD.SetWriteDeadline(time.Now().Add(idle)) + } + if _, werr := dst.Write(buf[:n]); werr != nil { + break + } + } + if err != nil { + break + } + } + done <- struct{}{} + } + + go forward(a, b) + go forward(b, a) + + <-done + a.Close() + b.Close() +} diff --git a/mod/gateway/src/pool.go b/mod/gateway/src/pool.go new file mode 100644 index 000000000..404f6f041 --- /dev/null +++ b/mod/gateway/src/pool.go @@ -0,0 +1,113 @@ +package gateway + +import ( + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/gateway" + "github.com/cryptopunkscc/astrald/sig" +) + +const ( + socketPoolTargetIdle = 2 + socketPoolMaxFails = 3 + socketPingInterval = 2 * time.Second + socketPingTimeout = 3 * time.Second +) + +// SocketPool maintains socketPoolTargetIdle idle socket connections to a gateway. +type SocketPool struct { + *Module + ctx *astral.Context + socket gateway.Socket + gatewayID *astral.Identity + + conns sig.Set[*bindingConn] + wake chan struct{} +} + +func (mod *Module) newSocketPool(ctx *astral.Context, gatewayID *astral.Identity, socket gateway.Socket) *SocketPool { + return &SocketPool{ + ctx: ctx, + Module: mod, + socket: socket, + gatewayID: gatewayID, + wake: make(chan struct{}, 1), + } +} + +func (p *SocketPool) Run() error { + retry, _ := sig.NewRetry(time.Second, 30*time.Second, 2) + p.notify() + + for { + select { + case <-p.ctx.Done(): + return p.ctx.Err() + case <-p.wake: + for p.idleCount() < socketPoolTargetIdle { + conn, err := p.acquireConn() + if err != nil { + select { + case <-p.ctx.Done(): + return p.ctx.Err() + case count := <-retry.Retry(): + if count >= socketPoolMaxFails { + return gateway.ErrSocketUnreachable + } + } + continue + } + retry.Reset() + p.startIdleSocket(conn) + } + } + } +} + +func (p *SocketPool) acquireConn() (exonet.Conn, error) { + p.log.Logv(2, "acquiring socket connection to %v through %v", p.gatewayID, p.socket.Endpoint) + conn, err := p.Exonet.Dial(p.ctx, p.socket.Endpoint) + if err != nil { + return nil, err + } + + if _, err := p.socket.Nonce.WriteTo(conn); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +// idleCount returns the number of non-active conns in the pool. +func (p *SocketPool) idleCount() int { + return len(p.conns.Select(func(a *bindingConn) bool { + return !a.active.Load() + })) +} + +func (p *SocketPool) startIdleSocket(conn exonet.Conn) { + bc := newBinderConn(conn) + bc.onClose = func() { + p.conns.Remove(bc) + p.notify() + } + + p.conns.Add(bc) + go bc.keepalive(p.ctx.Done(), func() error { + return p.Nodes.EstablishInboundLink(p.ctx, &gwConn{ + ReadWriteCloser: bc, + local: gateway.NewEndpoint(p.node.Identity(), p.node.Identity()), + remote: gateway.NewEndpoint(p.gatewayID, p.node.Identity()), + }) + }) +} + +func (p *SocketPool) notify() { + select { + case p.wake <- struct{}{}: + default: + } +} diff --git a/mod/gateway/src/route.go b/mod/gateway/src/route.go new file mode 100644 index 000000000..8ca8db1ee --- /dev/null +++ b/mod/gateway/src/route.go @@ -0,0 +1,62 @@ +package gateway + +import ( + "context" + "io" + "strings" + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/lib/query" + "github.com/cryptopunkscc/astrald/mod/gateway" +) + +const acceptTimeout = 30 * time.Second + +func (mod *Module) routeQuery(ctx *astral.Context, q *astral.Query, w io.WriteCloser) (io.WriteCloser, error) { + ctx = ctx.IncludeZone(astral.ZoneNetwork) + + var targetKey string + switch { + case strings.HasPrefix(q.Query, gateway.MethodRoute+"."): + targetKey, _ = strings.CutPrefix(q.Query, gateway.MethodRoute+".") + default: + return query.Reject() + } + + // target is us + if targetKey == mod.node.Identity().String() { + return query.Accept(q, w, func(conn astral.Conn) { + c := &gwConn{ + ReadWriteCloser: conn, + local: gateway.NewEndpoint(q.Target, q.Target), + remote: gateway.NewEndpoint(q.Caller, q.Target), + } + + // prevents slow gateway connections + actx, cancel := context.WithTimeout(context.Background(), acceptTimeout) + defer cancel() + + if err := mod.Nodes.EstablishInboundLink(actx, c); err != nil { + mod.log.Errorv(1, "inbound link from %v failed: %v", q.Caller, err) + } + }) + } + + // forward query (will automatically use existing link) + + targetIdentity, err := astral.ParseIdentity(targetKey) + if err != nil { + return query.Reject() + } + + nextQuery := &astral.Query{ + Nonce: astral.NewNonce(), + Caller: mod.node.Identity(), + Target: targetIdentity, + Query: q.Query, + } + + mod.log.Logv(2, "routing %v to %v via link", q.Caller, targetIdentity) + return mod.node.RouteQuery(ctx, nextQuery, w) +} diff --git a/mod/gateway/src/route_service.go b/mod/gateway/src/route_service.go deleted file mode 100644 index f2136d8f6..000000000 --- a/mod/gateway/src/route_service.go +++ /dev/null @@ -1,79 +0,0 @@ -package gateway - -import ( - "context" - "io" - "strings" - "time" - - "github.com/cryptopunkscc/astrald/astral" - "github.com/cryptopunkscc/astrald/lib/query" - "github.com/cryptopunkscc/astrald/mod/gateway" -) - -const RouteServiceName = ".gateway" -const acceptTimeout = 15 * time.Second - -type RouteService struct { - *Module - router astral.Router -} - -func (srv *RouteService) Run(ctx *astral.Context) error { - err := srv.AddRoute(RouteServiceName+".*", srv) - if err != nil { - return err - } - defer srv.RemoveRoute(RouteServiceName + ".*") - - <-ctx.Done() - return nil -} - -func (srv *RouteService) RouteQuery(ctx *astral.Context, q *astral.Query, w io.WriteCloser) (io.WriteCloser, error) { - var targetKey string - - switch { - case strings.HasPrefix(q.Query, RouteServiceName+"."): - targetKey, _ = strings.CutPrefix(q.Query, RouteServiceName+".") - - default: - return query.Reject() - } - - // check if the target is us - if targetKey == srv.node.Identity().String() { - return query.Accept(q, w, func(conn astral.Conn) { - gwConn := newConn( - conn, - gateway.NewEndpoint(q.Target, q.Target), - gateway.NewEndpoint(q.Caller, q.Target), - false, - ) - - actx, cancel := context.WithTimeout(context.Background(), acceptTimeout) - defer cancel() - - err := srv.Nodes.EstablishInboundLink(actx, gwConn) - if err != nil { - return - } - }) - } - - targetIdentity, err := astral.ParseIdentity(targetKey) - if err != nil { - return query.Reject() - } - - nextQuery := &astral.Query{ - Nonce: astral.NewNonce(), - Caller: srv.node.Identity(), - Target: targetIdentity, - Query: q.Query, - } - - srv.log.Logv(2, "forwarding %v to %v", q.Caller, targetIdentity) - - return srv.router.RouteQuery(ctx, nextQuery, w) -} diff --git a/mod/gateway/src/server.go b/mod/gateway/src/server.go new file mode 100644 index 000000000..9100d4e49 --- /dev/null +++ b/mod/gateway/src/server.go @@ -0,0 +1,43 @@ +package gateway + +import ( + "strings" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/tcp" +) + +func (mod *Module) startServers(ctx *astral.Context) { + for _, addr := range mod.config.Gateway.Listen { + parts := strings.SplitN(addr, ":", 2) + if len(parts) != 2 { + mod.log.Error("invalid listen address: %v", addr) + continue + } + network, address := parts[0], parts[1] + endpoint, err := mod.Exonet.Parse(network, address) + if err != nil { + mod.log.Error("parse listen address %v: %v", addr, err) + continue + } + + switch network { + case "tcp": + tcpEndpoint, ok := endpoint.(*tcp.Endpoint) + if !ok { + mod.log.Error("invalid listen address: %v", addr) + continue + } + + mod.log.Logv(1, "start listening on %v", tcpEndpoint) + if err := mod.TCP.CreateEphemeralListener(ctx, tcpEndpoint.Port, mod.acceptSocketConn); err != nil { + mod.log.Error("create ephemeral listener on %v: %v", addr, err) + continue + } + + mod.listenEndpoints.Set("tcp", tcpEndpoint) + default: + mod.log.Error("unsupported gateway socket network: %v", network) + } + } +} diff --git a/mod/gateway/src/service_discoverer.go b/mod/gateway/src/service_discoverer.go new file mode 100644 index 000000000..8fef1c18c --- /dev/null +++ b/mod/gateway/src/service_discoverer.go @@ -0,0 +1,39 @@ +package gateway + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/gateway" + "github.com/cryptopunkscc/astrald/mod/services" +) + +var _ services.Discoverer = &Module{} + +func (mod *Module) DiscoverServices( + ctx *astral.Context, + caller *astral.Identity, + follow bool, +) (<-chan *services.Update, error) { + var ch = make(chan *services.Update, 2) + + if mod.config.Gateway.Enabled { + ch <- &services.Update{ + Available: true, + Name: gateway.ModuleName, + ProviderID: mod.node.Identity(), + } + } + + if !follow { + close(ch) + return ch, nil + } + + ch <- nil + + go func() { + <-ctx.Done() + close(ch) + }() + + return ch, nil +} diff --git a/mod/gateway/src/subscribe_service.go b/mod/gateway/src/subscribe_service.go deleted file mode 100644 index 2c44d7278..000000000 --- a/mod/gateway/src/subscribe_service.go +++ /dev/null @@ -1,46 +0,0 @@ -package gateway - -import ( - "encoding/json" - "github.com/cryptopunkscc/astrald/astral" - "github.com/cryptopunkscc/astrald/lib/query" - "io" - "time" -) - -const SubscribeServiceName = ".gateway.subscribe" -const SubscribeServiceType = "mod.gateway.subscribe" -const defaultSubscriptionDuration = 24 * time.Hour - -type SubscribeService struct { - *Module -} - -type Subscription struct { - Status string - ExpiresAt time.Time `json:"expires_at,omitempty"` -} - -func (srv *SubscribeService) Run(ctx *astral.Context) error { - var err = srv.AddRoute(SubscribeServiceName, srv) - if err != nil { - return err - } - defer srv.RemoveRoute(SubscribeServiceName) - - <-ctx.Done() - return nil -} - -func (srv *SubscribeService) RouteQuery(ctx *astral.Context, q *astral.Query, w io.WriteCloser) (io.WriteCloser, error) { - return query.Accept(q, w, func(conn astral.Conn) { - defer conn.Close() - - s := &Subscription{ - Status: "ok", - ExpiresAt: time.Now().Add(defaultSubscriptionDuration), - } - - json.NewEncoder(conn).Encode(s) - }) -} diff --git a/mod/gateway/src/subscriber.go b/mod/gateway/src/subscriber.go deleted file mode 100644 index 4f12950c1..000000000 --- a/mod/gateway/src/subscriber.go +++ /dev/null @@ -1,83 +0,0 @@ -package gateway - -import ( - "context" - "encoding/json" - "errors" - "github.com/cryptopunkscc/astrald/astral" - "github.com/cryptopunkscc/astrald/astral/log" - "github.com/cryptopunkscc/astrald/lib/query" - "time" -) - -const minimumSubscriptionDuration = 15 * time.Minute -const subscribeRetryInterval = 60 * time.Second - -type Subscriber struct { - node astral.Node - log *log.Logger - gateway *astral.Identity - cancel context.CancelFunc -} - -func (s *Subscriber) Gateway() *astral.Identity { - return s.gateway -} - -func NewSubscriber(gateway *astral.Identity, node astral.Node, log *log.Logger) *Subscriber { - return &Subscriber{node: node, log: log, gateway: gateway} -} - -func (s *Subscriber) Run(ctx *astral.Context) error { - ctx, s.cancel = ctx.WithCancel() - defer s.cancel() - - var expiresAt time.Time - for { - conn, err := query.Route(ctx, s.node, astral.NewQuery(s.node.Identity(), s.gateway, SubscribeServiceName)) - if err != nil { - select { - case <-ctx.Done(): - return nil - case <-time.After(subscribeRetryInterval): - } - continue - } - - var info Subscription - err = json.NewDecoder(conn).Decode(&info) - conn.Close() - - if err != nil { - select { - case <-ctx.Done(): - return nil - case <-time.After(subscribeRetryInterval): - } - continue - } - - if info.Status != "ok" { - return errors.New("subscription rejected") - } - - expiresAt = info.ExpiresAt - if time.Until(expiresAt) < minimumSubscriptionDuration { - return errors.New("subscription too short") - } - - s.log.Infov(2, "subscribed to %v until %v", s.gateway, expiresAt) - - select { - case <-ctx.Done(): - return nil - case <-time.After(time.Until(expiresAt) - time.Minute): - } - } -} - -func (s *Subscriber) Cancel() { - if s.cancel != nil { - s.cancel() - } -} diff --git a/mod/gateway/visibility.go b/mod/gateway/visibility.go new file mode 100644 index 000000000..8dbecf5dc --- /dev/null +++ b/mod/gateway/visibility.go @@ -0,0 +1,10 @@ +package gateway + +import "github.com/cryptopunkscc/astrald/astral" + +type Visibility = astral.String8 + +const ( + VisibilityPublic Visibility = "public" + VisibilityPrivate Visibility = "private" +) diff --git a/mod/nodes/src/peers.go b/mod/nodes/src/peers.go index fddfaa92a..c0c7d6aec 100644 --- a/mod/nodes/src/peers.go +++ b/mod/nodes/src/peers.go @@ -11,6 +11,7 @@ import ( "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/lib/query" "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/gateway" "github.com/cryptopunkscc/astrald/mod/nodes" "github.com/cryptopunkscc/astrald/mod/nodes/src/frames" "github.com/cryptopunkscc/astrald/mod/nodes/src/noise" @@ -369,6 +370,11 @@ func (mod *Peers) reflectStream(s *Stream) (err error) { return } + // note: rethink maybe switch (?) + if _, ok := s.RemoteEndpoint().(*gateway.Endpoint); ok { + // dont reflect gateway endpoints + return + } // reflect the endpoint err = mod.Objects.Push(mod.ctx, s.RemoteIdentity(), &nodes.ObservedEndpointMessage{ diff --git a/mod/tcp/client/client.go b/mod/tcp/client/client.go new file mode 100644 index 000000000..96a78e3d3 --- /dev/null +++ b/mod/tcp/client/client.go @@ -0,0 +1,27 @@ +package tcp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/astrald" +) + +type Client struct { + astral *astrald.Client + targetID *astral.Identity +} + +func New(targetID *astral.Identity, a *astrald.Client) *Client { + if a == nil { + a = astrald.Default() + } + return &Client{astral: a, targetID: targetID} +} + +func (client *Client) WithTarget(target *astral.Identity) *Client { + return &Client{astral: client.astral, targetID: target} +} + +func (client *Client) queryCh(ctx *astral.Context, method string, args any, cfg ...channel.ConfigFunc) (*channel.Channel, error) { + return client.astral.WithTarget(client.targetID).QueryChannel(ctx, method, args, cfg...) +} diff --git a/mod/tcp/client/close_ephemeral_listener.go b/mod/tcp/client/close_ephemeral_listener.go new file mode 100644 index 000000000..d6a5e036d --- /dev/null +++ b/mod/tcp/client/close_ephemeral_listener.go @@ -0,0 +1,25 @@ +package tcp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/query" + "github.com/cryptopunkscc/astrald/mod/tcp" +) + +func (client *Client) CloseEphemeralListener(ctx *astral.Context, port astral.Uint16) error { + ch, err := client.queryCh(ctx, tcp.MethodCloseEphemeralListener, query.Args{ + "port": port, + }) + if err != nil { + return err + } + defer ch.Close() + + return ch.Switch( + channel.ExpectAck, + func(msg *astral.ErrorMessage) error { + return msg + }, + ) +} diff --git a/mod/tcp/client/new_ephemeral_listener.go b/mod/tcp/client/new_ephemeral_listener.go new file mode 100644 index 000000000..108d684de --- /dev/null +++ b/mod/tcp/client/new_ephemeral_listener.go @@ -0,0 +1,25 @@ +package tcp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/query" + "github.com/cryptopunkscc/astrald/mod/tcp" +) + +func (client *Client) CreateEphemeralListener(ctx *astral.Context, port astral.Uint16) error { + ch, err := client.queryCh(ctx, tcp.MethodNewEphemeralListener, query.Args{ + "port": port, + }) + if err != nil { + return err + } + defer ch.Close() + + return ch.Switch( + channel.ExpectAck, + func(msg *astral.ErrorMessage) error { + return msg + }, + ) +} diff --git a/mod/tcp/errors.go b/mod/tcp/errors.go new file mode 100644 index 000000000..3ad3a4e80 --- /dev/null +++ b/mod/tcp/errors.go @@ -0,0 +1,6 @@ +package tcp + +import "errors" + +var ErrEphemeralListenerExists = errors.New("ephemeral listener already exists") +var ErrEphemeralListenerNotExist = errors.New("ephemeral listener not exists") diff --git a/mod/tcp/module.go b/mod/tcp/module.go index 315b2202d..5d649612c 100644 --- a/mod/tcp/module.go +++ b/mod/tcp/module.go @@ -1,14 +1,21 @@ package tcp import ( + "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/mod/exonet" ) const ModuleName = "tcp" +const ( + MethodNewEphemeralListener = "tcp.new_ephemeral_listener" + MethodCloseEphemeralListener = "tcp.close_ephemeral_listener" +) + type Module interface { exonet.Dialer exonet.Unpacker exonet.Parser ListenPort() int + CreateEphemeralListener(ctx *astral.Context, port astral.Uint16, handler exonet.EphemeralHandler) error } diff --git a/mod/tcp/src/dial.go b/mod/tcp/src/dial.go index 783bf7405..7682ae550 100644 --- a/mod/tcp/src/dial.go +++ b/mod/tcp/src/dial.go @@ -2,6 +2,7 @@ package tcp import ( _net "net" + "time" "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/mod/exonet" @@ -20,7 +21,7 @@ func (mod *Module) Dial(ctx *astral.Context, endpoint exonet.Endpoint) (exonet.C return nil, exonet.ErrDisabledNetwork } - var dialer = _net.Dialer{Timeout: mod.config.DialTimeout} + var dialer = _net.Dialer{Timeout: mod.config.DialTimeout, KeepAlive: 5 * time.Second} tcpConn, err := dialer.DialContext(ctx, "tcp", endpoint.Address()) if err != nil { diff --git a/mod/tcp/src/ephemeral_listener.go b/mod/tcp/src/ephemeral_listener.go new file mode 100644 index 000000000..ac90cd0b0 --- /dev/null +++ b/mod/tcp/src/ephemeral_listener.go @@ -0,0 +1,44 @@ +package tcp + +import ( + "fmt" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/tcp" +) + +func (mod *Module) CreateEphemeralListener(ctx *astral.Context, port astral.Uint16, handler exonet.EphemeralHandler) error { + mod.mu.Lock() + defer mod.mu.Unlock() + + if _, ok := mod.ephemeralListeners.Get(port); ok { + return fmt.Errorf("%w: port %v", tcp.ErrEphemeralListenerExists, port) + } + + srv := NewServer(mod, port, handler) + mod.ephemeralListeners.Set(port, srv) + + go func() { + err := srv.Run(ctx) + if err != nil { + mod.log.Error("ephemeral listener error: %v", err) + } + + mod.ephemeralListeners.Delete(port) + }() + + return nil +} + +func (mod *Module) CloseEphemeralListener(port astral.Uint16) error { + listener, ok := mod.ephemeralListeners.Get(port) + if !ok { + return tcp.ErrEphemeralListenerNotExist + } + + listener.Close() + mod.ephemeralListeners.Delete(port) + + return nil +} diff --git a/mod/tcp/src/loader.go b/mod/tcp/src/loader.go index cc2368278..b1d0d59df 100644 --- a/mod/tcp/src/loader.go +++ b/mod/tcp/src/loader.go @@ -22,6 +22,8 @@ func (Loader) Load(node astral.Node, assets assets.Assets, l *log.Logger) (core. _ = assets.LoadYAML(tcp.ModuleName, &mod.config) + mod.ops.AddStructPrefix(mod, "Op") + for _, addr := range mod.config.Endpoints { addr, _ = strings.CutPrefix(addr, fmt.Sprintf("%s:", tcp.ModuleName)) diff --git a/mod/tcp/src/module.go b/mod/tcp/src/module.go index 130e2a7e3..b2a9a831f 100644 --- a/mod/tcp/src/module.go +++ b/mod/tcp/src/module.go @@ -1,10 +1,13 @@ package tcp import ( + "context" + "sync" "time" "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/astral/log" + "github.com/cryptopunkscc/astrald/lib/ops" "github.com/cryptopunkscc/astrald/mod/exonet" "github.com/cryptopunkscc/astrald/mod/nodes" "github.com/cryptopunkscc/astrald/mod/tcp" @@ -23,8 +26,12 @@ type Module struct { log *log.Logger ctx *astral.Context configEndpoints []exonet.Endpoint + ops ops.Set - server sig.Switch + mu sync.Mutex + + server sig.Switch + ephemeralListeners sig.Map[astral.Uint16, exonet.EphemeralListener] } type Settings struct { @@ -32,6 +39,23 @@ type Settings struct { Dial *tree.Value[*astral.Bool] `tree:"dial"` } +func (mod *Module) GetOpSet() *ops.Set { + return &mod.ops +} + +func (mod *Module) String() string { + return tcp.ModuleName +} + +func (mod *Module) acceptAll(ctx context.Context, conn exonet.Conn) (shouldStop bool, err error) { + err = mod.Nodes.EstablishInboundLink(ctx, conn) + if err != nil { + return false, err + } + + return false, nil +} + func (mod *Module) Run(ctx *astral.Context) (err error) { mod.ctx = ctx diff --git a/mod/tcp/src/op_close_ephemeral_listener.go b/mod/tcp/src/op_close_ephemeral_listener.go new file mode 100644 index 000000000..1ed7179e2 --- /dev/null +++ b/mod/tcp/src/op_close_ephemeral_listener.go @@ -0,0 +1,25 @@ +package tcp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/ops" +) + +type opCloseEphemeralListenerArgs struct { + Port astral.Uint16 + In string `query:"optional"` + Out string `query:"optional"` +} + +func (mod *Module) OpCloseEphemeralListener(ctx *astral.Context, q *ops.Query, args opCloseEphemeralListenerArgs) (err error) { + ch := channel.New(q.Accept(), channel.WithFormats(args.In, args.Out)) + defer ch.Close() + + err = mod.CloseEphemeralListener(args.Port) + if err != nil { + return ch.Send(astral.NewError(err.Error())) + } + + return ch.Send(&astral.Ack{}) +} diff --git a/mod/tcp/src/op_new_ephemeral_listener.go b/mod/tcp/src/op_new_ephemeral_listener.go new file mode 100644 index 000000000..bb649ffc4 --- /dev/null +++ b/mod/tcp/src/op_new_ephemeral_listener.go @@ -0,0 +1,25 @@ +package tcp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/channel" + "github.com/cryptopunkscc/astrald/lib/ops" +) + +type opNewEphemeralListenerArgs struct { + Port astral.Uint16 + In string `query:"optional"` + Out string `query:"optional"` +} + +func (mod *Module) OpNewEphemeralListener(ctx *astral.Context, q *ops.Query, args opNewEphemeralListenerArgs) (err error) { + ch := channel.New(q.Accept(), channel.WithFormats(args.In, args.Out)) + defer ch.Close() + + err = mod.CreateEphemeralListener(ctx, args.Port, mod.acceptAll) + if err != nil { + return ch.Send(astral.NewError(err.Error())) + } + + return ch.Send(&astral.Ack{}) +} diff --git a/mod/tcp/src/server.go b/mod/tcp/src/server.go index 5e7af3142..379e11208 100644 --- a/mod/tcp/src/server.go +++ b/mod/tcp/src/server.go @@ -2,62 +2,109 @@ package tcp import ( "context" + "fmt" "net" - "strconv" + "sync/atomic" + "time" "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" "github.com/cryptopunkscc/astrald/mod/tcp" ) +var _ exonet.EphemeralListener = &Server{} + type Server struct { *Module + listenPort astral.Uint16 + listener net.Listener + onAccept exonet.EphemeralHandler + closed atomic.Bool + closedCh chan struct{} } -func NewServer(module *Module) *Server { - return &Server{Module: module} +func NewServer(module *Module, listenPort astral.Uint16, onAccept exonet.EphemeralHandler) *Server { + return &Server{ + Module: module, + listenPort: listenPort, + onAccept: onAccept, + closedCh: make(chan struct{}), + } } -func (srv *Server) Run(ctx context.Context) error { - // start the listener - var addrStr = ":" + strconv.Itoa(srv.config.ListenPort) +func (s *Server) Run(ctx *astral.Context) error { + addr := fmt.Sprintf(":%d", s.listenPort) - listener, err := net.Listen("tcp", addrStr) + listener, err := net.Listen("tcp", addr) if err != nil { - srv.log.Errorv(0, "failed to start server: %v", err) - return err + return fmt.Errorf("tcp server/run: failed to listen on %v: %w", addr, err) } - endpoint, _ := tcp.ParseEndpoint(listener.Addr().String()) + s.listener = listener - srv.log.Info("started server at %v", endpoint) - defer srv.log.Info("stopped server at %v", endpoint) + endpoint, _ := tcp.ParseEndpoint(listener.Addr().String()) + s.log.Info("started server at %v", endpoint) go func() { - <-ctx.Done() - listener.Close() + select { + case <-ctx.Done(): + s.Close() + case <-s.Done(): + } }() - // accept connections for { rawConn, err := listener.Accept() if err != nil { - return err + if s.closed.Load() || ctx.Err() != nil { + s.log.Info("stopped server at %v", endpoint) + return nil + } + + return fmt.Errorf("tcp server/run: accept failed: %w", err) } - var conn = tcp.WrapConn(rawConn, false) + if tc, ok := rawConn.(*net.TCPConn); ok { + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(30 * time.Second) + } + conn := tcp.WrapConn(rawConn, false) go func() { - err := srv.Nodes.EstablishInboundLink(ctx, conn) + stopListener, err := s.onAccept(ctx, conn) if err != nil { - srv.log.Errorv(1, "handshake failed from %v: %v", conn.RemoteEndpoint(), err) + conn.Close() + s.log.Errorv(1, "tcp server/onAccept error from %v: %v", conn.RemoteEndpoint(), err) return } + + if stopListener { + s.Close() + } }() } } +func (s *Server) Done() <-chan struct{} { + return s.closedCh +} + +func (s *Server) Close() error { + if !s.closed.CompareAndSwap(false, true) { + return nil + } + + if s.listener != nil { + return s.listener.Close() + } + + close(s.closedCh) + return nil +} + func (mod *Module) startServer(ctx context.Context) { - srv := NewServer(mod) + listenPort := astral.Uint16(mod.config.ListenPort) + srv := NewServer(mod, listenPort, mod.acceptAll) if err := srv.Run(astral.NewContext(ctx)); err != nil { mod.log.Errorv(1, "server error: %v", err) }