Skip to content

Commit 5e65616

Browse files
authored
Merge pull request #1835 from dearchap/issue_1834
Fix:(issue_1834) Add check for persistent required flags
2 parents 2458b93 + d6eaf9a commit 5e65616

File tree

2 files changed

+104
-21
lines changed

2 files changed

+104
-21
lines changed

command.go

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -611,18 +611,26 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
611611

612612
if cmd.Action == nil {
613613
cmd.Action = helpCommandAction
614-
} else if len(cmd.Arguments) > 0 {
615-
rargs := cmd.Args().Slice()
616-
tracef("calling argparse with %[1]v", rargs)
617-
for _, arg := range cmd.Arguments {
618-
var err error
619-
rargs, err = arg.Parse(rargs)
620-
if err != nil {
621-
tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name)
622-
return err
614+
} else {
615+
if err := cmd.checkPersistentRequiredFlags(); err != nil {
616+
cmd.isInError = true
617+
_ = ShowSubcommandHelp(cmd)
618+
return err
619+
}
620+
621+
if len(cmd.Arguments) > 0 {
622+
rargs := cmd.Args().Slice()
623+
tracef("calling argparse with %[1]v", rargs)
624+
for _, arg := range cmd.Arguments {
625+
var err error
626+
rargs, err = arg.Parse(rargs)
627+
if err != nil {
628+
tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name)
629+
return err
630+
}
623631
}
632+
cmd.parsedArgs = &stringSliceArgs{v: rargs}
624633
}
625-
cmd.parsedArgs = &stringSliceArgs{v: rargs}
626634
}
627635

628636
if err := cmd.Action(ctx, cmd); err != nil {
@@ -929,26 +937,59 @@ func (cmd *Command) lookupFlagSet(name string) *flag.FlagSet {
929937
return nil
930938
}
931939

940+
func (cmd *Command) checkRequiredFlag(f Flag) (bool, string) {
941+
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
942+
flagPresent := false
943+
flagName := ""
944+
945+
for _, key := range f.Names() {
946+
flagName = key
947+
948+
if cmd.IsSet(strings.TrimSpace(key)) {
949+
flagPresent = true
950+
}
951+
}
952+
953+
if !flagPresent && flagName != "" {
954+
return false, flagName
955+
}
956+
}
957+
return true, ""
958+
}
959+
932960
func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
933961
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)
934962

935963
missingFlags := []string{}
936964

937965
for _, f := range cmd.Flags {
938-
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
939-
flagPresent := false
940-
flagName := ""
966+
if pf, ok := f.(PersistentFlag); !ok || !pf.IsPersistent() {
967+
if ok, name := cmd.checkRequiredFlag(f); !ok {
968+
missingFlags = append(missingFlags, name)
969+
}
970+
}
971+
}
941972

942-
for _, key := range f.Names() {
943-
flagName = key
973+
if len(missingFlags) != 0 {
974+
tracef("found missing required flags %[1]q (cmd=%[2]q)", missingFlags, cmd.Name)
944975

945-
if cmd.IsSet(strings.TrimSpace(key)) {
946-
flagPresent = true
947-
}
948-
}
976+
return &errRequiredFlags{missingFlags: missingFlags}
977+
}
978+
979+
tracef("all required flags set (cmd=%[1]q)", cmd.Name)
980+
981+
return nil
982+
}
983+
984+
func (cmd *Command) checkPersistentRequiredFlags() requiredFlagsErr {
985+
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)
986+
987+
missingFlags := []string{}
949988

950-
if !flagPresent && flagName != "" {
951-
missingFlags = append(missingFlags, flagName)
989+
for _, f := range cmd.appliedFlags {
990+
if pf, ok := f.(PersistentFlag); ok && pf.IsPersistent() {
991+
if ok, name := cmd.checkRequiredFlag(f); !ok {
992+
missingFlags = append(missingFlags, name)
952993
}
953994
}
954995
}

command_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,6 +2926,7 @@ func TestFlagAction(t *testing.T) {
29262926
func TestPersistentFlag(t *testing.T) {
29272927
var topInt, topPersistentInt, subCommandInt, appOverrideInt int64
29282928
var appFlag string
2929+
var appRequiredFlag string
29292930
var appOverrideCmdInt int64
29302931
var appSliceFloat64 []float64
29312932
var persistentCommandSliceInt []int64
@@ -2957,6 +2958,12 @@ func TestPersistentFlag(t *testing.T) {
29572958
Persistent: true,
29582959
Destination: &appOverrideInt,
29592960
},
2961+
&StringFlag{
2962+
Name: "persistentRequiredCommandFlag",
2963+
Persistent: true,
2964+
Required: true,
2965+
Destination: &appRequiredFlag,
2966+
},
29602967
},
29612968
Commands: []*Command{
29622969
{
@@ -3005,6 +3012,7 @@ func TestPersistentFlag(t *testing.T) {
30053012
"--persistentCommandSliceFlag", "102",
30063013
"--persistentCommandFloatSliceFlag", "102.455",
30073014
"--paof", "105",
3015+
"--persistentRequiredCommandFlag", "hellor",
30083016
"subcmd",
30093017
"--cmdPersistentFlag", "20",
30103018
"--cmdFlag", "11",
@@ -3021,6 +3029,10 @@ func TestPersistentFlag(t *testing.T) {
30213029
t.Errorf("Expected 'bar' got %s", appFlag)
30223030
}
30233031

3032+
if appRequiredFlag != "hellor" {
3033+
t.Errorf("Expected 'hellor' got %s", appRequiredFlag)
3034+
}
3035+
30243036
if topInt != 12 {
30253037
t.Errorf("Expected 12 got %d", topInt)
30263038
}
@@ -3096,6 +3108,36 @@ func TestPersistentFlagIsSet(t *testing.T) {
30963108
r.True(resultIsSet)
30973109
}
30983110

3111+
func TestRequiredPersistentFlag(t *testing.T) {
3112+
3113+
app := &Command{
3114+
Name: "root",
3115+
Flags: []Flag{
3116+
&StringFlag{
3117+
Name: "result",
3118+
Persistent: true,
3119+
Required: true,
3120+
},
3121+
},
3122+
Commands: []*Command{
3123+
{
3124+
Name: "sub",
3125+
Action: func(ctx context.Context, c *Command) error {
3126+
return nil
3127+
},
3128+
},
3129+
},
3130+
}
3131+
3132+
r := require.New(t)
3133+
3134+
err := app.Run(context.Background(), []string{"root", "sub"})
3135+
r.Error(err)
3136+
3137+
err = app.Run(context.Background(), []string{"root", "sub", "--result", "after"})
3138+
r.NoError(err)
3139+
}
3140+
30993141
func TestFlagDuplicates(t *testing.T) {
31003142
tests := []struct {
31013143
name string

0 commit comments

Comments
 (0)