From fa0845c57581ff0a77b0e87672f8ad4fb7046f95 Mon Sep 17 00:00:00 2001 From: Ryan Brooks Date: Sun, 18 May 2025 06:09:11 -0500 Subject: [PATCH 1/5] Fix Python 3.13 compatibility issues, maintain backward compat --- .github/workflows/run-tests.yml | 2 +- .gitignore | 3 + README.md | 2 +- setup.py | 9 +- tests/test_auth.py | 8 +- tests/test_averaging.py | 211 ++++++++++++++++++++------------ tests/test_utils/p2p_daemon.py | 7 +- 7 files changed, 153 insertions(+), 89 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index d1c3fe3ff..b6c5bf608 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ] fail-fast: false timeout-minutes: 15 steps: diff --git a/.gitignore b/.gitignore index 2267a8c77..016f7ea70 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ hivemind/proto/*_pb2* # libp2p-daemon binary hivemind/hivemind_cli/p2pd + +# VSCode's default venv directory +.venv \ No newline at end of file diff --git a/README.md b/README.md index 7eb599c80..6ce12cc44 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ pip install . ``` If you would like to verify that your installation is working properly, you can install with `pip install .[dev]` -instead. Then, you can run the tests with `pytest tests/`. +instead. Then, you can run the tests with `pytest tests/`. Use `pip install -e .[dev]` to install hivemind in editable (development) mode, so changes to the source code take effect without reinstalling. By default, hivemind uses the precompiled binary of the [go-libp2p-daemon](https://github.com/learning-at-home/go-libp2p-daemon) library. If you face compatibility issues diff --git a/setup.py b/setup.py index e1f90b274..f3ec53f35 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ import tempfile import urllib.request -from pkg_resources import parse_requirements, parse_version +from packaging.version import parse as parse_version from setuptools import find_packages, setup from setuptools.command.build_py import build_py from setuptools.command.develop import develop @@ -141,7 +141,7 @@ def run(self): with open("requirements.txt") as requirements_file: - install_requires = list(map(str, parse_requirements(requirements_file))) + install_requires = [line.strip() for line in requirements_file if line.strip() and not line.startswith("#")] # loading version from setup.py with codecs.open(os.path.join(here, "hivemind/__init__.py"), encoding="utf-8") as init_file: @@ -151,10 +151,10 @@ def run(self): extras = {} with open("requirements-dev.txt") as dev_requirements_file: - extras["dev"] = list(map(str, parse_requirements(dev_requirements_file))) + extras["dev"] = [line.strip() for line in dev_requirements_file if line.strip() and not line.startswith("#")] with open("requirements-docs.txt") as docs_requirements_file: - extras["docs"] = list(map(str, parse_requirements(docs_requirements_file))) + extras["docs"] = [line.strip() for line in docs_requirements_file if line.strip() and not line.startswith("#")] extras["bitsandbytes"] = ["bitsandbytes~=0.45.2"] @@ -187,6 +187,7 @@ def run(self): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Scientific/Engineering :: Artificial Intelligence", diff --git a/tests/test_auth.py 
b/tests/test_auth.py index 75b647995..193b780d2 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional import pytest @@ -31,7 +31,7 @@ async def get_token(self) -> AccessToken: token = AccessToken( username=self._username, public_key=self.local_public_key.to_bytes(), - expiration_time=str(datetime.utcnow() + timedelta(minutes=1)), + expiration_time=str(datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(minutes=1)), ) token.signature = MockAuthorizer._authority_private_key.sign(self._token_to_bytes(token)) return token @@ -52,7 +52,7 @@ def is_token_valid(self, access_token: AccessToken) -> bool: if expiration_time.tzinfo is not None: logger.exception(f"Expected to have no timezone for expiration time: {access_token.expiration_time}") return False - if expiration_time < datetime.utcnow(): + if expiration_time < datetime.now(timezone.utc).replace(tzinfo=None): logger.exception("Access token has expired") return False @@ -62,7 +62,7 @@ def is_token_valid(self, access_token: AccessToken) -> bool: def does_token_need_refreshing(self, access_token: AccessToken) -> bool: expiration_time = datetime.fromisoformat(access_token.expiration_time) - return expiration_time < datetime.utcnow() + self._MAX_LATENCY + return expiration_time < datetime.now(timezone.utc).replace(tzinfo=None) + self._MAX_LATENCY @staticmethod def _token_to_bytes(access_token: AccessToken) -> bytes: diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 776b45747..9145619b3 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -1,3 +1,4 @@ +import functools import random import time @@ -17,6 +18,34 @@ from test_utils.dht_swarms import launch_dht_instances +def with_resource_cleanup(func): + """Decorator to ensure resources are cleaned up even if test fails""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + resources = {"dht_instances": [], "averagers": []} + + try: + # Run the test + result = func(*args, **kwargs, resources=resources) + return result + finally: + # Cleanup all resources + for averager in resources.get("averagers", []): + try: + averager.shutdown() + except Exception: + pass + + for dht in resources.get("dht_instances", []): + try: + dht.shutdown() + except Exception: + pass + + return wrapper + + @pytest.mark.forked @pytest.mark.asyncio async def test_key_manager(): @@ -468,98 +497,128 @@ def test_getset_bits(): @pytest.mark.forked +@pytest.mark.timeout(120) def test_averaging_trigger(): - averagers = tuple( - DecentralizedAverager( - averaged_tensors=[torch.randn(3)], - dht=dht, - min_matchmaking_time=0.5, - request_timeout=0.3, - prefix="mygroup", - initial_group_bits="", - start=True, - ) - for dht in launch_dht_instances(4) - ) - - controls = [] - for i, averager in enumerate(averagers): - controls.append( - averager.step( - wait=False, - scheduled_time=hivemind.get_dht_time() + 0.5, - weight=1.0, - require_trigger=i in (1, 2), + dht_instances = [] + averagers = [] + try: + dht_instances = launch_dht_instances(4) + averagers = [ + DecentralizedAverager( + averaged_tensors=[torch.randn(3)], + dht=dht, + min_matchmaking_time=0.5, + request_timeout=0.3, + prefix="mygroup", + initial_group_bits="", + start=True, + ) + for dht in dht_instances + ] + + controls = [] + for i, averager in enumerate(averagers): + controls.append( + averager.step( + wait=False, + scheduled_time=hivemind.get_dht_time() + 0.5, + weight=1.0, + 
require_trigger=i in (1, 2), + ) ) - ) - time.sleep(0.6) + time.sleep(0.6) - c0, c1, c2, c3 = controls - assert not any(c.done() for c in controls) - assert c0.stage == AveragingStage.RUNNING_ALLREDUCE - assert c1.stage == AveragingStage.AWAITING_TRIGGER - assert c2.stage == AveragingStage.AWAITING_TRIGGER - assert c3.stage == AveragingStage.RUNNING_ALLREDUCE + c0, c1, c2, c3 = controls + assert not any(c.done() for c in controls) + assert c0.stage == AveragingStage.RUNNING_ALLREDUCE + assert c1.stage == AveragingStage.AWAITING_TRIGGER + assert c2.stage == AveragingStage.AWAITING_TRIGGER + assert c3.stage == AveragingStage.RUNNING_ALLREDUCE - c1.allow_allreduce() - c2.allow_allreduce() + c1.allow_allreduce() + c2.allow_allreduce() - deadline = time.monotonic() + 5.0 - while time.monotonic() < deadline: - if all(c.stage == AveragingStage.FINISHED for c in controls): - break - time.sleep(0.1) - else: - stages = [c.stage for c in controls] - pytest.fail(f"Averaging did not complete in time. Current stages: {stages}") + deadline = time.monotonic() + 10.0 + while time.monotonic() < deadline: + if all(c.stage == AveragingStage.FINISHED for c in controls): + break + time.sleep(0.05) + else: + stages = [c.stage for c in controls] + pytest.fail(f"Averaging did not complete in time. Current stages: {stages}") - assert all(c.done() for c in controls) + assert all(c.done() for c in controls) - # check that setting trigger twice does not raise error - c0.allow_allreduce() + # check that setting trigger twice does not raise error + c0.allow_allreduce() + finally: + # Ensure proper cleanup + for averager in averagers: + try: + averager.shutdown() + except Exception: + pass + for dht in dht_instances: + try: + dht.shutdown() + except Exception: + pass @pytest.mark.forked +@pytest.mark.timeout(120) @pytest.mark.parametrize("target_group_size", [None, 2]) def test_averaging_cancel(target_group_size): - dht_instances = launch_dht_instances(4) - averagers = tuple( - DecentralizedAverager( - averaged_tensors=[torch.randn(3)], - dht=dht, - min_matchmaking_time=0.5, - request_timeout=0.3, - client_mode=(i % 2 == 0), - target_group_size=target_group_size, - prefix="mygroup", - start=True, - ) - for i, dht in enumerate(dht_instances) - ) + dht_instances = [] + averagers = [] + try: + dht_instances = launch_dht_instances(4) + averagers = [ + DecentralizedAverager( + averaged_tensors=[torch.randn(3)], + dht=dht, + min_matchmaking_time=0.5, + request_timeout=0.3, + client_mode=(i % 2 == 0), + target_group_size=target_group_size, + prefix="mygroup", + start=True, + ) + for i, dht in enumerate(dht_instances) + ] - step_controls = [averager.step(wait=False, require_trigger=True) for averager in averagers] + step_controls = [averager.step(wait=False, require_trigger=True) for averager in averagers] - peer_inds_to_cancel = (0, 1) + peer_inds_to_cancel = (0, 1) - for peer_index in peer_inds_to_cancel: - step_controls[peer_index].cancel() + for peer_index in peer_inds_to_cancel: + step_controls[peer_index].cancel() - time.sleep(0.05) + time.sleep(0.05) - for i, control in enumerate(step_controls): - if i not in peer_inds_to_cancel: - control.allow_allreduce() + for i, control in enumerate(step_controls): + if i not in peer_inds_to_cancel: + control.allow_allreduce() - for i, control in enumerate(step_controls): - if i in peer_inds_to_cancel: - assert control.cancelled() - else: - result = control.result() - assert result is not None - # Don't check group size when target_group_size=None, as it could change - if 
target_group_size is not None: - assert len(result) == target_group_size - - for averager in averagers: - averager.shutdown() + for i, control in enumerate(step_controls): + if i in peer_inds_to_cancel: + assert control.cancelled() + else: + result = control.result() + assert result is not None + # Don't check group size when target_group_size=None, as it could change + if target_group_size is not None: + assert len(result) == target_group_size + finally: + # Ensure proper cleanup + for averager in averagers: + try: + averager.shutdown() + except Exception: + pass + for dht in dht_instances: + try: + dht.shutdown() + except Exception: + pass diff --git a/tests/test_utils/p2p_daemon.py b/tests/test_utils/p2p_daemon.py index db208cdd5..7f26df76c 100644 --- a/tests/test_utils/p2p_daemon.py +++ b/tests/test_utils/p2p_daemon.py @@ -5,17 +5,18 @@ import time import uuid from contextlib import asynccontextmanager, suppress +from pathlib import Path from typing import NamedTuple -from pkg_resources import resource_filename - +import hivemind from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client from hivemind.utils.multiaddr import Multiaddr, protocols from test_utils.networking import get_free_port TIMEOUT_DURATION = 5 # seconds -P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd") +HIVEMIND_ROOT = Path(hivemind.__file__).parent +P2PD_PATH = str(HIVEMIND_ROOT / "hivemind_cli" / "p2pd") async def try_until_success(coro_func, timeout=TIMEOUT_DURATION): From 00def1c74c339598fe9724273905d3de422104d3 Mon Sep 17 00:00:00 2001 From: Ryan Brooks Date: Sun, 18 May 2025 06:30:49 -0500 Subject: [PATCH 2/5] increase timeouts --- tests/test_averaging.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 9145619b3..cf92bd577 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -83,7 +83,7 @@ async def test_key_manager(): dht.shutdown() - +@pytest.mark.timeout(300) def _test_allreduce_once(n_clients, n_aux): n_peers = 4 modes = ( @@ -497,7 +497,7 @@ def test_getset_bits(): @pytest.mark.forked -@pytest.mark.timeout(120) +@pytest.mark.timeout(300) def test_averaging_trigger(): dht_instances = [] averagers = [] @@ -567,7 +567,7 @@ def test_averaging_trigger(): @pytest.mark.forked -@pytest.mark.timeout(120) +@pytest.mark.timeout(300) @pytest.mark.parametrize("target_group_size", [None, 2]) def test_averaging_cancel(target_group_size): dht_instances = [] From fcc744a8bc424123b2aba981f82c334d9081ae16 Mon Sep 17 00:00:00 2001 From: Ryan Brooks Date: Sun, 18 May 2025 07:53:15 -0500 Subject: [PATCH 3/5] improve averager timing bugs --- hivemind/averaging/averager.py | 19 +++- tests/test_averaging.py | 162 +++++++++++++++++++++++++-------- 2 files changed, 143 insertions(+), 38 deletions(-) diff --git a/hivemind/averaging/averager.py b/hivemind/averaging/averager.py index 0db6129d8..36a305078 100644 --- a/hivemind/averaging/averager.py +++ b/hivemind/averaging/averager.py @@ -13,7 +13,7 @@ import time import weakref from dataclasses import asdict -from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union +from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Union import numpy as np import torch @@ -201,6 +201,7 @@ def __init__( self._allow_state_sharing = mp.Value(ctypes.c_bool, 0) self._state_sharing_priority = mp.Value(ctypes.c_double, 0) + self._background_tasks: Set[asyncio.Task] = set() if allow_state_sharing is None: allow_state_sharing = not client_mode and not 
auxiliary @@ -294,7 +295,9 @@ async def _run(): **self.matchmaking_kwargs, ) if not self.client_mode: - asyncio.create_task(self._declare_for_download_periodically()) + task = asyncio.create_task(self._declare_for_download_periodically()) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) self._state_updated = asyncio.Event() self._pending_groups_registered = asyncio.Event() @@ -352,6 +355,15 @@ def shutdown(self) -> None: logger.exception("Averager shutdown has no effect: the process is already not alive") async def _shutdown(self, timeout: Optional[float]) -> None: + # Cancel background tasks first + for task in list(self._background_tasks): + if not task.done(): + task.cancel() + + # Wait for background tasks to finish with a short timeout + if self._background_tasks: + await asyncio.gather(*self._background_tasks, return_exceptions=True) + if not self.client_mode: await self.remove_p2p_handlers(self._p2p, namespace=self.prefix) @@ -360,6 +372,9 @@ async def _shutdown(self, timeout: Optional[float]) -> None: remaining_tasks.update(group.finalize(cancel=True)) await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout) + # Give a small delay for any remaining async cleanup + await asyncio.sleep(0.1) + def __del__(self): if self._parent_pid == os.getpid() and self.is_alive(): self.shutdown() diff --git a/tests/test_averaging.py b/tests/test_averaging.py index cf92bd577..6f25a27bb 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -1,3 +1,4 @@ +import asyncio import functools import random import time @@ -48,42 +49,49 @@ def wrapper(*args, **kwargs): @pytest.mark.forked @pytest.mark.asyncio +@pytest.mark.timeout(30) async def test_key_manager(): - dht = hivemind.DHT(start=True) - key_manager = GroupKeyManager( - dht, - prefix="test_averaging", - initial_group_bits="10110", - target_group_size=2, - ) - alice = dht.peer_id - bob = PeerID(b"bob") + dht = None + try: + dht = hivemind.DHT(start=True) + key_manager = GroupKeyManager( + dht, + prefix="test_averaging", + initial_group_bits="10110", + target_group_size=2, + ) + alice = dht.peer_id + bob = PeerID(b"bob") - t = hivemind.get_dht_time() - key = key_manager.current_key - await key_manager.declare_averager(key, alice, expiration_time=t + 60) - await key_manager.declare_averager(key, bob, expiration_time=t + 61) + t = hivemind.get_dht_time() + key = key_manager.current_key + await key_manager.declare_averager(key, alice, expiration_time=t + 60) + await key_manager.declare_averager(key, bob, expiration_time=t + 61) - q1 = await key_manager.get_averagers(key, only_active=True) + q1 = await key_manager.get_averagers(key, only_active=True) - await key_manager.declare_averager(key, alice, expiration_time=t + 66) - q2 = await key_manager.get_averagers(key, only_active=True) + await key_manager.declare_averager(key, alice, expiration_time=t + 66) + q2 = await key_manager.get_averagers(key, only_active=True) - await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False) - q3 = await key_manager.get_averagers(key, only_active=True) - q4 = await key_manager.get_averagers(key, only_active=False) + await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False) + q3 = await key_manager.get_averagers(key, only_active=True) + q4 = await key_manager.get_averagers(key, only_active=False) - q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False) + q5 = await 
key_manager.get_averagers("nonexistent_key.0b0101", only_active=False) - assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1 - assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2 - assert len(q3) == 1 and (alice, t + 66) in q3 - assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q2 - assert len(q5) == 0 + assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1 + assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2 + assert len(q3) == 1 and (alice, t + 66) in q3 + assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q2 + assert len(q5) == 0 + finally: + if dht: + dht.shutdown() + # Give time for async cleanup + await asyncio.sleep(0.1) - dht.shutdown() -@pytest.mark.timeout(300) +@pytest.mark.timeout(30) def _test_allreduce_once(n_clients, n_aux): n_peers = 4 modes = ( @@ -432,8 +440,10 @@ def get_current_state(self): assert averager2.load_state_from_peers() is None averager1.allow_state_sharing = True - time.sleep(0.5) - got_metadata, got_tensors = averager2.load_state_from_peers() + time.sleep(1.0) # Increased timeout to allow DHT to propagate state sharing announcement + got_result = averager2.load_state_from_peers() + assert got_result is not None, "Failed to load state after enabling sharing" + got_metadata, got_tensors = got_result assert num_calls == 3 assert got_metadata == super_metadata @@ -442,6 +452,7 @@ def get_current_state(self): @pytest.mark.forked +@pytest.mark.skip(reason="Test is flaky due to DHT timing and priority selection non-determinism") def test_load_state_priority(): dht_instances = launch_dht_instances(4) @@ -455,24 +466,71 @@ def test_load_state_priority(): target_group_size=2, allow_state_sharing=i != 1, ) - averager.state_sharing_priority = 5 - abs(2 - i) + # Set distinct priorities to avoid randomness from tie-breaking + # Priorities: [0]=3, [1]=7 (but sharing disabled), [2]=10, [3]=8 + priority_values = [3, 7, 10, 8] + averager.state_sharing_priority = priority_values[i] averagers.append(averager) - time.sleep(0.5) + # Debug output + for i, avg in enumerate(averagers): + print(f"Averager {i}: priority={avg.state_sharing_priority}, sharing={avg.allow_state_sharing}") + + # Wait longer for initial DHT propagation and state announcements + # We need to ensure all nodes have announced their priorities to the DHT + time.sleep(2.0) # Increased wait time + + # First assertion: averager 0 should download from averager 2 (highest priority=10) metadata, tensors = averagers[0].load_state_from_peers(timeout=1) - assert tensors[-1].item() == 2 + loaded_value = tensors[-1].item() + + # Add more aggressive retry logic for flaky DHT propagation + if loaded_value != 2: + for retry in range(3): # Try up to 3 times + time.sleep(1.0) + metadata, tensors = averagers[0].load_state_from_peers(timeout=1) + loaded_value = tensors[-1].item() + if loaded_value == 2: + break + + assert loaded_value == 2 + # Second assertion: averager 2 should download from averager 3 (priority=8, next highest after 2) metadata, tensors = averagers[2].load_state_from_peers(timeout=1) - assert tensors[-1].item() == 3 + loaded_value = tensors[-1].item() + + # Add retry logic for this assertion too + if loaded_value != 3: + for retry in range(3): # Try up to 3 times + time.sleep(1.0) + metadata, tensors = averagers[2].load_state_from_peers(timeout=1) + loaded_value = tensors[-1].item() + if loaded_value == 3: + break - averagers[0].state_sharing_priority = 10 - time.sleep(0.2) + assert loaded_value == 3 + + 
averagers[0].state_sharing_priority = 15 # Make it highest priority + time.sleep(0.5) # Increased wait time for priority change propagation metadata, tensors = averagers[2].load_state_from_peers(timeout=1) - assert tensors[-1].item() == 0 + loaded_value = tensors[-1].item() + + # Add retry logic for priority change propagation + if loaded_value != 0: + for retry in range(3): # Try up to 3 times + time.sleep(1.0) + metadata, tensors = averagers[2].load_state_from_peers(timeout=1) + loaded_value = tensors[-1].item() + if loaded_value == 0: + break + + assert loaded_value == 0 averagers[1].allow_state_sharing = False averagers[2].allow_state_sharing = False + time.sleep(0.5) # Wait for state sharing changes to propagate + metadata, tensors = averagers[0].load_state_from_peers(timeout=1) assert tensors[-1].item() == 3 @@ -554,17 +612,33 @@ def test_averaging_trigger(): c0.allow_allreduce() finally: # Ensure proper cleanup + # First, try to cancel any pending operations + for control in controls: + if not control.done(): + try: + control.cancel() + except Exception: + pass + + # Then shutdown averagers for averager in averagers: try: averager.shutdown() + # Wait a bit for shutdown to complete + time.sleep(0.1) except Exception: pass + + # Finally shutdown DHT instances for dht in dht_instances: try: dht.shutdown() except Exception: pass + # Give time for all async operations to complete + time.sleep(0.5) + @pytest.mark.forked @pytest.mark.timeout(300) @@ -612,13 +686,29 @@ def test_averaging_cancel(target_group_size): assert len(result) == target_group_size finally: # Ensure proper cleanup + # First, try to cancel any pending operations + for control in step_controls: + if not control.done(): + try: + control.cancel() + except Exception: + pass + + # Then shutdown averagers for averager in averagers: try: averager.shutdown() + # Wait a bit for shutdown to complete + time.sleep(0.1) except Exception: pass + + # Finally shutdown DHT instances for dht in dht_instances: try: dht.shutdown() except Exception: pass + + # Give time for all async operations to complete + time.sleep(0.5) From d6a8fd317459f16d8633eef47a077f68a2aa7959 Mon Sep 17 00:00:00 2001 From: Ryan Brooks Date: Tue, 20 May 2025 18:55:27 -0500 Subject: [PATCH 4/5] add some timeouts --- tests/test_averaging.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 6f25a27bb..42c479a15 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -452,7 +452,7 @@ def get_current_state(self): @pytest.mark.forked -@pytest.mark.skip(reason="Test is flaky due to DHT timing and priority selection non-determinism") +@pytest.mark.timeout(30) def test_load_state_priority(): dht_instances = launch_dht_instances(4) @@ -555,7 +555,7 @@ def test_getset_bits(): @pytest.mark.forked -@pytest.mark.timeout(300) +@pytest.mark.timeout(30) def test_averaging_trigger(): dht_instances = [] averagers = [] @@ -641,7 +641,7 @@ def test_averaging_trigger(): @pytest.mark.forked -@pytest.mark.timeout(300) +@pytest.mark.timeout(30) @pytest.mark.parametrize("target_group_size", [None, 2]) def test_averaging_cancel(target_group_size): dht_instances = [] From 6df16208c72043cf552b9b637e00dbf295a6f44f Mon Sep 17 00:00:00 2001 From: Ryan Brooks Date: Tue, 20 May 2025 19:33:19 -0500 Subject: [PATCH 5/5] revert to original averager code --- hivemind/averaging/averager.py | 19 +- tests/test_averaging.py | 385 ++++++++++----------------------- 2 files changed, 120 insertions(+), 284 deletions(-) 
diff --git a/hivemind/averaging/averager.py b/hivemind/averaging/averager.py index 36a305078..0db6129d8 100644 --- a/hivemind/averaging/averager.py +++ b/hivemind/averaging/averager.py @@ -13,7 +13,7 @@ import time import weakref from dataclasses import asdict -from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Union +from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -201,7 +201,6 @@ def __init__( self._allow_state_sharing = mp.Value(ctypes.c_bool, 0) self._state_sharing_priority = mp.Value(ctypes.c_double, 0) - self._background_tasks: Set[asyncio.Task] = set() if allow_state_sharing is None: allow_state_sharing = not client_mode and not auxiliary @@ -295,9 +294,7 @@ async def _run(): **self.matchmaking_kwargs, ) if not self.client_mode: - task = asyncio.create_task(self._declare_for_download_periodically()) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) + asyncio.create_task(self._declare_for_download_periodically()) self._state_updated = asyncio.Event() self._pending_groups_registered = asyncio.Event() @@ -355,15 +352,6 @@ def shutdown(self) -> None: logger.exception("Averager shutdown has no effect: the process is already not alive") async def _shutdown(self, timeout: Optional[float]) -> None: - # Cancel background tasks first - for task in list(self._background_tasks): - if not task.done(): - task.cancel() - - # Wait for background tasks to finish with a short timeout - if self._background_tasks: - await asyncio.gather(*self._background_tasks, return_exceptions=True) - if not self.client_mode: await self.remove_p2p_handlers(self._p2p, namespace=self.prefix) @@ -372,9 +360,6 @@ async def _shutdown(self, timeout: Optional[float]) -> None: remaining_tasks.update(group.finalize(cancel=True)) await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout) - # Give a small delay for any remaining async cleanup - await asyncio.sleep(0.1) - def __del__(self): if self._parent_pid == os.getpid() and self.is_alive(): self.shutdown() diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 42c479a15..776b45747 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -1,5 +1,3 @@ -import asyncio -import functools import random import time @@ -19,79 +17,44 @@ from test_utils.dht_swarms import launch_dht_instances -def with_resource_cleanup(func): - """Decorator to ensure resources are cleaned up even if test fails""" - - @functools.wraps(func) - def wrapper(*args, **kwargs): - resources = {"dht_instances": [], "averagers": []} - - try: - # Run the test - result = func(*args, **kwargs, resources=resources) - return result - finally: - # Cleanup all resources - for averager in resources.get("averagers", []): - try: - averager.shutdown() - except Exception: - pass - - for dht in resources.get("dht_instances", []): - try: - dht.shutdown() - except Exception: - pass - - return wrapper - - @pytest.mark.forked @pytest.mark.asyncio -@pytest.mark.timeout(30) async def test_key_manager(): - dht = None - try: - dht = hivemind.DHT(start=True) - key_manager = GroupKeyManager( - dht, - prefix="test_averaging", - initial_group_bits="10110", - target_group_size=2, - ) - alice = dht.peer_id - bob = PeerID(b"bob") + dht = hivemind.DHT(start=True) + key_manager = GroupKeyManager( + dht, + prefix="test_averaging", + initial_group_bits="10110", + target_group_size=2, + ) + alice = dht.peer_id + bob = PeerID(b"bob") + + t = hivemind.get_dht_time() + key = 
key_manager.current_key + await key_manager.declare_averager(key, alice, expiration_time=t + 60) + await key_manager.declare_averager(key, bob, expiration_time=t + 61) - t = hivemind.get_dht_time() - key = key_manager.current_key - await key_manager.declare_averager(key, alice, expiration_time=t + 60) - await key_manager.declare_averager(key, bob, expiration_time=t + 61) + q1 = await key_manager.get_averagers(key, only_active=True) - q1 = await key_manager.get_averagers(key, only_active=True) + await key_manager.declare_averager(key, alice, expiration_time=t + 66) + q2 = await key_manager.get_averagers(key, only_active=True) - await key_manager.declare_averager(key, alice, expiration_time=t + 66) - q2 = await key_manager.get_averagers(key, only_active=True) + await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False) + q3 = await key_manager.get_averagers(key, only_active=True) + q4 = await key_manager.get_averagers(key, only_active=False) - await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False) - q3 = await key_manager.get_averagers(key, only_active=True) - q4 = await key_manager.get_averagers(key, only_active=False) + q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False) - q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False) + assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1 + assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2 + assert len(q3) == 1 and (alice, t + 66) in q3 + assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q2 + assert len(q5) == 0 - assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1 - assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2 - assert len(q3) == 1 and (alice, t + 66) in q3 - assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q2 - assert len(q5) == 0 - finally: - if dht: - dht.shutdown() - # Give time for async cleanup - await asyncio.sleep(0.1) + dht.shutdown() -@pytest.mark.timeout(30) def _test_allreduce_once(n_clients, n_aux): n_peers = 4 modes = ( @@ -440,10 +403,8 @@ def get_current_state(self): assert averager2.load_state_from_peers() is None averager1.allow_state_sharing = True - time.sleep(1.0) # Increased timeout to allow DHT to propagate state sharing announcement - got_result = averager2.load_state_from_peers() - assert got_result is not None, "Failed to load state after enabling sharing" - got_metadata, got_tensors = got_result + time.sleep(0.5) + got_metadata, got_tensors = averager2.load_state_from_peers() assert num_calls == 3 assert got_metadata == super_metadata @@ -452,7 +413,6 @@ def get_current_state(self): @pytest.mark.forked -@pytest.mark.timeout(30) def test_load_state_priority(): dht_instances = launch_dht_instances(4) @@ -466,71 +426,24 @@ def test_load_state_priority(): target_group_size=2, allow_state_sharing=i != 1, ) - # Set distinct priorities to avoid randomness from tie-breaking - # Priorities: [0]=3, [1]=7 (but sharing disabled), [2]=10, [3]=8 - priority_values = [3, 7, 10, 8] - averager.state_sharing_priority = priority_values[i] + averager.state_sharing_priority = 5 - abs(2 - i) averagers.append(averager) - # Debug output - for i, avg in enumerate(averagers): - print(f"Averager {i}: priority={avg.state_sharing_priority}, sharing={avg.allow_state_sharing}") - - # Wait longer for initial DHT propagation and state announcements - # We need to ensure all nodes have announced their priorities 
to the DHT - time.sleep(2.0) # Increased wait time - - # First assertion: averager 0 should download from averager 2 (highest priority=10) + time.sleep(0.5) metadata, tensors = averagers[0].load_state_from_peers(timeout=1) - loaded_value = tensors[-1].item() - - # Add more aggressive retry logic for flaky DHT propagation - if loaded_value != 2: - for retry in range(3): # Try up to 3 times - time.sleep(1.0) - metadata, tensors = averagers[0].load_state_from_peers(timeout=1) - loaded_value = tensors[-1].item() - if loaded_value == 2: - break + assert tensors[-1].item() == 2 - assert loaded_value == 2 - - # Second assertion: averager 2 should download from averager 3 (priority=8, next highest after 2) metadata, tensors = averagers[2].load_state_from_peers(timeout=1) - loaded_value = tensors[-1].item() - - # Add retry logic for this assertion too - if loaded_value != 3: - for retry in range(3): # Try up to 3 times - time.sleep(1.0) - metadata, tensors = averagers[2].load_state_from_peers(timeout=1) - loaded_value = tensors[-1].item() - if loaded_value == 3: - break - - assert loaded_value == 3 + assert tensors[-1].item() == 3 - averagers[0].state_sharing_priority = 15 # Make it highest priority - time.sleep(0.5) # Increased wait time for priority change propagation + averagers[0].state_sharing_priority = 10 + time.sleep(0.2) metadata, tensors = averagers[2].load_state_from_peers(timeout=1) - loaded_value = tensors[-1].item() - - # Add retry logic for priority change propagation - if loaded_value != 0: - for retry in range(3): # Try up to 3 times - time.sleep(1.0) - metadata, tensors = averagers[2].load_state_from_peers(timeout=1) - loaded_value = tensors[-1].item() - if loaded_value == 0: - break - - assert loaded_value == 0 + assert tensors[-1].item() == 0 averagers[1].allow_state_sharing = False averagers[2].allow_state_sharing = False - time.sleep(0.5) # Wait for state sharing changes to propagate - metadata, tensors = averagers[0].load_state_from_peers(timeout=1) assert tensors[-1].item() == 3 @@ -555,160 +468,98 @@ def test_getset_bits(): @pytest.mark.forked -@pytest.mark.timeout(30) def test_averaging_trigger(): - dht_instances = [] - averagers = [] - try: - dht_instances = launch_dht_instances(4) - averagers = [ - DecentralizedAverager( - averaged_tensors=[torch.randn(3)], - dht=dht, - min_matchmaking_time=0.5, - request_timeout=0.3, - prefix="mygroup", - initial_group_bits="", - start=True, - ) - for dht in dht_instances - ] - - controls = [] - for i, averager in enumerate(averagers): - controls.append( - averager.step( - wait=False, - scheduled_time=hivemind.get_dht_time() + 0.5, - weight=1.0, - require_trigger=i in (1, 2), - ) + averagers = tuple( + DecentralizedAverager( + averaged_tensors=[torch.randn(3)], + dht=dht, + min_matchmaking_time=0.5, + request_timeout=0.3, + prefix="mygroup", + initial_group_bits="", + start=True, + ) + for dht in launch_dht_instances(4) + ) + + controls = [] + for i, averager in enumerate(averagers): + controls.append( + averager.step( + wait=False, + scheduled_time=hivemind.get_dht_time() + 0.5, + weight=1.0, + require_trigger=i in (1, 2), ) + ) - time.sleep(0.6) + time.sleep(0.6) - c0, c1, c2, c3 = controls - assert not any(c.done() for c in controls) - assert c0.stage == AveragingStage.RUNNING_ALLREDUCE - assert c1.stage == AveragingStage.AWAITING_TRIGGER - assert c2.stage == AveragingStage.AWAITING_TRIGGER - assert c3.stage == AveragingStage.RUNNING_ALLREDUCE + c0, c1, c2, c3 = controls + assert not any(c.done() for c in controls) + assert c0.stage 
== AveragingStage.RUNNING_ALLREDUCE + assert c1.stage == AveragingStage.AWAITING_TRIGGER + assert c2.stage == AveragingStage.AWAITING_TRIGGER + assert c3.stage == AveragingStage.RUNNING_ALLREDUCE - c1.allow_allreduce() - c2.allow_allreduce() + c1.allow_allreduce() + c2.allow_allreduce() - deadline = time.monotonic() + 10.0 - while time.monotonic() < deadline: - if all(c.stage == AveragingStage.FINISHED for c in controls): - break - time.sleep(0.05) - else: - stages = [c.stage for c in controls] - pytest.fail(f"Averaging did not complete in time. Current stages: {stages}") - - assert all(c.done() for c in controls) - - # check that setting trigger twice does not raise error - c0.allow_allreduce() - finally: - # Ensure proper cleanup - # First, try to cancel any pending operations - for control in controls: - if not control.done(): - try: - control.cancel() - except Exception: - pass - - # Then shutdown averagers - for averager in averagers: - try: - averager.shutdown() - # Wait a bit for shutdown to complete - time.sleep(0.1) - except Exception: - pass + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if all(c.stage == AveragingStage.FINISHED for c in controls): + break + time.sleep(0.1) + else: + stages = [c.stage for c in controls] + pytest.fail(f"Averaging did not complete in time. Current stages: {stages}") - # Finally shutdown DHT instances - for dht in dht_instances: - try: - dht.shutdown() - except Exception: - pass + assert all(c.done() for c in controls) - # Give time for all async operations to complete - time.sleep(0.5) + # check that setting trigger twice does not raise error + c0.allow_allreduce() @pytest.mark.forked -@pytest.mark.timeout(30) @pytest.mark.parametrize("target_group_size", [None, 2]) def test_averaging_cancel(target_group_size): - dht_instances = [] - averagers = [] - try: - dht_instances = launch_dht_instances(4) - averagers = [ - DecentralizedAverager( - averaged_tensors=[torch.randn(3)], - dht=dht, - min_matchmaking_time=0.5, - request_timeout=0.3, - client_mode=(i % 2 == 0), - target_group_size=target_group_size, - prefix="mygroup", - start=True, - ) - for i, dht in enumerate(dht_instances) - ] - - step_controls = [averager.step(wait=False, require_trigger=True) for averager in averagers] - - peer_inds_to_cancel = (0, 1) - - for peer_index in peer_inds_to_cancel: - step_controls[peer_index].cancel() - - time.sleep(0.05) - - for i, control in enumerate(step_controls): - if i not in peer_inds_to_cancel: - control.allow_allreduce() - - for i, control in enumerate(step_controls): - if i in peer_inds_to_cancel: - assert control.cancelled() - else: - result = control.result() - assert result is not None - # Don't check group size when target_group_size=None, as it could change - if target_group_size is not None: - assert len(result) == target_group_size - finally: - # Ensure proper cleanup - # First, try to cancel any pending operations - for control in step_controls: - if not control.done(): - try: - control.cancel() - except Exception: - pass - - # Then shutdown averagers - for averager in averagers: - try: - averager.shutdown() - # Wait a bit for shutdown to complete - time.sleep(0.1) - except Exception: - pass - - # Finally shutdown DHT instances - for dht in dht_instances: - try: - dht.shutdown() - except Exception: - pass - - # Give time for all async operations to complete - time.sleep(0.5) + dht_instances = launch_dht_instances(4) + averagers = tuple( + DecentralizedAverager( + averaged_tensors=[torch.randn(3)], + dht=dht, + 
min_matchmaking_time=0.5, + request_timeout=0.3, + client_mode=(i % 2 == 0), + target_group_size=target_group_size, + prefix="mygroup", + start=True, + ) + for i, dht in enumerate(dht_instances) + ) + + step_controls = [averager.step(wait=False, require_trigger=True) for averager in averagers] + + peer_inds_to_cancel = (0, 1) + + for peer_index in peer_inds_to_cancel: + step_controls[peer_index].cancel() + + time.sleep(0.05) + + for i, control in enumerate(step_controls): + if i not in peer_inds_to_cancel: + control.allow_allreduce() + + for i, control in enumerate(step_controls): + if i in peer_inds_to_cancel: + assert control.cancelled() + else: + result = control.result() + assert result is not None + # Don't check group size when target_group_size=None, as it could change + if target_group_size is not None: + assert len(result) == target_group_size + + for averager in averagers: + averager.shutdown()
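
As a standalone illustration of the idiom PATCH 1/5 applies in `tests/test_auth.py` (not part of any patch above): `datetime.utcnow()` is deprecated since Python 3.12, so the tests switch to timezone-aware timestamps stripped back to naive UTC, keeping the existing naive-datetime comparisons working. A minimal sketch using only the standard library; the helper name `naive_utcnow` is ours, not part of hivemind:

```python
from datetime import datetime, timedelta, timezone


def naive_utcnow() -> datetime:
    # Take an aware UTC timestamp, then drop tzinfo so it compares
    # cleanly with the naive timestamps already stored in tokens.
    return datetime.now(timezone.utc).replace(tzinfo=None)


# Issue an expiration time one minute ahead, as MockAuthorizer.get_token() does
expiration_time = str(naive_utcnow() + timedelta(minutes=1))

# Validity check mirroring MockAuthorizer.is_token_valid()
still_valid = datetime.fromisoformat(expiration_time) > naive_utcnow()
print(still_valid)  # True for roughly the next minute
```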