Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package app

import (
"context"
"fmt"
"os"
"os/signal"
Expand Down Expand Up @@ -33,7 +34,7 @@ var _ App = (*app)(nil)
type App interface {
// Start kicks off the application and returns immediately.
// Start should only be called once.
Start()
Start(ctx context.Context)

// Stop notifies the application to exit and returns immediately.
// Stop should only be called after [Start].
Expand All @@ -45,7 +46,7 @@ type App interface {
ExitCode() int
}

func New(config nodeconfig.Config) (App, error) {
func New(ctx context.Context, config nodeconfig.Config) (App, error) {
// Set the data directory permissions to be read write.
if err := perms.ChmodR(config.DatabaseConfig.Path, true, perms.ReadWriteExecute); err != nil {
return nil, fmt.Errorf("failed to restrict the permissions of the database directory with: %w", err)
Expand All @@ -71,7 +72,7 @@ func New(config nodeconfig.Config) (App, error) {
return nil, err
}

n, err := node.New(&config, logFactory, log)
n, err := node.New(ctx, &config, logFactory, log)
if err != nil {
log.Fatal("failed to initialize node", zap.Error(err))
log.Stop()
Expand All @@ -86,9 +87,9 @@ func New(config nodeconfig.Config) (App, error) {
}, nil
}

func Run(app App) int {
func Run(ctx context.Context, app App) int {
// start running the application
app.Start()
app.Start(ctx)

// register terminationSignals to kill the application
terminationSignals := make(chan os.Signal, 1)
Expand Down Expand Up @@ -138,7 +139,7 @@ type app struct {

// Start the business logic of the node (as opposed to config reading, etc).
// Does not block until the node is done.
func (a *app) Start() {
func (a *app) Start(ctx context.Context) {
// [p.ExitCode] will block until [p.exitWG.Done] is called
a.exitWG.Add(1)
go func() {
Expand All @@ -157,7 +158,7 @@ func (a *app) Start() {
a.log.StopOnPanic()
}()

err := a.node.Dispatch()
err := a.node.Dispatch(ctx)
a.log.Debug("dispatch returned",
zap.Error(err),
)
Expand Down
5 changes: 3 additions & 2 deletions main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package main

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -61,12 +62,12 @@ func main() {
fmt.Println(app.Header)
}

nodeApp, err := app.New(nodeConfig)
nodeApp, err := app.New(context.Background(), nodeConfig)
if err != nil {
fmt.Printf("couldn't start node: %s\n", err)
os.Exit(1)
}

exitCode := app.Run(nodeApp)
exitCode := app.Run(context.Background(), nodeApp)
os.Exit(exitCode)
}
5 changes: 3 additions & 2 deletions nat/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package nat

import (
"context"
"net/netip"
"sync"
"time"
Expand Down Expand Up @@ -33,15 +34,15 @@ type Router interface {
}

// GetRouter returns a router on the current network.
func GetRouter() Router {
func GetRouter(ctx context.Context) Router {
if r := getUPnPRouter(); r != nil {
return r
}
if r := getPMPRouter(); r != nil {
return r
}

return NewNoRouter()
return NewNoRouter(ctx)
}

// Mapper attempts to open a set of ports on a router
Expand Down
13 changes: 9 additions & 4 deletions nat/no_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package nat

import (
"context"
"errors"
"net"
"net/netip"
Expand Down Expand Up @@ -40,8 +41,12 @@ func (r noRouter) ExternalIP() (netip.Addr, error) {
return r.ip, r.ipErr
}

func getOutboundIP() (netip.Addr, error) {
conn, err := net.Dial("udp", googleDNSServer)
func getOutboundIP(ctx context.Context) (netip.Addr, error) {
conn, err := (&net.Dialer{}).DialContext(
ctx,
"udp",
googleDNSServer,
)
if err != nil {
return netip.Addr{}, err
}
Expand All @@ -63,8 +68,8 @@ func getOutboundIP() (netip.Addr, error) {
}

// NewNoRouter returns a router that assumes the network is public
func NewNoRouter() Router {
ip, err := getOutboundIP()
func NewNoRouter(ctx context.Context) Router {
ip, err := getOutboundIP(ctx)
return &noRouter{
ip: ip,
ipErr: err,
Expand Down
6 changes: 5 additions & 1 deletion network/dialer/dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ func TestDialerDialCanceledContext(t *testing.T) {
func TestDialerDial(t *testing.T) {
require := require.New(t)

l, err := net.Listen("tcp", "127.0.0.1:0")
l, err := (&net.ListenConfig{}).Listen(
context.Background(),
"tcp",
"127.0.0.1:0",
)
require.NoError(err)

listenedAddrPort, err := netip.ParseAddrPort(l.Addr().String())
Expand Down
8 changes: 6 additions & 2 deletions network/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ func ExampleNewTestNetwork() {
// gossip will enable connecting to all the remaining nodes in the network.
bootstrappers := genesis.SampleBootstrappers(constants.FujiID, 5)
for _, bootstrapper := range bootstrappers {
network.ManuallyTrack(bootstrapper.ID, bootstrapper.IP)
network.ManuallyTrack(
context.Background(),
bootstrapper.ID,
bootstrapper.IP,
)
}

// Typically network.StartClose() should be called based on receiving a
Expand All @@ -137,7 +141,7 @@ func ExampleNewTestNetwork() {

// Calling network.Dispatch() will block until a fatal error occurs or
// network.StartClose() is called.
err = network.Dispatch()
err = network.Dispatch(context.Background())
log.Info(
"network exited",
zap.Error(err),
Expand Down
66 changes: 45 additions & 21 deletions network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ type Network interface {

// Should only be called once, will run until either a fatal error occurs,
// or the network is closed.
Dispatch() error
Dispatch(ctx context.Context) error

// Attempt to connect to this IP. The network will never stop attempting to
// connect to this ID.
ManuallyTrack(nodeID ids.NodeID, ip netip.AddrPort)
ManuallyTrack(ctx context.Context, nodeID ids.NodeID, ip netip.AddrPort)

// PeerInfo returns information about peers. If [nodeIDs] is empty, returns
// info about all peers that have finished the handshake. Otherwise, returns
Expand Down Expand Up @@ -506,10 +506,13 @@ func (n *network) AllowConnection(nodeID ids.NodeID) bool {
return areWeAPrimaryNetworkAValidator || n.ipTracker.WantsConnection(nodeID)
}

func (n *network) Track(claimedIPPorts []*ips.ClaimedIPPort) error {
func (n *network) Track(
ctx context.Context,
claimedIPPorts []*ips.ClaimedIPPort,
) error {
_, areWeAPrimaryNetworkAValidator := n.config.Validators.GetValidator(constants.PrimaryNetworkID, n.config.MyNodeID)
for _, ip := range claimedIPPorts {
if err := n.track(ip, areWeAPrimaryNetworkAValidator); err != nil {
if err := n.track(ctx, ip, areWeAPrimaryNetworkAValidator); err != nil {
return err
}
}
Expand All @@ -521,17 +524,17 @@ func (n *network) Track(claimedIPPorts []*ips.ClaimedIPPort) error {
// It is guaranteed that [Connected] will not be called with [nodeID] after this
// call. Note that this is from the perspective of a single peer object, because
// a peer with the same ID can reconnect to this network instance.
func (n *network) Disconnected(nodeID ids.NodeID) {
func (n *network) Disconnected(ctx context.Context, nodeID ids.NodeID) {
n.peersLock.RLock()
_, connecting := n.connectingPeers.GetByID(nodeID)
peer, connected := n.connectedPeers.GetByID(nodeID)
n.peersLock.RUnlock()

if connecting {
n.disconnectedFromConnecting(nodeID)
n.disconnectedFromConnecting(ctx, nodeID)
}
if connected {
n.disconnectedFromConnected(peer, nodeID)
n.disconnectedFromConnected(ctx, peer, nodeID)
}
}

Expand Down Expand Up @@ -599,7 +602,7 @@ func (n *network) Peers(

// Dispatch starts accepting connections from other nodes attempting to connect
// to this node.
func (n *network) Dispatch() error {
func (n *network) Dispatch(ctx context.Context) error {
go n.runTimers() // Periodically perform operations
go n.inboundConnUpgradeThrottler.Dispatch()
for n.onCloseCtx.Err() == nil { // Continuously accept new connections
Expand Down Expand Up @@ -647,7 +650,7 @@ func (n *network) Dispatch() error {
zap.Stringer("peerIP", ip),
)

if err := n.upgrade(conn, n.serverUpgrader, true); err != nil {
if err := n.upgrade(ctx, conn, n.serverUpgrader, true); err != nil {
n.peerConfig.Log.Verbo("failed to upgrade connection",
zap.String("direction", "inbound"),
zap.Error(err),
Expand All @@ -670,7 +673,11 @@ func (n *network) Dispatch() error {
return errs.Err
}

func (n *network) ManuallyTrack(nodeID ids.NodeID, ip netip.AddrPort) {
func (n *network) ManuallyTrack(
ctx context.Context,
nodeID ids.NodeID,
ip netip.AddrPort,
) {
n.ipTracker.ManuallyTrack(nodeID)

n.peersLock.Lock()
Expand All @@ -688,11 +695,15 @@ func (n *network) ManuallyTrack(nodeID ids.NodeID, ip netip.AddrPort) {
if !isTracked {
tracked := newTrackedIP(ip)
n.trackedIPs[nodeID] = tracked
n.dial(nodeID, tracked)
n.dial(ctx, nodeID, tracked)
}
}

func (n *network) track(ip *ips.ClaimedIPPort, trackAllSubnets bool) error {
func (n *network) track(
ctx context.Context,
ip *ips.ClaimedIPPort,
trackAllSubnets bool,
) error {
// To avoid signature verification when the IP isn't needed, we
// optimistically filter out IPs. This can result in us not tracking an IP
// that we otherwise would have. This case can only happen if the node
Expand Down Expand Up @@ -741,7 +752,7 @@ func (n *network) track(ip *ips.ClaimedIPPort, trackAllSubnets bool) error {
tracked = newTrackedIP(ip.AddrPort)
}
n.trackedIPs[ip.NodeID] = tracked
n.dial(ip.NodeID, tracked)
n.dial(ctx, ip.NodeID, tracked)
return nil
}

Expand Down Expand Up @@ -834,7 +845,10 @@ func (n *network) samplePeers(
)
}

func (n *network) disconnectedFromConnecting(nodeID ids.NodeID) {
func (n *network) disconnectedFromConnecting(
ctx context.Context,
nodeID ids.NodeID,
) {
n.peersLock.Lock()
defer n.peersLock.Unlock()

Expand All @@ -846,7 +860,7 @@ func (n *network) disconnectedFromConnecting(nodeID ids.NodeID) {
if n.ipTracker.WantsConnection(nodeID) {
tracked := tracked.trackNewIP(tracked.ip)
n.trackedIPs[nodeID] = tracked
n.dial(nodeID, tracked)
n.dial(ctx, nodeID, tracked)
} else {
tracked.stopTracking()
delete(n.trackedIPs, nodeID)
Expand All @@ -856,7 +870,11 @@ func (n *network) disconnectedFromConnecting(nodeID ids.NodeID) {
n.metrics.disconnected.Inc()
}

func (n *network) disconnectedFromConnected(peer peer.Peer, nodeID ids.NodeID) {
func (n *network) disconnectedFromConnected(
ctx context.Context,
peer peer.Peer,
nodeID ids.NodeID,
) {
n.ipTracker.Disconnected(nodeID)
n.router.Disconnected(nodeID)

Expand All @@ -869,7 +887,7 @@ func (n *network) disconnectedFromConnected(peer peer.Peer, nodeID ids.NodeID) {
if ip, wantsConnection := n.ipTracker.GetIP(nodeID); wantsConnection {
tracked := newTrackedIP(ip.AddrPort)
n.trackedIPs[nodeID] = tracked
n.dial(nodeID, tracked)
n.dial(ctx, nodeID, tracked)
}

n.metrics.markDisconnected(peer)
Expand All @@ -894,7 +912,7 @@ func (n *network) disconnectedFromConnected(peer peer.Peer, nodeID ids.NodeID) {
// If initiating a connection to [ip] fails, then dial will reattempt. However,
// there is a randomized exponential backoff to avoid spamming connection
// attempts.
func (n *network) dial(nodeID ids.NodeID, ip *trackedIP) {
func (n *network) dial(ctx context.Context, nodeID ids.NodeID, ip *trackedIP) {
n.peerConfig.Log.Verbo("attempting to dial node",
zap.Stringer("nodeID", nodeID),
zap.Stringer("ip", ip.ip),
Expand Down Expand Up @@ -994,7 +1012,7 @@ func (n *network) dial(nodeID ids.NodeID, ip *trackedIP) {
zap.Stringer("peerIP", ip.ip),
)

err = n.upgrade(conn, n.clientUpgrader, false)
err = n.upgrade(ctx, conn, n.clientUpgrader, false)
if err != nil {
n.peerConfig.Log.Verbo(
"failed to upgrade, attempting again",
Expand All @@ -1017,7 +1035,12 @@ func (n *network) dial(nodeID ids.NodeID, ip *trackedIP) {
// If the connection is desired by the node, then the resulting upgraded
// connection will be used to create a new peer. Otherwise the connection will
// be immediately closed.
func (n *network) upgrade(conn net.Conn, upgrader peer.Upgrader, isIngress bool) error {
func (n *network) upgrade(
ctx context.Context,
conn net.Conn,
upgrader peer.Upgrader,
isIngress bool,
) error {
upgradeTimeout := n.peerConfig.Clock.Time().Add(n.config.ReadHandshakeTimeout)
if err := conn.SetReadDeadline(upgradeTimeout); err != nil {
_ = conn.Close()
Expand All @@ -1027,7 +1050,7 @@ func (n *network) upgrade(conn net.Conn, upgrader peer.Upgrader, isIngress bool)
return err
}

nodeID, tlsConn, cert, err := upgrader.Upgrade(conn)
nodeID, tlsConn, cert, err := upgrader.Upgrade(ctx, conn)
if err != nil {
_ = conn.Close()
n.peerConfig.Log.Verbo("failed to upgrade connection",
Expand Down Expand Up @@ -1107,6 +1130,7 @@ func (n *network) upgrade(conn net.Conn, upgrader peer.Upgrader, isIngress bool)
// same [peerConfig.InboundMsgThrottler]. This is guaranteed by the above
// de-duplications for [connectingPeers] and [connectedPeers].
peer := peer.Start(
ctx,
n.peerConfig,
tlsConn,
cert,
Expand Down
Loading
Loading