Skip to content

Commit 14f84bf

Browse files
authored
refactor apply deadline (#59)
1 parent 032c922 commit 14f84bf

File tree

5 files changed

+343
-191
lines changed

5 files changed

+343
-191
lines changed

keysign/signature_notifier_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"github.com/zeta-chain/go-tss/conversion"
1919

2020
"github.com/zeta-chain/go-tss/common"
21-
"github.com/zeta-chain/go-tss/p2p"
2221
)
2322

2423
func TestSignatureNotifierHappyPath(t *testing.T) {
@@ -31,7 +30,6 @@ func TestSignatureNotifierHappyPath(t *testing.T) {
3130
assert.NoError(t, err)
3231
messageID, err := common.MsgToHashString(buf)
3332
assert.NoError(t, err)
34-
p2p.ApplyDeadline.Store(false)
3533
id1 := tnet.RandIdentityOrFatal(t)
3634
id2 := tnet.RandIdentityOrFatal(t)
3735
id3 := tnet.RandIdentityOrFatal(t)
@@ -104,7 +102,6 @@ func TestSignatureNotifierBroadcastFirst(t *testing.T) {
104102
assert.NoError(t, err)
105103
messageID, err := common.MsgToHashString(buf)
106104
assert.NoError(t, err)
107-
p2p.ApplyDeadline.Store(false)
108105
id1 := tnet.RandIdentityOrFatal(t)
109106
id2 := tnet.RandIdentityOrFatal(t)
110107
id3 := tnet.RandIdentityOrFatal(t)

p2p/mocks/stream.go

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
package mocks
2+
3+
import (
4+
"bytes"
5+
"math/rand/v2"
6+
"strconv"
7+
"sync"
8+
"testing"
9+
"time"
10+
11+
"github.com/libp2p/go-libp2p/core/network"
12+
"github.com/libp2p/go-libp2p/core/protocol"
13+
"github.com/pkg/errors"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
const testProtocolID protocol.ID = "/p2p/test-stream"
18+
19+
type Stream struct {
20+
t *testing.T
21+
buffer *bytes.Buffer
22+
23+
id uint64
24+
protocol protocol.ID
25+
26+
readDeadline time.Time
27+
writeDeadline time.Time
28+
29+
errSetReadDeadLine bool
30+
errSetWriteDeadLine bool
31+
errRead bool
32+
33+
mu *sync.RWMutex
34+
}
35+
36+
var _ network.Stream = &Stream{}
37+
38+
func NewStream(t *testing.T) *Stream {
39+
return &Stream{
40+
t: t,
41+
buffer: &bytes.Buffer{},
42+
43+
id: rand.Uint64() % 10_000,
44+
protocol: testProtocolID,
45+
46+
readDeadline: time.Time{},
47+
writeDeadline: time.Time{},
48+
49+
errSetReadDeadLine: false,
50+
errSetWriteDeadLine: false,
51+
errRead: false,
52+
53+
mu: &sync.RWMutex{},
54+
}
55+
}
56+
57+
func (s *Stream) Read(buf []byte) (n int, err error) {
58+
s.mu.RLock()
59+
defer s.mu.RUnlock()
60+
61+
if s.errRead {
62+
return 0, errors.New("you asked for it")
63+
}
64+
65+
// no deadline, read immediately (sync)
66+
if s.readDeadline.IsZero() {
67+
return s.buffer.Read(buf)
68+
}
69+
70+
var (
71+
timeout = time.Until(s.readDeadline)
72+
done = make(chan struct{})
73+
)
74+
75+
go func() {
76+
n, err = s.buffer.Read(buf)
77+
close(done)
78+
}()
79+
80+
select {
81+
case <-done:
82+
return n, err
83+
case <-time.After(timeout):
84+
return 0, errors.New("mock: read deadline exceeded")
85+
}
86+
}
87+
88+
func (s *Stream) MustRead(buf []byte) {
89+
_, err := s.Read(buf)
90+
require.NoError(s.t, err, "failed to read from stream")
91+
}
92+
93+
func (s *Stream) Write(buf []byte) (n int, err error) {
94+
s.mu.Lock()
95+
defer s.mu.Unlock()
96+
97+
if s.errSetWriteDeadLine {
98+
return 0, errors.New("mock: unable to set write deadline")
99+
}
100+
101+
// no deadline, write immediately (sync)
102+
if s.writeDeadline.IsZero() {
103+
return s.buffer.Write(buf)
104+
}
105+
106+
var (
107+
timeout = time.Until(s.writeDeadline)
108+
done = make(chan struct{})
109+
)
110+
111+
go func() {
112+
n, err = s.buffer.Write(buf)
113+
close(done)
114+
}()
115+
116+
select {
117+
case <-done:
118+
return n, err
119+
case <-time.After(timeout):
120+
return 0, errors.New("mock: write deadline exceeded")
121+
}
122+
}
123+
124+
func (s *Stream) MustWrite(buf []byte) {
125+
_, err := s.Write(buf)
126+
require.NoError(s.t, err, "failed to write to stream")
127+
}
128+
129+
func (s *Stream) Stat() network.Stats {
130+
return network.Stats{
131+
Direction: network.DirUnknown,
132+
Extra: make(map[any]any),
133+
}
134+
}
135+
136+
func (s *Stream) Protocol() protocol.ID {
137+
s.mu.RLock()
138+
defer s.mu.RUnlock()
139+
140+
return s.protocol
141+
}
142+
143+
func (s *Stream) SetProtocol(id protocol.ID) error {
144+
s.mu.Lock()
145+
defer s.mu.Unlock()
146+
147+
s.protocol = id
148+
149+
return nil
150+
}
151+
152+
func (s *Stream) SetDeadline(at time.Time) error {
153+
if err := s.SetReadDeadline(at); err != nil {
154+
return err
155+
}
156+
157+
if err := s.SetWriteDeadline(at); err != nil {
158+
return err
159+
}
160+
161+
return nil
162+
}
163+
164+
func (s *Stream) SetReadDeadline(at time.Time) error {
165+
s.mu.Lock()
166+
defer s.mu.Unlock()
167+
168+
if s.errSetReadDeadLine {
169+
return errors.New("mock: unable to set read deadline")
170+
}
171+
172+
s.readDeadline = at
173+
174+
return nil
175+
}
176+
177+
func (s *Stream) SetWriteDeadline(at time.Time) error {
178+
s.mu.Lock()
179+
defer s.mu.Unlock()
180+
181+
if s.errSetWriteDeadLine {
182+
return errors.New("mock: unable to set write deadline")
183+
}
184+
185+
s.writeDeadline = at
186+
187+
return nil
188+
}
189+
190+
func (s *Stream) ErrRead(v bool) {
191+
s.mu.Lock()
192+
defer s.mu.Unlock()
193+
194+
s.errRead = v
195+
}
196+
197+
func (s *Stream) ErrSetReadDeadline(v bool) {
198+
s.mu.Lock()
199+
defer s.mu.Unlock()
200+
201+
s.errSetReadDeadLine = v
202+
}
203+
204+
func (s *Stream) ErrSetWriteDeadline(v bool) {
205+
s.mu.Lock()
206+
defer s.mu.Unlock()
207+
208+
s.errSetWriteDeadLine = v
209+
}
210+
211+
func (s *Stream) ID() string { return strconv.FormatUint(s.id, 10) }
212+
func (s *Stream) Scope() network.StreamScope { return &network.NullScope{} }
213+
func (s *Stream) Conn() network.Conn { return nil }
214+
func (s *Stream) Close() error { return nil }
215+
func (s *Stream) CloseRead() error { return nil }
216+
func (s *Stream) CloseWrite() error { return nil }
217+
func (s *Stream) Reset() error { return nil }

p2p/party_coordinator_test.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@ import (
1717
"github.com/zeta-chain/go-tss/conversion"
1818
)
1919

20-
func init() {
21-
ApplyDeadline.Store(false)
22-
}
23-
2420
func setupHosts(t *testing.T, n int) []host.Host {
2521
mn := mocknet.New()
2622

p2p/stream_manager.go

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ import (
55
"bytes"
66
"encoding/binary"
77
"io"
8+
"net"
9+
"strings"
810
"sync"
9-
"sync/atomic"
1011
"time"
1112

1213
"github.com/libp2p/go-libp2p/core/network"
@@ -24,14 +25,6 @@ const (
2425

2526
const unknown = "unknown"
2627

27-
// ApplyDeadline will be true, and only disable it when we are doing test
28-
// the reason being the p2p network, mocknet, mock stream doesn't support SetReadDeadline ,SetWriteDeadline feature
29-
var ApplyDeadline = &atomic.Bool{}
30-
31-
func init() {
32-
ApplyDeadline.Store(true)
33-
}
34-
3528
// StreamManager is responsible fro libp2p stream bookkeeping.
3629
// It can store streams by message id for latter release
3730
// as we can't have thousands of streams opened at the same time.
@@ -214,14 +207,8 @@ func (sm *StreamManager) cleanup() {
214207

215208
// ReadStreamWithBuffer read data from the given stream
216209
func ReadStreamWithBuffer(stream network.Stream) ([]byte, error) {
217-
if ApplyDeadline.Load() {
218-
deadline := time.Now().Add(TimeoutReadPayload)
219-
if err := stream.SetReadDeadline(deadline); err != nil {
220-
if errReset := stream.Reset(); errReset != nil {
221-
return nil, errReset
222-
}
223-
return nil, err
224-
}
210+
if err := applyDeadline(stream, TimeoutReadPayload, true); err != nil {
211+
return nil, err
225212
}
226213

227214
streamReader := bufio.NewReader(stream)
@@ -253,16 +240,8 @@ func WriteStreamWithBuffer(msg []byte, stream network.Stream) error {
253240
return errors.Errorf("payload size exceeded (got %d, max %d)", len(msg), MaxPayload)
254241
}
255242

256-
if ApplyDeadline.Load() {
257-
deadline := time.Now().Add(TimeoutWritePayload)
258-
259-
if err := stream.SetWriteDeadline(deadline); err != nil {
260-
if errReset := stream.Reset(); errReset != nil {
261-
return errors.Wrap(errReset, "failed to reset stream during failure in write deadline")
262-
}
263-
264-
return errors.Wrap(err, "failed to set write deadline")
265-
}
243+
if err := applyDeadline(stream, TimeoutWritePayload, false); err != nil {
244+
return err
266245
}
267246

268247
// Create header containing the message length
@@ -283,3 +262,59 @@ func WriteStreamWithBuffer(msg []byte, stream network.Stream) error {
283262

284263
return nil
285264
}
265+
266+
// applies read/write (read=true, write=false) deadline to the stream.
267+
// Tolerates mocknet errors.
268+
// Resets the stream on failure.
269+
func applyDeadline(stream network.Stream, timeout time.Duration, readOrWrite bool) error {
270+
// noop
271+
if timeout == 0 {
272+
return nil
273+
}
274+
275+
// calculate deadline
276+
deadline := time.Now().Add(timeout)
277+
278+
set := stream.SetReadDeadline
279+
if !readOrWrite {
280+
set = stream.SetWriteDeadline
281+
}
282+
283+
err := set(deadline)
284+
285+
if err == nil || isMockNetError(err) {
286+
return nil
287+
}
288+
289+
// err is not nil, so we need to reset the stream
290+
if errReset := stream.Reset(); errReset != nil {
291+
return errors.Wrap(errReset, "failed to reset stream after setDeadline failure")
292+
}
293+
294+
return err
295+
}
296+
297+
// mocknet doesn't support deadlines, so we need to check for it and ignore.
298+
// See: libp2p/p2p/net/mock/mock_stream.go
299+
//
300+
// func (s *stream) SetDeadline(...) error {
301+
// return &net.OpError{Op: "set", Net: "pipe", Err: errors.New("deadline not supported")}
302+
// }
303+
func isMockNetError(err error) bool {
304+
if err == nil {
305+
return false
306+
}
307+
308+
opError := &net.OpError{}
309+
if !errors.As(err, &opError) {
310+
return false
311+
}
312+
313+
if opError.Err == nil {
314+
return false
315+
}
316+
317+
return opError.Op == "set" &&
318+
opError.Net == "pipe" &&
319+
strings.Contains(opError.Err.Error(), "deadline not supported")
320+
}

0 commit comments

Comments
 (0)