Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ require (
github.com/eiannone/keyboard v0.0.0-20220611211555-0d226195f203
github.com/fsnotify/fsevents v0.2.0
github.com/google/go-cmp v0.7.0
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-version v1.7.0
github.com/jonboulle/clockwork v0.5.0
github.com/mattn/go-shellwords v1.0.12
Expand Down Expand Up @@ -117,6 +116,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/in-toto/in-toto-golang v0.9.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf // indirect
Expand Down
91 changes: 59 additions & 32 deletions pkg/compose/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ package compose

import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"slices"
"sync"
"sync/atomic"
"syscall"

Expand All @@ -33,7 +35,6 @@ import (
"github.com/docker/compose/v2/pkg/api"
"github.com/docker/compose/v2/pkg/progress"
"github.com/eiannone/keyboard"
"github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -61,14 +62,11 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
return err
}

var eg multierror.Group

// if we get a second signal during shutdown, we kill the services
// immediately, so the channel needs to have sufficient capacity or
// we might miss a signal while setting up the second channel read
// (this is also why signal.Notify is used vs signal.NotifyContext)
signalChan := make(chan os.Signal, 2)
defer close(signalChan)
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(signalChan)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closing the signalChan isn't needed for these; signal.Stop should be enough.

var isTerminated atomic.Bool
Expand Down Expand Up @@ -103,26 +101,45 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options

printer := newLogPrinter(logConsumer)

doneCh := make(chan bool)
// global context to handle canceling goroutines
globalCtx, cancel := context.WithCancel(ctx)
defer cancel()

var (
eg errgroup.Group
mu sync.Mutex
errs []error
)

appendErr := func(err error) {
if err != nil {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
}

eg.Go(func() error {
first := true
gracefulTeardown := func() {
first = false
fmt.Println("Gracefully Stopping... press Ctrl+C again to force")
eg.Go(func() error {
return progress.RunWithLog(context.WithoutCancel(ctx), func(ctx context.Context) error {
return s.stop(ctx, project.Name, api.StopOptions{
err := progress.RunWithLog(context.WithoutCancel(globalCtx), func(c context.Context) error {
return s.stop(c, project.Name, api.StopOptions{
Services: options.Create.Services,
Project: project,
}, printer.HandleEvent)
}, s.stdinfo(), logConsumer)
appendErr(err)
return nil
})
isTerminated.Store(true)
}

for {
select {
case <-doneCh:
case <-globalCtx.Done():
if watcher != nil {
Comment on lines 141 to 143
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a bit of input from my AI friends to see if we could have everything handled through the context (there were some paths that were a bit involved, and I saw some paths were possibly errgroup.Wait were not handled).

So pay close attention if I did it right, as I'm not so familiar with this code ❤️ 🙈

return watcher.Stop()
}
Expand All @@ -133,12 +150,12 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
}
case <-signalChan:
if first {
keyboard.Close() //nolint:errcheck
_ = keyboard.Close()
gracefulTeardown()
break
}
eg.Go(func() error {
err := s.kill(context.WithoutCancel(ctx), project.Name, api.KillOptions{
err := s.kill(context.WithoutCancel(globalCtx), project.Name, api.KillOptions{
Services: options.Create.Services,
Project: project,
All: true,
Expand All @@ -148,18 +165,21 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
return nil
}

return err
appendErr(err)
return nil
})
return nil
case event := <-kEvents:
navigationMenu.HandleKeyEvents(ctx, event, project, options)
navigationMenu.HandleKeyEvents(globalCtx, event, project, options)
}
}
})

if options.Start.Watch && watcher != nil {
err = watcher.Start(ctx)
if err != nil {
if err := watcher.Start(globalCtx); err != nil {
// cancel the global context to terminate background goroutines
cancel()
_ = eg.Wait()
return err
}
}
Expand All @@ -186,12 +206,14 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
exitCode = event.ExitCode
_, _ = fmt.Fprintln(s.stdinfo(), progress.ErrorColor("Aborting on container exit..."))
eg.Go(func() error {
return progress.RunWithLog(context.WithoutCancel(ctx), func(ctx context.Context) error {
return s.stop(ctx, project.Name, api.StopOptions{
err := progress.RunWithLog(context.WithoutCancel(globalCtx), func(c context.Context) error {
return s.stop(c, project.Name, api.StopOptions{
Services: options.Create.Services,
Project: project,
}, printer.HandleEvent)
}, s.stdinfo(), logConsumer)
appendErr(err)
return nil
})
}
})
Expand All @@ -208,13 +230,10 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
})
}

// use an independent context tied to the errgroup for background attach operations
// the primary context is still used for other operations
// this means that once any attach operation fails, all other attaches are cancelled,
// but an attach failing won't interfere with the rest of the start
_, attachCtx := errgroup.WithContext(ctx)
containers, err := s.attach(attachCtx, project, printer.HandleEvent, options.Start.AttachTo)
containers, err := s.attach(globalCtx, project, printer.HandleEvent, options.Start.AttachTo)
if err != nil {
cancel()
_ = eg.Wait()
return err
}
attached := make([]string, len(containers))
Expand All @@ -230,38 +249,46 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
return
}
eg.Go(func() error {
ctr, err := s.apiClient().ContainerInspect(ctx, event.ID)
ctr, err := s.apiClient().ContainerInspect(globalCtx, event.ID)
if err != nil {
return err
appendErr(err)
return nil
}

err = s.doLogContainer(ctx, options.Start.Attach, event.Source, ctr, api.LogOptions{
err = s.doLogContainer(globalCtx, options.Start.Attach, event.Source, ctr, api.LogOptions{
Follow: true,
Since: ctr.State.StartedAt,
})
if errdefs.IsNotImplemented(err) {
// container may be configured with logging_driver: none
// as container already started, we might miss the very first logs. But still better than none
return s.doAttachContainer(ctx, event.Service, event.ID, event.Source, printer.HandleEvent)
err := s.doAttachContainer(globalCtx, event.Service, event.ID, event.Source, printer.HandleEvent)
appendErr(err)
return nil
}
return err
appendErr(err)
return nil
})
})

eg.Go(func() error {
err := monitor.Start(context.Background())
// Signal for the signal-handler goroutines to stop
close(doneCh)
return err
err := monitor.Start(globalCtx)
// cancel the global context to terminate signal-handler goroutines
cancel()
appendErr(err)
return nil
})

// We use the parent context without cancellation as we manage sigterm to stop the stack
err = s.start(context.WithoutCancel(ctx), project.Name, options.Start, printer.HandleEvent)
if err != nil && !isTerminated.Load() { // Ignore error if the process is terminated
cancel()
_ = eg.Wait()
return err
}

err = eg.Wait().ErrorOrNil()
_ = eg.Wait()
err = errors.Join(errs...)
if exitCode != 0 {
errMsg := ""
if err != nil {
Expand Down
Loading