diff --git a/app/app.go b/app/app.go index ca9af1b..4a0981e 100644 --- a/app/app.go +++ b/app/app.go @@ -84,7 +84,11 @@ func NewApp( app := new(App) anonymiser.Enable(settings.Anonymise) - app.db = connector.NewConnector(connectorFlags).DB + conn, err := connector.NewConnector(connectorFlags) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + app.db = conn.DB status := global.NewStatus(app.db) variables := global.NewVariables(app.db) @@ -124,10 +128,10 @@ func NewApp( app.resetDBStatistics() - var err error - app.currentView, err = view.SetupAndValidate(settings.ViewName, app.db) // if empty will use the default - if err != nil { - return nil, fmt.Errorf("app.NewApp: %w", err) + var viewErr error + app.currentView, viewErr = view.SetupAndValidate(settings.ViewName, app.db) // if empty will use the default + if viewErr != nil { + return nil, fmt.Errorf("app.NewApp: %w", viewErr) } app.UpdateCurrentTabler() diff --git a/connector/connector.go b/connector/connector.go index a47a526..a2ad4e7 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -6,11 +6,9 @@ import ( "database/sql" "fmt" "math" - "os" "github.com/sjmudd/mysql_defaults_file" "github.com/sjmudd/ps-top/log" - "github.com/sjmudd/ps-top/utils" ) // Method indicates how we want to connect to MySQL @@ -101,51 +99,24 @@ func (c *Connector) Connect() error { return nil } -// ConnectByConfig connects to MySQL using various configuration settings -// needed to create the DSN. -func (c *Connector) ConnectByConfig(config mysql_defaults_file.Config) { - c.config = config - c.SetMethod(ConnectByConfig) - if err := c.Connect(); err != nil { - fmt.Println(utils.ProgName+": ConnectByConfig failed:", err.Error()) - os.Exit(1) - } -} - -// ConnectByDefaultsFile connects to the database with the given -// defaults-file, or ~/.my.cnf if not provided. -func (c *Connector) ConnectByDefaultsFile(defaultsFile string) { - c.config = mysql_defaults_file.NewConfig(defaultsFile) - c.SetMethod(ConnectByDefaultsFile) - if err := c.Connect(); err != nil { - fmt.Println(utils.ProgName+": ConnectByDefaultsFile failed:", err.Error()) - os.Exit(1) - } -} -// ConnectByEnvironment connects using environment variables -func (c *Connector) ConnectByEnvironment() { - c.SetMethod(ConnectByEnvironment) - if err := c.Connect(); err != nil { - fmt.Println(utils.ProgName+": ConnectByEnvironment failed:", err.Error()) - os.Exit(1) - } -} - -// NewConnector returns a connected Connector given the provided configuration -func NewConnector(cfg Config) *Connector { // nolint:gocyclo +// NewConnector returns a connected Connector given the provided configuration. +// It returns an error if connection fails instead of calling os.Exit. +func NewConnector(cfg Config) (*Connector, error) { // nolint:gocyclo var defaultsFile string connector := new(Connector) if *cfg.UseEnvironment { - connector.ConnectByEnvironment() + connector.method = ConnectByEnvironment + if err := connector.Connect(); err != nil { + return nil, fmt.Errorf("ConnectByEnvironment: %w", err) + } } else { if *cfg.Host != "" || *cfg.Socket != "" { log.Println("--host= or --socket= defined") var config mysql_defaults_file.Config if *cfg.Host != "" && *cfg.Socket != "" { - fmt.Println(utils.ProgName + ": Do not specify --host and --socket together") - os.Exit(1) + return nil, fmt.Errorf("do not specify --host and --socket together") } if *cfg.Host != "" { config.Host = *cfg.Host @@ -155,13 +126,11 @@ func NewConnector(cfg Config) *Connector { // nolint:gocyclo // validate port number port := *cfg.Port if port < 0 || port > math.MaxUint16 { - fmt.Println(utils.ProgName+": Invalid port value", *cfg.Port) - os.Exit(1) + return nil, fmt.Errorf("invalid port value: %d", port) } config.Port = uint16(port) // nolint:gosec } else { - fmt.Println(utils.ProgName + ": Do not specify --socket and --port together") - os.Exit(1) + return nil, fmt.Errorf("do not specify --socket and --port together") } } if *cfg.Socket != "" { @@ -173,7 +142,11 @@ func NewConnector(cfg Config) *Connector { // nolint:gocyclo if *cfg.Password != "" { config.Password = *cfg.Password } - connector.ConnectByConfig(config) + connector.config = config + connector.method = ConnectByConfig + if err := connector.Connect(); err != nil { + return nil, fmt.Errorf("ConnectByConfig: %w", err) + } } else { // no host or socket provided so assume connecting by a defaults file. // - if an explicit defaults-file is provided use that. @@ -186,9 +159,13 @@ func NewConnector(cfg Config) *Connector { // nolint:gocyclo } else { log.Println("NewConnector: connecting by implicit defaults file") } - connector.ConnectByDefaultsFile(defaultsFile) + connector.config = mysql_defaults_file.NewConfig(defaultsFile) + connector.method = ConnectByDefaultsFile + if err := connector.Connect(); err != nil { + return nil, fmt.Errorf("ConnectByDefaultsFile(%s): %w", defaultsFile, err) + } } } - return connector + return connector, nil }