13
13
import inspect
14
14
import itertools
15
15
import logging
16
+ import os
16
17
import pickle
17
18
import random
18
19
import traceback
56
57
MonitoredOncePortReceiver ,
57
58
MonitoredPortReceiver ,
58
59
PythonActorMesh ,
60
+ PythonActorMeshRef ,
59
61
)
60
62
from monarch ._rust_bindings .monarch_hyperactor .mailbox import (
61
63
Mailbox ,
65
67
PortRef ,
66
68
)
67
69
from monarch ._rust_bindings .monarch_hyperactor .proc import ActorId
70
+ from monarch ._rust_bindings .monarch_hyperactor .selection import Selection as HySelection
68
71
from monarch ._rust_bindings .monarch_hyperactor .shape import Point as HyPoint , Shape
69
72
from monarch ._rust_bindings .monarch_hyperactor .supervision import SupervisionError
70
-
71
73
from monarch ._rust_bindings .monarch_hyperactor .telemetry import enter_span , exit_span
72
74
from monarch ._src .actor .allocator import LocalAllocator , ProcessAllocator
73
75
from monarch ._src .actor .future import Future
@@ -146,6 +148,28 @@ def set(debug_context: "DebugContext") -> None:
146
148
Selection = Literal ["all" , "choose" ] | int # TODO: replace with real selection objects
147
149
148
150
151
+ def to_hy_sel (selection : Selection , shape : Shape ) -> HySelection :
152
+ if selection == "choose" :
153
+ dim = len (shape .labels )
154
+ assert dim > 0
155
+ query = "," .join (["?" ] * dim )
156
+ return HySelection .from_string (f"{ query } " )
157
+ elif selection == "all" :
158
+ return HySelection .from_string ("*" )
159
+ else :
160
+ raise ValueError (f"invalid selection: { selection } " )
161
+
162
+
163
+ # A temporary gate used by the PythonActorMesh/PythonActorMeshRef migration.
164
+ # We can use this gate to quickly roll back to using _ActorMeshRefImpl, if we
165
+ # encounter any issues with the migration.
166
+ #
167
+ # This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is
168
+ # working correctly in production.
169
+ def _use_standin_mesh () -> bool :
170
+ return bool (os .getenv ("USE_STANDIN_ACTOR_MESH" , default = False ))
171
+
172
+
149
173
# standin class for whatever is the serializable python object we use
150
174
# to name an actor mesh. Hacked up today because ActorMesh
151
175
# isn't plumbed to non-clients
@@ -158,6 +182,10 @@ def __init__(
158
182
shape : Shape ,
159
183
actor_ids : List [ActorId ],
160
184
) -> None :
185
+ if not _use_standin_mesh ():
186
+ raise ValueError (
187
+ "ActorMeshRefImpl should only be used when USE_STANDIN_ACTOR_MESH is set"
188
+ )
161
189
self ._mailbox = mailbox
162
190
self ._actor_mesh = hy_actor_mesh
163
191
# actor meshes do not have a way to look this up at the moment,
@@ -296,8 +324,8 @@ def __init__(
296
324
297
325
def cast (
298
326
self ,
299
- message : PythonMessage ,
300
327
selection : Selection ,
328
+ message : PythonMessage ,
301
329
) -> None :
302
330
self ._mailbox .post (self ._actor_id , message )
303
331
@@ -309,6 +337,110 @@ def monitor(self) -> Optional[ActorMeshMonitor]:
309
337
return None
310
338
311
339
340
+ # A temporary wrapper used by the PythonActorMesh/PythonActorMeshRef migration.
341
+ # This wrapper is used to enable switching between PythonActorMeshRef and
342
+ # _ActorMeshRefImpl through the `USE_STANDIN_ACTOR_MESH` env var.
343
+ #
344
+ # This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is
345
+ # working correctly in production.
346
+ class EitherPyActorMeshRef :
347
+ def __init__ (self , inner : PythonActorMeshRef | _ActorMeshRefImpl ) -> None :
348
+ if _use_standin_mesh ():
349
+ assert isinstance (
350
+ inner , _ActorMeshRefImpl
351
+ ), "expect _ActorMeshRefImpl because env var USE_STANDIN_ACTOR_MESH is set"
352
+ else :
353
+ assert isinstance (
354
+ inner , PythonActorMeshRef
355
+ ), "expect PythonActorMeshRef because env var USE_STANDIN_ACTOR_MESH is not set"
356
+ self ._inner : PythonActorMeshRef | _ActorMeshRefImpl = inner
357
+
358
+ def cast (
359
+ self , mailbox : Mailbox , selection : Selection , message : PythonMessage
360
+ ) -> None :
361
+ inner = self ._inner
362
+ if isinstance (inner , _ActorMeshRefImpl ):
363
+ inner .cast (message , selection )
364
+ elif isinstance (inner , PythonActorMeshRef ):
365
+ inner .cast (mailbox , to_hy_sel (selection , self .shape ), message )
366
+ else :
367
+ raise ValueError (f"unsupported mesh type: { inner .__class__ .__name__ } " )
368
+
369
+ def slice (self , ** kwargs ) -> "EitherPyActorMeshRef" :
370
+ return EitherPyActorMeshRef (self ._inner .slice (** kwargs ))
371
+
372
+ @property
373
+ def shape (self ) -> Shape :
374
+ return self ._inner .shape
375
+
376
+ def monitor (self ) -> Optional [ActorMeshMonitor ]:
377
+ return None
378
+
379
+
380
+ # A temporary wrapper used by the PythonActorMesh/PythonActorMesh migration.
381
+ # This wrapper is used to enable switching between PythonActorMesh and
382
+ # _ActorMeshRefImpl through the `USE_STANDIN_ACTOR_MESH` env var.
383
+ #
384
+ # This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is
385
+ # working correctly in production.
386
+ class EitherPyActorMesh :
387
+ def __init__ (
388
+ self , actor_mesh : PythonActorMesh , mailbox : Mailbox , proc_mesh : "ProcMesh"
389
+ ) -> None :
390
+ if _use_standin_mesh ():
391
+ inner = _ActorMeshRefImpl .from_hyperactor_mesh (
392
+ mailbox , actor_mesh , proc_mesh
393
+ )
394
+ else :
395
+ inner = actor_mesh
396
+ self ._inner : PythonActorMesh | _ActorMeshRefImpl = inner
397
+ self ._proc_mesh = proc_mesh
398
+
399
+ def bind (self ) -> "EitherPyActorMeshRef" :
400
+ inner = self ._inner
401
+ if isinstance (inner , PythonActorMesh ):
402
+ return EitherPyActorMeshRef (inner .bind ())
403
+ elif isinstance (inner , _ActorMeshRefImpl ):
404
+ return EitherPyActorMeshRef (inner )
405
+ else :
406
+ raise ValueError (f"unsupported mesh type: { inner .__class__ .__name__ } " )
407
+
408
+ def cast (self , selection : Selection , message : PythonMessage ) -> None :
409
+ inner = self ._inner
410
+ if isinstance (inner , _ActorMeshRefImpl ):
411
+ inner .cast (message , selection )
412
+ elif isinstance (inner , PythonActorMesh ):
413
+ inner .cast (to_hy_sel (selection , self .shape ), message )
414
+ else :
415
+ raise ValueError (f"unsupported mesh type: { inner .__class__ .__name__ } " )
416
+
417
+ def slice (self , ** kwargs ) -> "EitherPyActorMeshRef" :
418
+ return EitherPyActorMeshRef (self ._inner .slice (** kwargs ))
419
+
420
+ @property
421
+ def shape (self ) -> Shape :
422
+ return self ._inner .shape
423
+
424
+ def monitor (self ) -> Optional [ActorMeshMonitor ]:
425
+ return self ._inner .monitor ()
426
+
427
+ @property
428
+ def proc_mesh (self ) -> "ProcMesh" :
429
+ return self ._proc_mesh
430
+
431
+ @property
432
+ def name_pid (self ) -> Tuple [str , int ]:
433
+ inner = self ._inner
434
+ if isinstance (inner , _ActorMeshRefImpl ):
435
+ return inner ._name_pid
436
+ elif isinstance (inner , PythonActorMesh ):
437
+ actor_id0 = inner .get (0 )
438
+ assert actor_id0 is not None
439
+ return actor_id0 .actor_name , actor_id0 .pid
440
+ else :
441
+ raise ValueError (f"unsupported mesh type: { inner .__class__ .__name__ } " )
442
+
443
+
312
444
class Extent (NamedTuple ):
313
445
labels : Sequence [str ]
314
446
sizes : Sequence [int ]
@@ -377,26 +509,28 @@ def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
377
509
extent = self ._send (args , kwargs , port = p )
378
510
379
511
async def process () -> ValueMesh [R ]:
380
- results : List [ R ] = [ None ] * extent . nelements # pyre-fixme[9]
512
+ results : Dict [ int , R ] = dict ()
381
513
for _ in range (extent .nelements ):
382
514
rank , value = await r .recv ()
383
515
results [rank ] = value
384
516
call_shape = Shape (
385
517
extent .labels ,
386
518
NDSlice .new_row_major (extent .sizes ),
387
519
)
388
- return ValueMesh (call_shape , results )
520
+ sorted_values = [results [rank ] for rank in sorted (results )]
521
+ return ValueMesh (call_shape , sorted_values )
389
522
390
523
def process_blocking () -> ValueMesh [R ]:
391
- results : List [ R ] = [ None ] * extent . nelements # pyre-fixme[9]
524
+ results : Dict [ int , R ] = dict ()
392
525
for _ in range (extent .nelements ):
393
526
rank , value = r .recv ().get ()
394
527
results [rank ] = value
395
528
call_shape = Shape (
396
529
extent .labels ,
397
530
NDSlice .new_row_major (extent .sizes ),
398
531
)
399
- return ValueMesh (call_shape , results )
532
+ sorted_values = [results [rank ] for rank in sorted (results )]
533
+ return ValueMesh (call_shape , sorted_values )
400
534
401
535
return Future (process , process_blocking )
402
536
@@ -428,7 +562,7 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
428
562
class ActorEndpoint (Endpoint [P , R ]):
429
563
def __init__ (
430
564
self ,
431
- actor_mesh_ref : _ActorMeshRefImpl | ActorIdFakeMesh ,
565
+ actor_mesh_ref : EitherPyActorMesh | EitherPyActorMeshRef | ActorIdFakeMesh ,
432
566
name : str ,
433
567
impl : Callable [Concatenate [Any , P ], Awaitable [R ]],
434
568
mailbox : Mailbox ,
@@ -469,7 +603,15 @@ def _send(
469
603
),
470
604
bytes ,
471
605
)
472
- self ._actor_mesh .cast (message , selection )
606
+ mesh = self ._actor_mesh
607
+ if isinstance (mesh , EitherPyActorMeshRef ):
608
+ mesh .cast (self ._mailbox , selection , message )
609
+ elif isinstance (mesh , EitherPyActorMesh ) or isinstance (
610
+ mesh , ActorIdFakeMesh
611
+ ):
612
+ mesh .cast (selection , message )
613
+ else :
614
+ raise ValueError (f"unsupported mesh type: { mesh .__class__ .__name__ } " )
473
615
else :
474
616
importlib .import_module ("monarch." + "mesh_controller" ).actor_send (
475
617
self , self ._name , bytes , refs , port
@@ -931,12 +1073,14 @@ class _ActorMeshTrait(MeshTrait):
931
1073
def __init__ (
932
1074
self ,
933
1075
Class : Type [T ],
934
- actor_mesh_ref : _ActorMeshRefImpl | ActorIdFakeMesh ,
1076
+ actor_mesh_ref : EitherPyActorMesh | EitherPyActorMeshRef | ActorIdFakeMesh ,
935
1077
mailbox : Mailbox ,
936
1078
) -> None :
937
1079
self .__name__ : str = Class .__name__
938
1080
self ._class : Type [T ] = Class
939
- self ._actor_mesh_ref : _ActorMeshRefImpl | ActorIdFakeMesh = actor_mesh_ref
1081
+ self ._actor_mesh_ref : (
1082
+ EitherPyActorMesh | EitherPyActorMeshRef | ActorIdFakeMesh
1083
+ ) = actor_mesh_ref
940
1084
self ._mailbox : Mailbox = mailbox
941
1085
for attr_name in dir (self ._class ):
942
1086
attr_value = getattr (self ._class , attr_name , None )
@@ -1003,20 +1147,22 @@ class ActorMeshHandle(_ActorMeshTrait, Generic[T]):
1003
1147
def __init__ (
1004
1148
self ,
1005
1149
Class : Type [T ],
1006
- actor_mesh : _ActorMeshRefImpl ,
1150
+ actor_mesh : PythonActorMesh ,
1007
1151
mailbox : Mailbox ,
1152
+ proc_mesh : "ProcMesh" ,
1008
1153
) -> None :
1009
- super ().__init__ (Class , actor_mesh , mailbox )
1154
+ wrapper = EitherPyActorMesh (actor_mesh , mailbox , proc_mesh )
1155
+ super ().__init__ (Class , wrapper , mailbox )
1010
1156
1011
- def _inner (self ) -> "_ActorMeshRefImpl " :
1157
+ def _inner (self ) -> "EitherPyActorMesh " :
1012
1158
mesh = self ._actor_mesh_ref
1013
1159
assert isinstance (
1014
- mesh , _ActorMeshRefImpl
1160
+ mesh , EitherPyActorMesh
1015
1161
), f"mesh type is { mesh .__class__ .__name__ } "
1016
1162
return mesh
1017
1163
1018
1164
def bind (self ) -> "ActorMeshRef[T]" :
1019
- return ActorMeshRef (self ._class , self ._inner (), self ._mailbox )
1165
+ return ActorMeshRef (self ._class , self ._inner (). bind () , self ._mailbox )
1020
1166
1021
1167
def _create (
1022
1168
self ,
@@ -1056,22 +1202,22 @@ def proc_mesh(self) -> "Optional[ProcMesh]":
1056
1202
1057
1203
@property
1058
1204
def name_pid (self ) -> Tuple [str , int ]:
1059
- return self ._inner ()._name_pid
1205
+ return self ._inner ().name_pid
1060
1206
1061
1207
1062
1208
class ActorMeshRef (_ActorMeshTrait , Generic [T ]):
1063
1209
def __init__ (
1064
1210
self ,
1065
1211
Class : Type [T ],
1066
- actor_mesh : _ActorMeshRefImpl ,
1212
+ actor_mesh : EitherPyActorMeshRef ,
1067
1213
mailbox : Mailbox ,
1068
1214
) -> None :
1069
1215
super ().__init__ (Class , actor_mesh , mailbox )
1070
1216
1071
- def _inner (self ) -> "_ActorMeshRefImpl " :
1217
+ def _inner (self ) -> "EitherPyActorMeshRef " :
1072
1218
mesh = self ._actor_mesh_ref
1073
1219
assert isinstance (
1074
- mesh , _ActorMeshRefImpl
1220
+ mesh , EitherPyActorMeshRef
1075
1221
), f"mesh type is { mesh .__class__ .__name__ } "
1076
1222
return mesh
1077
1223
0 commit comments