Skip to content

Commit 6af4b02

Browse files
authored
use plain dict and list in grouper [pr] (#10580)
1 parent 4ab3391 commit 6af4b02

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

tinygrad/engine/grouper.py

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
1-
from collections import defaultdict, deque
21
from dataclasses import dataclass
3-
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
4-
from tinygrad.uop.ops import can_pad, sint, track_rewrites, _substitute
2+
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, can_pad, sint
3+
from tinygrad.uop.ops import track_rewrites, _substitute
4+
from tinygrad.uop.spec import type_verify, tensor_uop_spec
55
from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contraction
66
from tinygrad.codegen.symbolic import symbolic_simple
77
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put
@@ -10,7 +10,6 @@
1010
from tinygrad.engine.multi import multi_pm, replace_allreduce
1111
from tinygrad.shape.shapetracker import ShapeTracker
1212
from tinygrad.shape.view import View, strides_for_shape
13-
from tinygrad.uop.spec import type_verify, tensor_uop_spec
1413

1514
# creation can recurse a lot
1615
import sys
@@ -142,7 +141,7 @@ def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
142141
(UPat((Ops.COPY, Ops.MSELECT), src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), allow_any_len=True), realize),
143142
])
144143

145-
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
144+
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
146145
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
147146
if (tr, st) in cache: return
148147
cache.setdefault((tr, st))
@@ -152,7 +151,7 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
152151
# max one reduceop per kernel
153152
if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
154153
return group.setdefault(tr)
155-
for tr_next in children[tr]:
154+
for tr_next in children.get(tr, {}):
156155
# max one reduceop per kernel
157156
if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
158157
# can only fuse contiguous
@@ -166,12 +165,12 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
166165
if DONT_GROUP_REDUCES: return realizes
167166

168167
# construct children graph (only for bases)
169-
children: defaultdict[UOp, dict[UOp, None]] = defaultdict(dict)
168+
children: dict[UOp, dict[UOp, None]] = {}
170169
assigns: dict[UOp, None] = {}
171170
for u in (toposort:=sink.toposort()):
172171
if u.op in {Ops.VIEW, Ops.SINK}: continue
173172
if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
174-
for s in u.src: children[s.base][u] = None
173+
for s in u.src: children.setdefault(s.base, {})[u] = None
175174

176175
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
177176
reduce_for_op: dict[UOp, UOp] = {}
@@ -191,7 +190,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
191190
if not forced_realize and len(group) > 1: forced_realize = True
192191
# can only fuse assign if no other assign_target is used in the kernel
193192
if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
194-
parents = deque((r, *group))
193+
parents = [r, *group]
195194
while parents and not forced_realize:
196195
p = parents.pop().base
197196
if p.op is Ops.BUFFER and p in assigns and p not in assign_targets: forced_realize, can_chase = True, False
@@ -202,8 +201,8 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
202201
if can_chase:
203202
# can chase this down to contiguous children
204203
st = unwrap(tr.st)
205-
while len(children[tr]) == 1:
206-
tr_next = next(iter(children[tr]))
204+
while len(lst:=children.get(tr, {})) == 1:
205+
tr_next = next(iter(lst))
207206
st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
208207
if len(st_childs) > 1: break
209208
if st.size != st_childs[0].size: break
@@ -219,7 +218,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
219218
# fuse double reduces with no other child
220219
for reduceop in double_reduces:
221220
top_reduce = reduceop.src[0].base
222-
if len(children[top_reduce]) == 1: del realizes[top_reduce]
221+
if len(children.get(top_reduce, {})) == 1: del realizes[top_reduce]
223222
return realizes
224223

225224
# **** create kernels

0 commit comments

Comments
 (0)