diff --git a/src/TensorCast.jl b/src/TensorCast.jl index 6950525..f5b9e84 100644 --- a/src/TensorCast.jl +++ b/src/TensorCast.jl @@ -12,6 +12,7 @@ export @cast, @reduce, @matmul, @pretty using MacroTools, StaticArrays, Compat using LinearAlgebra, Random +include("capture.jl") include("macro.jl") include("pretty.jl") include("string.jl")
diff --git a/src/capture.jl b/src/capture.jl
new file mode 100644
index 0000000..1f3bd96
--- /dev/null
+++ b/src/capture.jl
@@ -0,0 +1,173 @@
+"""
+    @capture_(ex, A_[ijk__])
+
+Faster drop-in replacement for `MacroTools.@capture`, for a few patterns only:
+* `A_[ijk__]` and `A_{ijk__}`
+* `[ijk__]`
+* `A_.field_`
+* `A_ := B_` and `A_ = B_` and `A_ += B_` etc.
+* `f_(x_)`
+
+Evaluates to `true` if `ex` matches, and then the pattern's names (minus underscores)
+are bound in the caller's scope to the matching pieces of `ex`. Evaluates to `false`
+otherwise, leaving those names set to `nothing`. Patterns not listed above are an
+error when the macro is expanded.
+"""
+macro capture_(ex, pat::Expr)
+
+    H = QuoteNode(pat.head)
+    two = length(pat.args) == 2 # every pattern except [ijk__] has exactly two parts
+
+    A,B = if pat.head in (:ref, :curly) && two &&
+            _endswithone(pat.args[1]) && _endswithtwo(pat.args[2]) # :( A_[ijk__] )
+        _symbolone(pat.args[1]), _symboltwo(pat.args[2])
+
+    elseif pat.head == :. && two && pat.args[2] isa QuoteNode &&
+            _endswithone(pat.args[1]) && _endswithone(pat.args[2].value) # :( A_.field_ )
+        _symbolone(pat.args[1]), _symbolone(pat.args[2].value)
+
+    elseif pat.head == :call && two &&
+            _endswithone(pat.args[1]) && _endswithone(pat.args[2]) # :( f_(x_) )
+        _symbolone(pat.args[1]), _symbolone(pat.args[2])
+
+    elseif pat.head in (:(=), :(:=), :+=, :-=, :*=, :/=) && two &&
+            _endswithone(pat.args[1]) && _endswithone(pat.args[2]) # :( A_ += B_ )
+        _symbolone(pat.args[1]), _symbolone(pat.args[2])
+
+    elseif pat.head == :vect && length(pat.args)==1 && _endswithtwo(pat.args[1]) # :( [ijk__] )
+        _symboltwo(pat.args[1]), gensym(:ignore)
+
+    else
+        error("@capture_ doesn't work on pattern $pat")
+    end
+
+    # Refer to the matcher through its module, so that the escaped expansion does not
+    # need the name `TensorCast` (or `_trymatch`) to be visible in the caller's scope.
+    trymatch = GlobalRef(@__MODULE__, :_trymatch)
+
+    @gensym res
+    quote
+        $A, $B = nothing, nothing
+        $res = $trymatch($ex, Val($H))
+        if $res === nothing
+            false
+        else
+            $A, $B = $res
+            true
+        end
+    end |> esc
+end
+
+# A pattern name captures one thing if it ends with a single underscore, `A_`,
+# and a vector of things if it ends with two, `ijk__`.
+_endswithone(ex) = endswith(string(ex), '_') && !_endswithtwo(ex)
+_endswithtwo(ex) = endswith(string(ex), "__")
+
+# Strip the trailing underscore(s). `chop` counts characters rather than bytes,
+# so this is safe for names such as `α_`.
+_symbolone(ex) = Symbol(chop(string(ex)))
+_symboltwo(ex) = Symbol(chop(string(ex), tail=2))
+
+_getvalue(::Val{val}) where {val} = val
+
+# `_trymatch(ex, Val(head))` returns the two captured pieces when `ex` has this head
+# and the shape the pattern needs, and `nothing` otherwise. It should never throw.
+_trymatch(s, v) = nothing # Symbol, or other Expr
+_trymatch(ex::Expr, pat::Union{Val{:ref}, Val{:curly}}) = # A_[ijk__] or A_{ijk__}
+    if ex.head === _getvalue(pat) && !isempty(ex.args)
+        ex.args[1], ex.args[2:end]
+    else
+        nothing
+    end
+_trymatch(ex::Expr, ::Val{:.}) = # A_.field_
+    if ex.head === :. && length(ex.args) == 2 && ex.args[2] isa QuoteNode
+        ex.args[1], ex.args[2].value # excludes broadcasting f.(x), whose args[2] is an Expr
+    else
+        nothing
+    end
+_trymatch(ex::Expr, pat::Val{:call}) = # f_(x_), exactly one argument
+    if ex.head === _getvalue(pat) && length(ex.args) == 2
+        ex.args[1], ex.args[2]
+    else
+        nothing
+    end
+_trymatch(ex::Expr, pat::Union{Val{:(=)}, Val{:(:=)}, Val{:(+=)}, Val{:(-=)}, Val{:(*=)}, Val{:(/=)}}) =
+    if ex.head === _getvalue(pat) && length(ex.args) == 2
+        ex.args[1], ex.args[2]
+    else
+        nothing
+    end
+_trymatch(ex::Expr, ::Val{:vect}) = # [ijk__]
+    if ex.head === :vect
+        ex.args, nothing
+    else
+        nothing
+    end
+
+# Cases for Tullio:
+# @capture(ex, B_[inds__].field_) --> @capture_(ex, Binds_.field_) && @capture_(Binds, B_[inds__])
+# @capture(ex, [inds__])
+
+#=
+
+julia> ex = :(Z[1,2,3])
+
+julia> @pretty @capture(ex, A_[ijk__])
+begin
+    A = MacroTools.nothing
+    ijk = MacroTools.nothing
+    tarsier = trymatch($(Expr(:copyast, :($(QuoteNode(:(A_[ijk__])))))), ex)
+    if tarsier == MacroTools.nothing
+        false
+    else
+        A = get(tarsier, :A, MacroTools.nothing)
+        ijk = get(tarsier, :ijk, MacroTools.nothing)
+        true
+    end
+end
+
+julia> @pretty @capture_(ex, A_[ijk__])
+begin
+    A = nothing
+    ijk = nothing
+    louse = _trymatch(ex)
+    if louse == nothing
+        false
+    else
+        (A, ijk) = louse
+        true
+    end
+end
+
+
+
+ex = :( A[i,j][k] + B[I[i],J[j],k]^2 / 2 )
+f1(x) = MacroTools.postwalk(ex) do x
+        @capture(x, A_[ijk__]) || return x
+        :($A[$(ijk...),9])
+    end
+f2(x) = MacroTools.postwalk(ex) do x
+        @capture_(x, A_[ijk__]) || return x
+        :($A[$(ijk...),9])
+    end
+f1(ex)
+f2(ex)
+
+@btime f1(x) setup=(x=ex) # 3.181 ms
+@btime f2(x) setup=(x=ex) # 31.440 μs -- 100x 
faster. + + +$ time julia -e 'using TensorCast; TensorCast._macro(:( Z[i,k][j] := fun(A[i,:], B[j])[k] + C[k]^2 ))' + +real 0m8.132s # was 0m8.900s on master, noise or signal? +user 0m7.747s +sys 0m0.358s +real 0m8.132s +user 0m8.295s +sys 0m0.329s + +$ time julia -e 'using TensorCast; @time TensorCast._macro(:( Z[i,k][j] := fun(A[i,:], B[j])[k] + C[k]^2 ))' + +4.899634 seconds, best run # was 5.845 on master, that's a second? + +=# diff --git a/src/macro.jl b/src/macro.jl index cf6e61f..e358e33 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -237,17 +237,17 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) @nospecialize ex # This acts only on single indexing expressions: - if @capture(ex, A_{ijk__}) + if @capture_(ex, A_{ijk__}) static=true push!(call.flags, :staticslice) - elseif @capture(ex, A_[ijk__]) + elseif @capture_(ex, A_[ijk__]) static=false else return ex end # Ensure that f(x)[i,j] will evaluate once, including in size(A) - if A isa Symbol || @capture(A, AA_.ff_) # caller has ensured !containsindexing(A) + if A isa Symbol || @capture_(A, AA_.ff_) # caller has ensured !containsindexing(A) else Asym = Symbol(A,"_val") # exact same symbol is used by rightsizes() push!(store.top, :( local $Asym = $A ) ) @@ -287,11 +287,11 @@ function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false) end # Nested indices A[i,j,B[k,l,m],n] or worse A[i,B[j,k],C[i,j]] - if any(i -> @capture(i, B_[klm__]), ijk) + if any(i -> @capture_(i, B_[klm__]), ijk) newijk, beecolon = [], [] # for simple case # listB, listijk = [], [] for i in ijk - if @capture(i, B_[klm__]) + if @capture_(i, B_[klm__]) append!(newijk, klm) push!(beecolon, B) # push!(listijk, klm) @@ -383,9 +383,9 @@ function standardglue(ex, target, store::NamedTuple, call::CallInfo) @nospecialize ex # The sole target here is indexing expressions: - if @capture(ex, A_[inner__]) + if @capture_(ex, A_[inner__]) static=false - elseif @capture(ex, A_{inner__}) + elseif @capture_(ex, 
A_{inner__}) static=true else return ex @@ -397,7 +397,7 @@ function standardglue(ex, target, store::NamedTuple, call::CallInfo) end # Otherwise there are two options, (brodcasting...)[k] or simple B[i,j][k] - needcast = !@capture(A, B_[outer__]) + needcast = !@capture_(A, B_[outer__]) if needcast outer = unique(reduce(vcat, listindices(A))) @@ -475,7 +475,7 @@ function targetcast(ex, target, store::NamedTuple, call::CallInfo) @nospecialize ex # If just one naked expression, then we won't broadcast: - if @capture(ex, A_[ijk__]) + if @capture_(ex, A_[ijk__]) containsindexing(A) && error("that should have been dealt with") return readycast(ex, target, store, call) end @@ -508,6 +508,7 @@ This is walked over the expression to prepare for `@__dot__` etc, by `targetcast """ function readycast(ex, target, store::NamedTuple, call::CallInfo) @nospecialize ex + ex isa Symbol && return ex # quit early? # Scalar functions can be protected entirely from broadcasting: # TODO this means A[i,j] + rand()/10 doesn't work, /(...,10) is a function! @@ -525,10 +526,10 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) return :( getproperty($fun($(arg...)), $(QuoteNode(field))) ) # tuple creation... now including namedtuples @capture(ex, (args__,) ) && any(containsindexing, args) && - if any(a -> @capture(a, sym_ = val_), args) + if any(a -> @capture_(a, sym_ = val_), args) syms, vals = [], [] map(args) do a - @capture(a, sym_ = val_ ) || throw(MacroError("invalid named tuple element $a", call)) + @capture_(a, sym_ = val_ ) || throw(MacroError("invalid named tuple element $a", call)) push!(syms, QuoteNode(sym)) push!(vals, val) end @@ -541,7 +542,7 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) return :( Core._apply($funs[$(ijk...)], $(args...) 
) ) # Apart from those, readycast acts only on lone tensors: - @capture(ex, A_[ijk__]) || return ex + @capture_(ex, A_[ijk__]) || return ex dims = Int[ findcheck(i, target, call, " on the left") for i in ijk ] @@ -638,6 +639,7 @@ Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).sa """ function recursemacro(ex, store::NamedTuple, call::CallInfo) @nospecialize ex + ex isa Symbol && return ex # quit early? # Actually look for recursion if @capture(ex, @reduce(subex__) ) @@ -663,9 +665,9 @@ function recursemacro(ex, store::NamedTuple, call::CallInfo) end # Tidy up indices, A[i,j][k] will be hit on different rounds... - if @capture(ex, A_[ijk__]) + if @capture_(ex, A_[ijk__]) return :( $A[$(tensorprimetidy(ijk)...)] ) - elseif @capture(ex, A_{ijk__}) + elseif @capture_(ex, A_{ijk__}) return :( $A{$(tensorprimetidy(ijk)...)} ) else return ex @@ -689,20 +691,21 @@ function rightsizes(ex, store::NamedTuple, call::CallInfo) if @capture(ex, A_[outer__][inner__] | A_[outer__]{inner__} ) field = nothing elseif @capture(ex, A_[outer__].field_[inner__] | A_[outer__].field_{inner__} ) - elseif @capture(ex, A_[outer__] | A_{outer__} ) + # elseif @capture(ex, A_[outer__] | A_{outer__} ) + elseif @capture_(ex, A_[outer__] ) || @capture_(ex, A_{outer__} ) field = nothing else return ex end # Special treatment for fun(x)[i,j], goldilocks A not just symbol, but no indexing - if A isa Symbol || @capture(A, AA_.ff_) + if A isa Symbol || @capture_(A, AA_.ff_) elseif !containsindexing(A) A = Symbol(A,"_val") # the exact same symbol is used by standardiser end # When we can save the sizes, then we destroy so as not to save again: - if A isa Symbol || @capture(A, AA_.ff_) && !containsindexing(A) + if A isa Symbol || @capture_(A, AA_.ff_) && !containsindexing(A) indexparse(A, outer, store, call; save=true) if field==nothing innerparse(:(first($A)), inner, store, call; save=true) @@ -728,24 +731,24 @@ function castparse(ex, store::NamedTuple, call::CallInfo; 
reduce=false) Z = gensym(:left) # Do we make a new array? With or without collecting: - if @capture(ex, left_ := right_ ) + if @capture_(ex, left_ := right_ ) elseif @capture(ex, left_ == right_ ) @warn "using == no longer does anything" call.string maxlog=1 _id=hash(call.string) elseif @capture(ex, left_ |= right_ ) push!(call.flags, :collect) # Do we write into an exising array? Possibly updating it: - elseif @capture(ex, left_ = right_ ) + elseif @capture_(ex, left_ = right_ ) push!(call.flags, :inplace) - elseif @capture(ex, left_ += right_ ) + elseif @capture_(ex, left_ += right_ ) push!(call.flags, :inplace) right = :( $left + $right ) reduce && throw(MacroError("can't use += with @reduce", call)) - elseif @capture(ex, left_ -= right_ ) + elseif @capture_(ex, left_ -= right_ ) push!(call.flags, :inplace) right = :( $left - ($right) ) reduce && throw(MacroError("can't use -= with @reduce", call)) - elseif @capture(ex, left_ *= right_ ) + elseif @capture_(ex, left_ *= right_ ) push!(call.flags, :inplace) right = :( $left * ($right) ) reduce && throw(MacroError("can't use *= with @reduce", call)) @@ -762,7 +765,7 @@ function castparse(ex, store::NamedTuple, call::CallInfo; reduce=false) error("wtf is $ex") end - static = @capture(left, ZZ_{ii__}) + static = @capture_(left, ZZ_{ii__}) if @capture(left, Z_[outer__][inner__] | [outer__][inner__] | Z_[outer__]{inner__} | [outer__]{inner__} ) isnothing(Z) && (:inplace in call.flags) && throw(MacroError("can't write into a nameless tensor", call)) @@ -884,7 +887,7 @@ function indexparse(A, ijk::Vector, store=nothing, call=nothing; save=false) push!(outsize, szwrap(ii)) save && saveonesize(ii, :(size($A, $d)), store) - elseif @capture(i, B_[klm__]) + elseif @capture_(i, B_[klm__]) innerparse(B, klm, store, call) # called just for error on tensor/colon/constant sub = indexparse(B, klm, store, call; save=save) # I do want to save size(B,1) etc. 
append!(flat, sub.flat) @@ -1125,6 +1128,8 @@ end tensorprimetidy(v::Vector) = Any[ tensorprimetidy(x) for x in v ] function tensorprimetidy(ex) MacroTools.postwalk(ex) do @nospecialize x + x isa Symbol && return x # quit early? + @capture(x, ((ij__,) \ k_) ) && return :( ($(ij...),$k) ) @capture(x, i_ \ j_ ) && return :( ($i,$j) ) @@ -1151,7 +1156,7 @@ isconstant(ex::Expr) = ex.head == :($) isconstant(q::QuoteNode) = false isindexing(s) = false -isindexing(ex::Expr) = @capture(x, A_[ijk__]) +isindexing(ex::Expr) = @capture_(ex, A_[ijk__]) isCorI(i) = isconstant(i) || isindexing(ii) @@ -1182,7 +1187,7 @@ function containsindexing(ex::Expr) # MacroTools.postwalk(x -> @capture(x, A_[ijk__]) && (flag=true), ex) MacroTools.postwalk(ex) do @nospecialize x # @capture(x, A_[ijk__]) && !(all(isconstant, ijk)) && (flag=true) - if @capture(x, A_[ijk__]) + if @capture_(x, A_[ijk__]) # @show x ijk # TODO this is a bit broken? @pretty @cast Z[i,j] := W[i] * exp(X[1][i] - X[2][j]) flag=true end @@ -1194,7 +1199,7 @@ listindices(s::Symbol) = [] function listindices(ex::Expr) list = [] MacroTools.postwalk(ex) do @nospecialize x - if @capture(x, A_[ijk__]) + if @capture_(x, A_[ijk__]) flat, _ = indexparse(nothing, ijk) push!(list, flat) end @@ -1447,13 +1452,13 @@ function inplaceoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo) pop!(call.flags, :nolazy, :ok) # ensure we use diagview(), Reverse{}, etc, not a copy if @capture(parsed.left, zed_[]) # special case Z[] = ... 
else allconst pulls it out - zed isa Symbol || @capture(zed, ZZ_.field_) || error("wtf") + zed isa Symbol || @capture_(zed, ZZ_.field_) || error("wtf") newleft = parsed.left str = "expected a 0-tensor $zed[]" push!(store.mustassert, :( TensorCast.@assert_ ndims($zed)==0 $str) ) else newleft = standardise(parsed.left, store, call) - @capture(newleft, zed_[ijk__]) || throw(MacroError("failed to parse LHS correctly, $(parsed.left) -> $newleft")) + @capture_(newleft, zed_[ijk__]) || throw(MacroError("failed to parse LHS correctly, $(parsed.left) -> $newleft")) if !(zed isa Symbol) # then standardise did something! push!(call.flags, :showfinal) diff --git a/test/runtests.jl b/test/runtests.jl index 8abf1a9..10d3417 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ using Compat if VERSION >= v"1.1" using LoopVectorization end +using TensorCast: @capture_ @testset "ex-@shape" begin include("shape.jl") end @testset "@reduce" begin include("reduce.jl") end diff --git a/test/two.jl b/test/two.jl index 0992c32..b6bc6e3 100644 --- a/test/two.jl +++ b/test/two.jl @@ -356,3 +356,36 @@ end # @test_throws DimensionMismatch @cast M5[i,j] := fun(M[:,j]).same[i] i:99, j:4 # TODO make this check canonical length? 
end
+@testset "capture_ macro" begin
+
+    using TensorCast: @capture_
+
+    EXS = [:(A[i,j,k]), :(B{i,2,:}), :(C.dee), :(fun(5)), :(g := h+i), :(k[3] += l[4]), :([m,n,0]) ]
+    PATS = [:(A_[ijk__]), :(B_{ind__}), :(C_.d_), :(f_(arg_)), :(left_ := right_), :(a_ += b_), :([emm__]) ]
+    # WANT[i] lists the bindings which PATS[i] must produce when matched against EXS[i]
+    WANT = Any[
+        (:A => :A, :ijk => [:i, :j, :k]),
+        (:B => :B, :ind => [:i, 2, :(:)]),
+        (:C => :C, :d => :dee),
+        (:f => :fun, :arg => 5),
+        (:left => :g, :right => :(h+i)),
+        (:a => :(k[3]), :b => :(l[4])),
+        (:emm => [:m, :n, 0],),
+    ]
+    @test length(EXS) == length(PATS) == length(WANT)
+
+    @testset "ex = $(EXS[i])" for i in eachindex(EXS)
+        for j in eachindex(PATS)
+            # @eval so that the pattern is spliced in before @capture_ is expanded
+            @eval res = @capture_($EXS[$i], $(PATS[j]))
+            @test res == (i == j) # only the pattern written for this expression may match
+            i == j || continue
+            for (name, val) in WANT[i]
+                @test Core.eval(@__MODULE__, name) == val
+            end
+        end
+    end
+
+    @test !@capture_( :(f(1,2,3)), f_(x_) )
+
+end