Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ simple
example-repeatable
example-validators
example-struct
example-command

!/examples/**

Expand Down
271 changes: 271 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
package rig

import (
"flag"
"io"
"os"
"reflect"
"strings"

"github.com/pkg/errors"
)

type Command interface {
Name() string
Usage() string
Call(invokedName string, arguments []string) error
}

type CommandsConfig struct {
CommandName string
CommandUsage string
Config *Config
Commands map[string]Command

lastError error
}

func Commands(commands ...Command) *CommandsConfig {
cs := &CommandsConfig{
Config: &Config{
FlagSet: DefaultFlagSet(),
},
Commands: make(map[string]Command, len(commands)),
}

docs := make([][]string, len(commands))
for i, command := range commands {
if _, ok := cs.Commands[command.Name()]; ok {
cs.lastError = errors.Errorf("command with name %q already set", command.Name())
return cs
}
cs.Commands[command.Name()] = command
docs[i] = append(docs[i], command.Name())
usage := command.Usage()
if usage != "" {
docs[i] = append(docs[i], usage)
}
}

cs.Config.UsageExtra = func() string {
out := &strings.Builder{}

printUsageLines(out, docs, 2, 4)

return "\nCommands:\n" + out.String()
}

return cs
}

func SubCommands(name, usage string, commands ...Command) *CommandsConfig {
cs := Commands(commands...)
cs.CommandName = name
cs.CommandUsage = usage

return cs
}

func (cs *CommandsConfig) AdditionalFlags(flags ...*Flag) *CommandsConfig {
cs.Config.Flags = append(cs.Config.Flags, flags...)

return cs
}

func (cs *CommandsConfig) ParseAndCall() error {
return cs.Call(os.Args[0], os.Args[1:])
}

func (cs *CommandsConfig) Call(name string, arguments []string) error {
if cs.lastError != nil {
return cs.Config.handleError(cs.lastError)
}

resetFlagSet(cs.Config, name)

err := cs.Config.Parse(arguments)
if err != nil {
return err
}

args := cs.Config.Args()
if len(args) == 0 {
return cs.Config.handleError(errors.New("missing command"))
}
cmd, ok := cs.Commands[args[0]]
if !ok {
return cs.Config.handleError(errors.Errorf("%q: unknown command", args[0]))
}

return cmd.Call(name+" "+args[0], args[1:])
}

func (cs *CommandsConfig) Name() string {
return cs.CommandName
}

func (cs *CommandsConfig) Usage() string {
return cs.CommandUsage
}

type CallbackCommandConfig struct {
CommandName string
CommandUsage string
Config *Config
Callback Command
}

func CallbackCommand(name string, fn Command, usage string, flags ...*Flag) *CallbackCommandConfig {
return &CallbackCommandConfig{
CommandName: name,
CommandUsage: usage,
Config: &Config{
FlagSet: DefaultFlagSet(),
Flags: flags,
},
Callback: fn,
}
}

func (cc *CallbackCommandConfig) Call(name string, arguments []string) error {
resetFlagSet(cc.Config, name)

err := cc.Config.Parse(arguments)
if err != nil {
return err
}

return cc.Callback.Call(name, cc.Config.Args())
}

func (cc CallbackCommandConfig) Name() string {
return cc.CommandName
}

func (cc CallbackCommandConfig) Usage() string {
return cc.CommandUsage
}

type StructCommandConfig struct {
CommandName string
CommandUsage string
Config *Config

fn reflect.Value
arg reflect.Value
lastError error
}

func StructCommand(name string, fn interface{}, usage string, additionalFlags ...*Flag) *StructCommandConfig {
fnV := reflect.ValueOf(fn)
if fnV.Kind() != reflect.Func {
return &StructCommandConfig{
Config: &Config{
FlagSet: DefaultFlagSet(),
},
lastError: errors.Errorf("expected fn to be a function, got %T instead", fn),
}
}
fnT := fnV.Type()
if fnT.NumIn() != 1 {
return &StructCommandConfig{
Config: &Config{
FlagSet: DefaultFlagSet(),
},
lastError: errors.Errorf("expected fn to take 1 argument, got %d instead", fnT.NumIn()),
}
}
if fnT.NumOut() != 1 {
return &StructCommandConfig{
Config: &Config{
FlagSet: DefaultFlagSet(),
},
lastError: errors.Errorf("expected fn to return 1 value, got %d instead", fnT.NumOut()),
}
}
if !fnT.Out(0).AssignableTo(reflect.TypeOf((*error)(nil)).Elem()) {
return &StructCommandConfig{
Config: &Config{
FlagSet: DefaultFlagSet(),
},
lastError: errors.Errorf("expected fn to return an 'error', got %v instead", fnT.Out(0)),
}
}

argV := reflect.New(fnT.In(0))
if !argV.CanInterface() {
return &StructCommandConfig{
Config: &Config{
FlagSet: DefaultFlagSet(),
},
lastError: errors.Errorf("failed to instanciate value of type %v", fnT.In(0)),
}
}
flags, err := StructToFlags(argV.Interface())
if err != nil {
return &StructCommandConfig{
Config: &Config{
FlagSet: DefaultFlagSet(),
},
lastError: err,
}
}

return &StructCommandConfig{
CommandName: name,
CommandUsage: usage,
Config: &Config{
FlagSet: DefaultFlagSet(),
Flags: append(flags, additionalFlags...),
},
fn: fnV,
arg: reflect.Indirect(argV),
}
}

func (sc *StructCommandConfig) Call(name string, arguments []string) error {
if sc.lastError != nil {
return sc.Config.handleError(sc.lastError)
}

resetFlagSet(sc.Config, name)

err := sc.Config.Parse(arguments)
if err != nil {
return err
}

ret := sc.fn.Call([]reflect.Value{sc.arg})
if ret[0].IsNil() {
return nil
}
if !ret[0].CanInterface() {
return errors.New("failed to get error back")
}

err, ok := ret[0].Interface().(error)
if !ok {
return errors.Errorf("expected return value to be of type 'error', got %T instead", ret[0].Interface())
}

return err
}

func (sc *StructCommandConfig) Name() string {
return sc.CommandName
}

func (sc *StructCommandConfig) Usage() string {
return sc.CommandUsage
}

func resetFlagSet(config *Config, name string) {
errorHandling := flag.ExitOnError
output := io.Writer(os.Stderr)
if config.FlagSet != nil {
errorHandling = config.FlagSet.ErrorHandling()
output = config.FlagSet.Output()
}
config.FlagSet = flag.NewFlagSet(name, errorHandling)
config.FlagSet.SetOutput(output)
}
5 changes: 5 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Config struct {
FlagSet *flag.FlagSet

Flags []*Flag
UsageExtra func() string
defaultValuesSet bool
}

Expand Down Expand Up @@ -171,6 +172,10 @@ func (c *Config) Usage() {
}

printUsageLines(c.FlagSet.Output(), lines, 2, 4)

if c.UsageExtra != nil {
fmt.Fprintln(c.FlagSet.Output(), c.UsageExtra())
}
}

func printUsageLines(output io.Writer, lines [][]string, margin, sep int) {
Expand Down
57 changes: 57 additions & 0 deletions examples/example-command/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import (
"fmt"

"github.com/Pimmr/rig"
)

var hello = "world"

type FooConfig struct {
A int
B string
}

func Foo(c FooConfig) error {
fmt.Printf("%q: %+v\n", hello, c)

return nil
}

type BarConfig struct {
A []string
B float64
}

func Bar(c BarConfig) error {
fmt.Printf("%q: %+v\n", hello, c)

return nil
}

type BazConfig struct {
A bool
B bool
}

func Baz(c BazConfig) error {
fmt.Printf("%q: %+v\n", hello, c)

return nil
}

func main() {
err := rig.Commands(
rig.SubCommands("test", "test commands",
rig.StructCommand("foo", Foo, "this is foo ..."),
rig.StructCommand("bar", Bar, "this is bar ..."),
),
rig.StructCommand("baz", Baz, "this is baz ..."),
).AdditionalFlags(
rig.String(&hello, "hello", "HELLO", "hello ..."),
).ParseAndCall()
if err != nil {
panic(err)
}
}