From 275fa9ca0306f59e50655ff7087e2ad14545eb80 Mon Sep 17 00:00:00 2001 From: elewis787 Date: Wed, 1 Oct 2025 11:41:49 -0700 Subject: [PATCH] feat: add additional process kill for sigterm and sigkill --- client/transport/stdio.go | 101 ++++++++++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 26 deletions(-) diff --git a/client/transport/stdio.go b/client/transport/stdio.go index e4f26857..26c1d73a 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -11,6 +11,8 @@ import ( "os/exec" "strings" "sync" + "syscall" + "time" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/util" @@ -25,21 +27,22 @@ type Stdio struct { args []string env []string - cmd *exec.Cmd - cmdFunc CommandFunc - stdin io.WriteCloser - stdout *bufio.Reader - stderr io.ReadCloser - responses map[string]chan *JSONRPCResponse - mu sync.RWMutex - done chan struct{} - onNotification func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - onRequest RequestHandler - requestMu sync.RWMutex - ctx context.Context - ctxMu sync.RWMutex - logger util.Logger + cmd *exec.Cmd + cmdFunc CommandFunc + stdin io.WriteCloser + stdout *bufio.Reader + stderr io.ReadCloser + responses map[string]chan *JSONRPCResponse + mu sync.RWMutex + done chan struct{} + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + onRequest RequestHandler + requestMu sync.RWMutex + ctx context.Context + ctxMu sync.RWMutex + logger util.Logger + terminateDuration time.Duration } // StdioOption defines a function that configures a Stdio transport instance. @@ -67,6 +70,13 @@ func WithCommandLogger(logger util.Logger) StdioOption { } } +// WithTerminateDuration sets the duration to wait for graceful shutdown before sending SIGTERM. +func WithTerminateDuration(duration time.Duration) StdioOption { + return func(s *Stdio) { + s.terminateDuration = duration + } +} + // NewIO returns a new stdio-based transport using existing input, output, and // logging streams instead of spawning a subprocess. // This is useful for testing and simulating client behavior. @@ -76,10 +86,11 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio stdout: bufio.NewReader(input), stderr: logging, - responses: make(map[string]chan *JSONRPCResponse), - done: make(chan struct{}), - ctx: context.Background(), - logger: util.DefaultLogger(), + responses: make(map[string]chan *JSONRPCResponse), + done: make(chan struct{}), + ctx: context.Background(), + logger: util.DefaultLogger(), + terminateDuration: 5 * time.Second, // Default 5 second timeout } } @@ -110,10 +121,11 @@ func NewStdioWithOptions( args: args, env: env, - responses: make(map[string]chan *JSONRPCResponse), - done: make(chan struct{}), - ctx: context.Background(), - logger: util.DefaultLogger(), + responses: make(map[string]chan *JSONRPCResponse), + done: make(chan struct{}), + ctx: context.Background(), + logger: util.DefaultLogger(), + terminateDuration: 5 * time.Second, // Default 5 second timeout } for _, opt := range opts { @@ -190,8 +202,10 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { return nil } -// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. -// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. +// Close closes the input stream to the child process, and awaits normal +// termination of the command. If the command does not exit, it is signalled to +// terminate, and then eventually killed. This follows the MCP specification +// for stdio transport shutdown. func (c *Stdio) Close() error { select { case <-c.done: @@ -201,6 +215,8 @@ func (c *Stdio) Close() error { // cancel all in-flight request close(c.done) + // For the stdio transport, the client SHOULD initiate shutdown by: + // First, closing the input stream to the child process (the server) if c.stdin != nil { if err := c.stdin.Close(); err != nil { return fmt.Errorf("failed to close stdin: %w", err) @@ -213,7 +229,40 @@ func (c *Stdio) Close() error { } if c.cmd != nil { - return c.cmd.Wait() + resChan := make(chan error, 1) + go func() { + resChan <- c.cmd.Wait() + }() + + // Waiting for the server to exit, or sending SIGTERM if the server does not exit within a reasonable time + wait := func() (error, bool) { + select { + case err := <-resChan: + return err, true + case <-time.After(c.terminateDuration): + } + return nil, false + } + + if err, ok := wait(); ok { + return err + } + + // Note the condition here: if sending SIGTERM fails, don't wait and just + // move on to SIGKILL. + if err := c.cmd.Process.Signal(syscall.SIGTERM); err == nil { + if err, ok := wait(); ok { + return err + } + } + // Sending SIGKILL if the server does not exit within a reasonable time after SIGTERM + if err := c.cmd.Process.Kill(); err != nil { + return err + } + if err, ok := wait(); ok { + return err + } + return fmt.Errorf("unresponsive subprocess") } return nil