feat: add flag sets

This commit is contained in:
Kevin Franklin Kim 2023-03-20 11:40:09 +01:00
parent db05bf2784
commit 6cfe3ed9d0
17 changed files with 591 additions and 412 deletions

View File

@ -1,6 +1,8 @@
package cmd
import (
"context"
intconfig "github.com/foomo/posh/internal/config"
"github.com/foomo/posh/pkg/config"
"github.com/spf13/cobra"
@ -30,6 +32,6 @@ var promptCmd = &cobra.Command{
return err
}
return plg.Prompt(cmd.Context(), cfg)
return plg.Prompt(context.TODO(), cfg)
},
}

View File

@ -17,7 +17,7 @@ import (
type Cache struct {
l log.Logger
tree *tree.Root
tree tree.Root
cache cache.Cache
}
@ -30,8 +30,9 @@ func NewCache(l log.Logger, cache cache.Cache) *Cache {
l: l,
cache: cache,
}
inst.tree = &tree.Root{
Name: "cache",
inst.tree = tree.New(&tree.Node{
Name: "cache",
Description: "manage the internal cache",
Nodes: tree.Nodes{
{
Name: "clear",
@ -44,7 +45,7 @@ func NewCache(l log.Logger, cache cache.Cache) *Cache {
Execute: inst.list,
},
},
}
})
return inst
}
@ -53,11 +54,11 @@ func NewCache(l log.Logger, cache cache.Cache) *Cache {
// ------------------------------------------------------------------------------------------------
func (c *Cache) Name() string {
return c.tree.Name
return c.tree.Node().Name
}
func (c *Cache) Description() string {
return "manage the internal cache"
return c.tree.Node().Description
}
func (c *Cache) Complete(ctx context.Context, r *readline.Readline) []goprompt.Suggest {

View File

@ -63,8 +63,6 @@ func (c *Help) Validate(ctx context.Context, r *readline.Readline) error {
}
}
return errors.Errorf("invalid [command] argument: %s", r.Args().At(0))
case r.Args().LenGte(2):
return errors.New("too many arguments")
}
return nil
@ -84,9 +82,11 @@ Available Commands:
ret += c.format(value.Name(), value.Description())
}
c.l.Print(ret)
case 1:
default:
if helper, ok := c.commands.Get(r.Args().At(0)).(Helper); ok {
c.l.Print(helper.Help(ctx, r))
} else {
c.l.Print("command not found")
}
}
return nil

View File

@ -8,8 +8,9 @@ import (
)
type Arg struct {
Name string
Repeat bool
Optional bool
Suggest func(ctx context.Context, t *Root, r *readline.Readline) []goprompt.Suggest
Name string
Description string
Repeat bool
Optional bool
Suggest func(ctx context.Context, t Root, r *readline.Readline) []goprompt.Suggest
}

View File

@ -1,7 +0,0 @@
package tree
type Flag struct {
Name string
Required bool
Value interface{}
}

View File

@ -6,21 +6,20 @@ import (
"github.com/foomo/posh/pkg/prompt/goprompt"
"github.com/foomo/posh/pkg/readline"
xstrings "github.com/foomo/posh/pkg/util/strings"
"github.com/foomo/posh/pkg/util/suggests"
"github.com/pkg/errors"
)
type Node struct {
Name string
Values func(ctx context.Context, r *readline.Readline) []goprompt.Suggest
Args Args
Flags func(ctx context.Context, r *readline.Readline, fs *readline.FlagSet) error
PassThroughArgs Args
PassThroughFlags func(ctx context.Context, r *readline.Readline, fs *readline.FlagSet) error
Description string
Nodes []*Node
Execute func(ctx context.Context, r *readline.Readline) error
Name string
Values func(ctx context.Context, r *readline.Readline) []goprompt.Suggest
Args Args
Flags func(ctx context.Context, r *readline.Readline, fs *readline.FlagSets) error
Description string
Nodes []*Node
Execute func(ctx context.Context, r *readline.Readline) error
}
// ------------------------------------------------------------------------------------------------
@ -29,33 +28,21 @@ type Node struct {
func (c *Node) setFlags(ctx context.Context, r *readline.Readline, parse bool) error {
if c.Flags != nil {
f := readline.NewFlagSet()
if err := c.Flags(ctx, r, f); err != nil {
fs := readline.NewFlagSets()
if err := c.Flags(ctx, r, fs); err != nil {
return err
}
r.SetFlags(f)
if parse {
if err := r.ParseFlags(); err != nil {
return errors.Wrap(err, "failed to parse flags")
}
}
r.SetFlagSets(fs)
}
if c.PassThroughFlags != nil {
f := readline.NewFlagSet()
if err := c.PassThroughFlags(ctx, r, f); err != nil {
return err
}
r.SetParsePassThroughFlags(f)
if parse {
if err := r.ParsePassThroughFlags(); err != nil {
return errors.Wrap(err, "failed to parse pass through flags")
}
if parse {
if err := r.ParseFlagSets(); err != nil {
return errors.Wrap(err, "failed to parse flags")
}
}
return nil
}
func (c *Node) completeArguments(ctx context.Context, p *Root, r *readline.Readline, i int) []goprompt.Suggest {
func (c *Node) completeArguments(ctx context.Context, p *root, r *readline.Readline, i int) []goprompt.Suggest {
var suggest []goprompt.Suggest
localArgs := r.Args()[i:]
switch {
@ -84,7 +71,7 @@ func (c *Node) completeArguments(ctx context.Context, p *Root, r *readline.Readl
func (c *Node) completeFlags(r *readline.Readline) []goprompt.Suggest {
allFlags := r.AllFlags()
if r.Flags().LenGt(1) {
if values := r.FlagSet().GetValues(strings.TrimPrefix(r.Flags().At(r.Flags().Len()-2), "--")); values != nil {
if values := r.FlagSets().All().GetValues(strings.TrimPrefix(r.Flags().At(r.Flags().Len()-2), "--")); values != nil {
return suggests.List(values)
}
}
@ -95,18 +82,11 @@ func (c *Node) completeFlags(r *readline.Readline) []goprompt.Suggest {
return suggest
}
func (c *Node) completePassThroughFlags(r *readline.Readline) []goprompt.Suggest {
allPassThroughFlags := r.AllPassThroughFlags()
suggest := make([]goprompt.Suggest, len(allPassThroughFlags))
for i, f := range allPassThroughFlags {
suggest[i] = goprompt.Suggest{Text: "--" + f.Name, Description: f.Usage}
}
return suggest
}
func (c *Node) execute(ctx context.Context, r *readline.Readline, i int) error {
localArgs := r.Args()[i:]
switch {
case len(localArgs) == 0 && c.Execute != nil:
break
case len(c.Nodes) > 0 && len(localArgs) == 0:
return ErrMissingCommand
case len(c.Args) > 0:
@ -122,3 +102,78 @@ func (c *Node) execute(ctx context.Context, r *readline.Readline, i int) error {
}
return c.Execute(ctx, r)
}
func (c *Node) find(ctx context.Context, r *readline.Readline, i int) (*Node, int) {
if r.Args().LenLt(i + 1) {
return nil, i
}
arg := r.Args().At(i)
for _, cmd := range c.Nodes {
if cmd.Name == arg {
if subCmd, j := cmd.find(ctx, r, i+1); subCmd != nil {
return subCmd, j
}
return cmd, i
}
if cmd.Values != nil {
for _, name := range cmd.Values(ctx, r) {
if name.Text == arg {
if subCmd, j := cmd.find(ctx, r, i+1); subCmd != nil {
return subCmd, j
}
return cmd, i
}
}
}
}
return nil, i
}
func (c *Node) help(ctx context.Context, r *readline.Readline) string {
ret := c.Description
if len(c.Nodes) > 0 {
ret += "\n\nUsage:\n"
ret += " " + c.Name + " [command]"
ret += "\n\nAvailable Commands:\n"
for _, node := range c.Nodes {
ret += " " + xstrings.PadEnd(node.Name, " ", 30) + node.Description + "\n"
}
} else {
ret += "\n\nUsage:\n"
ret += " " + c.Name
for _, arg := range c.Args {
ret += " "
if arg.Optional {
ret += "<"
} else {
ret += "["
}
ret += arg.Name
if arg.Optional {
ret += ">"
} else {
ret += "]"
}
ret += "\n"
}
if len(c.Args) > 0 {
ret += "\n\nArguments:\n"
for _, arg := range c.Args {
ret += " " + xstrings.PadEnd(arg.Name, " ", 30) + arg.Description + "\n"
}
}
if c.Flags != nil {
fs := readline.NewFlagSets()
if err := c.Flags(ctx, r, fs); err == nil {
ret += "\n\nFlags:\n"
ret += fs.All().FlagUsages()
}
}
}
return ret
}

View File

@ -8,34 +8,52 @@ import (
"github.com/foomo/posh/pkg/readline"
)
type Root struct {
Name string
Description string
Node *Node
Nodes Nodes
type Root interface {
Node() *Node
Complete(ctx context.Context, r *readline.Readline) []goprompt.Suggest
Execute(ctx context.Context, r *readline.Readline) error
Help(ctx context.Context, r *readline.Readline) string
}
type root struct {
node *Node
}
// ------------------------------------------------------------------------------------------------
// ~ Constructor
// ------------------------------------------------------------------------------------------------
func New(node *Node) Root {
return &root{
node: node,
}
}
// ------------------------------------------------------------------------------------------------
// ~ Public methods
// ------------------------------------------------------------------------------------------------
func (t *Root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Suggest {
func (t *root) Node() *Node {
return t.node
}
func (t *root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Suggest {
var suggests []goprompt.Suggest
switch r.Mode() {
case readline.ModeArgs:
if r.Args().LenLte(1) && len(t.Nodes) > 0 {
for _, command := range t.Nodes {
if r.Args().LenLte(1) && len(t.node.Nodes) > 0 {
for _, command := range t.node.Nodes {
if command.Values != nil {
suggests = command.Values(ctx, r)
} else {
suggests = append(suggests, goprompt.Suggest{Text: command.Name, Description: command.Description})
}
}
} else if cmd, i := t.find(ctx, t.Nodes, r, 0); cmd == nil && t.Node != nil {
if err := t.Node.setFlags(ctx, r, false); err != nil {
} else if cmd, i := t.node.find(ctx, r, 0); cmd == nil && t.node != nil {
if err := t.node.setFlags(ctx, r, false); err != nil {
return nil
} else {
suggests = t.Node.completeArguments(ctx, t, r, 0)
suggests = t.node.completeArguments(ctx, t, r, 0)
}
} else if cmd == nil {
return nil
@ -45,11 +63,11 @@ func (t *Root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Su
suggests = cmd.completeArguments(ctx, t, r, i+1)
}
case readline.ModeFlags:
if cmd, _ := t.find(ctx, t.Nodes, r, 0); cmd == nil && t.Node != nil {
if err := t.Node.setFlags(ctx, r, false); err != nil {
if cmd, _ := t.node.find(ctx, r, 0); cmd == nil && t.node != nil {
if err := t.node.setFlags(ctx, r, false); err != nil {
return nil
} else {
suggests = t.Node.completeFlags(r)
suggests = t.node.completeFlags(r)
}
} else if cmd == nil {
return nil
@ -58,20 +76,6 @@ func (t *Root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Su
} else {
suggests = cmd.completeFlags(r)
}
case readline.ModePassThroughFlags:
if cmd, _ := t.find(ctx, t.Nodes, r, 0); cmd == nil && t.Node != nil {
if err := t.Node.setFlags(ctx, r, false); err != nil {
return nil
} else {
suggests = t.Node.completePassThroughFlags(r)
}
} else if cmd == nil {
return nil
} else if err := cmd.setFlags(ctx, r, false); err != nil {
return nil
} else {
suggests = cmd.completePassThroughFlags(r)
}
case readline.ModeAdditionalArgs:
// do nothing
}
@ -81,28 +85,28 @@ func (t *Root) Complete(ctx context.Context, r *readline.Readline) []goprompt.Su
return suggests
}
func (t *Root) Execute(ctx context.Context, r *readline.Readline) error {
func (t *root) Execute(ctx context.Context, r *readline.Readline) error {
var (
cmd *Node
index int
)
switch {
case t.Node == nil && len(t.Nodes) == 0:
case t.node == nil && t.node.Execute == nil && len(t.node.Nodes) == 0:
return ErrNoop
case r.Args().LenIs(0) && t.Node == nil:
case r.Args().LenIs(1) && t.node == nil:
return ErrMissingCommand
}
if r.Args().LenIs(0) {
cmd = t.Node
} else if found, i := t.find(ctx, t.Nodes, r, 0); found != nil {
cmd = t.node
} else if found, i := t.node.find(ctx, r, 0); found != nil {
cmd = found
index = i
} else if t.Node == nil {
} else if t.node == nil {
return ErrInvalidCommand
} else {
cmd = t.Node
cmd = t.node
}
if err := cmd.setFlags(ctx, r, true); err != nil {
@ -113,47 +117,22 @@ func (t *Root) Execute(ctx context.Context, r *readline.Readline) error {
return nil
}
func (t *Root) Help(ctx context.Context, r *readline.Readline) string {
// TODO recursive help
ret := t.Description
if t.Nodes != nil {
ret += "\n\nUsage:\n"
ret += " " + t.Name + " [command]"
func (t *root) Help(ctx context.Context, r *readline.Readline) string {
var (
cmd *Node
)
ret += "\n\nAvailable Commands:\n"
for _, node := range t.Nodes {
ret += " " + node.Name
}
if t.node == nil {
return "command not found"
} else if r.Args().LenIs(1) {
cmd = t.node
} else if len(t.node.Nodes) == 0 {
return "command not found"
} else if found, _ := t.node.find(ctx, r, 1); found != nil {
cmd = found
} else {
cmd = t.node
}
return ret
}
// ------------------------------------------------------------------------------------------------
// ~ Private methods
// ------------------------------------------------------------------------------------------------
func (t *Root) find(ctx context.Context, cmds []*Node, r *readline.Readline, i int) (*Node, int) {
if r.Args().LenLt(i + 1) {
return nil, i
}
arg := r.Args().At(i)
for _, cmd := range cmds {
if cmd.Name == arg {
if subCmd, j := t.find(ctx, cmd.Nodes, r, i+1); subCmd != nil {
return subCmd, j
}
return cmd, i
}
if cmd.Values != nil {
for _, name := range cmd.Values(ctx, r) {
if name.Text == arg {
if subCmd, j := t.find(ctx, cmd.Nodes, r, i+1); subCmd != nil {
return subCmd, j
}
return cmd, i
}
}
}
}
return nil, i
return cmd.help(ctx, r)
}

View File

@ -29,13 +29,11 @@ func TestRoot(t *testing.T) {
ErrThird1 = errors.New("third1")
)
r := &tree.Root{
r := tree.New(&tree.Node{
Name: "root",
Description: "root tree",
Node: &tree.Node{
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrRoot
},
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrRoot
},
Nodes: tree.Nodes{
{
@ -85,7 +83,7 @@ func TestRoot(t *testing.T) {
},
},
},
}
})
tests := []struct {
name string
@ -179,39 +177,35 @@ func TestRoot_Node(t *testing.T) {
tests := []struct {
name string
root *tree.Root
root tree.Root
wantErr assert.ErrorAssertionFunc
}{
{
name: "tree",
root: &tree.Root{},
root: tree.New(&tree.Node{}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, tree.ErrNoop)
},
},
{
name: "tree",
root: &tree.Root{
Node: &tree.Node{
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrOK
},
root: tree.New(&tree.Node{
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrOK
},
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, ErrOK)
},
},
{
name: "tree one",
root: &tree.Root{
Node: &tree.Node{
Execute: func(ctx context.Context, r *readline.Readline) error {
assert.Equal(T(ctx), "one", r.Args().At(0))
return ErrOK
},
root: tree.New(&tree.Node{
Execute: func(ctx context.Context, r *readline.Readline) error {
assert.Equal(T(ctx), "one", r.Args().At(0))
return ErrOK
},
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, ErrOK)
},
@ -238,107 +232,97 @@ func TestRoot_NodeArgs(t *testing.T) {
tests := []struct {
name string
root *tree.Root
root tree.Root
wantErr assert.ErrorAssertionFunc
}{
{
name: "tree",
root: &tree.Root{
Node: &tree.Node{
Args: tree.Args{
{
Name: "first",
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrOK
root: tree.New(&tree.Node{
Args: tree.Args{
{
Name: "first",
},
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrOK
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, tree.ErrMissingArgument)
},
},
{
name: "tree one",
root: &tree.Root{
Node: &tree.Node{
Args: tree.Args{
{
Name: "first",
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
assert.Equal(T(ctx), "one", r.Args().At(0))
return ErrOK
root: tree.New(&tree.Node{
Args: tree.Args{
{
Name: "first",
},
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
assert.Equal(T(ctx), "one", r.Args().At(0))
return ErrOK
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, ErrOK)
},
},
{
name: "tree",
root: &tree.Root{
Node: &tree.Node{
Args: tree.Args{
{
Name: "first",
},
{
Name: "second",
},
root: tree.New(&tree.Node{
Args: tree.Args{
{
Name: "first",
},
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrOK
{
Name: "second",
},
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrOK
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, tree.ErrMissingArgument)
},
},
{
name: "tree one",
root: &tree.Root{
Node: &tree.Node{
Args: tree.Args{
{
Name: "first",
},
{
Name: "second",
},
root: tree.New(&tree.Node{
Args: tree.Args{
{
Name: "first",
},
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrOK
{
Name: "second",
},
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
return ErrOK
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, tree.ErrMissingArgument)
},
},
{
name: "tree one two",
root: &tree.Root{
Node: &tree.Node{
Args: tree.Args{
{
Name: "first",
},
{
Name: "second",
},
root: tree.New(&tree.Node{
Args: tree.Args{
{
Name: "first",
},
Execute: func(ctx context.Context, r *readline.Readline) error {
assert.Equal(T(ctx), "one", r.Args().At(0))
assert.Equal(T(ctx), "two", r.Args().At(1))
return ErrOK
{
Name: "second",
},
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
assert.Equal(T(ctx), "one", r.Args().At(0))
assert.Equal(T(ctx), "two", r.Args().At(1))
return ErrOK
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, ErrOK)
},
@ -365,49 +349,57 @@ func TestRoot_NodeFlags(t *testing.T) {
tests := []struct {
name string
root *tree.Root
root tree.Root
wantErr assert.ErrorAssertionFunc
}{
{
name: "tree",
root: &tree.Root{
Node: &tree.Node{
Flags: func(ctx context.Context, r *readline.Readline, fs *readline.FlagSet) error {
fs.String("first", "first", "first")
fs.Bool("second", false, "second")
fs.Int64("third", 0, "third")
return nil
},
Execute: func(ctx context.Context, r *readline.Readline) error {
assert.Equal(T(ctx), "first", r.FlagSet().GetString("first"))
assert.False(T(ctx), r.FlagSet().GetBool("second"))
assert.Equal(T(ctx), int64(0), r.FlagSet().GetInt64("third"))
return ErrOK
},
root: tree.New(&tree.Node{
Flags: func(ctx context.Context, r *readline.Readline, fs *readline.FlagSets) error {
fs.Default().String("first", "first", "first")
fs.Default().Bool("second", false, "second")
fs.Default().Int64("third", 0, "third")
return nil
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
if value, err := r.FlagSets().Default().GetString("first"); assert.NoError(T(ctx), err) {
assert.Equal(T(ctx), "first", value)
}
if value, err := r.FlagSets().Default().GetBool("second"); assert.NoError(T(ctx), err) {
assert.False(T(ctx), value)
}
if value, err := r.FlagSets().Default().GetInt64("third"); assert.NoError(T(ctx), err) {
assert.Equal(T(ctx), int64(0), value)
}
return ErrOK
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, ErrOK)
},
},
{
name: "tree --first one --second --third 13",
root: &tree.Root{
Node: &tree.Node{
Flags: func(ctx context.Context, r *readline.Readline, fs *readline.FlagSet) error {
fs.String("first", "first", "first")
fs.Bool("second", false, "second")
fs.Int64("third", 0, "third")
return nil
},
Execute: func(ctx context.Context, r *readline.Readline) error {
assert.Equal(T(ctx), "one", r.FlagSet().GetString("first"))
assert.True(T(ctx), r.FlagSet().GetBool("second"))
assert.Equal(T(ctx), int64(13), r.FlagSet().GetInt64("third"))
return ErrOK
},
root: tree.New(&tree.Node{
Flags: func(ctx context.Context, r *readline.Readline, fs *readline.FlagSets) error {
fs.Default().String("first", "first", "first")
fs.Default().Bool("second", false, "second")
fs.Default().Int64("third", 0, "third")
return nil
},
},
Execute: func(ctx context.Context, r *readline.Readline) error {
if value, err := r.FlagSets().Default().GetString("first"); assert.NoError(T(ctx), err) {
assert.Equal(T(ctx), "one", value)
}
if value, err := r.FlagSets().Default().GetBool("second"); assert.NoError(T(ctx), err) {
assert.True(T(ctx), value)
}
if value, err := r.FlagSets().Default().GetInt64("third"); assert.NoError(T(ctx), err) {
assert.Equal(T(ctx), int64(13), value)
}
return ErrOK
},
}),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, ErrOK)
},

8
pkg/log/must.go Normal file
View File

@ -0,0 +1,8 @@
package log
func MustGet[T any](value T, err error) func(l Logger) T {
return func(l Logger) T {
l.Must(err)
return value
}
}

View File

@ -2,8 +2,6 @@ package prompt
import (
"context"
"os"
"os/signal"
"strings"
"github.com/c-bata/go-prompt"
@ -239,7 +237,7 @@ func (s *Prompt) execute(input string) {
return
}
s.history.Persist(s.ctx, input)
defer s.history.Persist(s.ctx, input)
input = s.alias(input, s.aliases)
@ -317,12 +315,6 @@ func (s *Prompt) complete(d prompt.Document) []prompt.Suggest {
} else if value, ok := cmd.(command.Completer); ok {
return s.filter(value.Complete(ctx, s.readline), word, true)
}
case readline.ModePassThroughFlags:
if value, ok := cmd.(command.PassThroughFlagsCompleter); ok {
return s.filter(value.CompletePassTroughFlags(ctx, s.readline), word, true)
} else if value, ok := cmd.(command.Completer); ok {
return s.filter(value.Complete(ctx, s.readline), word, true)
}
case readline.ModeAdditionalArgs:
if value, ok := cmd.(command.AdditionalArgsCompleter); ok {
return s.filter(value.CompleteAdditionalArgs(ctx, s.readline), word, true)
@ -335,17 +327,21 @@ func (s *Prompt) complete(d prompt.Document) []prompt.Suggest {
// context returns and watches over a new context
func (s *Prompt) context() context.Context {
ctx, cancel := context.WithCancel(context.Background()) // FIXME context.WithCancel(s.ctx)
go func(ctx context.Context, cancel context.CancelFunc) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
select {
case <-sigChan:
cancel()
return
case <-ctx.Done():
return
}
}(ctx, cancel)
return ctx
//ctx, cancel := context.WithCancel(context.Background())
//go func(ctx context.Context, cancel context.CancelFunc) {
// sigChan := make(chan os.Signal, 1)
// signal.Notify(sigChan, os.Interrupt)
// select {
// case <-s.ctx.Done():
// cancel()
// return
// case <-sigChan:
// cancel()
// return
// case <-ctx.Done():
// return
// }
//}(ctx, cancel)
//return ctx
return s.ctx
}

View File

@ -30,5 +30,5 @@ func (a Arg) IsRedirect() bool {
}
func (a Arg) IsAdditional() bool {
return a.IsPipe() || a.IsRedirect()
return a.IsPass() || a.IsPipe() || a.IsRedirect()
}

37
pkg/readline/flags.go Normal file
View File

@ -0,0 +1,37 @@
package readline
import (
"github.com/spf13/pflag"
)
type Flags []*pflag.Flag
func (f Flags) Remove(name string) (*pflag.Flag, Flags) {
for i, v := range f {
if v.Name == name {
return v, f.Splice(i, 1)
}
}
return nil, f
}
func (f Flags) Slice(start, end int) Flags {
return append(f[:start], f[end:]...)
}
func (f Flags) Splice(start, num int) Flags {
return append(f[:start], f[start+num:]...)
}
func (f Flags) Args() Args {
var ret Args
for _, v := range f {
switch v.Value.Type() {
case "bool":
ret = append(ret, "--"+v.Name)
default:
ret = append(ret, "--"+v.Name, v.Value.String())
}
}
return ret
}

View File

@ -1,8 +1,6 @@
package readline
import (
"strconv"
"github.com/spf13/pflag"
)
@ -10,73 +8,31 @@ type FlagSet struct {
*pflag.FlagSet
}
func NewFlagSet() *FlagSet {
func NewFlagSet(name string) *FlagSet {
fs := pflag.NewFlagSet(name, pflag.ContinueOnError)
fs.ParseErrorsWhitelist = pflag.ParseErrorsWhitelist{UnknownFlags: true}
return &FlagSet{
FlagSet: pflag.NewFlagSet("readline", pflag.ContinueOnError),
FlagSet: fs,
}
}
func (a *FlagSet) SetValues(name string, values ...string) error {
return a.SetAnnotation(name, "values", values)
func (s *FlagSet) Visited() Flags {
var ret Flags
s.Visit(func(f *pflag.Flag) {
ret = append(ret, f)
})
return ret
}
func (a *FlagSet) GetValues(name string) []string {
if f := a.FlagSet.Lookup(name); f == nil {
func (s *FlagSet) SetValues(name string, values ...string) error {
return s.SetAnnotation(name, "values", values)
}
func (s *FlagSet) GetValues(name string) []string {
if f := s.FlagSet.Lookup(name); f == nil {
return nil
} else if v, ok := f.Annotations["values"]; ok {
return v
}
return nil
}
func (a *FlagSet) GetString(name string) string {
if f := a.FlagSet.Lookup(name); f == nil {
return ""
} else if !a.flagIsSet(name) {
return f.DefValue
} else {
return f.Value.String()
}
}
func (a *FlagSet) GetInt64(name string) int64 {
if value := a.GetString(name); value == "" {
return 0
} else if v, err := strconv.ParseInt(value, 10, 64); err != nil {
return 0
} else {
return v
}
}
func (a *FlagSet) GetFloat64(name string) float64 {
if value := a.GetString(name); value == "" {
return 0
} else if v, err := strconv.ParseFloat(value, 64); err != nil {
return 0
} else {
return v
}
}
func (a *FlagSet) GetBool(name string) bool {
if value := a.GetString(name); value == "" {
return false
} else if v, err := strconv.ParseBool(value); err != nil {
return false
} else {
return v
}
}
func (a *FlagSet) flagIsSet(name string) bool {
found := false
if fs := a.FlagSet; fs != nil {
fs.Visit(func(f *pflag.Flag) {
if f.Name == name {
found = true
}
})
}
return found
}

80
pkg/readline/flagsets.go Normal file
View File

@ -0,0 +1,80 @@
package readline
import (
"github.com/spf13/pflag"
)
type FlagSets struct {
sets map[string]*FlagSet
}
func NewFlagSets() *FlagSets {
return &FlagSets{
sets: map[string]*FlagSet{},
}
}
func (s *FlagSets) Default() *FlagSet {
return s.Get("default")
}
func (s *FlagSets) Internal() *FlagSet {
return s.Get("internal")
}
func (s *FlagSets) Get(name string) *FlagSet {
if _, ok := s.sets[name]; !ok {
s.sets[name] = NewFlagSet(name)
}
return s.sets[name]
}
func (s *FlagSets) Parse(arguments []string) error {
for _, set := range s.sets {
if err := set.Parse(arguments); err != nil {
return err
}
}
return nil
}
func (s *FlagSets) All() *FlagSet {
fs := NewFlagSet("all")
for _, set := range s.sets {
fs.AddFlagSet(set.FlagSet)
}
return fs
}
func (s *FlagSets) Visit(fn func(*pflag.Flag)) Flags {
var ret Flags
for _, set := range s.sets {
set.Visit(fn)
}
return ret
}
func (s *FlagSets) VisitAll(fn func(*pflag.Flag)) Flags {
var ret Flags
for _, set := range s.sets {
set.VisitAll(fn)
}
return ret
}
func (s *FlagSets) Visited() Flags {
var ret Flags
for _, set := range s.sets {
ret = append(ret, set.Visited()...)
}
return ret
}
func (s *FlagSets) ParseAll(arguments []string, fn func(flag *pflag.Flag, value string) error) error {
for _, group := range s.sets {
if err := group.ParseAll(arguments, fn); err != nil {
return err
}
}
return nil
}

View File

@ -3,8 +3,7 @@ package readline
type Mode string
const (
ModeArgs Mode = ""
ModeFlags Mode = "flags"
ModePassThroughFlags Mode = "passThroughFlags"
ModeAdditionalArgs Mode = "additional"
ModeArgs Mode = "args"
ModeFlags Mode = "flags"
ModeAdditionalArgs Mode = "additional"
)

View File

@ -11,16 +11,14 @@ import (
type (
Readline struct {
l log.Logger
mu sync.RWMutex
cmd string
mode Mode
args Args
flags Args
flagSet *FlagSet
passThroughFlags Args
passThroughFlagSet *FlagSet
additionalArgs Args
l log.Logger
mu sync.RWMutex
cmd string
mode Mode
args Args
flags Args
flagSets *FlagSets
additionalArgs Args
// regex - split cmd into args (https://regex101.com/r/EgiOzv/1)
regex *regexp.Regexp
}
@ -85,22 +83,10 @@ func (a *Readline) Flags() Args {
return a.flags
}
func (a *Readline) FlagSet() *FlagSet {
func (a *Readline) FlagSets() *FlagSets {
a.mu.RLock()
defer a.mu.RUnlock()
return a.flagSet
}
func (a *Readline) PassThroughFlags() Args {
a.mu.RLock()
defer a.mu.RUnlock()
return a.passThroughFlags
}
func (a *Readline) PassThroughFlagSet() *FlagSet {
a.mu.RLock()
defer a.mu.RUnlock()
return a.passThroughFlagSet
return a.flagSets
}
func (a *Readline) AdditionalArgs() Args {
@ -125,14 +111,10 @@ func (a *Readline) Parse(input string) error {
a.cmd, parts = parts[0], parts[1:]
}
last := len(parts) - 1
for i, part := range parts {
if a.mode == ModeArgs && Arg(part).IsFlag() {
a.mode = ModeFlags
}
if i != last && (a.mode == ModeArgs || a.mode == ModeFlags) && Arg(part).IsPass() {
a.mode = ModePassThroughFlags
}
if Arg(part).IsAdditional() && i < len(parts)-1 {
a.mode = ModeAdditionalArgs
}
@ -142,8 +124,6 @@ func (a *Readline) Parse(input string) error {
a.args = append(a.args, part)
case ModeFlags:
a.flags = append(a.flags, part)
case ModePassThroughFlags:
a.passThroughFlags = append(a.passThroughFlags, part)
case ModeAdditionalArgs:
a.additionalArgs = append(a.additionalArgs, part)
}
@ -152,84 +132,68 @@ func (a *Readline) Parse(input string) error {
return nil
}
func (a *Readline) SetFlags(fs *FlagSet) {
func (a *Readline) SetFlagSets(fs *FlagSets) {
a.mu.Lock()
defer a.mu.Unlock()
a.flagSet = fs
a.flagSets = fs
}
func (a *Readline) ParseFlags() error {
if fs := a.FlagSet(); fs == nil {
return nil
} else if err := fs.Parse(a.flags); err != nil {
return err
}
return nil
}
func (a *Readline) SetParsePassThroughFlags(fs *FlagSet) {
a.mu.Lock()
defer a.mu.Unlock()
a.passThroughFlagSet = fs
}
func (a *Readline) ParsePassThroughFlags() error {
if fs := a.PassThroughFlagSet(); fs == nil {
return nil
} else if err := fs.Parse(a.passThroughFlags); err != nil {
return err
func (a *Readline) ParseFlagSets() error {
if fs := a.FlagSets(); fs != nil {
if err := fs.Parse(a.flags); err != nil {
return err
}
}
return nil
}
func (a *Readline) String() string {
return fmt.Sprintf(`
Cmd: %s
Mode %s
Args: %s
Flags: %s
PassThroughFlags: %s
AdditionalArgs %s
`, a.Cmd(), a.Mode(), a.Args(), a.Flags(), a.PassThroughFlags(), a.AdditionalArgs())
Cmd: %s
Args: %s
Flags: %s
AdditionalArgs: %s
`, a.Cmd(), a.Args(), a.Flags(), a.AdditionalArgs())
}
func (a *Readline) IsModeDefault() bool {
return a.Mode() == ModeArgs
}
func (a *Readline) IsModePassThrough() bool {
return a.Mode() == ModePassThroughFlags
}
func (a *Readline) IsModeAdditional() bool {
return a.Mode() == ModeAdditionalArgs
}
func (a *Readline) AllFlags() []*pflag.Flag {
var ret []*pflag.Flag
if fs := a.FlagSet(); fs != nil {
fs.VisitAll(func(f *pflag.Flag) {
if fs := a.FlagSets(); fs != nil {
fs.All().VisitAll(func(f *pflag.Flag) {
ret = append(ret, f)
})
}
return ret
}
func (a *Readline) VisitedFlags() []*pflag.Flag {
var ret []*pflag.Flag
if fs := a.FlagSet(); fs != nil {
fs.Visit(func(f *pflag.Flag) {
ret = append(ret, f)
})
func (a *Readline) VisitedFlags() Flags {
var ret Flags
if fs := a.FlagSets(); fs != nil {
ret = fs.Visited()
}
return ret
}
func (a *Readline) AllPassThroughFlags() []*pflag.Flag {
var ret []*pflag.Flag
if fs := a.PassThroughFlagSet(); fs != nil {
func (a *Readline) AdditionalFlags() Args {
ret := append(Args{}, a.flags...)
if fs := a.FlagSets(); fs != nil {
fs.VisitAll(func(f *pflag.Flag) {
ret = append(ret, f)
if i := ret.IndexOf("--" + f.Name); i >= 0 {
switch f.Value.Type() {
case "bool":
ret = ret.Splice(ret.IndexOf("--"+f.Name), 1)
default:
ret = ret.Splice(ret.IndexOf("--"+f.Name), 2)
}
}
})
}
return ret
@ -244,8 +208,6 @@ func (a *Readline) reset() {
a.cmd = ""
a.args = nil
a.flags = nil
a.flagSet = nil
a.passThroughFlags = nil
a.passThroughFlagSet = nil
a.flagSets = nil
a.additionalArgs = nil
}

View File

@ -0,0 +1,118 @@
package readline_test
import (
"testing"
"github.com/foomo/posh/pkg/log"
"github.com/foomo/posh/pkg/readline"
"github.com/stretchr/testify/assert"
)
func TestReadline(t *testing.T) {
tests := []struct {
name string
want func(t *testing.T, r *readline.Readline)
}{
{
name: "foo bar",
want: func(t *testing.T, r *readline.Readline) {
assert.Equal(t, `
Cmd: foo
Args: [bar]
Flags: []
AdditionalArgs: []
`,
r.String(),
)
},
},
{
name: "foo bar baz",
want: func(t *testing.T, r *readline.Readline) {
assert.Equal(t, `
Cmd: foo
Args: [bar baz]
Flags: []
AdditionalArgs: []
`,
r.String(),
)
},
},
{
name: "foo bar --baz",
want: func(t *testing.T, r *readline.Readline) {
assert.Equal(t, `
Cmd: foo
Args: [bar]
Flags: [--baz]
AdditionalArgs: []
`,
r.String(),
)
},
},
{
name: "foo --baz bar",
want: func(t *testing.T, r *readline.Readline) {
assert.Equal(t, `
Cmd: foo
Args: []
Flags: [--baz bar]
AdditionalArgs: []
`,
r.String(),
)
},
},
{
name: "foo --baz bar1",
want: func(t *testing.T, r *readline.Readline) {
assert.Equal(t, `
Cmd: foo
Args: []
Flags: [--baz bar1]
AdditionalArgs: []
`,
r.String(),
)
},
},
{
name: "foo | cat",
want: func(t *testing.T, r *readline.Readline) {
assert.Equal(t, `
Cmd: foo
Args: []
Flags: []
AdditionalArgs: [| cat]
`,
r.String(),
)
},
},
{
name: "foo --bar1 --bar2 one --bar3 two,three,four",
want: func(t *testing.T, r *readline.Readline) {
assert.Equal(t, `
Cmd: foo
Args: []
Flags: [--bar1 --bar2 one --bar3 two,three,four]
AdditionalArgs: []
`,
r.String(),
)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := log.NewTest(t, log.TestWithLevel(log.LevelDebug))
if r, err := readline.New(l); assert.NoError(t, err) {
assert.NoError(t, r.Parse(tt.name))
tt.want(t, r)
}
})
}
}