diff --git a/cmd/centralserver/main.go b/cmd/centralserver/main.go index cc37e50..bb198e4 100644 --- a/cmd/centralserver/main.go +++ b/cmd/centralserver/main.go @@ -17,7 +17,20 @@ import ( "github.com/ParsaKSH/SlipStream-Plus/internal/tunnel" ) -//cc + +// sourceConn wraps a tunnel connection with a write mutex so that +// concurrent WriteFrame calls from different ConnIDs don't interleave bytes. +type sourceConn struct { + conn net.Conn + writeMu sync.Mutex +} + +func (sc *sourceConn) WriteFrame(f *tunnel.Frame) error { + sc.writeMu.Lock() + defer sc.writeMu.Unlock() + return tunnel.WriteFrame(sc.conn, f) +} + // connState tracks a single reassembled connection. type connState struct { mu sync.Mutex @@ -29,7 +42,7 @@ type connState struct { // Sources: all tunnel connections that can carry reverse data. // We round-robin responses across them (not broadcast). - sources []io.Writer + sources []*sourceConn sourceIdx int } @@ -39,6 +52,10 @@ type centralServer struct { mu sync.RWMutex conns map[uint32]*connState // ConnID → state + + // sourceMu protects the sources map (net.Conn → *sourceConn). 
+	sourceMu  sync.Mutex +	sourceMap map[net.Conn]*sourceConn } func main() { @@ -54,6 +71,7 @@ func main() { cs := &centralServer{ socksUpstream: *socksUpstream, conns:         make(map[uint32]*connState), +		sourceMap:     make(map[net.Conn]*sourceConn), } sigCh := make(chan os.Signal, 1) @@ -147,12 +165,15 @@ func (cs *centralServer) handleSOCKS5Passthrough(clientConn net.Conn, firstByte func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAddr string) { log.Printf("[central] frame connection from %s", remoteAddr) +	sc := cs.getSourceConn(conn) + // Track which ConnIDs this source served servedIDs := make(map[uint32]bool) defer func() { -		// Source TCP died — clean up connStates that only had this source -		cs.cleanupSource(conn, servedIDs, remoteAddr) +		// Source TCP died — clean up sourceConn and connStates +		cs.removeSourceConn(conn) +		cs.cleanupSource(sc, servedIDs, remoteAddr) }() // Read remaining header bytes (we already read 1) @@ -168,7 +189,7 @@ func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAd firstFrame := cs.parseHeader(fullHdr, conn, remoteAddr) if firstFrame != nil { servedIDs[firstFrame.ConnID] = true -		cs.dispatchFrame(firstFrame, conn) +		cs.dispatchFrame(firstFrame, sc) } for { @@ -180,13 +201,13 @@ func (cs *centralServer) handleFrameConn(conn net.Conn, firstByte byte, remoteAd return } servedIDs[frame.ConnID] = true -		cs.dispatchFrame(frame, conn) +		cs.dispatchFrame(frame, sc) } } // cleanupSource removes a dead source connection from all connStates. // If a connState has no remaining sources, it is fully cleaned up. 
-func (cs *centralServer) cleanupSource(deadSource net.Conn, servedIDs map[uint32]bool, remoteAddr string) { +func (cs *centralServer) cleanupSource(deadSource *sourceConn, servedIDs map[uint32]bool, remoteAddr string) { cs.mu.Lock() defer cs.mu.Unlock() @@ -227,6 +248,26 @@ func (cs *centralServer) cleanupSource(deadSource net.Conn, servedIDs map[uint32 } } +// getSourceConn returns the sourceConn wrapper for a raw net.Conn, +// creating one if it doesn't exist yet. +func (cs *centralServer) getSourceConn(conn net.Conn) *sourceConn { + cs.sourceMu.Lock() + defer cs.sourceMu.Unlock() + sc, ok := cs.sourceMap[conn] + if !ok { + sc = &sourceConn{conn: conn} + cs.sourceMap[conn] = sc + } + return sc +} + +// removeSourceConn removes the sourceConn wrapper when the raw conn dies. +func (cs *centralServer) removeSourceConn(conn net.Conn) { + cs.sourceMu.Lock() + delete(cs.sourceMap, conn) + cs.sourceMu.Unlock() +} + func isClosedConnErr(err error) bool { if err == nil { return false @@ -257,7 +298,7 @@ func (cs *centralServer) parseHeader(hdr [tunnel.HeaderSize]byte, conn net.Conn, } } -func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source net.Conn) { +func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source *sourceConn) { if frame.IsSYN() { cs.handleSYN(frame, source) return @@ -273,7 +314,7 @@ func (cs *centralServer) dispatchFrame(frame *tunnel.Frame, source net.Conn) { cs.handleData(frame, source) } -func (cs *centralServer) handleSYN(frame *tunnel.Frame, source net.Conn) { +func (cs *centralServer) handleSYN(frame *tunnel.Frame, source *sourceConn) { connID := frame.ConnID cs.mu.Lock() @@ -307,7 +348,7 @@ func (cs *centralServer) handleSYN(frame *tunnel.Frame, source net.Conn) { ctx, cancel := context.WithCancel(context.Background()) state := &connState{ reorderer: tunnel.NewReordererAt(frame.SeqNum + 1), // skip SYN's SeqNum - sources: []io.Writer{source}, + sources: []*sourceConn{source}, cancel: cancel, created: time.Now(), } @@ -315,11 
+356,11 @@ func (cs *centralServer) handleSYN(frame *tunnel.Frame, source net.Conn) { cs.mu.Unlock() log.Printf("[central] conn=%d: SYN → target=%s", connID, targetAddr) - go cs.connectUpstream(ctx, connID, state, atyp, addr, port, targetAddr, source) + go cs.connectUpstream(ctx, connID, state, atyp, addr, port, targetAddr) } func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, state *connState, - atyp byte, addr, port []byte, targetAddr string, source net.Conn) { + atyp byte, addr, port []byte, targetAddr string) { upConn, err := net.DialTimeout("tcp", cs.socksUpstream, 10*time.Second) if err != nil { @@ -349,6 +390,7 @@ func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, sta return } + // Read greeting response (2 bytes) + CONNECT response header (4 bytes) resp := make([]byte, 6) if _, err := io.ReadFull(upConn, resp); err != nil { log.Printf("[central] conn=%d: upstream response read failed: %v", connID, err) @@ -358,7 +400,18 @@ func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, sta return } - // Drain bind address + // Check CONNECT result BEFORE draining bind address. + // resp[3] = REP field (0x00 = success). If non-zero, upstream may close + // without sending bind address, so don't try to drain it. 
+ if resp[3] != 0x00 { + log.Printf("[central] conn=%d: upstream CONNECT rejected: 0x%02x", connID, resp[3]) + upConn.Close() + cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) + cs.removeConn(connID) + return + } + + // Drain bind address (only on success) switch resp[5] { case 0x01: io.ReadFull(upConn, make([]byte, 6)) @@ -372,14 +425,6 @@ func (cs *centralServer) connectUpstream(ctx context.Context, connID uint32, sta io.ReadFull(upConn, make([]byte, 6)) } - if resp[3] != 0x00 { - log.Printf("[central] conn=%d: upstream CONNECT rejected: 0x%02x", connID, resp[3]) - upConn.Close() - cs.sendFrame(connID, &tunnel.Frame{ConnID: connID, Flags: tunnel.FlagRST | tunnel.FlagReverse}) - cs.removeConn(connID) - return - } - state.mu.Lock() state.target = upConn @@ -453,7 +498,8 @@ func (cs *centralServer) relayUpstreamToTunnel(ctx context.Context, connID uint3 } // sendFrame picks ONE source via round-robin and writes the frame. -// If that source fails, tries the next one. Much better than broadcasting. +// If that source fails, tries the next one. Uses sourceConn.WriteFrame +// which is mutex-protected per TCP connection, preventing interleaved writes. func (cs *centralServer) sendFrame(connID uint32, frame *tunnel.Frame) { cs.mu.RLock() state, ok := cs.conns[connID] @@ -473,9 +519,9 @@ func (cs *centralServer) sendFrame(connID uint32, frame *tunnel.Frame) { for tries := 0; tries < len(state.sources); tries++ { idx := state.sourceIdx % len(state.sources) state.sourceIdx++ - w := state.sources[idx] + sc := state.sources[idx] - if err := tunnel.WriteFrame(w, frame); err != nil { + if err := sc.WriteFrame(frame); err != nil { // Remove dead source state.sources = append(state.sources[:idx], state.sources[idx+1:]...) 
if state.sourceIdx > 0 { @@ -488,7 +534,7 @@ func (cs *centralServer) sendFrame(connID uint32, frame *tunnel.Frame) { log.Printf("[central] conn=%d: all sources failed", connID) } -func (cs *centralServer) handleData(frame *tunnel.Frame, source net.Conn) { +func (cs *centralServer) handleData(frame *tunnel.Frame, source *sourceConn) { cs.mu.RLock() state, ok := cs.conns[frame.ConnID] cs.mu.RUnlock() @@ -548,7 +594,10 @@ func (cs *centralServer) handleFIN(frame *tunnel.Frame) { state.target.Close() } state.mu.Unlock() - log.Printf("[central] conn=%d: FIN received", frame.ConnID) + + // Remove from map and cancel context so relayUpstreamToTunnel exits cleanly + cs.removeConn(frame.ConnID) + log.Printf("[central] conn=%d: FIN received, cleaned up", frame.ConnID) } func (cs *centralServer) handleRST(frame *tunnel.Frame) { @@ -601,10 +650,8 @@ func (cs *centralServer) cleanupLoop() { if len(state.sources) == 0 && now.Sub(state.created) > 30*time.Second { shouldClean = true } - // Connection too old (5 min max lifetime) - if now.Sub(state.created) > 5*time.Minute { - shouldClean = true - } + // No max lifetime — long-lived connections (downloads, streams) + // are valid. Cleanup only based on actual broken state above. state.mu.Unlock() if shouldClean { diff --git a/internal/gui/dashboard.go b/internal/gui/dashboard.go index a200d64..cde306b 100644 --- a/internal/gui/dashboard.go +++ b/internal/gui/dashboard.go @@ -224,8 +224,8 @@ canvas{width:100%;height:200px;border-radius:var(--rs);background:var(--bg2);bor
Central Server Settings
- - + +
diff --git a/internal/health/checker.go b/internal/health/checker.go index 618cd4c..4afbbb2 100644 --- a/internal/health/checker.go +++ b/internal/health/checker.go @@ -2,12 +2,14 @@ package health import ( "context" + "crypto/rand" "encoding/binary" "fmt" "io" "log" "net" "sync" + "sync/atomic" "time" "github.com/ParsaKSH/SlipStream-Plus/internal/config" @@ -34,6 +36,8 @@ type Checker struct { mu sync.Mutex failures map[int]int + + probeSeq atomic.Uint32 // unique probe counter to avoid ConnID collisions } func NewChecker(mgr *engine.Manager, cfg *config.HealthCheckConfig) *Checker { @@ -371,8 +375,14 @@ func (c *Checker) probeFramingProtocol(inst *engine.Instance) (time.Duration, er synPayload = append(synPayload, []byte(domain)...) // domain synPayload = append(synPayload, 0x00, 0x50) // port 80 - // Use a unique probe ConnID (high range to avoid collision) - probeConnID := uint32(0xFFFF0000) + uint32(inst.ID()) + // Use a unique probe ConnID combining high-range prefix, instance ID, + // monotonic counter, and random bits to avoid collisions with real connections + // and previous probes that haven't been cleaned up yet. 
+ seq := c.probeSeq.Add(1) + var rndBuf [2]byte + rand.Read(rndBuf[:]) + rnd := uint32(binary.BigEndian.Uint16(rndBuf[:])) + probeConnID := uint32(0xFE000000) | (uint32(inst.ID())&0xFF)<<16 | (seq&0xFF)<<8 | (rnd & 0xFF) synFrame := &tunnel.Frame{ ConnID: probeConnID, diff --git a/internal/proxy/socks5.go b/internal/proxy/socks5.go index f715233..17f5931 100644 --- a/internal/proxy/socks5.go +++ b/internal/proxy/socks5.go @@ -34,7 +34,7 @@ type Server struct { // Packet-split mode fields packetSplit bool tunnelPool *tunnel.TunnelPool - connIDGen tunnel.ConnIDGenerator + connIDGen *tunnel.ConnIDGenerator chunkSize int } @@ -53,6 +53,7 @@ func NewServer(listenAddr string, bufferSize int, maxConns int, mgr *engine.Mana func (s *Server) EnablePacketSplit(pool *tunnel.TunnelPool, chunkSize int) { s.packetSplit = true s.tunnelPool = pool + s.connIDGen = tunnel.NewConnIDGenerator() s.chunkSize = chunkSize log.Printf("[proxy] packet-split mode enabled (chunk_size=%d)", chunkSize) } @@ -281,6 +282,13 @@ func (s *Server) handlePacketSplit(clientConn net.Conn, connID uint64, atyp byte ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // When context is cancelled (either direction finished), close the client + // connection so that blocking Read/Write calls unblock immediately. 
+ go func() { + <-ctx.Done() + clientConn.Close() + }() + var txN, rxN int64 var wg sync.WaitGroup diff --git a/internal/tunnel/pool.go b/internal/tunnel/pool.go index 5db735d..df0e070 100644 --- a/internal/tunnel/pool.go +++ b/internal/tunnel/pool.go @@ -230,8 +230,12 @@ func (p *TunnelPool) readLoop(tc *TunnelConn) { ch := v.(chan *Frame) select { case ch <- frame: - default: - // Buffer full — drop silently + case <-time.After(5 * time.Second): + // Buffer full for too long — connection is stuck, log and drop + log.Printf("[tunnel-pool] instance %d: frame buffer full for conn=%d, dropping frame seq=%d", + tc.inst.ID(), frame.ConnID, frame.SeqNum) + case <-p.stopCh: + return } } } diff --git a/internal/tunnel/protocol.go b/internal/tunnel/protocol.go index 922d9d5..4fb93b8 100644 --- a/internal/tunnel/protocol.go +++ b/internal/tunnel/protocol.go @@ -1,6 +1,7 @@ package tunnel import ( + "crypto/rand" "encoding/binary" "fmt" "io" @@ -104,10 +105,27 @@ func ReadFrame(r io.Reader) (*Frame, error) { } // ConnIDGenerator produces unique connection IDs. +// Starts from a random offset to avoid collisions after process restarts +// while CentralServer still holds state from the previous session. type ConnIDGenerator struct { counter atomic.Uint32 } +// NewConnIDGenerator creates a generator with a random starting offset. +func NewConnIDGenerator() *ConnIDGenerator { + g := &ConnIDGenerator{} + // Random start in [1, 0x7FFFFFFF) — avoids 0 and leaves room before wrap + var buf [4]byte + if _, err := rand.Read(buf[:]); err == nil { + start := binary.BigEndian.Uint32(buf[:]) & 0x7FFFFFFF + if start == 0 { + start = 1 + } + g.counter.Store(start) + } + return g +} + // Next returns the next unique connection ID. 
func (g *ConnIDGenerator) Next() uint32 { return g.counter.Add(1) diff --git a/internal/tunnel/splitter.go b/internal/tunnel/splitter.go index 456343a..9577b36 100644 --- a/internal/tunnel/splitter.go +++ b/internal/tunnel/splitter.go @@ -233,25 +233,28 @@ func (ps *PacketSplitter) pickInstanceExcluding(excludeID int) *engine.Instance } // Reorderer buffers out-of-order frames and delivers them in sequence order. +// If a frame is missing for longer than gapTimeout, it is skipped to prevent +// permanent stalls from lost frames. type Reorderer struct { - nextSeq uint32 - buffer map[uint32][]byte - timeout time.Duration + nextSeq uint32 + buffer map[uint32][]byte + gapTimeout time.Duration + waitingSince time.Time // when we first started waiting for nextSeq } func NewReorderer() *Reorderer { return &Reorderer{ - nextSeq: 0, - buffer: make(map[uint32][]byte), - timeout: 10 * time.Second, + nextSeq: 0, + buffer: make(map[uint32][]byte), + gapTimeout: 2 * time.Second, } } func NewReordererAt(startSeq uint32) *Reorderer { return &Reorderer{ - nextSeq: startSeq, - buffer: make(map[uint32][]byte), - timeout: 10 * time.Second, + nextSeq: startSeq, + buffer: make(map[uint32][]byte), + gapTimeout: 2 * time.Second, } } @@ -263,13 +266,59 @@ func (r *Reorderer) Insert(seq uint32, data []byte) { } func (r *Reorderer) Next() []byte { + // Fast path: next seq is available data, ok := r.buffer[r.nextSeq] - if !ok { + if ok { + delete(r.buffer, r.nextSeq) + r.nextSeq++ + r.waitingSince = time.Time{} // reset wait timer + return data + } + + // Nothing buffered at all — nothing to skip to + if len(r.buffer) == 0 { + r.waitingSince = time.Time{} return nil } - delete(r.buffer, r.nextSeq) - r.nextSeq++ - return data + + // There are buffered frames but nextSeq is missing. + // Start or check the gap timer. 
+ now := time.Now() + if r.waitingSince.IsZero() { + r.waitingSince = now + return nil + } + + if now.Sub(r.waitingSince) < r.gapTimeout { + return nil // still within grace period + } + + // Gap timeout expired — skip to the lowest available seq + r.skipToLowest() + r.waitingSince = time.Time{} + + data, ok = r.buffer[r.nextSeq] + if ok { + delete(r.buffer, r.nextSeq) + r.nextSeq++ + return data + } + return nil +} + +// skipToLowest advances nextSeq to the lowest seq number in the buffer. +func (r *Reorderer) skipToLowest() { + minSeq := r.nextSeq + found := false + for seq := range r.buffer { + if !found || seq < minSeq { + minSeq = seq + found = true + } + } + if found && minSeq > r.nextSeq { + r.nextSeq = minSeq + } } func (r *Reorderer) Pending() int {