diff --git a/cmd/prompt.go b/cmd/prompt.go index 1036dbb..84161f2 100644 --- a/cmd/prompt.go +++ b/cmd/prompt.go @@ -1,6 +1,8 @@ package cmd import ( + "context" + intconfig "github.com/foomo/posh/internal/config" "github.com/foomo/posh/pkg/config" "github.com/spf13/cobra" @@ -30,6 +32,6 @@ var promptCmd = &cobra.Command{ return err } - return plg.Prompt(cmd.Context(), cfg) + return plg.Prompt(context.TODO(), cfg) }, } diff --git a/pkg/command/cache.go b/pkg/command/cache.go index 3876b4f..2249252 100644 --- a/pkg/command/cache.go +++ b/pkg/command/cache.go @@ -17,7 +17,7 @@ import ( type Cache struct { l log.Logger - tree *tree.Root + tree tree.Root cache cache.Cache } @@ -30,8 +30,9 @@ func NewCache(l log.Logger, cache cache.Cache) *Cache { l: l, cache: cache, } - inst.tree = &tree.Root{ - Name: "cache", + inst.tree = tree.New(&tree.Node{ + Name: "cache", + Description: "manage the internal cache", Nodes: tree.Nodes{ { Name: "clear", @@ -44,7 +45,7 @@ func NewCache(l log.Logger, cache cache.Cache) *Cache { Execute: inst.list, }, }, - } + }) return inst } @@ -53,11 +54,11 @@ func NewCache(l log.Logger, cache cache.Cache) *Cache { // ------------------------------------------------------------------------------------------------ func (c *Cache) Name() string { - return c.tree.Name + return c.tree.Node().Name } func (c *Cache) Description() string { - return "manage the internal cache" + return c.tree.Node().Description } func (c *Cache) Complete(ctx context.Context, r *readline.Readline) []goprompt.Suggest { diff --git a/pkg/command/help.go b/pkg/command/help.go index 0d23772..8d69222 100644 --- a/pkg/command/help.go +++ b/pkg/command/help.go @@ -63,8 +63,6 @@ func (c *Help) Validate(ctx context.Context, r *readline.Readline) error { } } return errors.Errorf("invalid [command] argument: %s", r.Args().At(0)) - case r.Args().LenGte(2): - return errors.New("too many arguments") } return nil @@ -84,9 +82,11 @@ Available Commands: ret += c.format(value.Name(), value.Description()) } c.l.Print(ret) - case 1: + default: if helper, ok := c.commands.Get(r.Args().At(0)).(Helper); ok { c.l.Print(helper.Help(ctx, r)) + } else { + c.l.Print("command not found") } } return nil diff --git a/pkg/command/tree/arg.go b/pkg/command/tree/arg.go index 40aadc6..bdc6ed4 100644 --- a/pkg/command/tree/arg.go +++ b/pkg/command/tree/arg.go @@ -8,8 +8,9 @@ import ( ) type Arg struct { - Name string - Repeat bool - Optional bool - Suggest func(ctx context.Context, t *Root, r *readline.Readline) []goprompt.Suggest + Name string + Description string + Repeat bool + Optional bool + Suggest func(ctx context.Context, t Root, r *readline.Readline) []goprompt.Suggest } diff --git a/pkg/command/tree/flag.go b/pkg/command/tree/flag.go deleted file mode 100644 index 95a95ff..0000000 --- a/pkg/command/tree/flag.go +++ /dev/null @@ -1,7 +0,0 @@ -package tree - -type Flag struct { - Name string - Required bool - Value interface{} -} diff --git a/pkg/command/tree/node.go b/pkg/command/tree/node.go index 9ba1f5e..75e5a79 100644 --- a/pkg/command/tree/node.go +++ b/pkg/command/tree/node.go @@ -6,21 +6,20 @@ import ( "github.com/foomo/posh/pkg/prompt/goprompt" "github.com/foomo/posh/pkg/readline" + xstrings "github.com/foomo/posh/pkg/util/strings" "github.com/foomo/posh/pkg/util/suggests" "github.com/pkg/errors" ) type Node struct { - Name string - Values func(ctx context.Context, r *readline.Readline) []goprompt.Suggest - Args Args - Flags func(ctx context.Context, r *readline.Readline, fs *readline.FlagSet) error - PassThroughArgs Args - PassThroughFlags func(ctx context.Context, r *readline.Readline, fs *readline.FlagSet) error - Description string - Nodes []*Node - Execute func(ctx context.Context, r *readline.Readline) error + Name string + Values func(ctx context.Context, r *readline.Readline) []goprompt.Suggest + Args Args + Flags func(ctx context.Context, r *readline.Readline, fs *readline.FlagSets) error + Description string + Nodes []*Node + Execute func(ctx context.Context, r *readline.Readline) error } // ------------------------------------------------------------------------------------------------ @@ -29,33 +28,21 @@ type Node struct { func (c *Node) setFlags(ctx context.Context, r *readline.Readline, parse bool) error { if c.Flags != nil { - f := readline.NewFlagSet() - if err := c.Flags(ctx, r, f); err != nil { + fs := readline.NewFlagSets() + if err := c.Flags(ctx, r, fs); err != nil { return err } - r.SetFlags(f) - if parse { - if err := r.ParseFlags(); err != nil { - return errors.Wrap(err, "failed to parse flags") - } - } + r.SetFlagSets(fs) } - if c.PassThroughFlags != nil { - f := readline.NewFlagSet() - if err := c.PassThroughFlags(ctx, r, f); err != nil { - return err - } - r.SetParsePassThroughFlags(f) - if parse { - if err := r.ParsePassThroughFlags(); err != nil { - return errors.Wrap(err, "failed to parse pass through flags") - } + if parse { + if err := r.ParseFlagSets(); err != nil { + return errors.Wrap(err, "failed to parse flags") } } return nil } -func (c *Node) completeArguments(ctx context.Context, p *Root, r *readline.Readline, i int) []goprompt.Suggest { +func (c *Node) completeArguments(ctx context.Context, p *root, r *readline.Readline, i int) []goprompt.Suggest { var suggest []goprompt.Suggest localArgs := r.Args()[i:] switch { @@ -84,7 +71,7 @@ func (c *Node) completeArguments(ctx context.Context, p *Root, r *readline.Readl func (c *Node) completeFlags(r *readline.Readline) []goprompt.Suggest { allFlags := r.AllFlags() if r.Flags().LenGt(1) { - if values := r.FlagSet().GetValues(strings.TrimPrefix(r.Flags().At(r.Flags().Len()-2), "--")); values != nil { + if values := r.FlagSets().All().GetValues(strings.TrimPrefix(r.Flags().At(r.Flags().Len()-2), "--")); values != nil { return suggests.List(values) } } @@ -95,18 +82,11 @@ func (c *Node) completeFlags(r *readline.Readline) []goprompt.Suggest { return suggest } -func (c *Node) completePassThroughFlags(r *readline.Readline) []goprompt.Suggest { - allPassThroughFlags := r.AllPassThroughFlags() - suggest := make([]goprompt.Suggest, len(allPassThroughFlags)) - for i, f := range allPassThroughFlags { - suggest[i] = goprompt.Suggest{Text: "--" + f.Name, Description: f.Usage} - } - return suggest -} - func (c *Node) execute(ctx context.Context, r *readline.Readline, i int) error { localArgs := r.Args()[i:] switch { + case len(localArgs) == 0 && c.Execute != nil: + break case len(c.Nodes) > 0 && len(localArgs) == 0: return ErrMissingCommand case len(c.Args) > 0: @@ -122,3 +102,78 @@ func (c *Node) execute(ctx context.Context, r *readline.Readline, i int) error { } return c.Execute(ctx, r) } + +func (c *Node) find(ctx context.Context, r *readline.Readline, i int) (*Node, int) { + if r.Args().LenLt(i + 1) { + return nil, i + } + arg := r.Args().At(i) + for _, cmd := range c.Nodes { + if cmd.Name == arg { + if subCmd, j := cmd.find(ctx, r, i+1); subCmd != nil { + return subCmd, j + } + return cmd, i + } + if cmd.Values != nil { + for _, name := range cmd.Values(ctx, r) { + if name.Text == arg { + if subCmd, j := cmd.find(ctx, r, i+1); subCmd != nil { + return subCmd, j + } + return cmd, i + } + } + } + } + return nil, i +} + +func (c *Node) help(ctx context.Context, r *readline.Readline) string { + ret := c.Description + + if len(c.Nodes) > 0 { + ret += "\n\nUsage:\n" + ret += " " + c.Name + " [command]" + + ret += "\n\nAvailable Commands:\n" + for _, node := range c.Nodes { + ret += " " + xstrings.PadEnd(node.Name, " ", 30) + node.Description + "\n" + } + } else { + ret += "\n\nUsage:\n" + ret += " " + c.Name + + for _, arg := range c.Args { + ret += " " + if arg.Optional { + ret += "<" + } else { + ret += "[" + } + ret += arg.Name + if arg.Optional { + ret += ">" + } else { + ret += "]" + } + ret += "\n" + } + + if len(c.Args) > 0 { + ret += "\n\nArguments:\n" + for _, arg := range c.Args { + ret += " " + xstrings.PadEnd(arg.Name, " ", 30) + arg.Description + "\n" + } + } + + if c.Flags != nil { + fs := readline.NewFlagSets() + if err := c.Flags(ctx, r, fs); err == nil { + ret += "\n\nFlags:\n" + ret += fs.All().FlagUsages() + } + } + } + return ret +} diff --git a/pkg/command/tree/root.go b/pkg/command/tree/root.go index 6b80152..3ae041a 100644 --- a/pkg/command/tree/root.go +++ b/pkg/command/tree/root.go @@ -8,34 +8,52 @@ import ( "github.com/foomo/posh/pkg/readline" ) -type Root struct { - Name string - Description string - Node *Node - Nodes Nodes +type Root interface { + Node() *Node + Complete(ctx context.Context, r *readline.Readline) []goprompt.Suggest + Execute(ctx context.Context, r *readline.Readline) error + Help(ctx context.Context, r *readline.Readline) string +} + +type root struct { + node *Node +} + +// ------------------------------------------------------------------------------------------------ +// ~ Constructor +// ------------------------------------------------------------------------------------------------ + +func New(node *Node) Root { + return &root{ + node: node, + } } // ------------------------------------------------------------------------------------------------ // ~ Public methods // ------------------------------------------------------------------------------------------------ -func (t *Root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Suggest { +func (t *root) Node() *Node { + return t.node +} + +func (t *root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Suggest { var suggests []goprompt.Suggest switch r.Mode() { case readline.ModeArgs: - if r.Args().LenLte(1) && len(t.Nodes) > 0 { - for _, command := range t.Nodes { + if r.Args().LenLte(1) && len(t.node.Nodes) > 0 { + for _, command := range t.node.Nodes { if command.Values != nil { suggests = command.Values(ctx, r) } else { suggests = append(suggests, goprompt.Suggest{Text: command.Name, Description: command.Description}) } } - } else if cmd, i := t.find(ctx, t.Nodes, r, 0); cmd == nil && t.Node != nil { - if err := t.Node.setFlags(ctx, r, false); err != nil { + } else if cmd, i := t.node.find(ctx, r, 0); cmd == nil && t.node != nil { + if err := t.node.setFlags(ctx, r, false); err != nil { return nil } else { - suggests = t.Node.completeArguments(ctx, t, r, 0) + suggests = t.node.completeArguments(ctx, t, r, 0) } } else if cmd == nil { return nil @@ -45,11 +63,11 @@ func (t *Root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Su suggests = cmd.completeArguments(ctx, t, r, i+1) } case readline.ModeFlags: - if cmd, _ := t.find(ctx, t.Nodes, r, 0); cmd == nil && t.Node != nil { - if err := t.Node.setFlags(ctx, r, false); err != nil { + if cmd, _ := t.node.find(ctx, r, 0); cmd == nil && t.node != nil { + if err := t.node.setFlags(ctx, r, false); err != nil { return nil } else { - suggests = t.Node.completeFlags(r) + suggests = t.node.completeFlags(r) } } else if cmd == nil { return nil @@ -58,20 +76,6 @@ func (t *Root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Su } else { suggests = cmd.completeFlags(r) } - case readline.ModePassThroughFlags: - if cmd, _ := t.find(ctx, t.Nodes, r, 0); cmd == nil && t.Node != nil { - if err := t.Node.setFlags(ctx, r, false); err != nil { - return nil - } else { - suggests = t.Node.completePassThroughFlags(r) - } - } else if cmd == nil { - return nil - } else if err := cmd.setFlags(ctx, r, false); err != nil { - return nil - } else { - suggests = cmd.completePassThroughFlags(r) - } case readline.ModeAdditionalArgs: // do nothing } @@ -81,28 +85,28 @@ func (t *Root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Su return suggests } -func (t *Root) Execute(ctx context.Context, r *readline.Readline) error { +func (t *root) Execute(ctx context.Context, r *readline.Readline) error { var ( cmd *Node index int ) switch { - case t.Node == nil && len(t.Nodes) == 0: + case t.node == nil && t.node.Execute == nil && len(t.node.Nodes) == 0: return ErrNoop - case r.Args().LenIs(0) && t.Node == nil: + case r.Args().LenIs(1) && t.node == nil: return ErrMissingCommand } if r.Args().LenIs(0) { - cmd = t.Node - } else if found, i := t.find(ctx, t.Nodes, r, 0); found != nil { + cmd = t.node + } else if found, i := t.node.find(ctx, r, 0); found != nil { cmd = found index = i - } else if t.Node == nil { + } else if t.node == nil { return ErrInvalidCommand } else { - cmd = t.Node + cmd = t.node } if err := cmd.setFlags(ctx, r, true); err != nil { @@ -113,47 +117,22 @@ func (t *Root) Execute(ctx context.Context, r *readline.Readline) error { return nil } -func (t *Root) Help(ctx context.Context, r *readline.Readline) string { - // TODO recursive help - ret := t.Description - if t.Nodes != nil { - ret += "\n\nUsage:\n" - ret += " " + t.Name + " [command]" +func (t *root) Help(ctx context.Context, r *readline.Readline) string { + var ( + cmd *Node + ) - ret += "\n\nAvailable Commands:\n" - for _, node := range t.Nodes { - ret += " " + node.Name - } + if t.node == nil { + return "command not found" + } else if r.Args().LenIs(1) { + cmd = t.node + } else if len(t.node.Nodes) == 0 { + return "command not found" + } else if found, _ := t.node.find(ctx, r, 1); found != nil { + cmd = found + } else { + cmd = t.node } - return ret -} - -// ------------------------------------------------------------------------------------------------ -// ~ Private methods -// ------------------------------------------------------------------------------------------------ - -func (t *Root) find(ctx context.Context, cmds []*Node, r *readline.Readline, i int) (*Node, int) { - if r.Args().LenLt(i + 1) { - return nil, i - } - arg := r.Args().At(i) - for _, cmd := range cmds { - if cmd.Name == arg { - if subCmd, j := t.find(ctx, cmd.Nodes, r, i+1); subCmd != nil { - return subCmd, j - } - return cmd, i - } - if cmd.Values != nil { - for _, name := range cmd.Values(ctx, r) { - if name.Text == arg { - if subCmd, j := t.find(ctx, cmd.Nodes, r, i+1); subCmd != nil { - return subCmd, j - } - return cmd, i - } - } - } - } - return nil, i + + return cmd.help(ctx, r) } diff --git a/pkg/command/tree/root_test.go b/pkg/command/tree/root_test.go index 3d0150a..f65d6f3 100644 --- a/pkg/command/tree/root_test.go +++ b/pkg/command/tree/root_test.go @@ -29,13 +29,11 @@ func TestRoot(t *testing.T) { ErrThird1 = errors.New("third1") ) - r := &tree.Root{ + r := tree.New(&tree.Node{ Name: "root", Description: "root tree", - Node: &tree.Node{ - Execute: func(ctx context.Context, r *readline.Readline) error { - return ErrRoot - }, + Execute: func(ctx context.Context, r *readline.Readline) error { + return ErrRoot }, Nodes: tree.Nodes{ { @@ -85,7 +83,7 @@ func TestRoot(t *testing.T) { }, }, }, - } + }) tests := []struct { name string @@ -179,39 +177,35 @@ func TestRoot_Node(t *testing.T) { tests := []struct { name string - root *tree.Root + root tree.Root wantErr assert.ErrorAssertionFunc }{ { name: "tree", - root: &tree.Root{}, + root: tree.New(&tree.Node{}), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, tree.ErrNoop) }, }, { name: "tree", - root: &tree.Root{ - Node: &tree.Node{ - Execute: func(ctx context.Context, r *readline.Readline) error { - return ErrOK - }, + root: tree.New(&tree.Node{ + Execute: func(ctx context.Context, r *readline.Readline) error { + return ErrOK }, - }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, ErrOK) }, }, { name: "tree one", - root: &tree.Root{ - Node: &tree.Node{ - Execute: func(ctx context.Context, r *readline.Readline) error { - assert.Equal(T(ctx), "one", r.Args().At(0)) - return ErrOK - }, + root: tree.New(&tree.Node{ + Execute: func(ctx context.Context, r *readline.Readline) error { + assert.Equal(T(ctx), "one", r.Args().At(0)) + return ErrOK }, - }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, ErrOK) }, @@ -238,107 +232,97 @@ func TestRoot_NodeArgs(t *testing.T) { tests := []struct { name string - root *tree.Root + root tree.Root wantErr assert.ErrorAssertionFunc }{ { name: "tree", - root: &tree.Root{ - Node: &tree.Node{ - Args: tree.Args{ - { - Name: "first", - }, - }, - Execute: func(ctx context.Context, r *readline.Readline) error { - return ErrOK + root: tree.New(&tree.Node{ + Args: tree.Args{ + { + Name: "first", }, }, - }, + Execute: func(ctx context.Context, r *readline.Readline) error { + return ErrOK + }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, tree.ErrMissingArgument) }, }, { name: "tree one", - root: &tree.Root{ - Node: &tree.Node{ - Args: tree.Args{ - { - Name: "first", - }, - }, - Execute: func(ctx context.Context, r *readline.Readline) error { - assert.Equal(T(ctx), "one", r.Args().At(0)) - return ErrOK + root: tree.New(&tree.Node{ + Args: tree.Args{ + { + Name: "first", }, }, - }, + Execute: func(ctx context.Context, r *readline.Readline) error { + assert.Equal(T(ctx), "one", r.Args().At(0)) + return ErrOK + }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, ErrOK) }, }, { name: "tree", - root: &tree.Root{ - Node: &tree.Node{ - Args: tree.Args{ - { - Name: "first", - }, - { - Name: "second", - }, + root: tree.New(&tree.Node{ + Args: tree.Args{ + { + Name: "first", }, - Execute: func(ctx context.Context, r *readline.Readline) error { - return ErrOK + { + Name: "second", }, }, - }, + Execute: func(ctx context.Context, r *readline.Readline) error { + return ErrOK + }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, tree.ErrMissingArgument) }, }, { name: "tree one", - root: &tree.Root{ - Node: &tree.Node{ - Args: tree.Args{ - { - Name: "first", - }, - { - Name: "second", - }, + root: tree.New(&tree.Node{ + Args: tree.Args{ + { + Name: "first", }, - Execute: func(ctx context.Context, r *readline.Readline) error { - return ErrOK + { + Name: "second", }, }, - }, + Execute: func(ctx context.Context, r *readline.Readline) error { + return ErrOK + }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, tree.ErrMissingArgument) }, }, { name: "tree one two", - root: &tree.Root{ - Node: &tree.Node{ - Args: tree.Args{ - { - Name: "first", - }, - { - Name: "second", - }, + root: tree.New(&tree.Node{ + Args: tree.Args{ + { + Name: "first", }, - Execute: func(ctx context.Context, r *readline.Readline) error { - assert.Equal(T(ctx), "one", r.Args().At(0)) - assert.Equal(T(ctx), "two", r.Args().At(1)) - return ErrOK + { + Name: "second", }, }, - }, + Execute: func(ctx context.Context, r *readline.Readline) error { + assert.Equal(T(ctx), "one", r.Args().At(0)) + assert.Equal(T(ctx), "two", r.Args().At(1)) + return ErrOK + }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, ErrOK) }, @@ -365,49 +349,57 @@ func TestRoot_NodeFlags(t *testing.T) { tests := []struct { name string - root *tree.Root + root tree.Root wantErr assert.ErrorAssertionFunc }{ { name: "tree", - root: &tree.Root{ - Node: &tree.Node{ - Flags: func(ctx context.Context, r *readline.Readline, fs *readline.FlagSet) error { - fs.String("first", "first", "first") - fs.Bool("second", false, "second") - fs.Int64("third", 0, "third") - return nil - }, - Execute: func(ctx context.Context, r *readline.Readline) error { - assert.Equal(T(ctx), "first", r.FlagSet().GetString("first")) - assert.False(T(ctx), r.FlagSet().GetBool("second")) - assert.Equal(T(ctx), int64(0), r.FlagSet().GetInt64("third")) - return ErrOK - }, + root: tree.New(&tree.Node{ + Flags: func(ctx context.Context, r *readline.Readline, fs *readline.FlagSets) error { + fs.Default().String("first", "first", "first") + fs.Default().Bool("second", false, "second") + fs.Default().Int64("third", 0, "third") + return nil }, - }, + Execute: func(ctx context.Context, r *readline.Readline) error { + if value, err := r.FlagSets().Default().GetString("first"); assert.NoError(T(ctx), err) { + assert.Equal(T(ctx), "first", value) + } + if value, err := r.FlagSets().Default().GetBool("second"); assert.NoError(T(ctx), err) { + assert.False(T(ctx), value) + } + if value, err := r.FlagSets().Default().GetInt64("third"); assert.NoError(T(ctx), err) { + assert.Equal(T(ctx), int64(0), value) + } + return ErrOK + }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, ErrOK) }, }, { name: "tree --first one --second --third 13", - root: &tree.Root{ - Node: &tree.Node{ - Flags: func(ctx context.Context, r *readline.Readline, fs *readline.FlagSet) error { - fs.String("first", "first", "first") - fs.Bool("second", false, "second") - fs.Int64("third", 0, "third") - return nil - }, - Execute: func(ctx context.Context, r *readline.Readline) error { - assert.Equal(T(ctx), "one", r.FlagSet().GetString("first")) - assert.True(T(ctx), r.FlagSet().GetBool("second")) - assert.Equal(T(ctx), int64(13), r.FlagSet().GetInt64("third")) - return ErrOK - }, + root: tree.New(&tree.Node{ + Flags: func(ctx context.Context, r *readline.Readline, fs *readline.FlagSets) error { + fs.Default().String("first", "first", "first") + fs.Default().Bool("second", false, "second") + fs.Default().Int64("third", 0, "third") + return nil }, - }, + Execute: func(ctx context.Context, r *readline.Readline) error { + if value, err := r.FlagSets().Default().GetString("first"); assert.NoError(T(ctx), err) { + assert.Equal(T(ctx), "one", value) + } + if value, err := r.FlagSets().Default().GetBool("second"); assert.NoError(T(ctx), err) { + assert.True(T(ctx), value) + } + if value, err := r.FlagSets().Default().GetInt64("third"); assert.NoError(T(ctx), err) { + assert.Equal(T(ctx), int64(13), value) + } + return ErrOK + }, + }), wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorIs(t, err, ErrOK) }, diff --git a/pkg/log/must.go b/pkg/log/must.go new file mode 100644 index 0000000..c752876 --- /dev/null +++ b/pkg/log/must.go @@ -0,0 +1,8 @@ +package log + +func MustGet[T any](value T, err error) func(l Logger) T { + return func(l Logger) T { + l.Must(err) + return value + } +} diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index dfaa45b..2c9e8b6 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -2,8 +2,6 @@ package prompt import ( "context" - "os" - "os/signal" "strings" "github.com/c-bata/go-prompt" @@ -239,7 +237,7 @@ func (s *Prompt) execute(input string) { return } - s.history.Persist(s.ctx, input) + defer s.history.Persist(s.ctx, input) input = s.alias(input, s.aliases) @@ -317,12 +315,6 @@ func (s *Prompt) complete(d prompt.Document) []prompt.Suggest { } else if value, ok := cmd.(command.Completer); ok { return s.filter(value.Complete(ctx, s.readline), word, true) } - case readline.ModePassThroughFlags: - if value, ok := cmd.(command.PassThroughFlagsCompleter); ok { - return s.filter(value.CompletePassTroughFlags(ctx, s.readline), word, true) - } else if value, ok := cmd.(command.Completer); ok { - return s.filter(value.Complete(ctx, s.readline), word, true) - } case readline.ModeAdditionalArgs: if value, ok := cmd.(command.AdditionalArgsCompleter); ok { return s.filter(value.CompleteAdditionalArgs(ctx, s.readline), word, true) @@ -335,17 +327,21 @@ func (s *Prompt) complete(d prompt.Document) []prompt.Suggest { // context returns and watches over a new context func (s *Prompt) context() context.Context { - ctx, cancel := context.WithCancel(context.Background()) // FIXME context.WithCancel(s.ctx) - go func(ctx context.Context, cancel context.CancelFunc) { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt) - select { - case <-sigChan: - cancel() - return - case <-ctx.Done(): - return - } - }(ctx, cancel) - return ctx + //ctx, cancel := context.WithCancel(context.Background()) + //go func(ctx context.Context, cancel context.CancelFunc) { + // sigChan := make(chan os.Signal, 1) + // signal.Notify(sigChan, os.Interrupt) + // select { + // case <-s.ctx.Done(): + // cancel() + // return + // case <-sigChan: + // cancel() + // return + // case <-ctx.Done(): + // return + // } + //}(ctx, cancel) + //return ctx + return s.ctx } diff --git a/pkg/readline/arg.go b/pkg/readline/arg.go index d36d7c7..228f91d 100644 --- a/pkg/readline/arg.go +++ b/pkg/readline/arg.go @@ -30,5 +30,5 @@ func (a Arg) IsRedirect() bool { } func (a Arg) IsAdditional() bool { - return a.IsPipe() || a.IsRedirect() + return a.IsPass() || a.IsPipe() || a.IsRedirect() } diff --git a/pkg/readline/flags.go b/pkg/readline/flags.go new file mode 100644 index 0000000..bbd0c9d --- /dev/null +++ b/pkg/readline/flags.go @@ -0,0 +1,37 @@ +package readline + +import ( + "github.com/spf13/pflag" +) + +type Flags []*pflag.Flag + +func (f Flags) Remove(name string) (*pflag.Flag, Flags) { + for i, v := range f { + if v.Name == name { + return v, f.Splice(i, 1) + } + } + return nil, f +} + +func (f Flags) Slice(start, end int) Flags { + return append(f[:start], f[end:]...) +} + +func (f Flags) Splice(start, num int) Flags { + return append(f[:start], f[start+num:]...) +} + +func (f Flags) Args() Args { + var ret Args + for _, v := range f { + switch v.Value.Type() { + case "bool": + ret = append(ret, "--"+v.Name) + default: + ret = append(ret, "--"+v.Name, v.Value.String()) + } + } + return ret +} diff --git a/pkg/readline/flagset.go b/pkg/readline/flagset.go index eb15352..40c2002 100644 --- a/pkg/readline/flagset.go +++ b/pkg/readline/flagset.go @@ -1,8 +1,6 @@ package readline import ( - "strconv" - "github.com/spf13/pflag" ) @@ -10,73 +8,31 @@ type FlagSet struct { *pflag.FlagSet } -func NewFlagSet() *FlagSet { +func NewFlagSet(name string) *FlagSet { + fs := pflag.NewFlagSet(name, pflag.ContinueOnError) + fs.ParseErrorsWhitelist = pflag.ParseErrorsWhitelist{UnknownFlags: true} return &FlagSet{ - FlagSet: pflag.NewFlagSet("readline", pflag.ContinueOnError), + FlagSet: fs, } } -func (a *FlagSet) SetValues(name string, values ...string) error { - return a.SetAnnotation(name, "values", values) +func (s *FlagSet) Visited() Flags { + var ret Flags + s.Visit(func(f *pflag.Flag) { + ret = append(ret, f) + }) + return ret } -func (a *FlagSet) GetValues(name string) []string { - if f := a.FlagSet.Lookup(name); f == nil { +func (s *FlagSet) SetValues(name string, values ...string) error { + return s.SetAnnotation(name, "values", values) +} + +func (s *FlagSet) GetValues(name string) []string { + if f := s.FlagSet.Lookup(name); f == nil { return nil } else if v, ok := f.Annotations["values"]; ok { return v } return nil } - -func (a *FlagSet) GetString(name string) string { - if f := a.FlagSet.Lookup(name); f == nil { - return "" - } else if !a.flagIsSet(name) { - return f.DefValue - } else { - return f.Value.String() - } -} - -func (a *FlagSet) GetInt64(name string) int64 { - if value := a.GetString(name); value == "" { - return 0 - } else if v, err := strconv.ParseInt(value, 10, 64); err != nil { - return 0 - } else { - return v - } -} - -func (a *FlagSet) GetFloat64(name string) float64 { - if value := a.GetString(name); value == "" { - return 0 - } else if v, err := strconv.ParseFloat(value, 64); err != nil { - return 0 - } else { - return v - } -} - -func (a *FlagSet) GetBool(name string) bool { - if value := a.GetString(name); value == "" { - return false - } else if v, err := strconv.ParseBool(value); err != nil { - return false - } else { - return v - } -} - -func (a *FlagSet) flagIsSet(name string) bool { - found := false - if fs := a.FlagSet; fs != nil { - fs.Visit(func(f *pflag.Flag) { - if f.Name == name { - found = true - } - }) - } - return found -} diff --git a/pkg/readline/flagsets.go b/pkg/readline/flagsets.go new file mode 100644 index 0000000..dcc09a2 --- /dev/null +++ b/pkg/readline/flagsets.go @@ -0,0 +1,80 @@ +package readline + +import ( + "github.com/spf13/pflag" +) + +type FlagSets struct { + sets map[string]*FlagSet +} + +func NewFlagSets() *FlagSets { + return &FlagSets{ + sets: map[string]*FlagSet{}, + } +} + +func (s *FlagSets) Default() *FlagSet { + return s.Get("default") +} + +func (s *FlagSets) Internal() *FlagSet { + return s.Get("internal") +} + +func (s *FlagSets) Get(name string) *FlagSet { + if _, ok := s.sets[name]; !ok { + s.sets[name] = NewFlagSet(name) + } + return s.sets[name] +} + +func (s *FlagSets) Parse(arguments []string) error { + for _, set := range s.sets { + if err := set.Parse(arguments); err != nil { + return err + } + } + return nil +} + +func (s *FlagSets) All() *FlagSet { + fs := NewFlagSet("all") + for _, set := range s.sets { + fs.AddFlagSet(set.FlagSet) + } + return fs +} + +func (s *FlagSets) Visit(fn func(*pflag.Flag)) Flags { + var ret Flags + for _, set := range s.sets { + set.Visit(fn) + } + return ret +} + +func (s *FlagSets) VisitAll(fn func(*pflag.Flag)) Flags { + var ret Flags + for _, set := range s.sets { + set.VisitAll(fn) + } + return ret +} + +func (s *FlagSets) Visited() Flags { + var ret Flags + for _, set := range s.sets { + ret = append(ret, set.Visited()...) + } + return ret +} + +func (s *FlagSets) ParseAll(arguments []string, fn func(flag *pflag.Flag, value string) error) error { + for _, group := range s.sets { + if err := group.ParseAll(arguments, fn); err != nil { + return err + } + } + return nil +} diff --git a/pkg/readline/mode.go b/pkg/readline/mode.go index 492320f..e03e4e4 100644 --- a/pkg/readline/mode.go +++ b/pkg/readline/mode.go @@ -3,8 +3,7 @@ package readline type Mode string const ( - ModeArgs Mode = "" - ModeFlags Mode = "flags" - ModePassThroughFlags Mode = "passThroughFlags" - ModeAdditionalArgs Mode = "additional" + ModeArgs Mode = "args" + ModeFlags Mode = "flags" + ModeAdditionalArgs Mode = "additional" ) diff --git a/pkg/readline/readline.go b/pkg/readline/readline.go index dacf729..af13dd3 100644 --- a/pkg/readline/readline.go +++ b/pkg/readline/readline.go @@ -11,16 +11,14 @@ import ( type ( Readline struct { - l log.Logger - mu sync.RWMutex - cmd string - mode Mode - args Args - flags Args - flagSet *FlagSet - passThroughFlags Args - passThroughFlagSet *FlagSet - additionalArgs Args + l log.Logger + mu sync.RWMutex + cmd string + mode Mode + args Args + flags Args + flagSets *FlagSets + additionalArgs Args // regex - split cmd into args (https://regex101.com/r/EgiOzv/1) regex *regexp.Regexp } @@ -85,22 +83,10 @@ func (a *Readline) Flags() Args { return a.flags } -func (a *Readline) FlagSet() *FlagSet { +func (a *Readline) FlagSets() *FlagSets { a.mu.RLock() defer a.mu.RUnlock() - return a.flagSet -} - -func (a *Readline) PassThroughFlags() Args { - a.mu.RLock() - defer a.mu.RUnlock() - return a.passThroughFlags -} - -func (a *Readline) PassThroughFlagSet() *FlagSet { - a.mu.RLock() - defer a.mu.RUnlock() - return a.passThroughFlagSet + return a.flagSets } func (a *Readline) AdditionalArgs() Args { @@ -125,14 +111,10 @@ func (a *Readline) Parse(input string) error { a.cmd, parts = parts[0], parts[1:] } - last := len(parts) - 1 for i, part := range parts { if a.mode == ModeArgs && Arg(part).IsFlag() { a.mode = ModeFlags } - if i != last && (a.mode == ModeArgs || a.mode == ModeFlags) && Arg(part).IsPass() { - a.mode = ModePassThroughFlags - } if Arg(part).IsAdditional() && i < len(parts)-1 { a.mode = ModeAdditionalArgs } @@ -142,8 +124,6 @@ func (a *Readline) Parse(input string) error { a.args = append(a.args, part) case ModeFlags: a.flags = append(a.flags, part) - case ModePassThroughFlags: - a.passThroughFlags = append(a.passThroughFlags, part) case ModeAdditionalArgs: a.additionalArgs = append(a.additionalArgs, part) } @@ -152,84 +132,68 @@ func (a *Readline) Parse(input string) error { return nil } -func (a *Readline) SetFlags(fs *FlagSet) { +func (a *Readline) SetFlagSets(fs *FlagSets) { a.mu.Lock() defer a.mu.Unlock() - a.flagSet = fs + a.flagSets = fs } -func (a *Readline) ParseFlags() error { - if fs := a.FlagSet(); fs == nil { - return nil - } else if err := fs.Parse(a.flags); err != nil { - return err - } - return nil -} - -func (a *Readline) SetParsePassThroughFlags(fs *FlagSet) { - a.mu.Lock() - defer a.mu.Unlock() - a.passThroughFlagSet = fs -} - -func (a *Readline) ParsePassThroughFlags() error { - if fs := a.PassThroughFlagSet(); fs == nil { - return nil - } else if err := fs.Parse(a.passThroughFlags); err != nil { - return err +func (a *Readline) ParseFlagSets() error { + if fs := a.FlagSets(); fs != nil { + if err := fs.Parse(a.flags); err != nil { + return err + } } return nil } func (a *Readline) String() string { return fmt.Sprintf(` -Cmd: %s -Mode %s -Args: %s -Flags: %s -PassThroughFlags: %s -AdditionalArgs %s -`, a.Cmd(), a.Mode(), a.Args(), a.Flags(), a.PassThroughFlags(), a.AdditionalArgs()) +Cmd: %s +Args: %s +Flags: %s +AdditionalArgs: %s +`, a.Cmd(), a.Args(), a.Flags(), a.AdditionalArgs()) } func (a *Readline) IsModeDefault() bool { return a.Mode() == ModeArgs } -func (a *Readline) IsModePassThrough() bool { - return a.Mode() == ModePassThroughFlags -} - func (a *Readline) IsModeAdditional() bool { return a.Mode() == ModeAdditionalArgs } func (a *Readline) AllFlags() []*pflag.Flag { var ret []*pflag.Flag - if fs := a.FlagSet(); fs != nil { - fs.VisitAll(func(f *pflag.Flag) { + if fs := a.FlagSets(); fs != nil { + fs.All().VisitAll(func(f *pflag.Flag) { ret = append(ret, f) }) } return ret } -func (a *Readline) VisitedFlags() []*pflag.Flag { - var ret []*pflag.Flag - if fs := a.FlagSet(); fs != nil { - fs.Visit(func(f *pflag.Flag) { - ret = append(ret, f) - }) +func (a *Readline) VisitedFlags() Flags { + var ret Flags + if fs := a.FlagSets(); fs != nil { + ret = fs.Visited() } return ret } -func (a *Readline) AllPassThroughFlags() []*pflag.Flag { - var ret []*pflag.Flag - if fs := a.PassThroughFlagSet(); fs != nil { +func (a *Readline) AdditionalFlags() Args { + ret := append(Args{}, a.flags...) + if fs := a.FlagSets(); fs != nil { fs.VisitAll(func(f *pflag.Flag) { - ret = append(ret, f) + if i := ret.IndexOf("--" + f.Name); i >= 0 { + switch f.Value.Type() { + case "bool": + ret = ret.Splice(ret.IndexOf("--"+f.Name), 1) + default: + ret = ret.Splice(ret.IndexOf("--"+f.Name), 2) + } + } }) } return ret @@ -244,8 +208,6 @@ func (a *Readline) reset() { a.cmd = "" a.args = nil a.flags = nil - a.flagSet = nil - a.passThroughFlags = nil - a.passThroughFlagSet = nil + a.flagSets = nil a.additionalArgs = nil } diff --git a/pkg/readline/readline_test.go b/pkg/readline/readline_test.go new file mode 100644 index 0000000..8dcd909 --- /dev/null +++ b/pkg/readline/readline_test.go @@ -0,0 +1,118 @@ +package readline_test + +import ( + "testing" + + "github.com/foomo/posh/pkg/log" + "github.com/foomo/posh/pkg/readline" + "github.com/stretchr/testify/assert" +) + +func TestReadline(t *testing.T) { + tests := []struct { + name string + want func(t *testing.T, r *readline.Readline) + }{ + { + name: "foo bar", + want: func(t *testing.T, r *readline.Readline) { + assert.Equal(t, ` +Cmd: foo +Args: [bar] +Flags: [] +AdditionalArgs: [] +`, + r.String(), + ) + }, + }, + { + name: "foo bar baz", + want: func(t *testing.T, r *readline.Readline) { + assert.Equal(t, ` +Cmd: foo +Args: [bar baz] +Flags: [] +AdditionalArgs: [] +`, + r.String(), + ) + }, + }, + { + name: "foo bar --baz", + want: func(t *testing.T, r *readline.Readline) { + assert.Equal(t, ` +Cmd: foo +Args: [bar] +Flags: [--baz] +AdditionalArgs: [] +`, + r.String(), + ) + }, + }, + { + name: "foo --baz bar", + want: func(t *testing.T, r *readline.Readline) { + assert.Equal(t, ` +Cmd: foo +Args: [] +Flags: [--baz bar] +AdditionalArgs: [] +`, + r.String(), + ) + }, + }, + { + name: "foo --baz bar1", + want: func(t *testing.T, r *readline.Readline) { + assert.Equal(t, ` +Cmd: foo +Args: [] +Flags: [--baz bar1] +AdditionalArgs: [] +`, + r.String(), + ) + }, + }, + { + name: "foo | cat", + want: func(t *testing.T, r *readline.Readline) { + assert.Equal(t, ` +Cmd: foo +Args: [] +Flags: [] +AdditionalArgs: [| cat] +`, + r.String(), + ) + }, + }, + { + name: "foo --bar1 --bar2 one --bar3 two,three,four", + want: func(t *testing.T, r *readline.Readline) { + assert.Equal(t, ` +Cmd: foo +Args: [] +Flags: [--bar1 --bar2 one --bar3 two,three,four] +AdditionalArgs: [] +`, + r.String(), + ) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := log.NewTest(t, log.TestWithLevel(log.LevelDebug)) + if r, err := readline.New(l); assert.NoError(t, err) { + assert.NoError(t, r.Parse(tt.name)) + tt.want(t, r) + } + }) + } +}