Skip to content
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
package shutterservice
package syncmonitor

import (
"context"
"errors"
"fmt"
"time"

"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"

keyperDB "github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/database"
"github.com/shutter-network/rolling-shutter/rolling-shutter/keyperimpl/shutterservice/database"
"github.com/shutter-network/rolling-shutter/rolling-shutter/medley/service"
)

// BlockSyncState is an interface that different keyper implementations
// can implement to provide their own block sync state logic.
type BlockSyncState interface {
// GetSyncedBlockNumber retrieves the current synced block number.
GetSyncedBlockNumber(ctx context.Context) (int64, error)
}

// SyncMonitor monitors the sync state of the keyper.
type SyncMonitor struct {
DBPool *pgxpool.Pool
CheckInterval time.Duration
SyncState BlockSyncState
}

func (s *SyncMonitor) Start(ctx context.Context, runner service.Runner) error {
Expand All @@ -30,15 +35,13 @@ func (s *SyncMonitor) Start(ctx context.Context, runner service.Runner) error {

func (s *SyncMonitor) runMonitor(ctx context.Context) error {
var lastBlockNumber int64
db := database.New(s.DBPool)
keyperdb := keyperDB.New(s.DBPool)

log.Debug().Msg("starting the sync monitor")

for {
select {
case <-time.After(s.CheckInterval):
if err := s.runCheck(ctx, db, keyperdb, &lastBlockNumber); err != nil {
if err := s.runCheck(ctx, &lastBlockNumber); err != nil {
if errors.Is(err, ErrBlockNotIncreasing) {
return err
}
Expand All @@ -55,21 +58,18 @@ var ErrBlockNotIncreasing = errors.New("block number has not increased between c

func (s *SyncMonitor) runCheck(
ctx context.Context,
db *database.Queries,
keyperdb *keyperDB.Queries,
lastBlockNumber *int64,
) error {
record, err := db.GetIdentityRegisteredEventsSyncedUntil(ctx)
currentBlockNumber, err := s.SyncState.GetSyncedBlockNumber(ctx)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
log.Warn().Err(err).Msg("no rows found in table identity_registered_events_synced_until")
log.Warn().Err(err).Msg("no rows found in sync state table")
return nil // This is not an error condition that should stop monitoring
}
return fmt.Errorf("error getting identity_registered_events_synced_until: %w", err)
return fmt.Errorf("error getting synced block number: %w", err)
}

currentBlockNumber := record.BlockNumber
log.Debug().Int64("current-block-number", currentBlockNumber).Msg("current block number")
log.Debug().Int64("current-block-number", currentBlockNumber).Int64("last-block-number", *lastBlockNumber).Msg("current block number")

// if the current block number < last block number, this means a reorg is detected, so we do not throw error
// if the current block number > last block number, then syncing is working as expected
Expand Down
222 changes: 222 additions & 0 deletions rolling-shutter/keyper/syncmonitor/syncmonitor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
package syncmonitor

import (
"context"
"sync"
"testing"
"time"

"github.com/jackc/pgx/v4"
"gotest.tools/assert"

"github.com/shutter-network/rolling-shutter/rolling-shutter/medley/service"
)

// MockSyncState is a mock implementation of BlockSyncState for testing.
type MockSyncState struct {
mu sync.Mutex
blockNumber int64
err error
}

func (m *MockSyncState) GetSyncedBlockNumber(_ context.Context) (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.blockNumber, m.err
}

func (m *MockSyncState) SetBlockNumber(n int64) {
m.mu.Lock()
defer m.mu.Unlock()
m.blockNumber = n
}

func TestSyncMonitor_ThrowsErrorWhenBlockNotIncreasing(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

initialBlockNumber := int64(100)
mockSyncState := &MockSyncState{
blockNumber: initialBlockNumber,
}

monitor := &SyncMonitor{
CheckInterval: 5 * time.Second,
SyncState: mockSyncState,
}

errCh := make(chan error, 1)
go func() {
err := service.RunWithSighandler(ctx, monitor)
if err != nil {
errCh <- err
}
}()

time.Sleep(12 * time.Second)

select {
case err := <-errCh:
assert.ErrorContains(t, err, ErrBlockNotIncreasing.Error())
case <-time.After(5 * time.Second):
t.Fatal("expected an error, but none was returned")
}

// Verify final state
finalBlockNumber, err := mockSyncState.GetSyncedBlockNumber(ctx)
assert.NilError(t, err)
assert.Equal(t, initialBlockNumber, finalBlockNumber)
}

func TestSyncMonitor_HandlesBlockNumberIncreasing(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

initialBlockNumber := int64(100)
mockSyncState := &MockSyncState{
blockNumber: initialBlockNumber,
}

monitor := &SyncMonitor{
CheckInterval: 200 * time.Millisecond,
SyncState: mockSyncState,
}

monitorCtx, cancelMonitor := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := service.RunWithSighandler(monitorCtx, monitor); err != nil {
errCh <- err
}
}()

// Update block numbers more quickly
for i := 0; i < 5; i++ {
time.Sleep(200 * time.Millisecond)
mockSyncState.SetBlockNumber(initialBlockNumber + int64(i+1))
}

cancelMonitor()

// Verify final state
finalBlockNumber, err := mockSyncState.GetSyncedBlockNumber(ctx)
assert.NilError(t, err)
assert.Equal(t, initialBlockNumber+5, finalBlockNumber, "block number should have been incremented correctly")
}

func TestSyncMonitor_RunsNormallyWhenNoEons(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

initialBlockNumber := int64(100)
mockSyncState := &MockSyncState{
blockNumber: initialBlockNumber,
}

monitor := &SyncMonitor{
CheckInterval: 5 * time.Second,
SyncState: mockSyncState,
}

monitorCtx, cancelMonitor := context.WithCancel(ctx)
defer cancelMonitor()

errCh := make(chan error, 1)
go func() {
err := service.RunWithSighandler(monitorCtx, monitor)
if err != nil {
errCh <- err
}
}()

// Let it run for a while without incrementing the block number
time.Sleep(15 * time.Second)
cancelMonitor()

select {
case err := <-errCh:
assert.ErrorContains(t, err, ErrBlockNotIncreasing.Error())
case <-time.After(1 * time.Second):
t.Fatalf("expected monitor to throw error, but no error returned")
}
}

func TestSyncMonitor_ContinuesWhenNoRows(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Set up mock sync state that returns no rows error
mockSyncState := &MockSyncState{
err: pgx.ErrNoRows,
}
mockSyncState.SetBlockNumber(0) // Initialize block number

monitor := &SyncMonitor{
CheckInterval: 5 * time.Second,
SyncState: mockSyncState,
}

monitorCtx, cancelMonitor := context.WithCancel(ctx)
defer cancelMonitor()

errCh := make(chan error, 1)
go func() {
err := service.RunWithSighandler(monitorCtx, monitor)
if err != nil {
errCh <- err
}
}()

time.Sleep(15 * time.Second)
cancelMonitor()

select {
case err := <-errCh:
t.Fatalf("expected monitor to continue without error, but got: %v", err)
case <-time.After(1 * time.Second):
// Test passes if no error is received
}
}

func TestSyncMonitor_HandlesReorg(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Set up mock sync state that returns no rows error
mockSyncState := &MockSyncState{}
mockSyncState.SetBlockNumber(0) // Initialize block number

monitor := &SyncMonitor{
CheckInterval: 5 * time.Second,
SyncState: mockSyncState,
}

monitorCtx, cancelMonitor := context.WithCancel(ctx)
defer cancelMonitor()

errCh := make(chan error, 1)
go func() {
err := service.RunWithSighandler(monitorCtx, monitor)
if err != nil {
errCh <- err
}
}()

// Decrease the block number
decreasedBlockNumber := int64(50)
mockSyncState.SetBlockNumber(decreasedBlockNumber)

time.Sleep(4 * time.Second)
cancelMonitor()

select {
case err := <-errCh:
t.Fatalf("expected monitor to continue without error, but got: %v", err)
case <-time.After(1 * time.Second):
}

// Verify the block number was updated to the latest value
syncedData, err := mockSyncState.GetSyncedBlockNumber(ctx)
assert.NilError(t, err)
assert.Equal(t, decreasedBlockNumber, syncedData, "block number should be updated to the decreased value")
}
12 changes: 7 additions & 5 deletions rolling-shutter/keyperimpl/gnosis/keyper.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/shutter-network/rolling-shutter/rolling-shutter/keyper"
"github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/epochkghandler"
"github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/kprconfig"
"github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/syncmonitor"
"github.com/shutter-network/rolling-shutter/rolling-shutter/keyperimpl/gnosis/database"
"github.com/shutter-network/rolling-shutter/rolling-shutter/medley/beaconapiclient"
"github.com/shutter-network/rolling-shutter/rolling-shutter/medley/broker"
Expand Down Expand Up @@ -49,7 +50,7 @@ type Keyper struct {
validatorSyncer *ValidatorSyncer
eonKeyPublisher *eonkeypublisher.EonKeyPublisher
latestTriggeredSlot *uint64
syncMonitor *SyncMonitor
syncMonitor *syncmonitor.SyncMonitor

// input events
newBlocks chan *syncevent.LatestBlock
Expand All @@ -63,8 +64,7 @@ type Keyper struct {

func New(c *Config) *Keyper {
return &Keyper{
config: c,
syncMonitor: &SyncMonitor{},
config: c,
}
}

Expand Down Expand Up @@ -156,9 +156,11 @@ func (kpr *Keyper) Start(ctx context.Context, runner service.Runner) error {
return errors.Wrap(err, "failed to reset transaction pointer age")
}

kpr.syncMonitor = &SyncMonitor{
DBPool: kpr.dbpool,
kpr.syncMonitor = &syncmonitor.SyncMonitor{
CheckInterval: time.Duration(kpr.config.Gnosis.SyncMonitorCheckInterval) * time.Second,
SyncState: &GnosisSyncState{
kpr.dbpool,
},
}

runner.Go(func() error { return kpr.processInputs(ctx) })
Expand Down
Loading