Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 75 additions & 26 deletions client/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"os/exec"
"strings"
"sync"
"syscall"
"time"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/util"
Expand All @@ -25,21 +27,22 @@ type Stdio struct {
args []string
env []string

cmd *exec.Cmd
cmdFunc CommandFunc
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
responses map[string]chan *JSONRPCResponse
mu sync.RWMutex
done chan struct{}
onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
onRequest RequestHandler
requestMu sync.RWMutex
ctx context.Context
ctxMu sync.RWMutex
logger util.Logger
cmd *exec.Cmd
cmdFunc CommandFunc
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
responses map[string]chan *JSONRPCResponse
mu sync.RWMutex
done chan struct{}
onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
onRequest RequestHandler
requestMu sync.RWMutex
ctx context.Context
ctxMu sync.RWMutex
logger util.Logger
terminateDuration time.Duration
}

// StdioOption defines a function that configures a Stdio transport instance.
Expand Down Expand Up @@ -67,6 +70,13 @@ func WithCommandLogger(logger util.Logger) StdioOption {
}
}

// WithTerminateDuration returns a StdioOption that overrides the transport's
// terminateDuration (5 seconds by default). Close uses this value as the
// timeout for each shutdown stage: waiting for the subprocess to exit after
// stdin is closed, again after SIGTERM is sent, and again after the process is
// killed.
func WithTerminateDuration(duration time.Duration) StdioOption {
	configure := func(transport *Stdio) {
		transport.terminateDuration = duration
	}
	return configure
}

// NewIO returns a new stdio-based transport using existing input, output, and
// logging streams instead of spawning a subprocess.
// This is useful for testing and simulating client behavior.
Expand All @@ -76,10 +86,11 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio
stdout: bufio.NewReader(input),
stderr: logging,

responses: make(map[string]chan *JSONRPCResponse),
done: make(chan struct{}),
ctx: context.Background(),
logger: util.DefaultLogger(),
responses: make(map[string]chan *JSONRPCResponse),
done: make(chan struct{}),
ctx: context.Background(),
logger: util.DefaultLogger(),
terminateDuration: 5 * time.Second, // Default 5 second timeout
}
}

Expand Down Expand Up @@ -110,10 +121,11 @@ func NewStdioWithOptions(
args: args,
env: env,

responses: make(map[string]chan *JSONRPCResponse),
done: make(chan struct{}),
ctx: context.Background(),
logger: util.DefaultLogger(),
responses: make(map[string]chan *JSONRPCResponse),
done: make(chan struct{}),
ctx: context.Background(),
logger: util.DefaultLogger(),
terminateDuration: 5 * time.Second, // Default 5 second timeout
}

for _, opt := range opts {
Expand Down Expand Up @@ -190,8 +202,10 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
return nil
}

// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
// Close closes the input stream to the child process, and awaits normal
// termination of the command. If the command does not exit, it is signalled to
// terminate, and then eventually killed. This follows the MCP specification
// for stdio transport shutdown.
func (c *Stdio) Close() error {
select {
case <-c.done:
Expand All @@ -201,6 +215,8 @@ func (c *Stdio) Close() error {
// cancel all in-flight requests
close(c.done)

// For the stdio transport, the client SHOULD initiate shutdown by:
// First, closing the input stream to the child process (the server)
if c.stdin != nil {
if err := c.stdin.Close(); err != nil {
return fmt.Errorf("failed to close stdin: %w", err)
Expand All @@ -213,7 +229,40 @@ func (c *Stdio) Close() error {
}

if c.cmd != nil {
return c.cmd.Wait()
resChan := make(chan error, 1)
go func() {
resChan <- c.cmd.Wait()
}()

// Waiting for the server to exit, or sending SIGTERM if the server does not exit within a reasonable time
wait := func() (error, bool) {
select {
case err := <-resChan:
return err, true
case <-time.After(c.terminateDuration):
}
return nil, false
}

if err, ok := wait(); ok {
return err
}

// Note the condition here: if sending SIGTERM fails, don't wait and just
// move on to SIGKILL.
if err := c.cmd.Process.Signal(syscall.SIGTERM); err == nil {
if err, ok := wait(); ok {
return err
}
}
// Sending SIGKILL if the server does not exit within a reasonable time after SIGTERM
if err := c.cmd.Process.Kill(); err != nil {
return err
}
if err, ok := wait(); ok {
return err
}
return fmt.Errorf("unresponsive subprocess")
Comment on lines +247 to +265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Avoid returning an error when the child already exited.

wait() can fall through when cmd.Wait() finishes just as the timeout fires. In that case resChan already holds the exit result, but we still hit the fallback path. Because the process is already gone, Process.Signal/Process.Kill return os.ErrProcessDone, which bubbles up and causes Close() to report an error even though the subprocess shut down cleanly. Please drain resChan (or treat os.ErrProcessDone as success) before signalling, so we don’t misreport a failure when the child already exited.

🤖 Prompt for AI Agents
In client/transport/stdio.go around lines 247 to 265, Close() can return an
error when the child has already exited, because we signal/kill without first
draining the wait result; Process.Signal/Process.Kill then return
os.ErrProcessDone, which bubbles up. Fix this by first draining the wait result
channel (or by calling wait() and treating os.ErrProcessDone as success) before
attempting to Signal/Kill; if wait() indicates the process has already exited,
return that result (or nil if successful) and do not call Signal/Kill;
otherwise proceed with SIGTERM/SIGKILL as now.

}

return nil
Expand Down