Skip to content

Commit 5276b64

Browse files
committed
Fix overwrite by SetDefault for options that share Value
1 parent 5a56b57 commit 5276b64

File tree

2 files changed

+157
-15
lines changed

2 files changed

+157
-15
lines changed

command_test.go

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,14 @@ func sampleCommand(t *testing.T) *serpent.Command {
4949
Use: "root [subcommand]",
5050
Options: serpent.OptionSet{
5151
serpent.Option{
52-
Name: "verbose",
53-
Flag: "verbose",
52+
Name: "verbose",
53+
Flag: "verbose",
54+
Default: "false",
55+
Value: serpent.BoolOf(&verbose),
56+
},
57+
serpent.Option{
58+
Name: "verbose-old",
59+
Flag: "verbode-old",
5460
Value: serpent.BoolOf(&verbose),
5561
},
5662
serpent.Option{
@@ -742,6 +748,12 @@ func TestCommand_DefaultsOverride(t *testing.T) {
742748
Value: serpent.StringOf(&got),
743749
YAML: "url",
744750
},
751+
{
752+
Name: "url-deprecated",
753+
Flag: "url-deprecated",
754+
Env: "URL_DEPRECATED",
755+
Value: serpent.StringOf(&got),
756+
},
745757
{
746758
Name: "config",
747759
Flag: "config",
@@ -790,6 +802,17 @@ func TestCommand_DefaultsOverride(t *testing.T) {
790802
inv.Args = []string{"--config", fi.Name(), "--url", "good.com"}
791803
})
792804

805+
test("EnvOverYAML", "good.com", func(t *testing.T, inv *serpent.Invocation) {
806+
fi, err := os.CreateTemp(t.TempDir(), "config.yaml")
807+
require.NoError(t, err)
808+
defer fi.Close()
809+
810+
_, err = fi.WriteString("url: bad.com")
811+
require.NoError(t, err)
812+
813+
inv.Environ.Set("URL", "good.com")
814+
})
815+
793816
test("YAMLOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) {
794817
fi, err := os.CreateTemp(t.TempDir(), "config.yaml")
795818
require.NoError(t, err)
@@ -800,4 +823,57 @@ func TestCommand_DefaultsOverride(t *testing.T) {
800823

801824
inv.Args = []string{"--config", fi.Name()}
802825
})
826+
827+
test("AltFlagOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) {
828+
inv.Args = []string{"--url-deprecated", "good.com"}
829+
})
830+
}
831+
832+
func TestCommand_OptionsWithSharedValue(t *testing.T) {
833+
t.Parallel()
834+
835+
var got string
836+
makeCmd := func(def, altDef string) *serpent.Command {
837+
got = ""
838+
return &serpent.Command{
839+
Options: serpent.OptionSet{
840+
{
841+
Name: "url",
842+
Flag: "url",
843+
Default: def,
844+
Value: serpent.StringOf(&got),
845+
},
846+
{
847+
Name: "alt-url",
848+
Flag: "alt-url",
849+
Default: altDef,
850+
Value: serpent.StringOf(&got),
851+
},
852+
},
853+
Handler: (func(i *serpent.Invocation) error {
854+
return nil
855+
}),
856+
}
857+
}
858+
859+
// Check proper value propagation.
860+
err := makeCmd("def.com", "def.com").Invoke().Run()
861+
require.NoError(t, err, "default values are same")
862+
require.Equal(t, "def.com", got)
863+
864+
err = makeCmd("def.com", "").Invoke().Run()
865+
require.NoError(t, err, "other default value is empty")
866+
require.Equal(t, "def.com", got)
867+
868+
err = makeCmd("def.com", "").Invoke("--url", "sup").Run()
869+
require.NoError(t, err)
870+
require.Equal(t, "sup", got)
871+
872+
err = makeCmd("def.com", "").Invoke("--alt-url", "hup").Run()
873+
require.NoError(t, err)
874+
require.Equal(t, "hup", got)
875+
876+
// Catch invalid configuration.
877+
err = makeCmd("def.com", "alt-def.com").Invoke().Run()
878+
require.Error(t, err, "default values are different")
803879
}

option.go

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/json"
66
"os"
7+
"slices"
78
"strings"
89

910
"github.com/hashicorp/go-multierror"
@@ -21,6 +22,14 @@ const (
2122
ValueSourceDefault ValueSource = "default"
2223
)
2324

25+
var valueSourcePriority = []ValueSource{
26+
ValueSourceFlag,
27+
ValueSourceEnv,
28+
ValueSourceYAML,
29+
ValueSourceDefault,
30+
ValueSourceNone,
31+
}
32+
2433
// Option is a configuration option for a CLI application.
2534
type Option struct {
2635
Name string `json:"name,omitempty"`
@@ -305,16 +314,12 @@ func (optSet *OptionSet) SetDefaults() error {
305314

306315
var merr *multierror.Error
307316

308-
for i, opt := range *optSet {
309-
// Skip values that may have already been set by the user.
310-
if opt.ValueSource != ValueSourceNone {
311-
continue
312-
}
313-
314-
if opt.Default == "" {
315-
continue
316-
}
317-
317+
// It's common to have multiple options with the same value to
318+
// handle deprecation. We group the options by value so that we
319+
// don't let other options overwrite user input.
320+
groupByValue := make(map[pflag.Value][]*Option)
321+
for i := range *optSet {
322+
opt := &(*optSet)[i]
318323
if opt.Value == nil {
319324
merr = multierror.Append(
320325
merr,
@@ -325,13 +330,74 @@ func (optSet *OptionSet) SetDefaults() error {
325330
)
326331
continue
327332
}
328-
(*optSet)[i].ValueSource = ValueSourceDefault
329-
if err := opt.Value.Set(opt.Default); err != nil {
333+
groupByValue[opt.Value] = append(groupByValue[opt.Value], opt)
334+
}
335+
336+
for _, opts := range groupByValue {
337+
// Sort the options by priority and whether or not a default is
338+
// set. This won't affect the value but represents correctness
339+
// from whence the value originated.
340+
slices.SortFunc(opts, func(a, b *Option) int {
341+
if a.ValueSource != b.ValueSource {
342+
for _, vs := range valueSourcePriority {
343+
if a.ValueSource == vs {
344+
return -1
345+
}
346+
if b.ValueSource == vs {
347+
return 1
348+
}
349+
}
350+
}
351+
if a.Default != b.Default {
352+
if a.Default == "" {
353+
return 1
354+
}
355+
if b.Default == "" {
356+
return -1
357+
}
358+
}
359+
return 0
360+
})
361+
362+
// If the first option has a value source, then we don't need to
363+
// set the default, but mark the source for all options.
364+
if opts[0].ValueSource != ValueSourceNone {
365+
for _, opt := range opts[1:] {
366+
opt.ValueSource = opts[0].ValueSource
367+
}
368+
continue
369+
}
370+
371+
var optWithDefault *Option
372+
for _, opt := range opts {
373+
if opt.Default == "" {
374+
continue
375+
}
376+
if optWithDefault != nil && optWithDefault.Default != opt.Default {
377+
merr = multierror.Append(
378+
merr,
379+
xerrors.Errorf(
380+
"parse %q: multiple defaults set for the same value: %q and %q (%q)",
381+
opt.Name, opt.Default, optWithDefault.Default, optWithDefault.Name,
382+
),
383+
)
384+
continue
385+
}
386+
optWithDefault = opt
387+
}
388+
if optWithDefault == nil {
389+
continue
390+
}
391+
if err := optWithDefault.Value.Set(optWithDefault.Default); err != nil {
330392
merr = multierror.Append(
331-
merr, xerrors.Errorf("parse %q: %w", opt.Name, err),
393+
merr, xerrors.Errorf("parse %q: %w", optWithDefault.Name, err),
332394
)
333395
}
396+
for _, opt := range opts {
397+
opt.ValueSource = ValueSourceDefault
398+
}
334399
}
400+
335401
return merr.ErrorOrNil()
336402
}
337403

0 commit comments

Comments
 (0)