Skip to content

Commit 119c807

Browse files
committed
feat: add positional arguments as pseudo-arguments
1 parent 3d86fd7 commit 119c807

File tree

4 files changed

+240
-13
lines changed

4 files changed

+240
-13
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ require (
2424
github.com/charmbracelet/lipgloss v0.8.0 // indirect
2525
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
2626
github.com/go-logr/logr v1.4.1 // indirect
27-
github.com/google/go-cmp v0.6.0 // indirect
27+
github.com/google/go-cmp v0.7.0 // indirect
2828
github.com/hashicorp/errwrap v1.1.0 // indirect
2929
github.com/kr/pretty v0.3.1 // indirect
3030
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
2828
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
2929
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
3030
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
31+
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
32+
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
3133
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
3234
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
3335
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=

mcp.go

Lines changed: 133 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"io"
1010
"path"
11+
"slices"
1112
"strings"
1213

1314
"golang.org/x/xerrors"
@@ -473,20 +474,26 @@ func (s *MCPServer) handleCallTool(req JSONRPC2Request) {
473474
// Convert the arguments map to command-line args
474475
var cmdArgs []string
475476

476-
// Check for positional arguments (using "_" as the key)
477-
if posArgs, ok := args["_"]; ok {
478-
switch val := posArgs.(type) {
479-
case string:
480-
cmdArgs = append(cmdArgs, val)
481-
case []any:
482-
for _, item := range val {
483-
cmdArgs = append(cmdArgs, fmt.Sprintf("%v", item))
477+
// Check for positional arguments prefix with `argN__<name>`
478+
deleteKeys := make([]string, 0)
479+
for k, v := range args {
480+
if strings.HasPrefix(k, "arg") && len(k) > 4 && k[3] >= '0' && k[3] <= '9' {
481+
deleteKeys = append(deleteKeys, k)
482+
switch val := v.(type) {
483+
case string:
484+
cmdArgs = append(cmdArgs, val)
485+
case []any:
486+
for _, item := range val {
487+
cmdArgs = append(cmdArgs, fmt.Sprintf("%v", item))
488+
}
489+
default:
490+
cmdArgs = append(cmdArgs, fmt.Sprintf("%v", val))
484491
}
485-
default:
486-
cmdArgs = append(cmdArgs, fmt.Sprintf("%v", val))
487492
}
488-
// Remove the "_" key so it's not processed as a flag
489-
delete(args, "_")
493+
}
494+
// Delete any of the positional argument keys so they don't get processed below.
495+
for _, dk := range deleteKeys {
496+
delete(args, dk)
490497
}
491498

492499
// Process remaining arguments as flags
@@ -639,6 +646,15 @@ func (s *MCPServer) generateJSONSchema(cmd *Command) (json.RawMessage, error) {
639646
properties := schema["properties"].(map[string]any)
640647
requiredList := schema["required"].([]string)
641648

649+
// Add positional arguments based on the cmd usage.
650+
if posArgs, err := PosArgsFromCmdUsage(cmd.Use); err != nil {
651+
return nil, xerrors.Errorf("unable to process positional argument for command %q: %w", cmd.Name(), err)
652+
} else {
653+
for k, v := range posArgs {
654+
properties[k] = v
655+
}
656+
}
657+
642658
// Process each option in the command
643659
for _, opt := range cmd.Options {
644660
// Skip options that aren't exposed as flags
@@ -925,3 +941,108 @@ Commands with neither Tool nor Resource set will not be accessible via MCP.`,
925941
},
926942
}
927943
}
944+
945+
// PosArgsFromCmdUsage attempts to process a 'usage' string into a set of
946+
// arguments for display as tool parameters.
947+
// Example: the usage string `foo [flags] <bar> [baz] [razzle|dazzle]`
948+
// defines three arguments for the `foo` command:
949+
// - bar (required)
950+
// - baz (optional)
951+
// - the string `razzle` XOR `dazzle` (optional)
952+
//
953+
// The expected output of the above is as follows:
954+
//
955+
// {
956+
// "arg1:bar": {
957+
// "type": "string",
958+
// "description": "required argument",
959+
// },
960+
// "arg2:baz": {
961+
// "type": "string",
962+
// "description": "optional argument",
963+
// },
964+
// "arg3:razzle_dazzle": {
965+
// "type": "string",
966+
// "enum": ["razzle", "dazzle"]
967+
// },
968+
// }
969+
//
970+
// The usage string is processed given the following assumptions:
971+
// 1. The first non-whitespace string of usage is the name of the command
972+
// and will be skipped.
973+
// 2. The pseudo-argument specifier [flags] will also be skipped, if present.
974+
// 3. Argument specifiers enclosed by [square brackets] are considered optional.
975+
// 4. All other argument specifiers are considered required.
976+
// 5. Invidiual argument specifiers are separated by a single whitespace character.
977+
// Argument specifiers that contain a space are considered invalid (e.g. `[foo bar]`)
978+
//
979+
// Variadic arguments [arg...] are treated as a single argument.
980+
func PosArgsFromCmdUsage(usage string) (map[string]any, error) {
981+
if len(usage) == 0 {
982+
return nil, xerrors.Errorf("usage may not be empty")
983+
}
984+
985+
// Step 1: preprocessing. Skip the first token.
986+
parts := strings.Fields(usage)
987+
if len(parts) < 2 {
988+
return map[string]any{}, nil
989+
}
990+
parts = parts[1:]
991+
// Skip [flags], if present.
992+
parts = slices.DeleteFunc(parts, func(s string) bool {
993+
return s == "[flags]"
994+
})
995+
996+
result := make(map[string]any, len(parts))
997+
998+
// Process each argument token
999+
for i, part := range parts {
1000+
argIndex := i + 1
1001+
argKey := fmt.Sprintf("arg%d__", argIndex)
1002+
1003+
// Check for unbalanced brackets in the part.
1004+
// This catches cases like "command [flags] [a" or "command [flags] a b [c | d]"
1005+
// which would be split into multiple tokens by strings.Fields()
1006+
openSquare := strings.Count(part, "[")
1007+
closeSquare := strings.Count(part, "]")
1008+
openAngle := strings.Count(part, "<")
1009+
closeAngle := strings.Count(part, ">")
1010+
openBrace := strings.Count(part, "{")
1011+
closeBrace := strings.Count(part, "}")
1012+
1013+
if openSquare != closeSquare {
1014+
return nil, xerrors.Errorf("malformed usage: unbalanced square bracket at %q", part)
1015+
} else if openAngle != closeAngle {
1016+
return nil, xerrors.Errorf("malformed usage: unbalanced angle bracket at %q", part)
1017+
} else if openBrace != closeBrace {
1018+
return nil, xerrors.Errorf("malformed usage: unbalanced brace at %q", part)
1019+
}
1020+
1021+
// Determine if the argument is optional (enclosed in square brackets)
1022+
isOptional := openSquare > 0
1023+
cleanName := strings.Trim(part, "[]{}<>.")
1024+
description := "required argument"
1025+
if isOptional {
1026+
description = "optional argument"
1027+
}
1028+
1029+
argVal := map[string]any{
1030+
"type": "string",
1031+
"description": description,
1032+
// "required": !isOptional,
1033+
}
1034+
1035+
keyName := cleanName
1036+
// If an argument specifier contains a pipe, treat it as an enum.
1037+
if strings.Contains(cleanName, "|") {
1038+
choices := strings.Split(cleanName, "|")
1039+
// Create a name by joining alternatives with underscores
1040+
keyName = strings.Join(choices, "_")
1041+
argVal["enum"] = choices
1042+
}
1043+
argKey += keyName
1044+
result[argKey] = argVal
1045+
}
1046+
1047+
return result, nil
1048+
}

mcp_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"strings"
99
"testing"
1010
"time"
11+
12+
"github.com/google/go-cmp/cmp"
1113
)
1214

1315
func TestToolAndResourceFields(t *testing.T) {
@@ -347,3 +349,105 @@ func TestLenientParameterHandling(t *testing.T) {
347349
})
348350
}
349351
}
352+
353+
func TestPosArgsFromCmdUsage(t *testing.T) {
354+
for _, tc := range []struct {
355+
input string
356+
expected map[string]any
357+
expectedError string
358+
}{
359+
{
360+
input: "",
361+
expectedError: "usage may not be empty",
362+
},
363+
{
364+
input: "command",
365+
expected: map[string]any{},
366+
},
367+
{
368+
input: "[flags]",
369+
expected: map[string]any{},
370+
},
371+
{
372+
input: "command [flags]",
373+
expected: map[string]any{},
374+
},
375+
{
376+
input: "command [flags] a [b] <c> [<d>] [e...]",
377+
expected: map[string]any{
378+
"arg1__a": map[string]any{
379+
"description": "required argument",
380+
"type": "string",
381+
"required": true,
382+
},
383+
"arg2__b": map[string]any{
384+
"description": "optional argument",
385+
"type": "string",
386+
},
387+
"arg3__c": map[string]any{
388+
"description": "required argument",
389+
"type": "string",
390+
"required": true,
391+
},
392+
"arg4__d": map[string]any{
393+
"description": "optional argument",
394+
"type": "string",
395+
},
396+
"arg5__e": map[string]any{
397+
"description": "optional argument",
398+
"type": "string",
399+
},
400+
},
401+
},
402+
{
403+
input: "command [flags] <a|b> [c|d]",
404+
expected: map[string]any{
405+
"arg1__a_b": map[string]any{
406+
"description": "required argument",
407+
"enum": []string{"a", "b"},
408+
"type": "string",
409+
},
410+
"arg2__c_d": map[string]any{
411+
"description": "optional argument",
412+
"enum": []string{"c", "d"},
413+
"type": "string",
414+
},
415+
},
416+
},
417+
{
418+
input: "command [flags] <a b>",
419+
expectedError: "malformed usage",
420+
},
421+
{
422+
input: "command [flags] [c | d]",
423+
expectedError: "malformed usage",
424+
},
425+
{
426+
input: "command [flags] {e f}",
427+
expectedError: "malformed usage",
428+
},
429+
} {
430+
actual, err := PosArgsFromCmdUsage(tc.input)
431+
if tc.expectedError == "" {
432+
if err != nil {
433+
t.Errorf("expected no error from %q, got %v", tc.input, err)
434+
continue
435+
}
436+
if diff := cmp.Diff(tc.expected, actual); diff != "" {
437+
t.Errorf("unexpected diff (-want +got):\n%s", diff)
438+
continue
439+
}
440+
} else {
441+
if err == nil {
442+
t.Errorf("expected error containing '%s' from input %q, got no error", tc.expectedError, tc.input)
443+
continue
444+
}
445+
if !strings.Contains(err.Error(), tc.expectedError) {
446+
t.Errorf("expected error containing '%s' from input %q, got '%s'", tc.expectedError, tc.input, err.Error())
447+
}
448+
if len(actual) != 0 {
449+
t.Errorf("expected empty result on error, got %v", actual)
450+
}
451+
}
452+
}
453+
}

0 commit comments

Comments
 (0)