From cac92989b633c9d246bc5fcb98f1c0dcb0903f5c Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Nov 2024 20:51:45 +0000 Subject: [PATCH 01/11] Change the order of operations in test_cli_scripts.py --- tests/test_cli_scripts.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_cli_scripts.py b/tests/test_cli_scripts.py index 97c674000..de5bad3ed 100644 --- a/tests/test_cli_scripts.py +++ b/tests/test_cli_scripts.py @@ -3,7 +3,7 @@ from subprocess import PIPE, Popen from time import sleep -DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$") +_DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$") def test_dht_connection_successful(): @@ -23,7 +23,7 @@ def test_dht_connection_successful(): first_line = dht_proc.stderr.readline() second_line = dht_proc.stderr.readline() - dht_pattern_match = DHT_START_PATTERN.search(first_line) + dht_pattern_match = _DHT_START_PATTERN.search(first_line) assert dht_pattern_match is not None, first_line assert "Full list of visible multiaddresses:" in second_line, second_line @@ -37,6 +37,9 @@ def test_dht_connection_successful(): env=cloned_env, ) + # ensure we get the output of dht_proc after the start of dht_client_proc + sleep(dht_refresh_period) + # skip first two lines with connectivity info for _ in range(2): dht_client_proc.stderr.readline() @@ -44,9 +47,6 @@ def test_dht_connection_successful(): assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg - # ensure we get the output of dht_proc after the start of dht_client_proc - sleep(dht_refresh_period) - # expect that one of the next logging outputs from the first peer shows a new connection for _ in range(5): first_report_msg = dht_proc.stderr.readline() From bed05d71d40bc866b3bc371ecfdcfc6b1a9dc0ac Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Nov 2024 22:52:23 +0000 Subject: [PATCH 02/11] Clean up resources in DHT and P2P --- hivemind/dht/dht.py | 4 +++- hivemind/hivemind_cli/run_dht.py | 16 +++++++++++++--- hivemind/p2p/p2p_daemon_bindings/control.py | 6 ++++++ hivemind/p2p/p2p_daemon_bindings/utils.py | 2 ++ 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/hivemind/dht/dht.py b/hivemind/dht/dht.py index 85b371d1c..957c8d3df 100644 --- a/hivemind/dht/dht.py +++ b/hivemind/dht/dht.py @@ -72,7 +72,7 @@ def __init__( self.num_workers = num_workers self._record_validator = CompositeValidator(record_validators) - self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True) + self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False) self.shutdown_timeout = shutdown_timeout self._ready = MPFuture() self.daemon = daemon @@ -137,6 +137,7 @@ async def _run(): break loop.run_until_complete(_run()) + loop.close() def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None: """ @@ -154,6 +155,7 @@ def shutdown(self) -> None: """Shut down a running dht process""" if self.is_alive(): self._outer_pipe.send(("_shutdown", [], {})) + self._outer_pipe.close() self.join(self.shutdown_timeout) if self.is_alive(): logger.warning("DHT did not shut down within the grace period; terminating it the hard way") diff --git a/hivemind/hivemind_cli/run_dht.py b/hivemind/hivemind_cli/run_dht.py index d72dbd22b..40cd1a0c5 100644 --- a/hivemind/hivemind_cli/run_dht.py +++ b/hivemind/hivemind_cli/run_dht.py @@ -1,6 +1,7 @@ -import time from argparse import ArgumentParser from 
secrets import token_hex +from signal import SIGINT, SIGTERM, signal, strsignal +from threading import Event from hivemind.dht import DHT, DHTNode from hivemind.utils.logging import get_logger, use_hivemind_log_handler @@ -72,6 +73,8 @@ def main(): args = parser.parse_args() + exit_event = Event() + dht = DHT( start=True, initial_peers=args.initial_peers, @@ -84,10 +87,17 @@ def main(): ) log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs) + def signal_handler(signum: int, _) -> None: + logger.info(f"Caught signal {signum} ({strsignal(signum)}), shutting down") + exit_event.set() + + signal(SIGTERM, signal_handler) + signal(SIGINT, signal_handler) + try: - while True: + while not exit_event.is_set(): dht.run_coroutine(report_status, return_future=False) - time.sleep(args.refresh_period) + exit_event.wait(args.refresh_period) except KeyboardInterrupt: logger.info("Caught KeyboardInterrupt, shutting down") finally: diff --git a/hivemind/p2p/p2p_daemon_bindings/control.py b/hivemind/p2p/p2p_daemon_bindings/control.py index 4f229bbdb..a8de5d74e 100644 --- a/hivemind/p2p/p2p_daemon_bindings/control.py +++ b/hivemind/p2p/p2p_daemon_bindings/control.py @@ -322,6 +322,7 @@ async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]: resp = p2pd_pb.Response() # type: ignore await read_pbmsg_safe(reader, resp) writer.close() + await writer.wait_closed() raise_if_failed(resp) peer_id_bytes = resp.identify.id @@ -343,6 +344,7 @@ async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None: resp = p2pd_pb.Response() # type: ignore await read_pbmsg_safe(reader, resp) writer.close() + await writer.wait_closed() raise_if_failed(resp) async def list_peers(self) -> Tuple[PeerInfo, ...]: @@ -352,6 +354,7 @@ async def list_peers(self) -> Tuple[PeerInfo, ...]: resp = p2pd_pb.Response() # type: ignore await read_pbmsg_safe(reader, resp) writer.close() + await writer.wait_closed() raise_if_failed(resp) peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers) @@ -365,6 +368,7 @@ async def disconnect(self, peer_id: PeerID) -> None: resp = p2pd_pb.Response() # type: ignore await read_pbmsg_safe(reader, resp) writer.close() + await writer.wait_closed() raise_if_failed(resp) async def stream_open( @@ -403,6 +407,7 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: resp = p2pd_pb.Response() # type: ignore await read_pbmsg_safe(reader, resp) writer.close() + await writer.wait_closed() raise_if_failed(resp) async def remove_stream_handler(self, proto: str) -> None: @@ -420,6 +425,7 @@ async def remove_stream_handler(self, proto: str) -> None: resp = p2pd_pb.Response() # type: ignore await read_pbmsg_safe(reader, resp) writer.close() + await writer.wait_closed() raise_if_failed(resp) del self.handlers[proto] diff --git a/hivemind/p2p/p2p_daemon_bindings/utils.py b/hivemind/p2p/p2p_daemon_bindings/utils.py index c8ca87901..4a1f106c6 100644 --- a/hivemind/p2p/p2p_daemon_bindings/utils.py +++ b/hivemind/p2p/p2p_daemon_bindings/utils.py @@ -46,6 +46,7 @@ async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_ value |= 0x80 byte = value.to_bytes(1, "big") stream.write(byte) + await stream.drain() if integer == 0: break @@ -77,6 +78,7 @@ async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None: await write_unsigned_varint(stream, size) msg_bytes: bytes = pbmsg.SerializeToString() stream.write(msg_bytes) + await stream.drain() async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) 
-> None: From 80e3ec3cd404cdea1eab9c093901e7e2f5201952 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Nov 2024 22:52:46 +0000 Subject: [PATCH 03/11] Update the tests --- tests/test_allreduce.py | 4 +++- tests/test_cli_scripts.py | 16 +++++++++++++--- tests/test_p2p_daemon_bindings.py | 3 ++- tests/test_start_server.py | 7 +++++++ 4 files changed, 25 insertions(+), 5 deletions(-) diff --git a/tests/test_allreduce.py b/tests/test_allreduce.py index 43a1fdc5a..fb86951be 100644 --- a/tests/test_allreduce.py +++ b/tests/test_allreduce.py @@ -108,7 +108,9 @@ async def wait_synchronously(): wall_time = time.perf_counter() - start_time # check that event loop had enough time to respond to incoming requests; this is over 50% most of the time # we set 33% threshold to ensure that the test will pass reliably. If we break prefetch, this drops to <10% - assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time" + assert ( + time_in_waiting > wall_time / 3 + ), f"Event loop could only run {time_in_waiting / wall_time * 100 :.5f}% of the time" @pytest.mark.parametrize("num_senders", [1, 2, 4, 10]) diff --git a/tests/test_cli_scripts.py b/tests/test_cli_scripts.py index de5bad3ed..f9e044947 100644 --- a/tests/test_cli_scripts.py +++ b/tests/test_cli_scripts.py @@ -30,7 +30,14 @@ def test_dht_connection_successful(): initial_peers = dht_pattern_match.group(1).split(" ") dht_client_proc = Popen( - ["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"], + [ + "hivemind-dht", + *initial_peers, + "--host_maddrs", + "/ip4/127.0.0.1/tcp/0", + "--refresh_period", + str(dht_refresh_period), + ], stderr=PIPE, text=True, encoding="utf-8", @@ -38,7 +45,7 @@ def test_dht_connection_successful(): ) # ensure we get the output of dht_proc after the start of dht_client_proc - sleep(dht_refresh_period) + sleep(2 * dht_refresh_period) # skip first two lines with connectivity info for _ in range(2): @@ -48,7 +55,7 @@ def test_dht_connection_successful(): assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg # expect that one of the next logging outputs from the first peer shows a new connection - for _ in range(5): + for _ in range(10): first_report_msg = dht_proc.stderr.readline() second_report_msg = dht_proc.stderr.readline() @@ -63,6 +70,9 @@ def test_dht_connection_successful(): and "Local storage contains 0 keys" in second_report_msg ) + dht_proc.stderr.close() + dht_client_proc.stderr.close() + dht_proc.terminate() dht_client_proc.terminate() diff --git a/tests/test_p2p_daemon_bindings.py b/tests/test_p2p_daemon_bindings.py index 71658f2ef..d9160e173 100644 --- a/tests/test_p2p_daemon_bindings.py +++ b/tests/test_p2p_daemon_bindings.py @@ -71,7 +71,8 @@ async def readexactly(self, n): class MockWriter(io.BytesIO): - pass + async def drain(self): + pass class MockReaderWriter(MockReader, MockWriter): diff --git a/tests/test_start_server.py b/tests/test_start_server.py index b84dd5407..972131742 100644 --- a/tests/test_start_server.py +++ b/tests/test_start_server.py @@ -27,11 +27,16 @@ def test_cli_run_server_identity_path(): with TemporaryDirectory() as tempdir: id_path = os.path.join(tempdir, "id") + cloned_env = os.environ.copy() + # overriding the loglevel to prevent debug print statements + cloned_env["HIVEMIND_LOGLEVEL"] = "INFO" + server_1_proc = Popen( ["hivemind-server", "--num_experts", "1", "--identity_path", id_path], stderr=PIPE, text=True, encoding="utf-8", + 
env=cloned_env, ) line = server_1_proc.stderr.readline() @@ -50,6 +55,7 @@ def test_cli_run_server_identity_path(): stderr=PIPE, text=True, encoding="utf-8", + env=cloned_env, ) line = server_2_proc.stderr.readline() @@ -65,6 +71,7 @@ def test_cli_run_server_identity_path(): stderr=PIPE, text=True, encoding="utf-8", + env=cloned_env, ) line = server_3_proc.stderr.readline() From 30ad8826f58b02a17975867e9d2091819c87e738 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Nov 2024 23:11:15 +0000 Subject: [PATCH 04/11] Gracefully handle SIGTERM in run_server.py --- hivemind/hivemind_cli/run_dht.py | 2 -- hivemind/hivemind_cli/run_server.py | 16 ++++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/hivemind/hivemind_cli/run_dht.py b/hivemind/hivemind_cli/run_dht.py index 40cd1a0c5..19b1f10a3 100644 --- a/hivemind/hivemind_cli/run_dht.py +++ b/hivemind/hivemind_cli/run_dht.py @@ -98,8 +98,6 @@ def signal_handler(signum: int, _) -> None: while not exit_event.is_set(): dht.run_coroutine(report_status, return_future=False) exit_event.wait(args.refresh_period) - except KeyboardInterrupt: - logger.info("Caught KeyboardInterrupt, shutting down") finally: dht.shutdown() diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py index 1c6bc9a09..d3fab8248 100644 --- a/hivemind/hivemind_cli/run_server.py +++ b/hivemind/hivemind_cli/run_server.py @@ -1,5 +1,7 @@ from functools import partial from pathlib import Path +from signal import SIGINT, SIGTERM, signal, strsignal +from threading import Event import configargparse import torch @@ -102,12 +104,22 @@ def main(): compression_type = args.pop("compression") compression = getattr(CompressionType, compression_type) + exit_event = Event() + server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression) + def signal_handler(signum: int, _) -> None: + logger.info(f"Caught signal {signum} ({strsignal(signum)}), shutting down") + exit_event.set() + + signal(SIGTERM, signal_handler) + signal(SIGINT, signal_handler) + try: + exit_event.wait() + finally: + server.shutdown() server.join() - except KeyboardInterrupt: - logger.info("Caught KeyboardInterrupt, shutting down") if __name__ == "__main__": From c8dee5573029dcec970be41a88ecfb416a44cfd7 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Nov 2024 23:30:38 +0000 Subject: [PATCH 05/11] Optimize tests, add another synchronization event in test_mpfuture_done_callback --- tests/test_start_server.py | 8 +++++--- tests/test_util_modules.py | 15 +++++++++------ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/test_start_server.py b/tests/test_start_server.py index 972131742..b85507c1a 100644 --- a/tests/test_start_server.py +++ b/tests/test_start_server.py @@ -31,8 +31,10 @@ def test_cli_run_server_identity_path(): # overriding the loglevel to prevent debug print statements cloned_env["HIVEMIND_LOGLEVEL"] = "INFO" + common_server_args = ["--hidden_dim", "4", "--num_handlers", "1"] + server_1_proc = Popen( - ["hivemind-server", "--num_experts", "1", "--identity_path", id_path], + ["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args, stderr=PIPE, text=True, encoding="utf-8", @@ -51,7 +53,7 @@ def test_cli_run_server_identity_path(): assert len(ids_1) == 1 server_2_proc = Popen( - ["hivemind-server", "--num_experts", "1", "--identity_path", id_path], + ["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args, stderr=PIPE, text=True, 
encoding="utf-8", @@ -67,7 +69,7 @@ def test_cli_run_server_identity_path(): assert len(ids_2) == 1 server_3_proc = Popen( - ["hivemind-server", "--num_experts", "1"], + ["hivemind-server", "--num_experts", "1"] + common_server_args, stderr=PIPE, text=True, encoding="utf-8", diff --git a/tests/test_util_modules.py b/tests/test_util_modules.py index f245b777e..b4c284dd3 100644 --- a/tests/test_util_modules.py +++ b/tests/test_util_modules.py @@ -233,7 +233,7 @@ def _future_creator(): @pytest.mark.forked def test_mpfuture_done_callback(): receiver, sender = mp.Pipe(duplex=False) - events = [mp.Event() for _ in range(6)] + events = [mp.Event() for _ in range(7)] def _future_creator(): future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture() @@ -250,7 +250,7 @@ def _check_result_and_set(future): sender.send((future1, future2)) future2.cancel() # trigger future2 callback from the same process - + events[6].set() events[0].wait() future1.add_done_callback( lambda future: events[4].set() @@ -262,6 +262,7 @@ def _check_result_and_set(future): future1, future2 = receiver.recv() future1.set_result(123) + events[6].wait() with pytest.raises(RuntimeError): future1.add_done_callback(lambda future: (1, 2, 3)) @@ -514,21 +515,23 @@ async def test_async_context_flooding(): Here's how the test below works: suppose that the thread pool has at most N workers; If at least N + 1 coroutines await lock1 concurrently, N of them occupy workers and the rest are awaiting workers; - When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(1e-2); + When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(); During that sleep, one of the worker-less coroutines will take up the worker freed by coroutine A. Finally, coroutine A finishes sleeping and immediately gets stuck at lock2, because there are no free workers. Thus, every single coroutine is either awaiting an already acquired lock, or awaiting for free workers in executor. 
""" + total_sleep_time = 1 lock1, lock2 = mp.Lock(), mp.Lock() + + num_coros = max(33, mp.cpu_count() * 5 + 1) async def coro(): async with enter_asynchronously(lock1): - await asyncio.sleep(1e-2) + await asyncio.sleep(total_sleep_time/(num_coros*2)) async with enter_asynchronously(lock2): - await asyncio.sleep(1e-2) + await asyncio.sleep(total_sleep_time/(num_coros*2)) - num_coros = max(33, mp.cpu_count() * 5 + 1) await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)}) From 0b1264de8443dea87adfab9e769cea6bc0444b76 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Nov 2024 23:33:19 +0000 Subject: [PATCH 06/11] Reformat code with black --- tests/test_util_modules.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_util_modules.py b/tests/test_util_modules.py index b4c284dd3..33b4597fe 100644 --- a/tests/test_util_modules.py +++ b/tests/test_util_modules.py @@ -523,14 +523,14 @@ async def test_async_context_flooding(): """ total_sleep_time = 1 lock1, lock2 = mp.Lock(), mp.Lock() - + num_coros = max(33, mp.cpu_count() * 5 + 1) async def coro(): async with enter_asynchronously(lock1): - await asyncio.sleep(total_sleep_time/(num_coros*2)) + await asyncio.sleep(total_sleep_time / (num_coros * 2)) async with enter_asynchronously(lock2): - await asyncio.sleep(total_sleep_time/(num_coros*2)) + await asyncio.sleep(total_sleep_time / (num_coros * 2)) await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)}) From 85c86ad02e774283bc95f96bf212a0388bcbd14c Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Nov 2024 00:20:58 +0000 Subject: [PATCH 07/11] Disable fail-fast for test matrix --- .github/workflows/run-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 74d778bd0..11792a3c9 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -12,6 +12,7 @@ jobs: strategy: matrix: python-version: [ '3.8', '3.9', '3.10', '3.11' ] + fail-fast: false timeout-minutes: 15 steps: - uses: actions/checkout@v3 From e93416c72560edd8f92c87bbe21190f6a1a38590 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Nov 2024 02:09:24 +0000 Subject: [PATCH 08/11] Try temporary fix of test_client_anomaly_detection with DHT init --- tests/test_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_moe.py b/tests/test_moe.py index f62c2159d..efeef6c65 100644 --- a/tests/test_moe.py +++ b/tests/test_moe.py @@ -282,6 +282,7 @@ def test_client_anomaly_detection(): experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan") dht = DHT(start=True) + dht.get_visible_maddrs(latest=True) server = Server(dht, experts, num_connection_handlers=1) server.start() try: From 17d2a8228f8bd6f077e5a7ec7fd5d9ce55eee77e Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Nov 2024 10:33:22 +0000 Subject: [PATCH 09/11] Move exit_event creation closer to signal handler --- hivemind/hivemind_cli/run_dht.py | 8 ++++---- hivemind/hivemind_cli/run_server.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/hivemind/hivemind_cli/run_dht.py b/hivemind/hivemind_cli/run_dht.py index 19b1f10a3..62e783f5a 100644 --- a/hivemind/hivemind_cli/run_dht.py +++ b/hivemind/hivemind_cli/run_dht.py @@ -73,8 +73,6 @@ def main(): args = parser.parse_args() - exit_event = Event() - dht = DHT( start=True, initial_peers=args.initial_peers, @@ -86,9 +84,11 @@ def main(): use_auto_relay=args.use_auto_relay, ) 
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs) + + exit_event = Event() - def signal_handler(signum: int, _) -> None: - logger.info(f"Caught signal {signum} ({strsignal(signum)}), shutting down") + def signal_handler(signal_number: int, _) -> None: + logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down") exit_event.set() signal(SIGTERM, signal_handler) diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py index d3fab8248..70cb27df1 100644 --- a/hivemind/hivemind_cli/run_server.py +++ b/hivemind/hivemind_cli/run_server.py @@ -104,12 +104,12 @@ def main(): compression_type = args.pop("compression") compression = getattr(CompressionType, compression_type) - exit_event = Event() - server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression) + + exit_event = Event() - def signal_handler(signum: int, _) -> None: - logger.info(f"Caught signal {signum} ({strsignal(signum)}), shutting down") + def signal_handler(signal_number: int, _) -> None: + logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down") exit_event.set() signal(SIGTERM, signal_handler) From aa9caf49c1168bab5c5b1ace2dd6270a666f663f Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Nov 2024 10:38:33 +0000 Subject: [PATCH 10/11] Reformat code with black --- hivemind/hivemind_cli/run_dht.py | 2 +- hivemind/hivemind_cli/run_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hivemind/hivemind_cli/run_dht.py b/hivemind/hivemind_cli/run_dht.py index 62e783f5a..64f831e15 100644 --- a/hivemind/hivemind_cli/run_dht.py +++ b/hivemind/hivemind_cli/run_dht.py @@ -84,7 +84,7 @@ def main(): use_auto_relay=args.use_auto_relay, ) log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs) - + exit_event = Event() def signal_handler(signal_number: int, _) -> None: diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py index 70cb27df1..b5abd529d 100644 --- a/hivemind/hivemind_cli/run_server.py +++ b/hivemind/hivemind_cli/run_server.py @@ -105,7 +105,7 @@ def main(): compression = getattr(CompressionType, compression_type) server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression) - + exit_event = Event() def signal_handler(signal_number: int, _) -> None: From 059095455a56402e3a2b0e00d82d3e46068ba121 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Nov 2024 12:08:25 +0000 Subject: [PATCH 11/11] Acquire locks before mp.Value updates --- tests/test_moe.py | 3 ++- tests/test_optimizer.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_moe.py b/tests/test_moe.py index efeef6c65..d788cb0dc 100644 --- a/tests/test_moe.py +++ b/tests/test_moe.py @@ -319,7 +319,8 @@ def test_client_anomaly_detection(): def _measure_coro_running_time(n_coros, elapsed_fut, counter): async def coro(): await asyncio.sleep(0.1) - counter.value += 1 + with counter.get_lock(): + counter.value += 1 try: start_time = time.perf_counter() diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index c859e3879..16fb7f2f3 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -414,8 +414,8 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool): loss.backward() optimizer.step() - - total_samples_accumulated.value += batch_size + with total_samples_accumulated.get_lock(): + total_samples_accumulated.value += batch_size if not reuse_grad_buffers: 
optimizer.zero_grad()
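
Note on the final patch: wrapping every mp.Value update in the value's own lock matters because `counter.value += 1` is a separate read and write of the shared memory, so two processes can interleave and silently drop increments. `multiprocessing.Value` is created with an internal lock for exactly this case, exposed via `get_lock()`. Below is a minimal standalone sketch of the pattern (illustrative only, not part of the patch series; the worker and variable names are invented for the example):

    import multiprocessing as mp


    def _worker(counter, n_increments):
        # Each increment is a read-modify-write on shared memory; without the
        # wrapper's lock, concurrent workers can interleave and lose updates.
        for _ in range(n_increments):
            with counter.get_lock():
                counter.value += 1


    if __name__ == "__main__":
        counter = mp.Value("i", 0)  # synchronized wrapper with a built-in lock
        workers = [mp.Process(target=_worker, args=(counter, 10_000)) for _ in range(4)]
        for proc in workers:
            proc.start()
        for proc in workers:
            proc.join()
        # With get_lock() the total is exact; without it, increments are typically lost.
        assert counter.value == 4 * 10_000

In the patches above, this is the same pattern applied to the counters in tests/test_moe.py (_measure_coro_running_time) and tests/test_optimizer.py (run_trainer).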