Skip to content

Commit 3aded3b

Browse files
committed
allow_any_len in load pattern matcher to fix temp load inside loop
1 parent a3049a7 commit 3aded3b

File tree

1 file changed

+10 -12 lines changed

1 file changed

+10 -12 lines changed

tinygrad/renderer/wgsl.py

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -25,16 +25,14 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
2525
val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
2626
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
2727

28-
# No packing for shared mem
29-
def is_packed(idx:UOp, dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.half and idx.src[0].op != Ops.DEFINE_LOCAL
28+
def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.half
3029

3130
wgsl_matcher = PatternMatcher([
3231
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
3332
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
34-
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"),)), lambda l,b: packed_load(l, b, l.dtype) if is_packed(b,l.dtype) else None),
35-
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"), UPat.cvar("c"))),
36-
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(b,l.dtype) else None),
37-
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(bidx,var.dtype) else None),
33+
(UPat.load(UPat.var("b"), UPat.cvar("c"), name="l"),lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
34+
(UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype) else None),
35+
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
3836
# TODO: why is this needed, and only for this MUL order
3937
(UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
4038
lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
@@ -58,18 +56,18 @@ class WGSLRenderer(CStyleLanguage):
5856
(UPat.cvar("x", dtype=dtypes.bool), lambda x: "true" if x.arg else "false"),
5957
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
6058
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
61-
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.type_map[x.dtype.base]}, {x.dtype.size}>;"),
59+
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)}, {x.dtype.size}>;"),
6260
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
6361
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
6462
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
6563
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
6664
if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
6765
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
68-
(UPat.load(UPat.var("b"), UPat.cvar("v")),lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(b, ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
69-
(UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(b, ctx[b], b.dtype)),
66+
(UPat.load(UPat.var("b"), UPat.cvar("v")),lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
67+
(UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
7068
(UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
7169
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
72-
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b, b.src[0].dtype) \
70+
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
7371
else f"{ctx[b]} = {ctx[v]};"),
7472
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
7573
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
@@ -79,8 +77,8 @@ class WGSLRenderer(CStyleLanguage):
7977

8078
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
8179
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
82-
def render_load(self, idx:UOp, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(idx, dt) else x
83-
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt.itemsize < 4 and dt.base != dtypes.half else self.type_map[dt.base]
80+
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
81+
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
8482
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
8583
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
8684
if not local_size: local_size = [1]

0 commit comments

Comments (0)