From 8bd5bc4cfb31f18443275d93efb909cb486a48d4 Mon Sep 17 00:00:00 2001 From: Nikolas De Giorgis Date: Wed, 20 Aug 2025 09:51:13 +0100 Subject: [PATCH] backend: move eth update logic into its own package. --- backend/backend.go | 83 ++------- backend/coins/eth/account.go | 67 +------ backend/coins/eth/mocks/balancefetcher.go | 84 +++++++++ backend/coins/eth/updater.go | 185 +++++++++++++++++++ backend/coins/eth/updater_test.go | 214 ++++++++++++++++++++++ 5 files changed, 500 insertions(+), 133 deletions(-) create mode 100644 backend/coins/eth/mocks/balancefetcher.go create mode 100644 backend/coins/eth/updater.go create mode 100644 backend/coins/eth/updater_test.go diff --git a/backend/backend.go b/backend/backend.go index f39589039e..2c2114cfe4 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -227,7 +227,6 @@ type Backend struct { socksProxy socksproxy.SocksProxy // can be a regular or, if Tor is enabled in the config, a SOCKS5 proxy client. httpClient *http.Client - etherScanHTTPClient *http.Client etherScanRateLimiter *rate.Limiter ratesUpdater *rates.RateUpdater banners *banners.Banners @@ -243,14 +242,8 @@ type Backend struct { // isOnline indicates whether the backend is online, i.e. able to connect to the internet. isOnline atomic.Bool - // quit is used to indicate to running goroutines that they should stop as the backend is being closed - quit chan struct{} - - // enqueueUpdateForAccount is used to enqueue an update for an account. - enqueueUpdateForAccount chan *eth.Account - - // updateETHAccountsCh is used to trigger an update of all ETH accounts. - updateETHAccountsCh chan struct{} + // ethupdater takes care of updating ETH accounts. + ethupdater *eth.Updater } // NewBackend creates a new backend with the given arguments. @@ -294,10 +287,8 @@ func NewBackend(arguments *arguments.Arguments, environment Environment) (*Backe log: log, - testing: backendConfig.AppConfig().Backend.StartInTestnet || arguments.Testing(), - quit: make(chan struct{}), - etherScanRateLimiter: rate.NewLimiter(rate.Limit(etherscan.CallsPerSec), 1), - enqueueUpdateForAccount: accountUpdate, + testing: backendConfig.AppConfig().Backend.StartInTestnet || arguments.Testing(), + etherScanRateLimiter: rate.NewLimiter(rate.Limit(etherscan.CallsPerSec), 1), } // TODO: remove when connectivity check is present on all platforms backend.isOnline.Store(true) @@ -309,7 +300,7 @@ func NewBackend(arguments *arguments.Arguments, environment Environment) (*Backe backend.notifier = notifier backend.socksProxy = backendProxy backend.httpClient = hclient - backend.etherScanHTTPClient = hclient + backend.ethupdater = eth.NewUpdater(accountUpdate, backend.httpClient, backend.etherScanRateLimiter, backend.updateETHAccounts) ratesCache := filepath.Join(arguments.CacheDirectoryPath(), "exchangerates") if err := os.MkdirAll(ratesCache, 0700); err != nil { @@ -545,19 +536,19 @@ func (backend *Backend) Coin(code coinpkg.Code) (coinpkg.Coin, error) { coin = btc.NewCoin(coinpkg.CodeLTC, "Litecoin", "LTC", coinpkg.BtcUnitDefault, <c.MainNetParams, dbFolder, servers, "https://blockchair.com/litecoin/transaction/", backend.socksProxy) case code == coinpkg.CodeETH: - etherScan := etherscan.NewEtherScan("1", backend.etherScanHTTPClient, backend.etherScanRateLimiter) + etherScan := etherscan.NewEtherScan("1", backend.httpClient, backend.etherScanRateLimiter) coin = eth.NewCoin(etherScan, code, "Ethereum", "ETH", "ETH", params.MainnetChainConfig, "https://etherscan.io/tx/", etherScan, nil) case code == coinpkg.CodeSEPETH: - etherScan := etherscan.NewEtherScan("11155111", backend.etherScanHTTPClient, backend.etherScanRateLimiter) + etherScan := etherscan.NewEtherScan("11155111", backend.httpClient, backend.etherScanRateLimiter) coin = eth.NewCoin(etherScan, code, "Ethereum Sepolia", "SEPETH", "SEPETH", params.SepoliaChainConfig, "https://sepolia.etherscan.io/tx/", etherScan, nil) case erc20Token != nil: - etherScan := etherscan.NewEtherScan("1", backend.etherScanHTTPClient, backend.etherScanRateLimiter) + etherScan := etherscan.NewEtherScan("1", backend.httpClient, backend.etherScanRateLimiter) coin = eth.NewCoin(etherScan, erc20Token.code, erc20Token.name, erc20Token.unit, "ETH", params.MainnetChainConfig, "https://etherscan.io/tx/", etherScan, @@ -571,46 +562,6 @@ func (backend *Backend) Coin(code coinpkg.Code) (coinpkg.Coin, error) { return coin, nil } -func (backend *Backend) pollETHAccounts() { - timer := time.After(0) - - updateAll := func() { - if err := backend.updateETHAccounts(); err != nil { - backend.log.WithError(err).Error("could not update ETH accounts") - } - } - - for { - select { - case <-backend.quit: - return - default: - select { - case <-backend.quit: - return - case account := <-backend.enqueueUpdateForAccount: - go func() { - // A single ETH accounts needs an update. - ethCoin, ok := account.Coin().(*eth.Coin) - if !ok { - backend.log.WithField("account", account.Config().Config.Name).Errorf("expected ETH account to have ETH coin, got %T", account.Coin()) - } - etherScanClient := etherscan.NewEtherScan(ethCoin.ChainIDstr(), backend.etherScanHTTPClient, backend.etherScanRateLimiter) - if err := eth.UpdateBalances([]*eth.Account{account}, etherScanClient); err != nil { - backend.log.WithError(err).Errorf("could not update account %s", account.Config().Config.Name) - } - }() - case <-backend.updateETHAccountsCh: - go updateAll() - timer = time.After(eth.PollInterval) - case <-timer: - go updateAll() - timer = time.After(eth.PollInterval) - } - } - } -} - func (backend *Backend) updateETHAccounts() error { defer backend.accountsAndKeystoreLock.RLock()() backend.log.Debug("Updating ETH accounts balances") @@ -619,21 +570,15 @@ func (backend *Backend) updateETHAccounts() error { for _, account := range backend.accounts { ethAccount, ok := account.(*eth.Account) if ok { - ethCoin, ok := ethAccount.Coin().(*eth.Coin) - if !ok { - return errp.Newf("expected ETH account to have ETH coin, got %T", ethAccount.Coin()) - } - chainID := ethCoin.ChainIDstr() + chainID := ethAccount.ETHCoin().ChainIDstr() accountsChainID[chainID] = append(accountsChainID[chainID], ethAccount) } } for chainID, ethAccounts := range accountsChainID { - etherScanClient := etherscan.NewEtherScan(chainID, backend.etherScanHTTPClient, backend.etherScanRateLimiter) - if err := eth.UpdateBalances(ethAccounts, etherScanClient); err != nil { - backend.log.WithError(err).Errorf("could not update ETH accounts for chain ID %s", chainID) - } + etherScanClient := etherscan.NewEtherScan(chainID, backend.httpClient, backend.etherScanRateLimiter) + backend.ethupdater.UpdateBalances(ethAccounts, etherScanClient) } return nil @@ -676,7 +621,7 @@ func (backend *Backend) ManualReconnect(reconnectETH bool) { } if reconnectETH { backend.log.Info("Reconnecting ETH accounts") - backend.updateETHAccountsCh <- struct{}{} + backend.ethupdater.EnqueueUpdateForAllAccounts() } } @@ -739,7 +684,7 @@ func (backend *Backend) Start() <-chan interface{} { backend.environment.OnAuthSettingChanged(backend.config.AppConfig().Backend.Authentication) - go backend.pollETHAccounts() + go backend.ethupdater.PollBalances() if backend.config.AppConfig().Backend.StartInTestnet { if err := backend.config.ModifyAppConfig(func(c *config.AppConfig) error { c.Backend.StartInTestnet = false; return nil }); err != nil { @@ -1014,7 +959,7 @@ func (backend *Backend) Close() error { return errp.New(strings.Join(errors, "; ")) } - close(backend.quit) + backend.ethupdater.Close() return nil } diff --git a/backend/coins/eth/account.go b/backend/coins/eth/account.go index 27e7ce26ab..6274aea79e 100644 --- a/backend/coins/eth/account.go +++ b/backend/coins/eth/account.go @@ -25,7 +25,6 @@ import ( "path" "strconv" "strings" - "time" "github.com/BitBoxSwiss/bitbox-wallet-app/backend/accounts" "github.com/BitBoxSwiss/bitbox-wallet-app/backend/accounts/errors" @@ -49,9 +48,6 @@ import ( "github.com/sirupsen/logrus" ) -// PollInterval is the interval at which the account is polled for updates. -var PollInterval = 5 * time.Minute - func isMixedCase(s string) bool { return strings.ToLower(s) != s && strings.ToUpper(s) != s } @@ -1019,64 +1015,7 @@ func (account *Account) EnqueueUpdate() { account.enqueueUpdateCh <- account } -// UpdateBalances updates the balances of the accounts in the provided slice. -func UpdateBalances(accounts []*Account, etherScanClient *etherscan.EtherScan) error { - ethNonErc20Addresses := make([]ethcommon.Address, 0, len(accounts)) - for _, account := range accounts { - if account.isClosed() { - continue - } - address, err := account.Address() - if err != nil { - account.log.WithError(err).Errorf("Could not get address for account %s", account.Config().Config.Code) - account.SetOffline(err) - continue - } - if account.coin.erc20Token == nil { - ethNonErc20Addresses = append(ethNonErc20Addresses, address.Address) - } - } - - balances, err := etherScanClient.Balances(context.TODO(), ethNonErc20Addresses) - if err != nil { - return errp.WithStack(err) - } - - for _, account := range accounts { - if account.isClosed() { - continue - } - address, err := account.Address() - if err != nil { - account.log.WithError(err).Errorf("Could not get address for account %s", account.Config().Config.Code) - account.SetOffline(err) - continue - } - var balance *big.Int - if account.coin.erc20Token != nil { - var err error - balance, err = account.coin.client.ERC20Balance(account.address.Address, account.coin.erc20Token) - if err != nil { - account.log.WithError(err).Errorf("Could not get ERC20 balance for address %s", address.Address.Hex()) - account.SetOffline(err) - continue - } - } else { - var ok bool - balance, ok = balances[address.Address] - if !ok { - errMsg := fmt.Sprintf("Could not find balance for address %s", address.Address.Hex()) - account.log.Error(errMsg) - account.SetOffline(errp.New(errMsg)) - continue - } - } - if err := account.Update(balance); err != nil { - account.log.WithError(err).Errorf("Could not update balance for address %s", address.Address.Hex()) - account.SetOffline(err) - } else { - account.SetOffline(nil) - } - } - return nil +// ETHCoin returns the eth.Coin of the account. +func (account *Account) ETHCoin() *Coin { + return account.coin } diff --git a/backend/coins/eth/mocks/balancefetcher.go b/backend/coins/eth/mocks/balancefetcher.go new file mode 100644 index 0000000000..2180de99b0 --- /dev/null +++ b/backend/coins/eth/mocks/balancefetcher.go @@ -0,0 +1,84 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mocks + +import ( + "context" + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/coins/eth" + ethcommon "github.com/ethereum/go-ethereum/common" + "math/big" + "sync" +) + +// Ensure, that BalanceFetcherMock does implement eth.BalanceFetcher. +// If this is not the case, regenerate this file with moq. +var _ eth.BalanceFetcher = &BalanceFetcherMock{} + +// BalanceFetcherMock is a mock implementation of eth.BalanceFetcher. +// +// func TestSomethingThatUsesBalanceFetcher(t *testing.T) { +// +// // make and configure a mocked eth.BalanceFetcher +// mockedBalanceFetcher := &BalanceFetcherMock{ +// BalancesFunc: func(ctx context.Context, addresses []ethcommon.Address) (map[ethcommon.Address]*big.Int, error) { +// panic("mock out the Balances method") +// }, +// } +// +// // use mockedBalanceFetcher in code that requires eth.BalanceFetcher +// // and then make assertions. +// +// } +type BalanceFetcherMock struct { + // BalancesFunc mocks the Balances method. + BalancesFunc func(ctx context.Context, addresses []ethcommon.Address) (map[ethcommon.Address]*big.Int, error) + + // calls tracks calls to the methods. + calls struct { + // Balances holds details about calls to the Balances method. + Balances []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Addresses is the addresses argument value. + Addresses []ethcommon.Address + } + } + lockBalances sync.RWMutex +} + +// Balances calls BalancesFunc. +func (mock *BalanceFetcherMock) Balances(ctx context.Context, addresses []ethcommon.Address) (map[ethcommon.Address]*big.Int, error) { + if mock.BalancesFunc == nil { + panic("BalanceFetcherMock.BalancesFunc: method is nil but BalanceFetcher.Balances was just called") + } + callInfo := struct { + Ctx context.Context + Addresses []ethcommon.Address + }{ + Ctx: ctx, + Addresses: addresses, + } + mock.lockBalances.Lock() + mock.calls.Balances = append(mock.calls.Balances, callInfo) + mock.lockBalances.Unlock() + return mock.BalancesFunc(ctx, addresses) +} + +// BalancesCalls gets all the calls that were made to Balances. +// Check the length with: +// +// len(mockedBalanceFetcher.BalancesCalls()) +func (mock *BalanceFetcherMock) BalancesCalls() []struct { + Ctx context.Context + Addresses []ethcommon.Address +} { + var calls []struct { + Ctx context.Context + Addresses []ethcommon.Address + } + mock.lockBalances.RLock() + calls = mock.calls.Balances + mock.lockBalances.RUnlock() + return calls +} diff --git a/backend/coins/eth/updater.go b/backend/coins/eth/updater.go new file mode 100644 index 0000000000..8be591b044 --- /dev/null +++ b/backend/coins/eth/updater.go @@ -0,0 +1,185 @@ +package eth + +import ( + "context" + "fmt" + "math/big" + "net/http" + "time" + + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/coins/eth/etherscan" + "github.com/BitBoxSwiss/bitbox-wallet-app/util/errp" + "github.com/BitBoxSwiss/bitbox-wallet-app/util/logging" + "github.com/ethereum/go-ethereum/common" + "github.com/sirupsen/logrus" + "golang.org/x/time/rate" +) + +// pollInterval is the interval at which the account is polled for updates. +var pollInterval = 5 * time.Minute + +// BalanceFetcher is an interface that defines a method to fetch balances for a list of addresses. +// +//go:generate moq -pkg mocks -out mocks/balancefetcher.go . Interface +type BalanceFetcher interface { + Balances(ctx context.Context, addresses []common.Address) (map[common.Address]*big.Int, error) +} + +// Updater is a struct that takes care of updating ETH accounts. +type Updater struct { + // quit is used to indicate to running goroutines that they should stop as the backend is being closed + quit chan struct{} + + // enqueueUpdateForAccount is used to enqueue an update for a specific ETH account. + enqueueUpdateForAccount <-chan *Account + + // updateETHAccountsCh is used to trigger an update of all ETH accounts. + updateETHAccountsCh chan struct{} + + log *logrus.Entry + + etherscanClient *http.Client + etherscanRateLimiter *rate.Limiter + + // updateAccounts is a function that updates all ETH accounts. + updateAccounts func() error +} + +// NewUpdater creates a new Updater instance. +func NewUpdater( + accountUpdate chan *Account, + etherscanClient *http.Client, + etherscanRateLimiter *rate.Limiter, + updateETHAccounts func() error, +) *Updater { + return &Updater{ + quit: make(chan struct{}), + enqueueUpdateForAccount: accountUpdate, + updateETHAccountsCh: make(chan struct{}), + etherscanClient: etherscanClient, + etherscanRateLimiter: etherscanRateLimiter, + updateAccounts: updateETHAccounts, + log: logging.Get().WithGroup("ethupdater"), + } +} + +// Close closes the updater and its channels. +func (u *Updater) Close() { + close(u.quit) +} + +// EnqueueUpdateForAllAccounts enqueues an update for all ETH accounts. +func (u *Updater) EnqueueUpdateForAllAccounts() { + u.updateETHAccountsCh <- struct{}{} +} + +// PollBalances updates the balances of all ETH accounts. +// It does that in three different cases: +// - When a timer triggers the update. +// - When the signanl to update all accounts is sent through UpdateETHAccountsCh. +// - When a specific account is updated through EnqueueUpdateForAccount. +func (u *Updater) PollBalances() { + timer := time.After(0) + + updateAll := func() { + if err := u.updateAccounts(); err != nil { + u.log.WithError(err).Error("could not update ETH accounts") + } + } + + for { + select { + case <-u.quit: + return + default: + select { + case <-u.quit: + return + case account := <-u.enqueueUpdateForAccount: + go func() { + // A single ETH accounts needs an update. + etherScanClient := etherscan.NewEtherScan(account.ETHCoin().ChainIDstr(), u.etherscanClient, u.etherscanRateLimiter) + u.UpdateBalances([]*Account{account}, etherScanClient) + }() + case <-u.updateETHAccountsCh: + go updateAll() + timer = time.After(pollInterval) + case <-timer: + go updateAll() + timer = time.After(pollInterval) + } + } + } + +} + +// UpdateBalances updates the balances of the accounts in the provided slice. +func (u *Updater) UpdateBalances(accounts []*Account, etherScanClient BalanceFetcher) { + ethNonErc20Addresses := make([]common.Address, 0, len(accounts)) + for _, account := range accounts { + if account.isClosed() { + continue + } + address, err := account.Address() + if err != nil { + u.log.WithError(err).Errorf("Could not get address for account %s", account.Config().Config.Code) + account.SetOffline(err) + continue + } + if !IsERC20(account) { + ethNonErc20Addresses = append(ethNonErc20Addresses, address.Address) + } + } + + updateNonERC20 := true + balances, err := etherScanClient.Balances(context.TODO(), ethNonErc20Addresses) + if err != nil { + u.log.WithError(err).Error("Could not get balances for ETH accounts") + updateNonERC20 = false + } + + for _, account := range accounts { + if account.isClosed() { + continue + } + address, err := account.Address() + if err != nil { + u.log.WithError(err).Errorf("Could not get address for account %s", account.Config().Config.Code) + account.SetOffline(err) + } + var balance *big.Int + switch { + case IsERC20(account): + var err error + balance, err = account.coin.client.ERC20Balance(account.address.Address, account.coin.erc20Token) + if err != nil { + u.log.WithError(err).Errorf("Could not get ERC20 balance for address %s", address.Address.Hex()) + account.SetOffline(err) + } + case updateNonERC20: + var ok bool + balance, ok = balances[address.Address] + if !ok { + errMsg := fmt.Sprintf("Could not find balance for address %s", address.Address.Hex()) + u.log.Error(errMsg) + account.SetOffline(errp.Newf(errMsg)) + } + default: + // If we get there, this is a non-erc20 account and we failed getting balances. + // If we couldn't get the balances for non-erc20 accounts, we mark them as offline + errMsg := fmt.Sprintf("Could not get balance for address %s", address.Address.Hex()) + u.log.Error(errMsg) + account.SetOffline(errp.Newf(errMsg)) + } + + if account.Offline() != nil { + continue // Skip updating balance if the account is offline. + } + if err := account.Update(balance); err != nil { + u.log.WithError(err).Errorf("Could not update balance for address %s", address.Address.Hex()) + account.SetOffline(err) + } else { + account.SetOffline(nil) + } + } +} diff --git a/backend/coins/eth/updater_test.go b/backend/coins/eth/updater_test.go new file mode 100644 index 0000000000..c59f8e8059 --- /dev/null +++ b/backend/coins/eth/updater_test.go @@ -0,0 +1,214 @@ +package eth_test + +import ( + "context" + "math/big" + "net/http" + "os" + "slices" + "testing" + "time" + + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/accounts" + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/coins/coin" + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/coins/eth" + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/coins/eth/erc20" + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/coins/eth/mocks" + rpcclientmocks "github.com/BitBoxSwiss/bitbox-wallet-app/backend/coins/eth/rpcclient/mocks" + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/config" + "github.com/BitBoxSwiss/bitbox-wallet-app/backend/signing" + "github.com/BitBoxSwiss/bitbox-wallet-app/util/errp" + "github.com/BitBoxSwiss/bitbox-wallet-app/util/logging" + "github.com/BitBoxSwiss/bitbox-wallet-app/util/test" + "github.com/btcsuite/btcd/btcutil/hdkeychain" + "github.com/btcsuite/btcd/chaincfg" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/params" + "github.com/stretchr/testify/require" +) + +func newAccount(t *testing.T, erc20Token *erc20.Token, erc20error bool) *eth.Account { + t.Helper() + + log := logging.Get().WithGroup("updater_test") + dbFolder := test.TstTempDir("eth-dbfolder") + defer func() { _ = os.RemoveAll(dbFolder) }() + + net := &chaincfg.TestNet3Params + + keypath, err := signing.NewAbsoluteKeypath("m/60'/1'/0'/0") + require.NoError(t, err) + seed := make([]byte, 32) + if erc20Token != nil { + // For ERC20 tokens, we use a different seed to ensure the final address is + // different. + // We need this because in the test we check which addresses are passed to the + // balanceFetcher, but if the seed is the same, a test case in which we + // have both erc20 and non-erc20 accounts would have the same addresses. + for i := range seed { + seed[i] = byte(i + 1) // just something deterministic + } + } + xpub, err := hdkeychain.NewMaster(seed, net) + require.NoError(t, err) + xpub, err = xpub.Neuter() + require.NoError(t, err) + + signingConfigurations := signing.Configurations{signing.NewEthereumConfiguration( + []byte{1, 2, 3, 4}, + keypath, + xpub)} + client := &rpcclientmocks.InterfaceMock{ + BlockNumberFunc: func(ctx context.Context) (*big.Int, error) { + return big.NewInt(100), nil + }, + ERC20BalanceFunc: func(address common.Address, token *erc20.Token) (*big.Int, error) { + if erc20error { + return nil, errp.New("failed to fetch ERC20 balance") + } + return big.NewInt(1e16), nil // Mock balance for ERC20 token + }, + } + + coin := eth.NewCoin(client, coin.CodeSEPETH, "Sepolia", "SEPETH", "SEPETH", params.SepoliaChainConfig, "", nil, erc20Token) + acct := eth.NewAccount( + &accounts.AccountConfig{ + Config: &config.Account{ + Code: "accountcode", + Name: "accountname", + SigningConfigurations: signingConfigurations, + }, + GetNotifier: func(signing.Configurations) accounts.Notifier { return nil }, + DBFolder: dbFolder, + }, + coin, + &http.Client{}, + log, + make(chan *eth.Account), + ) + + require.NoError(t, acct.Initialize()) + require.NoError(t, acct.Update(big.NewInt(0))) + require.Eventually(t, acct.Synced, time.Second, time.Millisecond*200) + return acct +} + +func assertAccountBalance(t *testing.T, acct *eth.Account, expected *big.Int) { + t.Helper() + balance, err := acct.Balance() + require.NoError(t, err) + require.Equal(t, expected, balance.Available().BigInt()) +} + +func TestUpdateBalances(t *testing.T) { + testCases := []struct { + name string + accounts []*eth.Account + expectedBalances []*big.Int + accountsToClose []int + overrideBalanceFetcher *mocks.BalanceFetcherMock + }{ + { + name: "Single account - non erc20", + accounts: []*eth.Account{newAccount(t, nil, false)}, + expectedBalances: []*big.Int{big.NewInt(1000)}, + }, + { + name: "Single account - erc20", + accounts: []*eth.Account{newAccount(t, erc20.NewToken("0x0000000000000000000000000000000000000001", 12), false)}, + expectedBalances: []*big.Int{big.NewInt(1e16)}, + }, + { + name: "Multiple accounts - one erc20", + accounts: []*eth.Account{ + newAccount(t, nil, false), + newAccount(t, erc20.NewToken("0x0000000000000000000000000000000000000001", 12), false), + }, + expectedBalances: []*big.Int{big.NewInt(1000), big.NewInt(1e16)}, // 1e16 is the balance for the erc20 token + }, + { + name: "Multiple accounts - the nonerc20 account is closed", + accounts: []*eth.Account{ + newAccount(t, nil, false), + newAccount(t, erc20.NewToken("0x0000000000000000000000000000000000000001", 12), false), + }, + expectedBalances: []*big.Int{big.NewInt(1000), big.NewInt(1e16)}, + accountsToClose: []int{0}, + }, + } + + updatedBalances := []common.Address{} + balanceFetcher := mocks.BalanceFetcherMock{ + BalancesFunc: func(ctx context.Context, addresses []common.Address) (map[common.Address]*big.Int, error) { + updatedBalances = addresses + // We mock the balanceFetcher to always return a balance of 1000. + balances := make(map[common.Address]*big.Int) + for _, address := range addresses { + balances[address] = big.NewInt(1000) + } + return balances, nil + }, + } + + updater := eth.NewUpdater(nil, nil, nil, nil) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, acct := range tc.accounts { + defer acct.Close() + } + for _, idx := range tc.accountsToClose { + tc.accounts[idx].Close() + } + + updater.UpdateBalances(tc.accounts, &balanceFetcher) + + for i, acct := range tc.accounts { + accountWasClosed := slices.Contains(tc.accountsToClose, i) + address, err := acct.Address() + require.NoError(t, err) + if accountWasClosed { + // If the account was closed, it must not have its balance updated. + require.NotContains(t, updatedBalances, address.Address) + continue + } + assertAccountBalance(t, acct, tc.expectedBalances[i]) + if eth.IsERC20(acct) { + // ERC20 accounts should not have their balances updated by the balanceFetcher + // since they have their own balance fetching logic. + require.NotContains(t, updatedBalances, address.Address) + } else { + // Non-closed, non-erc20 accounts should have their balances updated + // by the balanceFetcher. + require.Contains(t, updatedBalances, address.Address) + } + } + }) + } + +} + +func TestUpdateBalancesWithError(t *testing.T) { + balanceFetcher := &mocks.BalanceFetcherMock{ + BalancesFunc: func(ctx context.Context, addresses []common.Address) (map[common.Address]*big.Int, error) { + // We mock the balanceFetcher to always return an error. + // This simulates a failure in fetching balances which should set the account to offline. + return nil, errp.New("balance fetch error") + }, + } + + updater := eth.NewUpdater(nil, nil, nil, nil) + account := newAccount(t, nil, false) + defer account.Close() + + updater.UpdateBalances([]*eth.Account{account}, balanceFetcher) + require.Error(t, account.Offline()) + + // We create an ERC20 account and pass "true" to the "erc20error" parameter to simulate an error. + // This way we expect the account to be set offline as well. + erc20Account := newAccount(t, erc20.NewToken("0x0000000000000000000000000000000000000001", 12), true) + defer erc20Account.Close() + + updater.UpdateBalances([]*eth.Account{erc20Account}, balanceFetcher) + require.Error(t, erc20Account.Offline()) + +}