Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 61 additions & 38 deletions cmd/main/mping.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
Expand All @@ -12,7 +13,7 @@ import (

"github.com/nagayon-935/mping/internal/pinger"
"github.com/nagayon-935/mping/internal/stats"
"github.com/nagayon-935/mping/internal/ui"
ui "github.com/nagayon-935/mping/internal/ui"
"github.com/spf13/pflag"
"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -78,7 +79,7 @@ var newPinger = func(targets []*stats.TargetStats, opts pinger.Options) pingerCo
return &pingerAdapter{Pinger: pinger.NewPingerWithOptions(targets, opts)}
}

var uiRun = ui.Run
var uiRun = func(opts ui.RunOptions) error { return ui.Run(opts) }

var (
interfaceByName = net.InterfaceByName
Expand All @@ -88,11 +89,11 @@ var (
func getInterfaceIP(ifaceName string, wantIPv6 bool) (string, error) {
iface, err := interfaceByName(ifaceName)
if err != nil {
return "", err
return "", fmt.Errorf("lookup interface %q: %w", ifaceName, err)
}
addrs, err := iface.Addrs()
if err != nil {
return "", err
return "", fmt.Errorf("get addresses for interface %q: %w", ifaceName, err)
}
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
Expand Down Expand Up @@ -478,7 +479,7 @@ func setupPortChecker(targets []*stats.TargetStats, portSpecs []pinger.PortSpec,
func run(args []string, out io.Writer, errOut io.Writer) int {
cfg, hosts, fs, usage, err := parseArgs(args)
if err != nil {
if err == pflag.ErrHelp {
if errors.Is(err, pflag.ErrHelp) {
fmt.Fprint(out, usage)
return 0
}
Expand Down Expand Up @@ -552,23 +553,22 @@ func run(args []string, out io.Writer, errOut io.Writer) int {
p pingerController
traceCtx context.Context
traceCancel context.CancelFunc
portChecker *pinger.PortChecker
)

startPinger := func() error {
next := makePinger(packetSizeToUse)
if err := next.Start(interval, timeout); err != nil {
return err
}
pMu.Lock()
if cfg.trace {
pMu.Lock()
if traceCancel != nil {
traceCancel()
}
traceCtx, traceCancel = context.WithCancel(context.Background())
go runTraceroutes(traceCtx, next, targets)
pMu.Unlock()
}
pMu.Lock()
p = next
pMu.Unlock()
return nil
Expand Down Expand Up @@ -603,16 +603,17 @@ func run(args []string, out io.Writer, errOut io.Writer) int {
}
portSpecs = append(portSpecs, spec)
}
portChecker := setupPortChecker(targets, portSpecs, interval, timeout)

if cfg.trace {
// Traceroute info can be added to logs if needed
}
pMu.Lock()
portChecker = setupPortChecker(targets, portSpecs, interval, timeout)
pMu.Unlock()

stopAll := func() {
stopPinger()
if portChecker != nil {
portChecker.Stop()
pMu.Lock()
cur := portChecker
pMu.Unlock()
if cur != nil {
cur.Stop()
}
}

Expand All @@ -632,40 +633,62 @@ func run(args []string, out io.Writer, errOut io.Writer) int {
var resetPort func()
if len(portSpecs) > 0 {
resetPort = func() {
if portChecker != nil {
portChecker.Stop()
pMu.Lock()
cur := portChecker
portChecker = nil
pMu.Unlock()
if cur != nil {
cur.Stop()
}
portChecker = setupPortChecker(targets, portSpecs, interval, timeout)
next := setupPortChecker(targets, portSpecs, interval, timeout)
pMu.Lock()
portChecker = next
pMu.Unlock()
}
}

// doneCh is closed when the pinger finishes (count-limited mode).
var doneCh chan struct{}
if cfg.count > 0 {
doneCh = make(chan struct{})
go func() {
pMu.Lock()
cur := p
pMu.Unlock()
if cur != nil {
cur.Wait()
}
close(doneCh)
}()
}

// Start TUI
if err := uiRun(
targets,
interval,
timeout,
nil,
displaySourceIPv4,
displaySourceIPv6,
packetSizeToUse,
preLogs,
cfg.trace,
len(portSpecs) > 0,
cfg.asnEnabled,
stopAll,
func() error {
if err := uiRun(ui.RunOptions{
Targets: targets,
Interval: interval,
Timeout: timeout,
DoneCh: doneCh,
SourceIPv4: displaySourceIPv4,
SourceIPv6: displaySourceIPv6,
PacketSize: packetSizeToUse,
InitialLogs: preLogs,
TraceEnabled: cfg.trace,
PortEnabled: len(portSpecs) > 0,
ASNEnabled: cfg.asnEnabled,
OnStop: stopAll,
OnRestart: func() error {
stopAll()
if err := startPinger(); err != nil {
return err
}
if portChecker != nil {
portChecker = setupPortChecker(targets, portSpecs, interval, timeout)
}
pMu.Lock()
portChecker = setupPortChecker(targets, portSpecs, interval, timeout)
pMu.Unlock()
return nil
},
resetTrace,
resetPort,
); err != nil {
OnResetTrace: resetTrace,
OnResetPort: resetPort,
}); err != nil {
fmt.Fprintf(errOut, "Error running application: %v\n", err)
stopAll()
return 1
Expand Down
47 changes: 24 additions & 23 deletions cmd/main/mping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/nagayon-935/mping/internal/pinger"
"github.com/nagayon-935/mping/internal/stats"
"github.com/nagayon-935/mping/internal/ui"
"github.com/spf13/pflag"
)

Expand Down Expand Up @@ -209,9 +210,9 @@ func TestRunStopRestart(t *testing.T) {
newPinger = func(targets []*stats.TargetStats, opts pinger.Options) pingerController {
return fp
}
uiRun = func(targets []*stats.TargetStats, interval, timeout time.Duration, doneCh chan struct{}, sourceIPv4, sourceIPv6 string, packetSize int, initialLogs []string, traceEnabled bool, portEnabled bool, asnEnabled bool, onStop func(), onRestart func() error, onResetTrace func(), onResetPort func()) error {
onStop()
if err := onRestart(); err != nil {
uiRun = func(opts ui.RunOptions) error {
opts.OnStop()
if err := opts.OnRestart(); err != nil {
t.Fatalf("restart failed: %v", err)
}
return nil
Expand Down Expand Up @@ -239,7 +240,7 @@ func TestRunStartError(t *testing.T) {
newPinger = func(targets []*stats.TargetStats, opts pinger.Options) pingerController {
return fp
}
uiRun = func(targets []*stats.TargetStats, interval, timeout time.Duration, doneCh chan struct{}, sourceIPv4, sourceIPv6 string, packetSize int, initialLogs []string, traceEnabled bool, portEnabled bool, asnEnabled bool, onStop func(), onRestart func() error, onResetTrace func(), onResetPort func()) error {
uiRun = func(opts ui.RunOptions) error {
return nil
}

Expand Down Expand Up @@ -592,7 +593,7 @@ func TestRunInvalidPortSpec(t *testing.T) {
newPinger = func(targets []*stats.TargetStats, opts pinger.Options) pingerController {
return &fakePinger{}
}
uiRun = func(targets []*stats.TargetStats, interval, timeout time.Duration, doneCh chan struct{}, sourceIPv4, sourceIPv6 string, packetSize int, initialLogs []string, traceEnabled bool, portEnabled bool, asnEnabled bool, onStop func(), onRestart func() error, onResetTrace func(), onResetPort func()) error {
uiRun = func(opts ui.RunOptions) error {
return nil
}
var out, errOut bytes.Buffer
Expand All @@ -616,7 +617,7 @@ func TestRunWithPortSpec(t *testing.T) {
newPinger = func(targets []*stats.TargetStats, opts pinger.Options) pingerController {
return fp
}
uiRun = func(targets []*stats.TargetStats, interval, timeout time.Duration, doneCh chan struct{}, sourceIPv4, sourceIPv6 string, packetSize int, initialLogs []string, traceEnabled bool, portEnabled bool, asnEnabled bool, onStop func(), onRestart func() error, onResetTrace func(), onResetPort func()) error {
uiRun = func(opts ui.RunOptions) error {
return nil
}
var out, errOut bytes.Buffer
Expand All @@ -637,7 +638,7 @@ func TestRunWithTrace(t *testing.T) {
newPinger = func(targets []*stats.TargetStats, opts pinger.Options) pingerController {
return fp
}
uiRun = func(targets []*stats.TargetStats, interval, timeout time.Duration, doneCh chan struct{}, sourceIPv4, sourceIPv6 string, packetSize int, initialLogs []string, traceEnabled bool, portEnabled bool, asnEnabled bool, onStop func(), onRestart func() error, onResetTrace func(), onResetPort func()) error {
uiRun = func(opts ui.RunOptions) error {
return nil
}
var out, errOut bytes.Buffer
Expand All @@ -662,26 +663,26 @@ func TestRunResetTrace(t *testing.T) {
return fp
}

uiRun = func(targets []*stats.TargetStats, interval, timeout time.Duration, doneCh chan struct{}, sourceIPv4, sourceIPv6 string, packetSize int, initialLogs []string, traceEnabled bool, portEnabled bool, asnEnabled bool, onStop func(), onRestart func() error, onResetTrace func(), onResetPort func()) error {
if onResetTrace == nil {
t.Error("onResetTrace must not be nil when traceEnabled=true")
uiRun = func(opts ui.RunOptions) error {
if opts.OnResetTrace == nil {
t.Error("OnResetTrace must not be nil when TraceEnabled=true")
return nil
}
// Simulate the UI clearing TraceHops before calling onResetTrace (as 'R' does).
for _, tg := range targets {
// Simulate the UI clearing TraceHops before calling OnResetTrace (as 'R' does).
for _, tg := range opts.Targets {
tg.SetTraceHops(nil)
}
onResetTrace()
opts.OnResetTrace()
// Wait for the re-run to populate TraceHops.
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if len(targets[0].GetView().TraceHops) > 0 {
if len(opts.Targets[0].GetView().TraceHops) > 0 {
break
}
time.Sleep(10 * time.Millisecond)
}
if hops := targets[0].GetView().TraceHops; len(hops) == 0 {
t.Error("TraceHops not repopulated after onResetTrace()")
if hops := opts.Targets[0].GetView().TraceHops; len(hops) == 0 {
t.Error("TraceHops not repopulated after OnResetTrace()")
}
return nil
}
Expand All @@ -708,17 +709,17 @@ func TestRunResetPort(t *testing.T) {
return fp
}

uiRun = func(targets []*stats.TargetStats, interval, timeout time.Duration, doneCh chan struct{}, sourceIPv4, sourceIPv6 string, packetSize int, initialLogs []string, traceEnabled bool, portEnabled bool, asnEnabled bool, onStop func(), onRestart func() error, onResetTrace func(), onResetPort func()) error {
if !portEnabled {
t.Error("portEnabled should be true when port specs given")
uiRun = func(opts ui.RunOptions) error {
if !opts.PortEnabled {
t.Error("PortEnabled should be true when port specs given")
return nil
}
if onResetPort == nil {
t.Error("onResetPort must not be nil when port specs are provided")
if opts.OnResetPort == nil {
t.Error("OnResetPort must not be nil when port specs are provided")
return nil
}
// Calling it must not panic
onResetPort()
opts.OnResetPort()
return nil
}

Expand All @@ -739,7 +740,7 @@ func TestRunWithMTUIPv6Warning(t *testing.T) {
newPinger = func(targets []*stats.TargetStats, opts pinger.Options) pingerController {
return &fakePinger{}
}
uiRun = func(targets []*stats.TargetStats, interval, timeout time.Duration, doneCh chan struct{}, sourceIPv4, sourceIPv6 string, packetSize int, initialLogs []string, traceEnabled bool, portEnabled bool, asnEnabled bool, onStop func(), onRestart func() error, onResetTrace func(), onResetPort func()) error {
uiRun = func(opts ui.RunOptions) error {
return nil
}
var out, errOut bytes.Buffer
Expand Down
11 changes: 8 additions & 3 deletions internal/pinger/pinger.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package pinger

import (
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/nagayon-935/mping/internal/stats"
Expand Down Expand Up @@ -44,7 +46,7 @@ type PacketConnV6 interface {
SetControlMessage(cf ipv6.ControlFlags, on bool) error
}

// Reply represents a received ICMP echo reply.
// Reply represents a single received ICMP echo reply or error from the receiver loop.
type Reply struct {
RTT time.Duration
TTL int
Expand Down Expand Up @@ -79,7 +81,7 @@ type Pinger struct {

traceChans []chan traceMsg // one per concurrent TraceRoute call
traceChansMu sync.RWMutex
traceCounter uint32 // atomic counter for unique traceID per concurrent call
traceCounter atomic.Uint32 // unique traceID per concurrent call

LogWriter io.Writer // Optional logger

Expand All @@ -104,10 +106,12 @@ type Options struct {
AsnEnabled bool
}

// NewPinger creates a Pinger with default options for the given targets.
func NewPinger(targets []*stats.TargetStats) *Pinger {
return NewPingerWithOptions(targets, Options{})
}

// NewPingerWithOptions creates a Pinger with the provided options.
func NewPingerWithOptions(targets []*stats.TargetStats, opts Options) *Pinger {
resolve := opts.ResolveIPAddr
if resolve == nil {
Expand Down Expand Up @@ -357,7 +361,8 @@ func (p *Pinger) runReceiver(
setDeadline(time.Now().Add(receiverReadTimeout))
n, ttl, src, err := readFrom(buf)
if err != nil {
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Timeout() {
continue
}
return
Expand Down
Loading