1
- from collections import defaultdict , deque
2
1
from dataclasses import dataclass
3
- from tinygrad .uop .ops import UOp , Ops , GroupOp , PatternMatcher , UPat , graph_rewrite , graph_rewrite_map , identity_element , resolve
4
- from tinygrad .uop .ops import can_pad , sint , track_rewrites , _substitute
2
+ from tinygrad .uop .ops import UOp , Ops , GroupOp , PatternMatcher , UPat , graph_rewrite , graph_rewrite_map , identity_element , resolve , can_pad , sint
3
+ from tinygrad .uop .ops import track_rewrites , _substitute
4
+ from tinygrad .uop .spec import type_verify , tensor_uop_spec
5
5
from tinygrad .codegen .lowerer import get_contraction_with_reduce , get_contraction
6
6
from tinygrad .codegen .symbolic import symbolic_simple
7
7
from tinygrad .helpers import Metadata , all_int , all_same , colored , prod , dedup , unwrap , getenv , pluralize , ContextVar , Context , diskcache_put
10
10
from tinygrad .engine .multi import multi_pm , replace_allreduce
11
11
from tinygrad .shape .shapetracker import ShapeTracker
12
12
from tinygrad .shape .view import View , strides_for_shape
13
- from tinygrad .uop .spec import type_verify , tensor_uop_spec
14
13
15
14
# creation can recurse a lot
16
15
import sys
@@ -142,7 +141,7 @@ def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
142
141
(UPat ((Ops .COPY , Ops .MSELECT ), src = (UPat (GroupOp .All - ALWAYS_CONTIGUOUS , name = "tr" ),), allow_any_len = True ), realize ),
143
142
])
144
143
145
- def recursive_group (tr :UOp , st :ShapeTracker , r :UOp , children :defaultdict [UOp , dict [UOp , None ]], realizes :dict [UOp , None ],
144
+ def recursive_group (tr :UOp , st :ShapeTracker , r :UOp , children :dict [UOp , dict [UOp , None ]], realizes :dict [UOp , None ],
146
145
reduce_for_op :dict [UOp , UOp ], group :dict [UOp , None ], cache :dict [tuple [UOp , ShapeTracker ], None ]) -> None :
147
146
if (tr , st ) in cache : return
148
147
cache .setdefault ((tr , st ))
@@ -152,7 +151,7 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
152
151
# max one reduceop per kernel
153
152
if not st .contiguous or st .size != rsize or tr in reduce_for_op : group .setdefault (r )
154
153
return group .setdefault (tr )
155
- for tr_next in children [ tr ] :
154
+ for tr_next in children . get ( tr , {}) :
156
155
# max one reduceop per kernel
157
156
if tr_next .op is Ops .REDUCE_AXIS : return group .setdefault (r )
158
157
# can only fuse contiguous
@@ -166,12 +165,12 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
166
165
if DONT_GROUP_REDUCES : return realizes
167
166
168
167
# construct children graph (only for bases)
169
- children : defaultdict [UOp , dict [UOp , None ]] = defaultdict ( dict )
168
+ children : dict [UOp , dict [UOp , None ]] = {}
170
169
assigns : dict [UOp , None ] = {}
171
170
for u in (toposort := sink .toposort ()):
172
171
if u .op in {Ops .VIEW , Ops .SINK }: continue
173
172
if u .op is Ops .ASSIGN : assigns [u .buf_uop ] = None
174
- for s in u .src : children [ s .base ] [u ] = None
173
+ for s in u .src : children . setdefault ( s .base , {}) [u ] = None
175
174
176
175
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
177
176
reduce_for_op : dict [UOp , UOp ] = {}
@@ -191,7 +190,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
191
190
if not forced_realize and len (group ) > 1 : forced_realize = True
192
191
# can only fuse assign if no other assign_target is used in the kernel
193
192
if not forced_realize and (assign_targets := {x .buf_uop for x in group if x .op is Ops .ASSIGN }):
194
- parents = deque (( r , * group ))
193
+ parents = [ r , * group ]
195
194
while parents and not forced_realize :
196
195
p = parents .pop ().base
197
196
if p .op is Ops .BUFFER and p in assigns and p not in assign_targets : forced_realize , can_chase = True , False
@@ -202,8 +201,8 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
202
201
if can_chase :
203
202
# can chase this down to contiguous children
204
203
st = unwrap (tr .st )
205
- while len (children [ tr ] ) == 1 :
206
- tr_next = next (iter (children [ tr ] ))
204
+ while len (lst := children . get ( tr , {}) ) == 1 :
205
+ tr_next = next (iter (lst ))
207
206
st_childs = dedup (unwrap (s .st ) for s in tr_next .src if s .base is tr )
208
207
if len (st_childs ) > 1 : break
209
208
if st .size != st_childs [0 ].size : break
@@ -219,7 +218,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
219
218
# fuse double reduces with no other child
220
219
for reduceop in double_reduces :
221
220
top_reduce = reduceop .src [0 ].base
222
- if len (children [ top_reduce ] ) == 1 : del realizes [top_reduce ]
221
+ if len (children . get ( top_reduce , {}) ) == 1 : del realizes [top_reduce ]
223
222
return realizes
224
223
225
224
# **** create kernels
0 commit comments