|
| 1 | +// based on https://github.com/google/gvisor/blob/74f22885dc45e2866985fe7179103e1000382415/pkg/tcpip/link/channel/channel.go |
| 2 | +// |
| 3 | +// Copyright 2018 The gVisor Authors. |
| 4 | +// |
| 5 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +// you may not use this file except in compliance with the License. |
| 7 | +// You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, software |
| 12 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +// See the License for the specific language governing permissions and |
| 15 | +// limitations under the License. |
| 16 | +// |
| 17 | +// Modifications from original source are Copyright 2024 Tailscale Inc & AUTHORS |
| 18 | + |
| 19 | +package netstack |
| 20 | + |
| 21 | +import ( |
| 22 | + "context" |
| 23 | + |
| 24 | + "gvisor.dev/gvisor/pkg/sync" |
| 25 | + "gvisor.dev/gvisor/pkg/tcpip" |
| 26 | + "gvisor.dev/gvisor/pkg/tcpip/header" |
| 27 | + "gvisor.dev/gvisor/pkg/tcpip/stack" |
| 28 | +) |
| 29 | + |
| 30 | +type queue struct { |
| 31 | + // c is the outbound packet channel. |
| 32 | + c chan *stack.PacketBuffer |
| 33 | + mu sync.RWMutex |
| 34 | + // +checklocks:mu |
| 35 | + closed bool |
| 36 | + |
| 37 | + closedChOnce sync.Once |
| 38 | + closedCh chan struct{} |
| 39 | +} |
| 40 | + |
| 41 | +func (q *queue) Close() { |
| 42 | + // This unblocks any calls to Write() which might be holding the mu. |
| 43 | + q.closedChOnce.Do(func() { |
| 44 | + close(q.closedCh) |
| 45 | + }) |
| 46 | + |
| 47 | + q.mu.Lock() |
| 48 | + defer q.mu.Unlock() |
| 49 | + if q.closed { |
| 50 | + return |
| 51 | + } |
| 52 | + close(q.c) |
| 53 | + q.closed = true |
| 54 | +} |
| 55 | + |
| 56 | +func (q *queue) Read() *stack.PacketBuffer { |
| 57 | + select { |
| 58 | + case p := <-q.c: |
| 59 | + return p |
| 60 | + default: |
| 61 | + return nil |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +func (q *queue) ReadContext(ctx context.Context) *stack.PacketBuffer { |
| 66 | + select { |
| 67 | + case pkt := <-q.c: |
| 68 | + return pkt |
| 69 | + case <-ctx.Done(): |
| 70 | + return nil |
| 71 | + } |
| 72 | +} |
| 73 | + |
| 74 | +func (q *queue) Write(pkt *stack.PacketBuffer) tcpip.Error { |
| 75 | + q.mu.RLock() |
| 76 | + defer q.mu.RUnlock() |
| 77 | + if q.closed { |
| 78 | + return &tcpip.ErrClosedForSend{} |
| 79 | + } |
| 80 | + select { |
| 81 | + case q.c <- pkt.IncRef(): |
| 82 | + return nil |
| 83 | + case <-q.closedCh: |
| 84 | + pkt.DecRef() |
| 85 | + return &tcpip.ErrClosedForSend{} |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +func (q *queue) Num() int { |
| 90 | + return len(q.c) |
| 91 | +} |
| 92 | + |
| 93 | +var _ stack.LinkEndpoint = (*Endpoint)(nil) |
| 94 | +var _ stack.GSOEndpoint = (*Endpoint)(nil) |
| 95 | + |
| 96 | +// Endpoint is link layer endpoint that stores outbound packets in a channel |
| 97 | +// and allows injection of inbound packets. It is based on gVisor |
| 98 | +// channel.Endpoint, however when the channel is full, it blocks writes until |
| 99 | +// there is space in the channel or until the Endpoint is closed. The gVisor |
| 100 | +// version dropped packets if the channel is full. This limits TCP throughput |
| 101 | +// as dropped packets need to be retransmitted and are interpreted as a |
| 102 | +// congestion event, causing the TCP sender to decrease the congestion window. |
| 103 | +// Much better to apply back-pressure to the TCP stack at the Endpoint. |
| 104 | +type Endpoint struct { |
| 105 | + mtu uint32 |
| 106 | + linkAddr tcpip.LinkAddress |
| 107 | + LinkEPCapabilities stack.LinkEndpointCapabilities |
| 108 | + SupportedGSOKind stack.SupportedGSO |
| 109 | + |
| 110 | + mu sync.RWMutex |
| 111 | + // +checklocks:mu |
| 112 | + dispatcher stack.NetworkDispatcher |
| 113 | + |
| 114 | + // Outbound packet queue. |
| 115 | + q *queue |
| 116 | +} |
| 117 | + |
| 118 | +// NewEndpoint creates a new channel endpoint. |
| 119 | +func NewEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { |
| 120 | + return &Endpoint{ |
| 121 | + q: &queue{ |
| 122 | + c: make(chan *stack.PacketBuffer, size), |
| 123 | + closedCh: make(chan struct{}), |
| 124 | + }, |
| 125 | + mtu: mtu, |
| 126 | + linkAddr: linkAddr, |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +// Close closes e. Further packet injections will return an error, and all pending |
| 131 | +// packets are discarded. Close may be called concurrently with WritePackets. |
| 132 | +func (e *Endpoint) Close() { |
| 133 | + e.q.Close() |
| 134 | + e.Drain() |
| 135 | +} |
| 136 | + |
| 137 | +// Read does non-blocking read one packet from the outbound packet queue. |
| 138 | +func (e *Endpoint) Read() *stack.PacketBuffer { |
| 139 | + return e.q.Read() |
| 140 | +} |
| 141 | + |
| 142 | +// ReadContext does blocking read for one packet from the outbound packet queue. |
| 143 | +// It can be cancelled by ctx, and in this case, it returns nil. |
| 144 | +func (e *Endpoint) ReadContext(ctx context.Context) *stack.PacketBuffer { |
| 145 | + return e.q.ReadContext(ctx) |
| 146 | +} |
| 147 | + |
| 148 | +// Drain removes all outbound packets from the channel and counts them. |
| 149 | +func (e *Endpoint) Drain() int { |
| 150 | + c := 0 |
| 151 | + for pkt := e.Read(); pkt != nil; pkt = e.Read() { |
| 152 | + pkt.DecRef() |
| 153 | + c++ |
| 154 | + } |
| 155 | + return c |
| 156 | +} |
| 157 | + |
| 158 | +// NumQueued returns the number of packet queued for outbound. |
| 159 | +func (e *Endpoint) NumQueued() int { |
| 160 | + return e.q.Num() |
| 161 | +} |
| 162 | + |
| 163 | +// InjectInbound injects an inbound packet. If the endpoint is not attached, the |
| 164 | +// packet is not delivered. |
| 165 | +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { |
| 166 | + e.mu.RLock() |
| 167 | + d := e.dispatcher |
| 168 | + e.mu.RUnlock() |
| 169 | + if d != nil { |
| 170 | + d.DeliverNetworkPacket(protocol, pkt) |
| 171 | + } |
| 172 | +} |
| 173 | + |
| 174 | +// Attach saves the stack network-layer dispatcher for use later when packets |
| 175 | +// are injected. |
| 176 | +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { |
| 177 | + e.mu.Lock() |
| 178 | + defer e.mu.Unlock() |
| 179 | + e.dispatcher = dispatcher |
| 180 | +} |
| 181 | + |
| 182 | +// IsAttached implements stack.LinkEndpoint.IsAttached. |
| 183 | +func (e *Endpoint) IsAttached() bool { |
| 184 | + e.mu.RLock() |
| 185 | + defer e.mu.RUnlock() |
| 186 | + return e.dispatcher != nil |
| 187 | +} |
| 188 | + |
| 189 | +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized |
| 190 | +// during construction. |
| 191 | +func (e *Endpoint) MTU() uint32 { |
| 192 | + return e.mtu |
| 193 | +} |
| 194 | + |
| 195 | +// Capabilities implements stack.LinkEndpoint.Capabilities. |
| 196 | +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { |
| 197 | + return e.LinkEPCapabilities |
| 198 | +} |
| 199 | + |
| 200 | +// GSOMaxSize implements stack.GSOEndpoint. |
| 201 | +func (*Endpoint) GSOMaxSize() uint32 { |
| 202 | + return 1 << 15 |
| 203 | +} |
| 204 | + |
| 205 | +// SupportedGSO implements stack.GSOEndpoint. |
| 206 | +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { |
| 207 | + return e.SupportedGSOKind |
| 208 | +} |
| 209 | + |
| 210 | +// MaxHeaderLength returns the maximum size of the link layer header. Given it |
| 211 | +// doesn't have a header, it just returns 0. |
| 212 | +func (*Endpoint) MaxHeaderLength() uint16 { |
| 213 | + return 0 |
| 214 | +} |
| 215 | + |
| 216 | +// LinkAddress returns the link address of this endpoint. |
| 217 | +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { |
| 218 | + return e.linkAddr |
| 219 | +} |
| 220 | + |
| 221 | +// WritePackets stores outbound packets into the channel. |
| 222 | +// Multiple concurrent calls are permitted. |
| 223 | +func (e *Endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { |
| 224 | + n := 0 |
| 225 | + for _, pkt := range pkts.AsSlice() { |
| 226 | + if err := e.q.Write(pkt); err != nil { |
| 227 | + return n, err |
| 228 | + } |
| 229 | + n++ |
| 230 | + } |
| 231 | + |
| 232 | + return n, nil |
| 233 | +} |
| 234 | + |
| 235 | +// Wait implements stack.LinkEndpoint.Wait. |
| 236 | +func (*Endpoint) Wait() {} |
| 237 | + |
| 238 | +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. |
| 239 | +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { |
| 240 | + return header.ARPHardwareNone |
| 241 | +} |
| 242 | + |
| 243 | +// AddHeader implements stack.LinkEndpoint.AddHeader. |
| 244 | +func (*Endpoint) AddHeader(*stack.PacketBuffer) {} |
| 245 | + |
| 246 | +// ParseHeader implements stack.LinkEndpoint.ParseHeader. |
| 247 | +func (*Endpoint) ParseHeader(*stack.PacketBuffer) bool { return true } |
0 commit comments