@@ -25,16 +25,14 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
25
25
val = (load .cast (dtypes .uint32 ) >> shift_am ) & (0xFF if dtype .itemsize == 1 else 0xFFFF )
26
26
return sign_extend (val , 8 * dtype .itemsize ).cast (dtype ) if dtype in [dtypes .char , dtypes .short ] else val .cast (dtype )
27
27
28
- # No packing for shared mem
29
- def is_packed (idx :UOp , dt :DType ) -> bool : return dt .itemsize < 4 and dt .base != dtypes .half and idx .src [0 ].op != Ops .DEFINE_LOCAL
28
# is_packed: True when `dt` is narrower than 4 bytes (itemsize < 4) and its base dtype is not half.
# The decision depends on the dtype alone; callers pass only the dtype.
+ def is_packed (dt :DType ) -> bool : return dt .itemsize < 4 and dt .base != dtypes .half
30
29
31
30
wgsl_matcher = PatternMatcher ([
32
31
(UPat ((Ops .CMPLT , Ops .XOR ), src = (UPat (name = "a" , dtype = dtypes .bool ), UPat .var ("b" )), name = "c" ),
33
32
lambda a ,b ,c : a .cast (dtypes .int ).alu (c .op , b .cast (dtypes .int )).cast (dtypes .bool )),
34
- (UPat (Ops .LOAD , name = "l" , src = (UPat .var ("b" ),)), lambda l ,b : packed_load (l , b , l .dtype ) if is_packed (b ,l .dtype ) else None ),
35
- (UPat (Ops .LOAD , name = "l" , src = (UPat .var ("b" ), UPat .cvar ("c" ))),
36
- lambda l ,b ,c : packed_load (l ,b ,l .dtype ,c .cast (dtypes .uint32 )) if is_packed (b ,l .dtype ) else None ),
37
- (UPat .store (UPat .var ("bidx" ), UPat .var ("var" ), allow_any_len = True ), lambda bidx ,var : packed_store (bidx ,var ) if is_packed (bidx ,var .dtype ) else None ),
33
+ (UPat .load (UPat .var ("b" ), UPat .cvar ("c" ), name = "l" ),lambda l ,b ,c : packed_load (l ,b ,l .dtype ,c .cast (dtypes .uint32 )) if is_packed (l .dtype ) else None ),
34
+ (UPat .load (UPat .var ("b" ), name = 'l' , allow_any_len = True ), lambda l ,b : packed_load (l , b , l .dtype ) if is_packed (l .dtype ) else None ),
35
+ (UPat .store (UPat .var ("bidx" ), UPat .var ("var" ), allow_any_len = True ), lambda bidx ,var : packed_store (bidx ,var ) if is_packed (var .dtype ) else None ),
38
36
# TODO: why is this needed, and only for this MUL order
39
37
(UPat (Ops .MUL , src = (UPat .var ("a" ), UPat .var ("g" ).where (UPat .cvar ("c1" ), UPat .cvar ("c2" )))),
40
38
lambda a ,g ,c1 ,c2 : g .where (c1 , a ) if math .isnan (c1 .arg ) and c2 .arg == 1.0 else None ),
@@ -58,18 +56,18 @@ class WGSLRenderer(CStyleLanguage):
58
56
(UPat .cvar ("x" , dtype = dtypes .bool ), lambda x : "true" if x .arg else "false" ),
59
57
(UPat (Ops .CONST , dtype = (dtypes .uchar , dtypes .ushort , dtypes .uint32 ), name = "x" ),
60
58
lambda x : f"bitcast<u32>({ x .arg } )" if x .arg < 0 else f"{ x .arg & 0xFFFFFFFF } u" ),
61
- (UPat (Ops .DEFINE_LOCAL , name = "x" ), lambda ctx ,x : f"var<workgroup> { ctx [x ]} : array<{ ctx .type_map [ x .dtype .base ] } , { x .dtype .size } >;" ),
59
+ (UPat (Ops .DEFINE_LOCAL , name = "x" ), lambda ctx ,x : f"var<workgroup> { ctx [x ]} : array<{ ctx .buf_map ( x .dtype .base ) } , { x .dtype .size } >;" ),
62
60
(UPat (Ops .BITCAST , dtype = dtypes .half , name = "x" , src = (UPat (dtype = (dtypes .short , dtypes .ushort , dtypes .uint32 ),),)),
63
61
lambda ctx ,x : f"bitcast<vec2<f16>>({ ctx [x .src [0 ]]} )[0]" ),
64
62
(UPat (Ops .BITCAST , dtype = (dtypes .char , dtypes .uchar ), name = "x" ), lambda ctx ,x : f"bitcast<{ ctx .type_map [x .dtype ]} >({ ctx [x .src [0 ]]} &0xFF)" ),
65
63
(UPat (Ops .BITCAST , dtype = (dtypes .short , dtypes .ushort ), name = "x" ),lambda ctx ,x :f"bitcast<{ ctx .type_map [x .dtype ]} >(vec2<f16>({ ctx [x .src [0 ]]} ,0))" \
66
64
if x .src [0 ].dtype == dtypes .half else f"bitcast<{ ctx .type_map [x .dtype ]} >({ ctx [x .src [0 ]]} &0xFFFF)" ),
67
65
(UPat (Ops .BITCAST , name = "x" ), lambda ctx ,x : f"bitcast<{ ctx .type_map [x .dtype ]} >({ ctx [x .src [0 ]]} )" ),
68
- (UPat .load (UPat .var ("b" ), UPat .cvar ("v" )),lambda ctx ,b ,v : f"select({ ctx [v ]} , { ctx .render_load (b , ctx [b ],b .src [0 ].dtype )} , { ctx [b .src [2 ]]} )" ),
69
- (UPat .load (UPat .var ("b" ), allow_any_len = True ), lambda ctx , b : ctx .render_load (b , ctx [b ], b .dtype )),
66
+ (UPat .load (UPat .var ("b" ), UPat .cvar ("v" )),lambda ctx ,b ,v : f"select({ ctx [v ]} , { ctx .render_load (ctx [b ],b .src [0 ].dtype )} , { ctx [b .src [2 ]]} )" ),
67
+ (UPat .load (UPat .var ("b" ), allow_any_len = True ), lambda ctx , b : ctx .render_load (ctx [b ], b .dtype )),
70
68
(UPat .store (UPat .var ("b" ), UPat .var ("v" ), allow_any_len = True ),lambda ctx ,b ,v :\
71
69
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
72
- f"atomicAnd(&{ ctx [b ]} ,{ ctx [v .src [0 ].src [1 ]]} );\n atomicAdd(&{ ctx [b ]} ,{ ctx [v .src [1 ]]} );" if is_packed (b , b .src [0 ].dtype ) \
70
+ f"atomicAnd(&{ ctx [b ]} ,{ ctx [v .src [0 ].src [1 ]]} );\n atomicAdd(&{ ctx [b ]} ,{ ctx [v .src [1 ]]} );" if is_packed (b .src [0 ].dtype ) \
73
71
else f"{ ctx [b ]} = { ctx [v ]} ;" ),
74
72
(UPat (Ops .INDEX , src = (UPat .var ("b" ), UPat .var ("idx" )), allow_any_len = True ),
75
73
lambda ctx ,b ,idx : f"{ ctx [b ]} [{ strip_parens (ctx [idx ]) if idx .arg is Ops .ADD else ctx [idx ]} ]" ),
@@ -79,8 +77,8 @@ class WGSLRenderer(CStyleLanguage):
79
77
80
78
def render_cast(self, dt:DType, val: str) -> str:
  """Render a cast of the expression string `val` to the type name mapped from `dt`.

  The type name is looked up in `self.type_map`; a missing entry raises KeyError.
  """
  # NOTE(review): the spacing inside the literal is reproduced exactly as it appears in the source.
  target_type = self.type_map[dt]
  return f"{target_type} ({val} )"
81
79
def render_dtype(self, dt:DType, mutable=True) -> str:
  """Return the declaration keyword used for every dtype.

  Neither `dt` nor `mutable` affects the result: the keyword is always "var".
  """
  declaration_keyword = "var"
  return declaration_keyword
82
- def render_load (self , idx : UOp , x :str , dt :DType ) -> str : return f"atomicLoad(&{ x } )" if is_packed (idx , dt ) else x
83
- def buf_map (self , dt :DType ) -> str : return "atomic<u32>" if dt . itemsize < 4 and dt . base != dtypes . half else self .type_map [dt .base ]
80
# render_load: wrap the already-rendered expression `x` in an atomicLoad(&...) call when
# is_packed(dt) is true; otherwise return `x` unchanged.
+ def render_load (self , x :str , dt :DType ) -> str : return f"atomicLoad(&{ x } )" if is_packed (dt ) else x
81
# buf_map: element type string for a buffer of dtype `dt` -- "atomic<u32>" when is_packed(dt) is true,
# otherwise the self.type_map entry for dt.base.
+ def buf_map (self , dt :DType ) -> str : return "atomic<u32>" if is_packed ( dt ) else self .type_map [dt .base ]
84
82
def render_kernel (self , function_name :str , kernel :list [str ], bufs :list [tuple [str ,tuple [DType ,bool ]]], uops :list [UOp ], prefix = None ) -> str :
85
83
local_size = [num for _ , num in sorted ([u .arg for u in uops if u .op is Ops .SPECIAL and u .arg [0 ][0 ] == 'l' ], key = lambda x : x [0 ])]
86
84
if not local_size : local_size = [1 ]
0 commit comments