Skip to content

Commit e0fddea

Browse files
authored
Merge pull request #52 from coder/spike/nodrop
fix: block writes from gVisor to tailscale instead of dropping
2 parents 5cd256c + 6225460 commit e0fddea

File tree

4 files changed

+394
-3
lines changed

4 files changed

+394
-3
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
# earlier contributions and clarifying whether it's you or your
1515
# company that owns the rights to your contribution.
1616

17+
Coder Technologies, Inc.
1718
Tailscale Inc.

wgengine/netstack/endpoint.go

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
// based on https://github.com/google/gvisor/blob/74f22885dc45e2866985fe7179103e1000382415/pkg/tcpip/link/channel/channel.go
2+
//
3+
// Copyright 2018 The gVisor Authors.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
//
17+
// Modifications from original source are Copyright 2024 Tailscale Inc & AUTHORS
18+
19+
package netstack
20+
21+
import (
22+
"context"
23+
24+
"gvisor.dev/gvisor/pkg/sync"
25+
"gvisor.dev/gvisor/pkg/tcpip"
26+
"gvisor.dev/gvisor/pkg/tcpip/header"
27+
"gvisor.dev/gvisor/pkg/tcpip/stack"
28+
)
29+
30+
type queue struct {
31+
// c is the outbound packet channel.
32+
c chan *stack.PacketBuffer
33+
mu sync.RWMutex
34+
// +checklocks:mu
35+
closed bool
36+
37+
closedChOnce sync.Once
38+
closedCh chan struct{}
39+
}
40+
41+
func (q *queue) Close() {
42+
// This unblocks any calls to Write() which might be holding the mu.
43+
q.closedChOnce.Do(func() {
44+
close(q.closedCh)
45+
})
46+
47+
q.mu.Lock()
48+
defer q.mu.Unlock()
49+
if q.closed {
50+
return
51+
}
52+
close(q.c)
53+
q.closed = true
54+
}
55+
56+
func (q *queue) Read() *stack.PacketBuffer {
57+
select {
58+
case p := <-q.c:
59+
return p
60+
default:
61+
return nil
62+
}
63+
}
64+
65+
func (q *queue) ReadContext(ctx context.Context) *stack.PacketBuffer {
66+
select {
67+
case pkt := <-q.c:
68+
return pkt
69+
case <-ctx.Done():
70+
return nil
71+
}
72+
}
73+
74+
func (q *queue) Write(pkt *stack.PacketBuffer) tcpip.Error {
75+
q.mu.RLock()
76+
defer q.mu.RUnlock()
77+
if q.closed {
78+
return &tcpip.ErrClosedForSend{}
79+
}
80+
select {
81+
case q.c <- pkt.IncRef():
82+
return nil
83+
case <-q.closedCh:
84+
pkt.DecRef()
85+
return &tcpip.ErrClosedForSend{}
86+
}
87+
}
88+
89+
func (q *queue) Num() int {
90+
return len(q.c)
91+
}
92+
93+
var _ stack.LinkEndpoint = (*Endpoint)(nil)
94+
var _ stack.GSOEndpoint = (*Endpoint)(nil)
95+
96+
// Endpoint is link layer endpoint that stores outbound packets in a channel
97+
// and allows injection of inbound packets. It is based on gVisor
98+
// channel.Endpoint, however when the channel is full, it blocks writes until
99+
// there is space in the channel or until the Endpoint is closed. The gVisor
100+
// version dropped packets if the channel is full. This limits TCP throughput
101+
// as dropped packets need to be retransmitted and are interpreted as a
102+
// congestion event, causing the TCP sender to decrease the congestion window.
103+
// Much better to apply back-pressure to the TCP stack at the Endpoint.
104+
type Endpoint struct {
105+
mtu uint32
106+
linkAddr tcpip.LinkAddress
107+
LinkEPCapabilities stack.LinkEndpointCapabilities
108+
SupportedGSOKind stack.SupportedGSO
109+
110+
mu sync.RWMutex
111+
// +checklocks:mu
112+
dispatcher stack.NetworkDispatcher
113+
114+
// Outbound packet queue.
115+
q *queue
116+
}
117+
118+
// NewEndpoint creates a new channel endpoint.
119+
func NewEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
120+
return &Endpoint{
121+
q: &queue{
122+
c: make(chan *stack.PacketBuffer, size),
123+
closedCh: make(chan struct{}),
124+
},
125+
mtu: mtu,
126+
linkAddr: linkAddr,
127+
}
128+
}
129+
130+
// Close closes e. Further packet injections will return an error, and all pending
131+
// packets are discarded. Close may be called concurrently with WritePackets.
132+
func (e *Endpoint) Close() {
133+
e.q.Close()
134+
e.Drain()
135+
}
136+
137+
// Read does non-blocking read one packet from the outbound packet queue.
138+
func (e *Endpoint) Read() *stack.PacketBuffer {
139+
return e.q.Read()
140+
}
141+
142+
// ReadContext does blocking read for one packet from the outbound packet queue.
143+
// It can be cancelled by ctx, and in this case, it returns nil.
144+
func (e *Endpoint) ReadContext(ctx context.Context) *stack.PacketBuffer {
145+
return e.q.ReadContext(ctx)
146+
}
147+
148+
// Drain removes all outbound packets from the channel and counts them.
149+
func (e *Endpoint) Drain() int {
150+
c := 0
151+
for pkt := e.Read(); pkt != nil; pkt = e.Read() {
152+
pkt.DecRef()
153+
c++
154+
}
155+
return c
156+
}
157+
158+
// NumQueued returns the number of packet queued for outbound.
159+
func (e *Endpoint) NumQueued() int {
160+
return e.q.Num()
161+
}
162+
163+
// InjectInbound injects an inbound packet. If the endpoint is not attached, the
164+
// packet is not delivered.
165+
func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
166+
e.mu.RLock()
167+
d := e.dispatcher
168+
e.mu.RUnlock()
169+
if d != nil {
170+
d.DeliverNetworkPacket(protocol, pkt)
171+
}
172+
}
173+
174+
// Attach saves the stack network-layer dispatcher for use later when packets
175+
// are injected.
176+
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
177+
e.mu.Lock()
178+
defer e.mu.Unlock()
179+
e.dispatcher = dispatcher
180+
}
181+
182+
// IsAttached implements stack.LinkEndpoint.IsAttached.
183+
func (e *Endpoint) IsAttached() bool {
184+
e.mu.RLock()
185+
defer e.mu.RUnlock()
186+
return e.dispatcher != nil
187+
}
188+
189+
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
190+
// during construction.
191+
func (e *Endpoint) MTU() uint32 {
192+
return e.mtu
193+
}
194+
195+
// Capabilities implements stack.LinkEndpoint.Capabilities.
196+
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
197+
return e.LinkEPCapabilities
198+
}
199+
200+
// GSOMaxSize implements stack.GSOEndpoint.
201+
func (*Endpoint) GSOMaxSize() uint32 {
202+
return 1 << 15
203+
}
204+
205+
// SupportedGSO implements stack.GSOEndpoint.
206+
func (e *Endpoint) SupportedGSO() stack.SupportedGSO {
207+
return e.SupportedGSOKind
208+
}
209+
210+
// MaxHeaderLength returns the maximum size of the link layer header. Given it
211+
// doesn't have a header, it just returns 0.
212+
func (*Endpoint) MaxHeaderLength() uint16 {
213+
return 0
214+
}
215+
216+
// LinkAddress returns the link address of this endpoint.
217+
func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
218+
return e.linkAddr
219+
}
220+
221+
// WritePackets stores outbound packets into the channel.
222+
// Multiple concurrent calls are permitted.
223+
func (e *Endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
224+
n := 0
225+
for _, pkt := range pkts.AsSlice() {
226+
if err := e.q.Write(pkt); err != nil {
227+
return n, err
228+
}
229+
n++
230+
}
231+
232+
return n, nil
233+
}
234+
235+
// Wait implements stack.LinkEndpoint.Wait.
236+
func (*Endpoint) Wait() {}
237+
238+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
239+
func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
240+
return header.ARPHardwareNone
241+
}
242+
243+
// AddHeader implements stack.LinkEndpoint.AddHeader.
244+
func (*Endpoint) AddHeader(*stack.PacketBuffer) {}
245+
246+
// ParseHeader implements stack.LinkEndpoint.ParseHeader.
247+
func (*Endpoint) ParseHeader(*stack.PacketBuffer) bool { return true }

wgengine/netstack/endpoint_test.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package netstack
5+
6+
import (
7+
"context"
8+
"testing"
9+
"time"
10+
11+
"gvisor.dev/gvisor/pkg/tcpip"
12+
"gvisor.dev/gvisor/pkg/tcpip/stack"
13+
)
14+
15+
func TestEndpointBlockingWrites(t *testing.T) {
16+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
17+
defer cancel()
18+
linkEP := NewEndpoint(1, 1500, "")
19+
pb1 := stack.NewPacketBuffer(stack.PacketBufferOptions{})
20+
defer pb1.DecRef()
21+
pb2 := stack.NewPacketBuffer(stack.PacketBufferOptions{})
22+
defer pb2.DecRef()
23+
numWrites := make(chan int, 2)
24+
go func() {
25+
bl := stack.PacketBufferList{}
26+
bl.PushBack(pb1)
27+
n, err := linkEP.WritePackets(bl)
28+
if err != nil {
29+
t.Errorf("expected no error, got %s", err)
30+
} else {
31+
pb1.DecRef()
32+
}
33+
numWrites <- n
34+
bl = stack.PacketBufferList{}
35+
bl.PushBack(pb2)
36+
n, err = linkEP.WritePackets(bl)
37+
if err != nil {
38+
t.Errorf("expected no error, got %s", err)
39+
} else {
40+
pb2.DecRef()
41+
}
42+
numWrites <- n
43+
}()
44+
45+
select {
46+
case n := <-numWrites:
47+
if n != 1 {
48+
t.Fatalf("expected 1 write got %d", n)
49+
}
50+
case <-ctx.Done():
51+
t.Fatal("timed out waiting for 1st write")
52+
}
53+
54+
// second write should block
55+
select {
56+
case <-numWrites:
57+
t.Fatalf("expected write to block")
58+
case <-time.After(50 * time.Millisecond):
59+
// OK
60+
}
61+
62+
pbg := linkEP.ReadContext(ctx)
63+
if pbg != pb1 {
64+
t.Fatalf("expected pb1")
65+
}
66+
// Read unblocks the 2nd write
67+
select {
68+
case n := <-numWrites:
69+
if n != 1 {
70+
t.Fatalf("expected 1 write got %d", n)
71+
}
72+
case <-ctx.Done():
73+
t.Fatal("timed out waiting for 2nd write")
74+
}
75+
pbg = linkEP.ReadContext(ctx)
76+
if pbg != pb2 {
77+
t.Fatalf("expected pb2")
78+
}
79+
}
80+
81+
func TestEndpointCloseUnblocksWrites(t *testing.T) {
82+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
83+
defer cancel()
84+
linkEP := NewEndpoint(1, 1500, "")
85+
pb1 := stack.NewPacketBuffer(stack.PacketBufferOptions{})
86+
pb2 := stack.NewPacketBuffer(stack.PacketBufferOptions{})
87+
defer pb2.DecRef()
88+
numWrites := make(chan int, 2)
89+
errors := make(chan tcpip.Error, 1)
90+
go func() {
91+
bl := stack.PacketBufferList{}
92+
bl.PushBack(pb1)
93+
n, err := linkEP.WritePackets(bl)
94+
if err != nil {
95+
t.Errorf("expected no error, got %s", err)
96+
} else {
97+
pb1.DecRef()
98+
}
99+
numWrites <- n
100+
bl = stack.PacketBufferList{}
101+
bl.PushBack(pb2)
102+
n, err = linkEP.WritePackets(bl)
103+
numWrites <- n
104+
errors <- err
105+
}()
106+
107+
select {
108+
case n := <-numWrites:
109+
if n != 1 {
110+
t.Fatalf("expected 1 write got %d", n)
111+
}
112+
case <-ctx.Done():
113+
t.Fatal("timed out waiting for 1st write")
114+
}
115+
116+
// second write should block
117+
select {
118+
case <-numWrites:
119+
t.Fatalf("expected write to block")
120+
case <-time.After(50 * time.Millisecond):
121+
// OK
122+
}
123+
124+
// close must unblock pending writes without deadlocking
125+
linkEP.Close()
126+
select {
127+
case n := <-numWrites:
128+
if n != 0 {
129+
t.Fatalf("expected 0 writes got %d", n)
130+
}
131+
case <-ctx.Done():
132+
t.Fatal("timed out waiting for 2nd write num")
133+
}
134+
select {
135+
case err := <-errors:
136+
if _, ok := err.(*tcpip.ErrClosedForSend); !ok {
137+
t.Fatalf("expected ErrClosedForSend got %s", err)
138+
}
139+
case <-ctx.Done():
140+
t.Fatal("timed out for 2nd write error")
141+
}
142+
}

0 commit comments

Comments
 (0)