From 902ced95330b80d2a08ae3d42a1a1ecb346d8618 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 1 May 2025 22:30:37 -0500 Subject: [PATCH 01/47] generate --- src/Compiler.jl | 43 ++++++++++ src/Interpreter.jl | 107 ++++++++++++++++++++++++ src/Overlay.jl | 12 +++ src/mlir/Dialects/Enzyme.jl | 162 ++++++++++++++++++++++++++++++++++++ test/probprog.jl | 32 +++++++ 5 files changed, 356 insertions(+) create mode 100644 test/probprog.jl diff --git a/src/Compiler.jl b/src/Compiler.jl index db02746a07..bff1046c9d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1489,6 +1489,49 @@ function compile_mlir!( ), "after_enzyme", ) + elseif optimize === :probprog + run_pass_pipeline!( + mod, + join( + if raise_first + [ + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + jit, + ] + else + [ + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, + raise_passes, + jit, + ] + end, + ',', + ), + "probprog", + ) elseif optimize === :canonicalize run_pass_pipeline!(mod, "canonicalize", "canonicalize") elseif optimize === :just_batch diff --git a/src/Interpreter.jl b/src/Interpreter.jl index ee299ca4c1..46c1f675e5 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -539,3 +539,110 @@ function overload_autodiff( end end end + +function overload_generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + + residx = 1 + for a in linear_results + resv = MLIR.IR.result(gen_op, residx) + residx += 1 + for path in a.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + end + end + + return result +end + +function overload_sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix = gensym("samplearg") + resprefix = gensym("sampleresult") + resargprefix = gensym("sampleresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + idx -= fnwrap ? 1 : 0 + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + + ridx = 1 + for a in linear_results + val = MLIR.IR.result(sample_op, ridx) + ridx += 1 + + for path in a.paths + isempty(path) && continue + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], val) + elseif path[1] == argprefix + idx = path[2]::Int - (fnwrap ? 1 : 0) + TracedUtils.set!(args[idx], path[3:end], val) + end + end + end + + return result +end diff --git a/src/Overlay.jl b/src/Overlay.jl index c97a06664d..cfef42541f 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -21,6 +21,18 @@ end return overload_autodiff(rmode, f, rt, args...) end +@reactant_overlay @noinline function Enzyme.generate( + f::Function, args::Vararg{Any,Nargs} +) where {Nargs} + return overload_generate(f, args...) +end + +@reactant_overlay @noinline function Enzyme.sample( + f::Function, args::Vararg{Any,Nargs} +) where {Nargs} + return overload_sample(f, args...) +end + # Random.jl overlays @reactant_overlay @noinline function Random.default_rng() return call_with_reactant(TracedRandom.default_rng) diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index e4306b06a1..f558ee0468 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -151,6 +151,33 @@ function fwddiff( ) end +""" +`generate` + +Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. +""" +function generate( + inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.generate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function genericAdjoint( inputs::Vector{Value}, outputs::Vector{Value}; @@ -323,4 +350,139 @@ function set(gradient::Value, value::Value; location=Location()) ) end +""" +`simulate` + +Simulate a probabilistic function to generate execution trace +by replacing all SampleOps with distribution calls and inserting +sampled values into the choice map. +""" +function simulate( + inputs::Vector{Value}; newTrace::IR.Type, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[newTrace,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.simulate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`trace` + +Execute a probabilistic function specified by a symbol reference using the provided arguments, +and a set of constraints on the sampled variables (if provided). Return the execution trace +(if provided) and the log-likelihood of the execution trace. +""" +function trace( + inputs::Vector{Value}, + oldTrace=nothing::Union{Nothing,Value}; + constraints=nothing::Union{Nothing,Value}, + newTrace::IR.Type, + weights::Vector{IR.Type}, + fn, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[newTrace, weights...] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(oldTrace) && push!(operands, oldTrace) + !isnothing(constraints) && push!(operands, constraints) + push!( + attributes, + operandsegmentsizes([ + length(inputs), if (oldTrace == nothing) + 0 + elseif 1(constraints == nothing) + 0 + else + 1 + end + ]), + ) + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.trace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`addSampleToTrace` + +Add a sampled value into the execution trace. +""" +function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[trace, sample] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.addSampleToTrace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`insertChoiceToMap` + +Insert a constraint on a sampled variable into the choice map. +""" +function insertChoiceToMap( + choiceMap::Value, + choice::Value; + newChoiceMap::IR.Type, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[newChoiceMap,] + operands = Value[choiceMap, choice] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.insertChoiceToMap", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + end # enzyme diff --git a/test/probprog.jl b/test/probprog.jl new file mode 100644 index 0000000000..e3f64faf30 --- /dev/null +++ b/test/probprog.jl @@ -0,0 +1,32 @@ +using Enzyme, Reactant, Test, Random, StableRNGs, Statistics + +normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) + +function model(mean, stddev) + s = Enzyme.sample(normal, StableRNG(0), mean, stddev) + t = Enzyme.sample(normal, StableRNG(0), s, stddev) + return t +end + +@testset "ProbProg" begin + @testset "normal_hlo" begin + hlo = @code_hlo Enzyme.generate( + model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + @test contains(repr(hlo), "enzyme.generate") + @test contains(repr(hlo), "enzyme.sample") + # println(hlo) + + lowered = Reactant.Compiler.run_pass_pipeline_on_source(repr(hlo), "probprog") + println(lowered) + end + + @testset "normal_generate" begin + X = Array( + @jit optimize = :probprog Enzyme.generate( + model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + ) + @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 + end +end From e2c77e402f41fd39084933bef7fac89e08eeee01 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 2 May 2025 15:40:51 -0500 Subject: [PATCH 02/47] refactor --- src/Interpreter.jl | 107 ----------------------------------------- src/Overlay.jl | 12 ----- src/ProbProg.jl | 115 +++++++++++++++++++++++++++++++++++++++++++++ src/Reactant.jl | 1 + test/probprog.jl | 11 +++-- test/runtests.jl | 1 + 6 files changed, 123 insertions(+), 124 deletions(-) create mode 100644 src/ProbProg.jl diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 46c1f675e5..ee299ca4c1 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -539,110 +539,3 @@ function overload_autodiff( end end end - -function overload_generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix::Symbol = gensym("generatearg") - resprefix::Symbol = gensym("generateresult") - resargprefix::Symbol = gensym("generateresarg") - - mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - batch_inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) - else - if fnwrap - idx -= 1 - end - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) - - residx = 1 - for a in linear_results - resv = MLIR.IR.result(gen_op, residx) - residx += 1 - for path in a.paths - if length(path) == 0 - continue - end - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], resv) - elseif path[1] == argprefix - idx = path[2]::Int - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) - end - end - end - end - - return result -end - -function overload_sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix = gensym("samplearg") - resprefix = gensym("sampleresult") - resargprefix = gensym("sampleresarg") - - mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - batch_inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) - else - idx -= fnwrap ? 1 : 0 - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - sym = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - - sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) - - ridx = 1 - for a in linear_results - val = MLIR.IR.result(sample_op, ridx) - ridx += 1 - - for path in a.paths - isempty(path) && continue - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], val) - elseif path[1] == argprefix - idx = path[2]::Int - (fnwrap ? 1 : 0) - TracedUtils.set!(args[idx], path[3:end], val) - end - end - end - - return result -end diff --git a/src/Overlay.jl b/src/Overlay.jl index cfef42541f..c97a06664d 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -21,18 +21,6 @@ end return overload_autodiff(rmode, f, rt, args...) end -@reactant_overlay @noinline function Enzyme.generate( - f::Function, args::Vararg{Any,Nargs} -) where {Nargs} - return overload_generate(f, args...) -end - -@reactant_overlay @noinline function Enzyme.sample( - f::Function, args::Vararg{Any,Nargs} -) where {Nargs} - return overload_sample(f, args...) -end - # Random.jl overlays @reactant_overlay @noinline function Random.default_rng() return call_with_reactant(TracedRandom.default_rng) diff --git a/src/ProbProg.jl b/src/ProbProg.jl new file mode 100644 index 0000000000..b80fb2f628 --- /dev/null +++ b/src/ProbProg.jl @@ -0,0 +1,115 @@ +module ProbProg + +using ..Reactant: Reactant, XLA, MLIR, TracedUtils +using ReactantCore: ReactantCore + +using Enzyme + +function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + + residx = 1 + for a in linear_results + resv = MLIR.IR.result(gen_op, residx) + residx += 1 + for path in a.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + end + end + + return result +end + +function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix = gensym("samplearg") + resprefix = gensym("sampleresult") + resargprefix = gensym("sampleresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + idx -= fnwrap ? 1 : 0 + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + + ridx = 1 + for a in linear_results + val = MLIR.IR.result(sample_op, ridx) + ridx += 1 + + for path in a.paths + isempty(path) && continue + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], val) + elseif path[1] == argprefix + idx = path[2]::Int - (fnwrap ? 1 : 0) + TracedUtils.set!(args[idx], path[3:end], val) + end + end + end + + return result +end + +end \ No newline at end of file diff --git a/src/Reactant.jl b/src/Reactant.jl index 090a8d6b90..d9f5d908b8 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -174,6 +174,7 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") +include("ProbProg.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/test/probprog.jl b/test/probprog.jl index e3f64faf30..a493fcee4b 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -1,16 +1,17 @@ -using Enzyme, Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) function model(mean, stddev) - s = Enzyme.sample(normal, StableRNG(0), mean, stddev) - t = Enzyme.sample(normal, StableRNG(0), s, stddev) + s = ProbProg.sample(normal, StableRNG(0), mean, stddev) + t = ProbProg.sample(normal, StableRNG(0), s, stddev) return t end @testset "ProbProg" begin @testset "normal_hlo" begin - hlo = @code_hlo Enzyme.generate( + hlo = @code_hlo ProbProg.generate( model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) @test contains(repr(hlo), "enzyme.generate") @@ -23,7 +24,7 @@ end @testset "normal_generate" begin X = Array( - @jit optimize = :probprog Enzyme.generate( + @jit optimize = :probprog ProbProg.generate( model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) ) diff --git a/test/runtests.jl b/test/runtests.jl index b93fb9ae20..383aa44cf1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Tracing" include("tracing.jl") @safetestset "Basic" include("basic.jl") @safetestset "Autodiff" include("autodiff.jl") + @safetestset "ProbProg" include("probprog.jl") @safetestset "Complex" include("complex.jl") @safetestset "Broadcast" include("bcast.jl") @safetestset "Struct" include("struct.jl") From d611ae4f818ec0aee692b71805a5b6041583d96b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 6 May 2025 20:14:21 -0500 Subject: [PATCH 03/47] add probprog pass to :all --- src/Compiler.jl | 45 ++------------------------------------------- 1 file changed, 2 insertions(+), 43 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 84db740901..e2c9fd93c7 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1285,6 +1285,7 @@ function compile_mlir!( raise_passes, "enzyme-batch", opt_passes2, + "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1299,6 +1300,7 @@ function compile_mlir!( opt_passes, "enzyme-batch", opt_passes2, + "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1506,49 +1508,6 @@ function compile_mlir!( ), "after_enzyme", ) - elseif optimize === :probprog - run_pass_pipeline!( - mod, - join( - if raise_first - [ - opt_passes, - kern, - raise_passes, - "enzyme-batch", - opt_passes2, - enzyme_pass, - "probprog", - enzyme_pass, - opt_passes2, - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - opt_passes2, - jit, - ] - else - [ - opt_passes, - "enzyme-batch", - opt_passes2, - enzyme_pass, - "probprog", - enzyme_pass, - opt_passes2, - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - opt_passes2, - kern, - raise_passes, - jit, - ] - end, - ',', - ), - "probprog", - ) elseif optimize === :canonicalize run_pass_pipeline!(mod, "mark-func-memory-effects,canonicalize", "canonicalize") elseif optimize === :just_batch From 3672d83caa1b53d66bb640cfee6901ece906b89a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 6 May 2025 20:14:28 -0500 Subject: [PATCH 04/47] improve test --- test/probprog.jl | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/test/probprog.jl b/test/probprog.jl index a493fcee4b..6272ec3312 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -3,29 +3,39 @@ using Reactant: ProbProg normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) -function model(mean, stddev) - s = ProbProg.sample(normal, StableRNG(0), mean, stddev) - t = ProbProg.sample(normal, StableRNG(0), s, stddev) +function model(rng, mean, stddev) + s = ProbProg.sample(normal, rng, mean, stddev) + t = ProbProg.sample(normal, rng, s, stddev) return t end @testset "ProbProg" begin @testset "normal_hlo" begin - hlo = @code_hlo ProbProg.generate( - model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + rng = StableRNG(0) + before = @code_hlo optimize = :none ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) - @test contains(repr(hlo), "enzyme.generate") - @test contains(repr(hlo), "enzyme.sample") - # println(hlo) + @test contains(repr(before), "enzyme.generate") + @test contains(repr(before), "enzyme.sample") - lowered = Reactant.Compiler.run_pass_pipeline_on_source(repr(hlo), "probprog") - println(lowered) + # println("Before") + # println(repr(before)) + + after = @code_hlo optimize = :all ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + @test !contains(repr(after), "enzyme.generate") + @test !contains(repr(after), "enzyme.sample") + + # println("After") + # println(repr(after)) end @testset "normal_generate" begin + rng = StableRNG(1) X = Array( - @jit optimize = :probprog ProbProg.generate( - model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + @jit optimize = :all ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) ) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 From b70843e34d96bff7fdfb6a4e6c83f19e60c7d9b9 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 8 May 2025 12:22:05 -0500 Subject: [PATCH 05/47] only probprog opt mode --- src/Compiler.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index e2c9fd93c7..690c964a01 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1424,6 +1424,22 @@ function compile_mlir!( ), "only_enzyme", ) + elseif optimize === :probprog + run_pass_pipeline!( + mod, + join( + [ + "mark-func-memory-effects", + "enzyme-batch", + "probprog", + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ], + ',', + ), + "probprog", + ) elseif optimize === :only_enzyme run_pass_pipeline!( mod, From 597fa89b0d4d009dfe0a463e63eb693f7105acd0 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 8 May 2025 12:22:15 -0500 Subject: [PATCH 06/47] fix up test --- test/probprog.jl | 60 +++++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/test/probprog.jl b/test/probprog.jl index 6272ec3312..b3cfe75970 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -1,43 +1,55 @@ using Reactant, Test, Random, StableRNGs, Statistics using Reactant: ProbProg -normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) +normal(rng, μ, σ) = μ .+ σ .* randn(rng, 10000) -function model(rng, mean, stddev) - s = ProbProg.sample(normal, rng, mean, stddev) - t = ProbProg.sample(normal, rng, s, stddev) - return t +function generate_model(seed, μ, σ) + function model(seed, μ, σ) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ) + t = ProbProg.sample(normal, rng, s, σ) + return t + end + + return ProbProg.generate(model, seed, μ, σ) end @testset "ProbProg" begin + @testset "normal_deterministic" begin + seed1 = Reactant.to_rarray(UInt64[1, 4]) + seed2 = Reactant.to_rarray(UInt64[1, 4]) + μ1 = Reactant.ConcreteRArray(0.0) + μ2 = Reactant.ConcreteRArray(1000.0) + σ1 = Reactant.ConcreteRArray(1.0) + σ2 = Reactant.ConcreteRArray(1.0) + model_compiled = @compile generate_model(seed1, μ1, σ1) + + @test Array(model_compiled(seed1, μ1, σ1)) ≈ Array(model_compiled(seed1, μ1, σ1)) + @test mean(Array(model_compiled(seed1, μ1, σ1))) ≈ 0.0 atol = 0.05 rtol = 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + @test !(all( + Array(model_compiled(seed1, μ1, σ1)) .≈ Array(model_compiled(seed2, μ2, σ2)) + )) + end @testset "normal_hlo" begin - rng = StableRNG(0) - before = @code_hlo optimize = :none ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + before = @code_hlo optimize = :none generate_model(seed, μ, σ) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") - # println("Before") - # println(repr(before)) - - after = @code_hlo optimize = :all ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) + after = @code_hlo optimize = :probprog generate_model(seed, μ, σ) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") - - # println("After") - # println(repr(after)) end @testset "normal_generate" begin - rng = StableRNG(1) - X = Array( - @jit optimize = :all ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) - ) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + X = Array(@jit optimize = :probprog generate_model(seed, μ, σ)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end end From e6c2c0a2a37dec3be89051798ce8335896c99f5b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 12 May 2025 09:38:53 -0500 Subject: [PATCH 07/47] move --- test/{probprog.jl => probprog/generate.jl} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename test/{probprog.jl => probprog/generate.jl} (98%) diff --git a/test/probprog.jl b/test/probprog/generate.jl similarity index 98% rename from test/probprog.jl rename to test/probprog/generate.jl index b3cfe75970..5a488479d4 100644 --- a/test/probprog.jl +++ b/test/probprog/generate.jl @@ -15,7 +15,7 @@ function generate_model(seed, μ, σ) return ProbProg.generate(model, seed, μ, σ) end -@testset "ProbProg" begin +@testset "Generate" begin @testset "normal_deterministic" begin seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) From 9b9395e361ea22db9e730c84a4a53da295335025 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 12 May 2025 16:10:31 -0500 Subject: [PATCH 08/47] simplify --- src/ProbProg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index b80fb2f628..c73fceb6ef 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -11,7 +11,7 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix::Symbol = gensym("generateresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + f, args, (), string(f), false; argprefix, resprefix, resargprefix ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped @@ -69,7 +69,7 @@ function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + f, args, (), string(f), false; argprefix, resprefix, resargprefix ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped From b3ba4779d709620ff345793659dac33a1ede1361 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 14 May 2025 16:52:47 -0500 Subject: [PATCH 09/47] fix up --- src/ProbProg.jl | 6 +- src/mlir/Dialects/Enzyme.jl | 453 ++++++++++++------------------------ test/runtests.jl | 1 - 3 files changed, 158 insertions(+), 302 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index c73fceb6ef..68d0ca3a3f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -64,9 +64,9 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix = gensym("samplearg") - resprefix = gensym("sampleresult") - resargprefix = gensym("sampleresarg") + argprefix::Symbol = gensym("samplearg") + resprefix::Symbol = gensym("sampleresult") + resargprefix::Symbol = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( f, args, (), string(f), false; argprefix, resprefix, resargprefix diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index f558ee0468..3863cc567c 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -1,18 +1,10 @@ module enzyme using ...IR -import ...IR: - NamedAttribute, - Value, - Location, - Block, - Region, - Attribute, - create_operation, - context, - IndexType +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API + """ `addTo` @@ -20,75 +12,49 @@ TODO """ function addTo(values::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[values...,] + operands = Value[values..., ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.addTo", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.addTo", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function autodiff( - inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function autodiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), - namedattribute("activity", activity), - namedattribute("ret_activity", ret_activity), - ] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - return create_operation( - "enzyme.autodiff", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.autodiff", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function batch( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function batch(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), namedattribute("batch_shape", batch_shape) - ] - - return create_operation( - "enzyme.batch", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("batch_shape", batch_shape), ] + + create_operation( + "enzyme.batch", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -101,53 +67,34 @@ For scalar operands, ranked tensor is created. NOTE: Only works for scalar and *ranked* tensor operands for now. """ function broadcast(input::Value; output::IR.Type, shape, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[input,] + op_ty_results = IR.Type[output, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("shape", shape),] - - return create_operation( - "enzyme.broadcast", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("shape", shape), ] + + create_operation( + "enzyme.broadcast", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function fwddiff( - inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function fwddiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), - namedattribute("activity", activity), - namedattribute("ret_activity", ret_activity), - ] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - return create_operation( - "enzyme.fwddiff", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.fwddiff", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -156,197 +103,151 @@ end Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. """ -function generate( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] +function generate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.generate", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.generate", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function genericAdjoint( - inputs::Vector{Value}, - outputs::Vector{Value}; - result_tensors::Vector{IR.Type}, - indexing_maps, - iterator_types, - doc=nothing, - library_call=nothing, - region::Region, - location=Location(), -) - op_ty_results = IR.Type[result_tensors...,] - operands = Value[inputs..., outputs...] - owned_regions = Region[region,] + +function genericAdjoint(inputs::Vector{Value}, outputs::Vector{Value}; result_tensors::Vector{IR.Type}, indexing_maps, iterator_types, doc=nothing, library_call=nothing, region::Region, location=Location()) + op_ty_results = IR.Type[result_tensors..., ] + operands = Value[inputs..., outputs..., ] + owned_regions = Region[region, ] successors = Block[] - attributes = NamedAttribute[ - namedattribute("indexing_maps", indexing_maps), - namedattribute("iterator_types", iterator_types), - ] - push!(attributes, operandsegmentsizes([length(inputs), length(outputs)])) + attributes = NamedAttribute[namedattribute("indexing_maps", indexing_maps), namedattribute("iterator_types", iterator_types), ] + push!(attributes, operandsegmentsizes([length(inputs), length(outputs), ])) !isnothing(doc) && push!(attributes, namedattribute("doc", doc)) - !isnothing(library_call) && - push!(attributes, namedattribute("library_call", library_call)) - - return create_operation( - "enzyme.genericAdjoint", - location; - operands, - owned_regions, - successors, - attributes, + !isnothing(library_call) && push!(attributes, namedattribute("library_call", library_call)) + + create_operation( + "enzyme.genericAdjoint", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function get(gradient::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] - operands = Value[gradient,] + op_ty_results = IR.Type[result_0, ] + operands = Value[gradient, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.get", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.get", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function init(; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result_0, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.init", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.init", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function placeholder(; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] + op_ty_results = IR.Type[output, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.placeholder", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.placeholder", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function pop(cache::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[cache,] + op_ty_results = IR.Type[output, ] + operands = Value[cache, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.pop", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.pop", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function push(cache::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[cache, value] + operands = Value[cache, value, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.push", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.push", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function sample( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function sample(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.sample", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.sample", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function set(gradient::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[gradient, value] + operands = Value[gradient, value, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.set", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.set", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -357,25 +258,19 @@ Simulate a probabilistic function to generate execution trace by replacing all SampleOps with distribution calls and inserting sampled values into the choice map. """ -function simulate( - inputs::Vector{Value}; newTrace::IR.Type, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[newTrace,] - operands = Value[inputs...,] +function simulate(inputs::Vector{Value}; trace::IR.Type, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[trace, ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.simulate", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.simulate", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -386,46 +281,22 @@ Execute a probabilistic function specified by a symbol reference using the provi and a set of constraints on the sampled variables (if provided). Return the execution trace (if provided) and the log-likelihood of the execution trace. """ -function trace( - inputs::Vector{Value}, - oldTrace=nothing::Union{Nothing,Value}; - constraints=nothing::Union{Nothing,Value}, - newTrace::IR.Type, - weights::Vector{IR.Type}, - fn, - name=nothing, - location=Location(), -) - op_ty_results = IR.Type[newTrace, weights...] - operands = Value[inputs...,] +function trace(inputs::Vector{Value}, oldTrace=nothing::Union{Nothing, Value}; constraints=nothing::Union{Nothing, Value}, newTrace::IR.Type, weights::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[newTrace, weights..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(oldTrace) && push!(operands, oldTrace) !isnothing(constraints) && push!(operands, constraints) - push!( - attributes, - operandsegmentsizes([ - length(inputs), if (oldTrace == nothing) - 0 - elseif 1(constraints == nothing) - 0 - else - 1 - end - ]), - ) + push!(attributes, operandsegmentsizes([length(inputs), (oldTrace==nothing) ? 0 : 1(constraints==nothing) ? 0 : 1])) !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.trace", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.trace", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -436,21 +307,17 @@ Add a sampled value into the execution trace. """ function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) op_ty_results = IR.Type[] - operands = Value[trace, sample] + operands = Value[trace, sample, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.addSampleToTrace", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.addSampleToTrace", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -459,29 +326,19 @@ end Insert a constraint on a sampled variable into the choice map. """ -function insertChoiceToMap( - choiceMap::Value, - choice::Value; - newChoiceMap::IR.Type, - name=nothing, - location=Location(), -) - op_ty_results = IR.Type[newChoiceMap,] - operands = Value[choiceMap, choice] +function insertChoiceToMap(choiceMap::Value, choice::Value; newChoiceMap::IR.Type, name=nothing, location=Location()) + op_ty_results = IR.Type[newChoiceMap, ] + operands = Value[choiceMap, choice, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.insertChoiceToMap", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.insertChoiceToMap", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end diff --git a/test/runtests.jl b/test/runtests.jl index 489731eff5..a52159b4a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,7 +16,6 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Tracing" include("tracing.jl") @safetestset "Basic" include("basic.jl") @safetestset "Autodiff" include("autodiff.jl") - @safetestset "ProbProg" include("probprog.jl") @safetestset "Complex" include("complex.jl") @safetestset "Broadcast" include("bcast.jl") @safetestset "Struct" include("struct.jl") From 47e9fe312e2a3e0de63a7e243bd82490a099ab26 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 15 May 2025 18:10:34 -0500 Subject: [PATCH 10/47] saving changes --- src/ProbProg.jl | 22 +++++++++++++++++++--- src/mlir/Dialects/Enzyme.jl | 8 ++++---- test/probprog/generate.jl | 33 +++++++++++++++++++-------------- 3 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 68d0ca3a3f..afa3a0f5b0 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -5,13 +5,21 @@ using ReactantCore: ReactantCore using Enzyme -function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f), false; argprefix, resprefix, resargprefix + f, + args, + (), + string(f), + false; + args_in_result=:result, + argprefix, + resprefix, + resargprefix, ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped @@ -69,7 +77,15 @@ function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix::Symbol = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f), false; argprefix, resprefix, resargprefix + f, + args, + (), + string(f), + false; + args_in_result=:result, + argprefix, + resprefix, + resargprefix, ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index 3863cc567c..54065e0136 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -258,8 +258,8 @@ Simulate a probabilistic function to generate execution trace by replacing all SampleOps with distribution calls and inserting sampled values into the choice map. """ -function simulate(inputs::Vector{Value}; trace::IR.Type, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[trace, ] +function simulate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] @@ -326,8 +326,8 @@ end Insert a constraint on a sampled variable into the choice map. """ -function insertChoiceToMap(choiceMap::Value, choice::Value; newChoiceMap::IR.Type, name=nothing, location=Location()) - op_ty_results = IR.Type[newChoiceMap, ] +function insertChoiceToMap(choiceMap::Value, choice::Value; outputs::IR.Type, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs, ] operands = Value[choiceMap, choice, ] owned_regions = Region[] successors = Block[] diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 5a488479d4..8f0ddfcaa0 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -1,55 +1,60 @@ using Reactant, Test, Random, StableRNGs, Statistics using Reactant: ProbProg -normal(rng, μ, σ) = μ .+ σ .* randn(rng, 10000) +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function generate_model(seed, μ, σ) - function model(seed, μ, σ) +function generate_model(seed, μ, σ, shape) + function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ) - t = ProbProg.sample(normal, rng, s, σ) + s = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, s, σ, shape) return t end - return ProbProg.generate(model, seed, μ, σ) + return ProbProg.generate(model, seed, μ, σ, shape) end @testset "Generate" begin @testset "normal_deterministic" begin + shape = (10000,) seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) μ1 = Reactant.ConcreteRArray(0.0) μ2 = Reactant.ConcreteRArray(1000.0) σ1 = Reactant.ConcreteRArray(1.0) σ2 = Reactant.ConcreteRArray(1.0) - model_compiled = @compile generate_model(seed1, μ1, σ1) - @test Array(model_compiled(seed1, μ1, σ1)) ≈ Array(model_compiled(seed1, μ1, σ1)) - @test mean(Array(model_compiled(seed1, μ1, σ1))) ≈ 0.0 atol = 0.05 rtol = 0.05 - @test mean(Array(model_compiled(seed2, μ2, σ2))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + model_compiled = @compile generate_model(seed1, μ1, σ1, shape) + + @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ Array(model_compiled(seed1, μ1, σ1, shape)) + @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = 0.05 @test !(all( - Array(model_compiled(seed1, μ1, σ1)) .≈ Array(model_compiled(seed2, μ2, σ2)) + Array(model_compiled(seed1, μ1, σ1, shape)) .≈ Array(model_compiled(seed2, μ2, σ2, shape)) )) end @testset "normal_hlo" begin + shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) - before = @code_hlo optimize = :none generate_model(seed, μ, σ) + + before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog generate_model(seed, μ, σ) + after = @code_hlo optimize = :probprog generate_model(seed, μ, σ, shape) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") end @testset "normal_generate" begin + shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) - X = Array(@jit optimize = :probprog generate_model(seed, μ, σ)) + X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end end From a6fcca3dcb18c751ee954b58f4ae0eadbd892c76 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:37:46 -0500 Subject: [PATCH 11/47] fix sample op --- src/ProbProg.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index afa3a0f5b0..23be48f1d4 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -16,7 +16,7 @@ using Enzyme (), string(f), false; - args_in_result=:result, + args_in_result=:result_and_mutated, argprefix, resprefix, resargprefix, @@ -71,7 +71,7 @@ using Enzyme return result end -function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") @@ -82,7 +82,7 @@ function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} (), string(f), false; - args_in_result=:result, + args_in_result=:result_and_mutated, argprefix, resprefix, resargprefix, From e51e04bb5646f1833a622814045c34d255729f6d Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:45:08 -0500 Subject: [PATCH 12/47] save tests --- test/probprog/generate.jl | 6 ++--- test/probprog/sample.jl | 50 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 test/probprog/sample.jl diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 8f0ddfcaa0..cc73d15b12 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -7,8 +7,8 @@ function generate_model(seed, μ, σ, shape) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ, shape) - t = ProbProg.sample(normal, rng, s, σ, shape) + s = ProbProg.sample!(normal, rng, μ, σ, shape) + t = ProbProg.sample!(normal, rng, s, σ, shape) return t end @@ -25,7 +25,7 @@ end σ1 = Reactant.ConcreteRArray(1.0) σ2 = Reactant.ConcreteRArray(1.0) - model_compiled = @compile generate_model(seed1, μ1, σ1, shape) + model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ Array(model_compiled(seed1, μ1, σ1, shape)) @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = 0.05 diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl new file mode 100644 index 0000000000..93d411f9a5 --- /dev/null +++ b/test/probprog/sample.jl @@ -0,0 +1,50 @@ +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg + +@noinline normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function sample1(seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample!(normal, rng, μ, σ, shape) + return s + end + + return ProbProg.generate(model, seed, μ, σ, shape) +end + +function sample2(seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample!(normal, rng, μ, σ, shape) + t = ProbProg.sample!(normal, rng, μ, σ, shape) + return t + end + + return ProbProg.generate(model, seed, μ, σ, shape) +end + +@testset "test" begin + @testset "sample_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + before = @code_hlo optimize = false sample2(seed, μ, σ, shape) + @test contains(repr(before), "enzyme.sample") + after = @code_hlo optimize = :probprog sample2(seed, μ, σ, shape) + @test !contains(repr(after), "enzyme.sample") + end + + @testset "sample_normal" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + X = Array(@jit optimize = :probprog sample1(seed, μ, σ, shape)) + Y = Array(@jit optimize = :probprog sample2(seed, μ, σ, shape)) + @test !all(X .≈ Y) + end +end From ce68f6a3da18d8da60bc01a971f03e5512b81b3f Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:45:47 -0500 Subject: [PATCH 13/47] temporarily removing probprog pass from :all as MLIR pass is not merged yet --- src/Compiler.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 74f3cd829d..4cee6f2446 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1290,7 +1290,6 @@ function compile_mlir!( raise_passes, "enzyme-batch", opt_passes2, - "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1306,7 +1305,6 @@ function compile_mlir!( opt_passes, "enzyme-batch", opt_passes2, - "probprog", enzyme_pass, opt_passes2, "canonicalize", From d31bba636aeb7e1537c1dc90d985553dda84e1a6 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:53:48 -0500 Subject: [PATCH 14/47] undo enzyme binding change --- src/mlir/Dialects/Enzyme.jl | 425 +++++++++++++++++------------------- 1 file changed, 203 insertions(+), 222 deletions(-) diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index 54065e0136..e4306b06a1 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -1,10 +1,18 @@ module enzyme using ...IR -import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API - """ `addTo` @@ -12,49 +20,75 @@ TODO """ function addTo(values::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[values..., ] + operands = Value[values...,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.addTo", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.addTo", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function autodiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] +function autodiff( + inputs::Vector{Value}; + outputs::Vector{IR.Type}, + fn, + activity, + ret_activity, + width=nothing, + location=Location(), +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] + attributes = NamedAttribute[ + namedattribute("fn", fn), + namedattribute("activity", activity), + namedattribute("ret_activity", ret_activity), + ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - create_operation( - "enzyme.autodiff", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.autodiff", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function batch(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] +function batch( + inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location() +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("batch_shape", batch_shape), ] - - create_operation( - "enzyme.batch", location; - operands, owned_regions, successors, attributes, + attributes = NamedAttribute[ + namedattribute("fn", fn), namedattribute("batch_shape", batch_shape) + ] + + return create_operation( + "enzyme.batch", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -67,278 +101,225 @@ For scalar operands, ranked tensor is created. NOTE: Only works for scalar and *ranked* tensor operands for now. """ function broadcast(input::Value; output::IR.Type, shape, location=Location()) - op_ty_results = IR.Type[output, ] - operands = Value[input, ] + op_ty_results = IR.Type[output,] + operands = Value[input,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("shape", shape), ] - - create_operation( - "enzyme.broadcast", location; - operands, owned_regions, successors, attributes, + attributes = NamedAttribute[namedattribute("shape", shape),] + + return create_operation( + "enzyme.broadcast", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function fwddiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] +function fwddiff( + inputs::Vector{Value}; + outputs::Vector{IR.Type}, + fn, + activity, + ret_activity, + width=nothing, + location=Location(), +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] + attributes = NamedAttribute[ + namedattribute("fn", fn), + namedattribute("activity", activity), + namedattribute("ret_activity", ret_activity), + ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - create_operation( - "enzyme.fwddiff", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end -""" -`generate` - -Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. -""" -function generate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.generate", location; - operands, owned_regions, successors, attributes, + return create_operation( + "enzyme.fwddiff", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function genericAdjoint(inputs::Vector{Value}, outputs::Vector{Value}; result_tensors::Vector{IR.Type}, indexing_maps, iterator_types, doc=nothing, library_call=nothing, region::Region, location=Location()) - op_ty_results = IR.Type[result_tensors..., ] - operands = Value[inputs..., outputs..., ] - owned_regions = Region[region, ] +function genericAdjoint( + inputs::Vector{Value}, + outputs::Vector{Value}; + result_tensors::Vector{IR.Type}, + indexing_maps, + iterator_types, + doc=nothing, + library_call=nothing, + region::Region, + location=Location(), +) + op_ty_results = IR.Type[result_tensors...,] + operands = Value[inputs..., outputs...] + owned_regions = Region[region,] successors = Block[] - attributes = NamedAttribute[namedattribute("indexing_maps", indexing_maps), namedattribute("iterator_types", iterator_types), ] - push!(attributes, operandsegmentsizes([length(inputs), length(outputs), ])) + attributes = NamedAttribute[ + namedattribute("indexing_maps", indexing_maps), + namedattribute("iterator_types", iterator_types), + ] + push!(attributes, operandsegmentsizes([length(inputs), length(outputs)])) !isnothing(doc) && push!(attributes, namedattribute("doc", doc)) - !isnothing(library_call) && push!(attributes, namedattribute("library_call", library_call)) - - create_operation( - "enzyme.genericAdjoint", location; - operands, owned_regions, successors, attributes, + !isnothing(library_call) && + push!(attributes, namedattribute("library_call", library_call)) + + return create_operation( + "enzyme.genericAdjoint", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function get(gradient::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0, ] - operands = Value[gradient, ] + op_ty_results = IR.Type[result_0,] + operands = Value[gradient,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.get", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.get", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function init(; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0, ] + op_ty_results = IR.Type[result_0,] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.init", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.init", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function placeholder(; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output, ] + op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.placeholder", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.placeholder", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function pop(cache::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output, ] - operands = Value[cache, ] + op_ty_results = IR.Type[output,] + operands = Value[cache,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.pop", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.pop", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function push(cache::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[cache, value, ] + operands = Value[cache, value] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.push", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end - -function sample(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.sample", location; - operands, owned_regions, successors, attributes, + return create_operation( + "enzyme.push", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function set(gradient::Value, value::Value; location=Location()) - op_ty_results = IR.Type[] - operands = Value[gradient, value, ] +function sample( + inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[] - - create_operation( - "enzyme.set", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end - -""" -`simulate` - -Simulate a probabilistic function to generate execution trace -by replacing all SampleOps with distribution calls and inserting -sampled values into the choice map. -""" -function simulate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] + attributes = NamedAttribute[namedattribute("fn", fn),] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.simulate", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end -""" -`trace` - -Execute a probabilistic function specified by a symbol reference using the provided arguments, -and a set of constraints on the sampled variables (if provided). Return the execution trace -(if provided) and the log-likelihood of the execution trace. -""" -function trace(inputs::Vector{Value}, oldTrace=nothing::Union{Nothing, Value}; constraints=nothing::Union{Nothing, Value}, newTrace::IR.Type, weights::Vector{IR.Type}, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[newTrace, weights..., ] - operands = Value[inputs..., ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - !isnothing(oldTrace) && push!(operands, oldTrace) - !isnothing(constraints) && push!(operands, constraints) - push!(attributes, operandsegmentsizes([length(inputs), (oldTrace==nothing) ? 0 : 1(constraints==nothing) ? 0 : 1])) - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.trace", location; - operands, owned_regions, successors, attributes, + return create_operation( + "enzyme.sample", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end -""" -`addSampleToTrace` - -Add a sampled value into the execution trace. -""" -function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) +function set(gradient::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[trace, sample, ] + operands = Value[gradient, value] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.addSampleToTrace", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end - -""" -`insertChoiceToMap` -Insert a constraint on a sampled variable into the choice map. -""" -function insertChoiceToMap(choiceMap::Value, choice::Value; outputs::IR.Type, name=nothing, location=Location()) - op_ty_results = IR.Type[outputs, ] - operands = Value[choiceMap, choice, ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[] - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.insertChoiceToMap", location; - operands, owned_regions, successors, attributes, + return create_operation( + "enzyme.set", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end From 573fa021e147f6ef2eac49e46b01edb8baaadd45 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:58:21 -0500 Subject: [PATCH 15/47] format --- test/probprog/generate.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index cc73d15b12..47f7beff45 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -27,11 +27,15 @@ end model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) - @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ Array(model_compiled(seed1, μ1, σ1, shape)) - @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = 0.05 - @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ + Array(model_compiled(seed1, μ1, σ1, shape)) + @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = + 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = + 0.05 @test !(all( - Array(model_compiled(seed1, μ1, σ1, shape)) .≈ Array(model_compiled(seed2, μ2, σ2, shape)) + Array(model_compiled(seed1, μ1, σ1, shape)) .≈ + Array(model_compiled(seed2, μ2, σ2, shape)), )) end @testset "normal_hlo" begin From 0264a3dc40ae1be90460db8e5cae7bdd3633f381 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 17:01:42 -0500 Subject: [PATCH 16/47] format --- src/ProbProg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 23be48f1d4..a5301a999f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -128,4 +128,4 @@ function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} return result end -end \ No newline at end of file +end From 2e18bdf8d878212510c4a2b70202183e1085b126 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 18:50:58 -0500 Subject: [PATCH 17/47] improve --- src/ProbProg.jl | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index a5301a999f..22824bee7b 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -44,14 +44,10 @@ using Enzyme gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) - residx = 1 - for a in linear_results - resv = MLIR.IR.result(gen_op, residx) - residx += 1 - for path in a.paths - if length(path) == 0 - continue - end + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(gen_op, i) + for path in res.paths + isempty(path) && continue if path[1] == resprefix TracedUtils.set!(result, path[2:end], resv) elseif path[1] == argprefix @@ -109,18 +105,23 @@ function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) - ridx = 1 - for a in linear_results - val = MLIR.IR.result(sample_op, ridx) - ridx += 1 + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(sample_op, i) - for path in a.paths + for path in res.paths isempty(path) && continue if path[1] == resprefix - TracedUtils.set!(result, path[2:end], val) + TracedUtils.set!(result, path[2:end], resv) elseif path[1] == argprefix - idx = path[2]::Int - (fnwrap ? 1 : 0) - TracedUtils.set!(args[idx], path[3:end], val) + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end end end end From 1f1997976987733303c33e841099f91a376493ef Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 18:51:34 -0500 Subject: [PATCH 18/47] improve --- src/ProbProg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 22824bee7b..14bca0db4a 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -67,7 +67,7 @@ using Enzyme return result end -function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") From 096d790abbd32fa235f7f123ee5aaed3e2b2f352 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 22 May 2025 17:22:36 -0500 Subject: [PATCH 19/47] get rid of result_and_mutated too --- src/ProbProg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 14bca0db4a..8cfa5bec35 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -16,7 +16,7 @@ using Enzyme (), string(f), false; - args_in_result=:result_and_mutated, + args_in_result=:all, argprefix, resprefix, resargprefix, @@ -78,7 +78,7 @@ end (), string(f), false; - args_in_result=:result_and_mutated, + args_in_result=:all, argprefix, resprefix, resargprefix, From 9ac653555b69a695a33421e331d0471e266999fb Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 00:10:15 -0500 Subject: [PATCH 20/47] working trace object pointer hacks + tests --- src/Compiler.jl | 3 + src/ProbProg.jl | 149 +++++++++++++++++++++++++++++++++++++- test/probprog/simulate.jl | 46 ++++++++++++ 3 files changed, 194 insertions(+), 4 deletions(-) create mode 100644 test/probprog/simulate.jl diff --git a/src/Compiler.jl b/src/Compiler.jl index 35bf0c4e9b..6a2becb4e2 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1488,6 +1488,7 @@ function compile_mlir!( blas_int_width = sizeof(BLAS.BlasInt) * 8 lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \ blas_int_width=$blas_int_width}" + lower_enzyme_probprog_pass = "lower-enzyme-probprog{backend=$backend}" if optimize === :all run_pass_pipeline!( @@ -1651,6 +1652,8 @@ function compile_mlir!( "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", + lower_enzyme_probprog_pass, + jit ], ',', ), diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 8cfa5bec35..7f83ac96c3 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -1,10 +1,50 @@ module ProbProg -using ..Reactant: Reactant, XLA, MLIR, TracedUtils +using ..Reactant: Reactant, XLA, MLIR, TracedUtils, TracedRArray, ConcretePJRTArray using ReactantCore: ReactantCore +using Libdl: Libdl using Enzyme +const Trace = Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) + +function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}}) + trace_ptr = unsafe_load(trace_ptr_ptr) + @assert reinterpret(UInt64, trace_ptr) == 42 + + unsafe_store!(trace_ptr_ptr, pointer_from_objref(Trace)) + + return nothing +end + +function addSampleToTraceLowered( + trace_ptr_ptr::Ptr{Ptr{Cvoid}}, + symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, + sample_ptr_ptr::Ptr{Cvoid}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) + + trace[symbol] = 888 + + return nothing +end + +function __init__() + init_trace_ptr = @cfunction(initTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + add_sample_to_trace_ptr = @cfunction( + addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + return nothing +end + @noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") @@ -67,7 +107,9 @@ using Enzyme return result end -@noinline function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function sample!( + f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") +) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") @@ -83,7 +125,7 @@ end resprefix, resargprefix, ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res + (; result, linear_args, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f @@ -103,7 +145,17 @@ end sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + symbol_ptr = pointer_from_objref(symbol) + symbol_addr = reinterpret(UInt64, symbol_ptr) + + addr_attr = MLIR.IR.DenseElementsAttribute([symbol_addr]) + + sample_op = MLIR.Dialects.enzyme.sample( + MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=addr_attr), 1), + batch_inputs; + outputs=out_tys, + fn=fn_attr, + ) for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(sample_op, i) @@ -129,4 +181,93 @@ end return result end +@noinline function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("simulatearg") + resprefix::Symbol = gensym("simulateresult") + resargprefix::Symbol = gensym("simulateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, + args, + (), + string(f), + false; + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; linear_args, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = MLIR.IR.Type[] + supress_rest = false + for res in linear_results + if TracedUtils.has_idx(res, resprefix) && !supress_rest + push!(out_tys, MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64))) + supress_rest = true + else + # push!(out_tys, MLIR.IR.type(TracedUtils.get_mlir_data(res))) + end + end + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + simulate_op = MLIR.Dialects.enzyme.simulate(batch_inputs; outputs=out_tys, fn=fname) + + result = nothing + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(simulate_op, i) + + if TracedUtils.has_idx(res, resprefix) + # casted = MLIR.IR.result( + # MLIR.Dialects.builtin.unrealized_conversion_cast( + # resv; to=MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64)) + # ), + # 1, + # ) + # result = TracedRArray(casted) + result = TracedRArray(resv) + break + # continue + end + + # for path in res.paths + # isempty(path) && continue + # if path[1] == argprefix + # idx = path[2]::Int + # if idx == 1 && fnwrap + # TracedUtils.set!(f, path[3:end], resv) + # else + # if fnwrap + # idx -= 1 + # end + # TracedUtils.set!(args[idx], path[3:end], resv) + # end + # end + # end + end + + return result +end + +function getTrace(t::ConcretePJRTArray) + return unsafe_pointer_to_objref(reinterpret(Ptr{Cvoid}, Array{UInt64,1}(t)[1])) +end + end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl new file mode 100644 index 0000000000..de1abc05df --- /dev/null +++ b/test/probprog/simulate.jl @@ -0,0 +1,46 @@ +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg +using Libdl: Libdl + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function simulate_model(seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample!(normal, rng, μ, σ, shape) + t = ProbProg.sample!(normal, rng, s, σ, shape) + return t + end + + return ProbProg.simulate(model, seed, μ, σ, shape) +end + + +@testset "Simulate" begin + @testset "normal_hlo" begin + shape = (10000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + + before = @code_hlo optimize = :no_enzyme simulate_model(seed, μ, σ, shape) + @test contains(repr(before), "enzyme.simulate") + @test contains(repr(before), "enzyme.sample") + + after = @code_hlo optimize = :probprog simulate_model(seed, μ, σ, shape) + @test !contains(repr(after), "enzyme.simulate") + @test !contains(repr(after), "enzyme.sample") + @test contains(repr(after), "enzyme_probprog_add_sample_to_trace") + @test contains(repr(after), "enzyme_probprog_init_trace") + end + + @testset "normal_simulate" begin + shape = (10000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + X = ProbProg.getTrace(@jit optimize = :probprog simulate_model(seed, μ, σ, shape)) + @test X[:_integrity_check] == 0x123456789abcdef + end +end From b24766f0ccd570e443a92b4e09f901860806e2ec Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 01:38:31 -0500 Subject: [PATCH 21/47] Assuming scalar samples for now; simple Bayesian linear regression test --- src/ProbProg.jl | 4 ++-- test/probprog/blr.jl | 28 ++++++++++++++++++++++++++++ test/probprog/simulate.jl | 4 ++-- 3 files changed, 32 insertions(+), 4 deletions(-) create mode 100644 test/probprog/blr.jl diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 7f83ac96c3..834e09910f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -20,12 +20,12 @@ end function addSampleToTraceLowered( trace_ptr_ptr::Ptr{Ptr{Cvoid}}, symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, - sample_ptr_ptr::Ptr{Cvoid}, + sample_ptr::Ptr{Cvoid}, ) trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) - trace[symbol] = 888 + trace[symbol] = unsafe_load(reinterpret(Ptr{Float64}, sample_ptr)) return nothing end diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl new file mode 100644 index 0000000000..619af0b5c5 --- /dev/null +++ b/test/probprog/blr.jl @@ -0,0 +1,28 @@ +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg +using Libdl: Libdl + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function blr(seed, xs) + function model(seed, xs) + rng = Random.default_rng() + Random.seed!(rng, seed) + slope = ProbProg.sample!(normal, rng, 0, 2, (1,); symbol=:slope) + intercept = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:intercept) + for (i, x) in enumerate(xs) + ProbProg.sample!(normal, rng, slope * x + intercept, 1, (1,); symbol=Symbol("y-$i")) + end + return intercept + end + + return ProbProg.simulate(model, seed, xs) +end + +@testset "BLR" begin + xs = [1, 2, 3, 4, 5] + seed = Reactant.to_rarray(UInt64[1, 4]) + X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, xs)) + @test X[:_integrity_check] == 0x123456789abcdef + @show X +end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index de1abc05df..80c820a861 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -8,8 +8,8 @@ function simulate_model(seed, μ, σ, shape) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape) - t = ProbProg.sample!(normal, rng, s, σ, shape) + s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol = :s) + t = ProbProg.sample!(normal, rng, s, σ, shape; symbol = :t) return t end From 3c52b39a913f2f529eaff0d7e88ef4d4e6736062 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 17:14:16 -0500 Subject: [PATCH 22/47] exclamation mark --- src/ProbProg.jl | 4 ++-- test/probprog/generate.jl | 2 +- test/probprog/sample.jl | 4 ++-- test/probprog/simulate.jl | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 834e09910f..5fdde5489b 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -45,7 +45,7 @@ function __init__() return nothing end -@noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") @@ -181,7 +181,7 @@ end return result end -@noinline function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function simulate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") resargprefix::Symbol = gensym("simulateresarg") diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 47f7beff45..64297a6ce2 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -12,7 +12,7 @@ function generate_model(seed, μ, σ, shape) return t end - return ProbProg.generate(model, seed, μ, σ, shape) + return ProbProg.generate!(model, seed, μ, σ, shape) end @testset "Generate" begin diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 93d411f9a5..8d79488566 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -11,7 +11,7 @@ function sample1(seed, μ, σ, shape) return s end - return ProbProg.generate(model, seed, μ, σ, shape) + return ProbProg.generate!(model, seed, μ, σ, shape) end function sample2(seed, μ, σ, shape) @@ -23,7 +23,7 @@ function sample2(seed, μ, σ, shape) return t end - return ProbProg.generate(model, seed, μ, σ, shape) + return ProbProg.generate!(model, seed, μ, σ, shape) end @testset "test" begin diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 80c820a861..5ec3f1d031 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -13,7 +13,7 @@ function simulate_model(seed, μ, σ, shape) return t end - return ProbProg.simulate(model, seed, μ, σ, shape) + return ProbProg.simulate!(model, seed, μ, σ, shape) end From af3d055e2166666d711f570aba418d5883c35ff1 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 23:38:59 -0500 Subject: [PATCH 23/47] sample metadata --- src/ProbProg.jl | 70 +++++++++++++++++++++++++++++++++++---- test/probprog/simulate.jl | 3 +- 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 5fdde5489b..d374e2efa4 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -6,7 +6,20 @@ using Libdl: Libdl using Enzyme -const Trace = Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) +struct SampleMetadata + shape::NTuple{N,Int} where {N} + element_type::Type + is_scalar::Bool + + function SampleMetadata( + shape::NTuple{N,Int}, element_type::Type, is_scalar::Bool + ) where {N} + return new(shape, element_type, is_scalar) + end +end + +const SAMPLE_METADATA_CACHE = IdDict{Symbol,SampleMetadata}() +const Trace = IdDict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}}) trace_ptr = unsafe_load(trace_ptr_ptr) @@ -18,14 +31,28 @@ function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}}) end function addSampleToTraceLowered( - trace_ptr_ptr::Ptr{Ptr{Cvoid}}, - symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, - sample_ptr::Ptr{Cvoid}, + trace_ptr_ptr::Ptr{Ptr{Cvoid}}, symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, sample_ptr::Ptr{Cvoid} ) trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) - trace[symbol] = unsafe_load(reinterpret(Ptr{Float64}, sample_ptr)) + @assert haskey(SAMPLE_METADATA_CACHE, symbol) "Symbol $symbol not found in metadata cache" + + metadata = SAMPLE_METADATA_CACHE[symbol] + shape = metadata.shape + element_type = metadata.element_type + is_scalar = metadata.is_scalar + + if is_scalar + value = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) + else + value = unsafe_wrap( + Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape) + ) + value = reshape(value, shape) # TODO: GC'd? + end + + trace[symbol] = value return nothing end @@ -145,9 +172,22 @@ end sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + if !isempty(linear_results) + sample_result = linear_results[1] # TODO: consider multiple results + sample_mlir_data = TracedUtils.get_mlir_data(sample_result) + @assert sample_mlir_data isa MLIR.IR.Value "Sample $sample_result is not a MLIR.IR.Value" + + sample_type = MLIR.IR.type(sample_mlir_data) + sample_shape = size(sample_type) + sample_element_type = MLIR.IR.julia_type(eltype(sample_type)) + + SAMPLE_METADATA_CACHE[symbol] = SampleMetadata( + sample_shape, sample_element_type, length(sample_shape) == 0 + ) + end + symbol_ptr = pointer_from_objref(symbol) symbol_addr = reinterpret(UInt64, symbol_ptr) - addr_attr = MLIR.IR.DenseElementsAttribute([symbol_addr]) sample_op = MLIR.Dialects.enzyme.sample( @@ -270,4 +310,22 @@ function getTrace(t::ConcretePJRTArray) return unsafe_pointer_to_objref(reinterpret(Ptr{Cvoid}, Array{UInt64,1}(t)[1])) end +function print_trace(trace::IdDict) + println("Probabilistic Program Trace:") + for (symbol, sample) in trace + symbol == :_integrity_check && continue + metadata = SAMPLE_METADATA_CACHE[symbol] + + println(" $symbol:") + println(" Sample: $(sample)") + println(" Shape: $(metadata.shape)") + println(" Element Type: $(metadata.element_type)") + end +end + +function clear_sample_metadata_cache!() + empty!(SAMPLE_METADATA_CACHE) + return nothing +end + end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 5ec3f1d031..94fde7e55a 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -36,11 +36,12 @@ end end @testset "normal_simulate" begin - shape = (10000,) + shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) X = ProbProg.getTrace(@jit optimize = :probprog simulate_model(seed, μ, σ, shape)) @test X[:_integrity_check] == 0x123456789abcdef + ProbProg.print_trace(X) end end From 6c7ffa3e4e692de353f9848789dd759f55dd8dda Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 23:54:42 -0500 Subject: [PATCH 24/47] fix up copy --- src/ProbProg.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index d374e2efa4..0c9a95b9bd 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -44,16 +44,15 @@ function addSampleToTraceLowered( is_scalar = metadata.is_scalar if is_scalar - value = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) + trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) else - value = unsafe_wrap( - Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape) + trace[symbol] = Base.deepcopy( + unsafe_wrap( + Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape) + ), ) - value = reshape(value, shape) # TODO: GC'd? end - trace[symbol] = value - return nothing end From 4e017d0476f33d41ac61e96490f2fce2a76ec6b2 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 23:58:33 -0500 Subject: [PATCH 25/47] fix up copy --- src/ProbProg.jl | 9 +++++++-- test/probprog/simulate.jl | 7 +++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 0c9a95b9bd..80c934f95e 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -47,8 +47,13 @@ function addSampleToTraceLowered( trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) else trace[symbol] = Base.deepcopy( - unsafe_wrap( - Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape) + reshape( + unsafe_wrap( + Array{element_type}, + reinterpret(Ptr{element_type}, sample_ptr), + prod(shape), + ), + shape, ), ) end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 94fde7e55a..97910443e2 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -8,15 +8,14 @@ function simulate_model(seed, μ, σ, shape) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol = :s) - t = ProbProg.sample!(normal, rng, s, σ, shape; symbol = :t) + s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s) + t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t) return t end return ProbProg.simulate!(model, seed, μ, σ, shape) end - @testset "Simulate" begin @testset "normal_hlo" begin shape = (10000,) @@ -36,7 +35,7 @@ end end @testset "normal_simulate" begin - shape = (10,) + shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) From e53fc7cf58d9d93047c055a9d6adc9e7c32f0487 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 6 Jun 2025 00:20:39 -0500 Subject: [PATCH 26/47] working vectorized blr test --- src/ProbProg.jl | 1 + test/probprog/blr.jl | 38 +++++++++++++++++++++++++------------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 80c934f95e..46d9b4f786 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -152,6 +152,7 @@ end string(f), false; args_in_result=:all, + do_transpose=false, # TODO: double check transpose argprefix, resprefix, resargprefix, diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 619af0b5c5..4f0836b76c 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -3,26 +3,38 @@ using Reactant: ProbProg using Libdl: Libdl normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) -function blr(seed, xs) - function model(seed, xs) +function blr(seed, N, K) + function model(seed, N, K) rng = Random.default_rng() Random.seed!(rng, seed) - slope = ProbProg.sample!(normal, rng, 0, 2, (1,); symbol=:slope) - intercept = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:intercept) - for (i, x) in enumerate(xs) - ProbProg.sample!(normal, rng, slope * x + intercept, 1, (1,); symbol=Symbol("y-$i")) - end - return intercept + + # α ~ Normal(0, 10, size = 1) + α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) + + # β ~ Normal(0, 2.5, size = K) + β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) + + # X ~ Normal(0, 10, size = (N, K)) + X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) # TODO: double check transpose + + # μ = α .+ X * β + μ = α .+ X * β + + ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) + + return μ end - return ProbProg.simulate(model, seed, xs) + return ProbProg.simulate!(model, seed, N, K) end @testset "BLR" begin - xs = [1, 2, 3, 4, 5] + N = 5 # number of observations + K = 3 # number of features seed = Reactant.to_rarray(UInt64[1, 4]) - X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, xs)) - @test X[:_integrity_check] == 0x123456789abcdef - @show X + + X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, N, K)) + ProbProg.print_trace(X) end From 1dbf5c73702b068d2ad338021f5d37912351c3b4 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 11 Jun 2025 17:35:40 -0500 Subject: [PATCH 27/47] fix test warning --- test/probprog/generate.jl | 16 ++++++++-------- test/probprog/sample.jl | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 64297a6ce2..605b375805 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -20,10 +20,10 @@ end shape = (10000,) seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) - μ1 = Reactant.ConcreteRArray(0.0) - μ2 = Reactant.ConcreteRArray(1000.0) - σ1 = Reactant.ConcreteRArray(1.0) - σ2 = Reactant.ConcreteRArray(1.0) + μ1 = Reactant.ConcreteRNumber(0.0) + μ2 = Reactant.ConcreteRNumber(1000.0) + σ1 = Reactant.ConcreteRNumber(1.0) + σ2 = Reactant.ConcreteRNumber(1.0) model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) @@ -41,8 +41,8 @@ end @testset "normal_hlo" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape) @test contains(repr(before), "enzyme.generate") @@ -56,8 +56,8 @@ end @testset "normal_generate" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 8d79488566..9c711241d8 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -30,8 +30,8 @@ end @testset "sample_hlo" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = false sample2(seed, μ, σ, shape) @test contains(repr(before), "enzyme.sample") after = @code_hlo optimize = :probprog sample2(seed, μ, σ, shape) @@ -41,8 +41,8 @@ end @testset "sample_normal" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) X = Array(@jit optimize = :probprog sample1(seed, μ, σ, shape)) Y = Array(@jit optimize = :probprog sample2(seed, μ, σ, shape)) @test !all(X .≈ Y) From dd9dcabe5bd683b384dc6bf40b6520d9c99fae18 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 11 Jun 2025 17:36:15 -0500 Subject: [PATCH 28/47] hacks to temporarily remove world age issue in tests --- src/ProbProg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 46d9b4f786..5b0d33a88f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -81,7 +81,7 @@ end resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") - mlir_fn_res = TracedUtils.make_mlir_fn( + mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn, f, args, (), @@ -145,7 +145,7 @@ end resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") - mlir_fn_res = TracedUtils.make_mlir_fn( + mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn, f, args, (), From a34472613d541527cb54b126704bd4345f533e89 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 12 Jun 2025 17:20:15 -0500 Subject: [PATCH 29/47] partial refactoring --- src/ProbProg.jl | 150 +++++++++++++++++++----------------------------- 1 file changed, 58 insertions(+), 92 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 5b0d33a88f..d9f0672071 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -1,9 +1,6 @@ module ProbProg -using ..Reactant: Reactant, XLA, MLIR, TracedUtils, TracedRArray, ConcretePJRTArray -using ReactantCore: ReactantCore -using Libdl: Libdl - +using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray using Enzyme struct SampleMetadata @@ -18,16 +15,10 @@ struct SampleMetadata end end -const SAMPLE_METADATA_CACHE = IdDict{Symbol,SampleMetadata}() -const Trace = IdDict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) - -function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}}) - trace_ptr = unsafe_load(trace_ptr_ptr) - @assert reinterpret(UInt64, trace_ptr) == 42 - - unsafe_store!(trace_ptr_ptr, pointer_from_objref(Trace)) +const SAMPLE_METADATA_CACHE = Dict{Symbol,SampleMetadata}() - return nothing +function createTrace() + return Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) end function addSampleToTraceLowered( @@ -46,7 +37,7 @@ function addSampleToTraceLowered( if is_scalar trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) else - trace[symbol] = Base.deepcopy( + trace[symbol] = copy( reshape( unsafe_wrap( Array{element_type}, @@ -62,10 +53,6 @@ function addSampleToTraceLowered( end function __init__() - init_trace_ptr = @cfunction(initTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}},)) - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( - :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} - )::Cvoid add_sample_to_trace_ptr = @cfunction( addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid}) ) @@ -81,7 +68,8 @@ end resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") - mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn, + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, f, args, (), @@ -139,13 +127,17 @@ end end @noinline function sample!( - f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + trace::Union{Dict,Nothing}=nothing, ) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") - mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn, + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, f, args, (), @@ -191,47 +183,44 @@ end ) end - symbol_ptr = pointer_from_objref(symbol) - symbol_addr = reinterpret(UInt64, symbol_ptr) - addr_attr = MLIR.IR.DenseElementsAttribute([symbol_addr]) + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) sample_op = MLIR.Dialects.enzyme.sample( - MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=addr_attr), 1), - batch_inputs; - outputs=out_tys, - fn=fn_attr, + batch_inputs; outputs=out_tys, fn=fn_attr, symbol=symbol_addr ) for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(sample_op, i) - - for path in res.paths - isempty(path) && continue - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], resv) - elseif path[1] == argprefix - idx = path[2]::Int - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + else + if fnwrap + idx -= 1 end + TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) end + else + TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) end end return result end -@noinline function simulate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function simulate!( + f::Function, args::Vararg{Any,Nargs}; trace::Dict +) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") resargprefix::Symbol = gensym("simulateresarg") - mlir_fn_res = TracedUtils.make_mlir_fn( + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, f, args, (), @@ -242,10 +231,14 @@ end resprefix, resargprefix, ) - (; linear_args, linear_results) = mlir_fn_res + (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + batch_inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) @@ -259,63 +252,36 @@ end end end - out_tys = MLIR.IR.Type[] - supress_rest = false - for res in linear_results - if TracedUtils.has_idx(res, resprefix) && !supress_rest - push!(out_tys, MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64))) - supress_rest = true - else - # push!(out_tys, MLIR.IR.type(TracedUtils.get_mlir_data(res))) - end - end + trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - simulate_op = MLIR.Dialects.enzyme.simulate(batch_inputs; outputs=out_tys, fn=fname) + simulate_op = MLIR.Dialects.enzyme.simulate( + batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr + ) - result = nothing for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(simulate_op, i) - if TracedUtils.has_idx(res, resprefix) - # casted = MLIR.IR.result( - # MLIR.Dialects.builtin.unrealized_conversion_cast( - # resv; to=MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64)) - # ), - # 1, - # ) - # result = TracedRArray(casted) - result = TracedRArray(resv) - break - # continue + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) + end + else + TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) end - - # for path in res.paths - # isempty(path) && continue - # if path[1] == argprefix - # idx = path[2]::Int - # if idx == 1 && fnwrap - # TracedUtils.set!(f, path[3:end], resv) - # else - # if fnwrap - # idx -= 1 - # end - # TracedUtils.set!(args[idx], path[3:end], resv) - # end - # end - # end end - return result -end - -function getTrace(t::ConcretePJRTArray) - return unsafe_pointer_to_objref(reinterpret(Ptr{Cvoid}, Array{UInt64,1}(t)[1])) + return trace, result end -function print_trace(trace::IdDict) +function print_trace(trace::Dict) println("Probabilistic Program Trace:") for (symbol, sample) in trace symbol == :_integrity_check && continue From ef2e77064851226259ab813e3131a3e9b119e82c Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sat, 14 Jun 2025 16:51:59 -0500 Subject: [PATCH 30/47] fixed tracing infra --- src/Compiler.jl | 92 ++++++++++++++++++++++---- src/ProbProg.jl | 133 ++++++++++++++++---------------------- test/probprog/simulate.jl | 54 +++++++++------- 3 files changed, 165 insertions(+), 114 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index dc92821853..a4e2ff44a5 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1740,21 +1740,91 @@ function compile_mlir!( ), "only_enzyme", ) + elseif optimize === :probprog_no_lowering + run_pass_pipeline!( + mod, + join( + if raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, + raise_passes, + ] + end, + ",", + ), + "probprog_no_lowering", + ) elseif optimize === :probprog run_pass_pipeline!( mod, join( - [ - "mark-func-memory-effects", - "enzyme-batch", - "probprog", - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - lower_enzyme_probprog_pass, - jit - ], - ',', + if raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, + raise_passes, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + end, + ",", ), "probprog", ) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index d9f0672071..74c014de82 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -3,50 +3,38 @@ module ProbProg using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray using Enzyme -struct SampleMetadata - shape::NTuple{N,Int} where {N} - element_type::Type - is_scalar::Bool - - function SampleMetadata( - shape::NTuple{N,Int}, element_type::Type, is_scalar::Bool - ) where {N} - return new(shape, element_type, is_scalar) - end -end - -const SAMPLE_METADATA_CACHE = Dict{Symbol,SampleMetadata}() - function createTrace() - return Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) + return Dict{Symbol,Any}() end function addSampleToTraceLowered( - trace_ptr_ptr::Ptr{Ptr{Cvoid}}, symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, sample_ptr::Ptr{Cvoid} + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr::Ptr{Any}, + num_dims_ptr::Ptr{Int64}, + shape_array_ptr::Ptr{Int64}, + datatype_width_ptr::Ptr{Int64}, ) trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) - @assert haskey(SAMPLE_METADATA_CACHE, symbol) "Symbol $symbol not found in metadata cache" + num_dims = unsafe_load(num_dims_ptr) + shape_array = unsafe_wrap(Array, shape_array_ptr, num_dims) + datatype_width = unsafe_load(datatype_width_ptr) - metadata = SAMPLE_METADATA_CACHE[symbol] - shape = metadata.shape - element_type = metadata.element_type - is_scalar = metadata.is_scalar + julia_type = if datatype_width == 32 + Float32 + elseif datatype_width == 64 + Float64 + else + error("Unsupported datatype width: $datatype_width") + end - if is_scalar - trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) + typed_ptr = Ptr{julia_type}(sample_ptr) + if num_dims == 0 + trace[symbol] = unsafe_load(typed_ptr) else - trace[symbol] = copy( - reshape( - unsafe_wrap( - Array{element_type}, - reinterpret(Ptr{element_type}, sample_ptr), - prod(shape), - ), - shape, - ), - ) + trace[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array))) end return nothing @@ -54,7 +42,9 @@ end function __init__() add_sample_to_trace_ptr = @cfunction( - addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid}) + addSampleToTraceLowered, + Cvoid, + (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Any}, Ptr{Int64}, Ptr{Int64}, Ptr{Int64}) ) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} @@ -105,21 +95,21 @@ end for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(gen_op, i) - for path in res.paths - isempty(path) && continue - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], resv) - elseif path[1] == argprefix - idx = path[2]::Int - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + else + if fnwrap + idx -= 1 end + TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) end + else + TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) end end @@ -127,10 +117,7 @@ end end @noinline function sample!( - f::Function, - args::Vararg{Any,Nargs}; - symbol::Symbol=gensym("sample"), - trace::Union{Dict,Nothing}=nothing, + f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") ) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") @@ -169,24 +156,21 @@ end sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - if !isempty(linear_results) - sample_result = linear_results[1] # TODO: consider multiple results - sample_mlir_data = TracedUtils.get_mlir_data(sample_result) - @assert sample_mlir_data isa MLIR.IR.Value "Sample $sample_result is not a MLIR.IR.Value" - - sample_type = MLIR.IR.type(sample_mlir_data) - sample_shape = size(sample_type) - sample_element_type = MLIR.IR.julia_type(eltype(sample_type)) - - SAMPLE_METADATA_CACHE[symbol] = SampleMetadata( - sample_shape, sample_element_type, length(sample_shape) == 0 - ) + traced_output_indices = Int[] + for (i, res) in enumerate(linear_results) + if TracedUtils.has_idx(res, resprefix) + push!(traced_output_indices, i - 1) + end end symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) sample_op = MLIR.Dialects.enzyme.sample( - batch_inputs; outputs=out_tys, fn=fn_attr, symbol=symbol_addr + batch_inputs; + outputs=out_tys, + fn=fn_attr, + symbol=symbol_addr, + traced_output_indices=traced_output_indices, ) for (i, res) in enumerate(linear_results) @@ -213,7 +197,7 @@ end end @noinline function simulate!( - f::Function, args::Vararg{Any,Nargs}; trace::Dict + f::Function, args::Vararg{Any,Nargs}; trace::Dict{Symbol,Any} ) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") @@ -278,25 +262,16 @@ end end end - return trace, result + return result end -function print_trace(trace::Dict) - println("Probabilistic Program Trace:") +function print_trace(trace::Dict{Symbol,Any}) + println("### Probabilistic Program Trace ###") for (symbol, sample) in trace - symbol == :_integrity_check && continue - metadata = SAMPLE_METADATA_CACHE[symbol] - println(" $symbol:") println(" Sample: $(sample)") - println(" Shape: $(metadata.shape)") - println(" Element Type: $(metadata.element_type)") end + println("### End of Trace ###") end -function clear_sample_metadata_cache!() - empty!(SAMPLE_METADATA_CACHE) - return nothing -end - -end +end \ No newline at end of file diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 97910443e2..59bbfe0509 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -1,46 +1,52 @@ using Reactant, Test, Random, StableRNGs, Statistics using Reactant: ProbProg -using Libdl: Libdl -normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +@testset "Simulate" begin + normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function simulate_model(seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s) - t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t) - return t - end + function simulate_model(trace, seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s) + t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t) + return t + end - return ProbProg.simulate!(model, seed, μ, σ, shape) -end - -@testset "Simulate" begin + result = ProbProg.simulate!(model, seed, μ, σ, shape; trace) + return result + end @testset "normal_hlo" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace = ProbProg.createTrace() - before = @code_hlo optimize = :no_enzyme simulate_model(seed, μ, σ, shape) + before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape) @test contains(repr(before), "enzyme.simulate") @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog simulate_model(seed, μ, σ, shape) + after = @code_hlo optimize = :probprog simulate_model(trace, seed, μ, σ, shape) @test !contains(repr(after), "enzyme.simulate") @test !contains(repr(after), "enzyme.sample") @test contains(repr(after), "enzyme_probprog_add_sample_to_trace") - @test contains(repr(after), "enzyme_probprog_init_trace") end @testset "normal_simulate" begin shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) - X = ProbProg.getTrace(@jit optimize = :probprog simulate_model(seed, μ, σ, shape)) - @test X[:_integrity_check] == 0x123456789abcdef - ProbProg.print_trace(X) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace = ProbProg.createTrace() + + result = Array( + @jit optimize = :probprog sync = true simulate_model(trace, seed, μ, σ, shape) + ) + + ProbProg.print_trace(trace) + @test size(result) == shape end end From 46e0f6b9f6f47ee69df4148bc665c4ea998e388c Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 16 Jun 2025 17:54:15 -0500 Subject: [PATCH 31/47] transpose fix up --- src/ProbProg.jl | 33 ++++++++++++++++++--------------- test/probprog/generate.jl | 22 +++++++++++++++++++--- test/probprog/sample.jl | 2 +- test/probprog/simulate.jl | 29 +++++++++++++++++++++++++---- 4 files changed, 63 insertions(+), 23 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 74c014de82..bfc6b7d942 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -27,7 +27,8 @@ function addSampleToTraceLowered( elseif datatype_width == 64 Float64 else - error("Unsupported datatype width: $datatype_width") + @ccall printf("Unsupported datatype width: %d\n"::Cstring, datatype_width::Cint)::Cvoid + return nothing end typed_ptr = Ptr{julia_type}(sample_ptr) @@ -65,6 +66,7 @@ end (), string(f), false; + do_transpose=false, args_in_result=:all, argprefix, resprefix, @@ -97,19 +99,19 @@ end resv = MLIR.IR.result(gen_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(result, path[2:end], resv) elseif TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(f, path[3:end], resv) else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(args[idx], path[3:end], resv) end else - TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) + TracedUtils.set!(res, (), resv) end end @@ -130,8 +132,8 @@ end (), string(f), false; + do_transpose=false, args_in_result=:all, - do_transpose=false, # TODO: double check transpose argprefix, resprefix, resargprefix, @@ -177,19 +179,19 @@ end resv = MLIR.IR.result(sample_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(result, path[2:end], resv) elseif TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(f, path[3:end], resv) else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(args[idx], path[3:end], resv) end else - TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) + TracedUtils.set!(res, (), resv) end end @@ -210,6 +212,7 @@ end (), string(f), false; + do_transpose=false, args_in_result=:all, argprefix, resprefix, @@ -246,19 +249,19 @@ end resv = MLIR.IR.result(simulate_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(result, path[2:end], resv) elseif TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(f, path[3:end], resv) else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(args[idx], path[3:end], resv) end else - TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) + TracedUtils.set!(res, (), resv) end end @@ -271,7 +274,7 @@ function print_trace(trace::Dict{Symbol,Any}) println(" $symbol:") println(" Sample: $(sample)") end - println("### End of Trace ###") + return println("### End of Trace ###") end end \ No newline at end of file diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 605b375805..cefd648a93 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -16,7 +16,7 @@ function generate_model(seed, μ, σ, shape) end @testset "Generate" begin - @testset "normal_deterministic" begin + @testset "deterministic" begin shape = (10000,) seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) @@ -38,7 +38,7 @@ end Array(model_compiled(seed2, μ2, σ2, shape)), )) end - @testset "normal_hlo" begin + @testset "hlo" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) @@ -53,7 +53,7 @@ end @test !contains(repr(after), "enzyme.sample") end - @testset "normal_generate" begin + @testset "normal" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) @@ -61,4 +61,20 @@ end X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end + + @testset "correctness" begin + op(x, y) = x * y' + + function fake_model(x, y) + return ProbProg.sample!(op, x, y) + end + + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test Array(@jit optimize = :probprog ProbProg.generate!(fake_model, x_ra, y_ra)) == + op(x, y) + end end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 9c711241d8..aabf476f94 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -18,7 +18,7 @@ function sample2(seed, μ, σ, shape) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape) + _ = ProbProg.sample!(normal, rng, μ, σ, shape) t = ProbProg.sample!(normal, rng, μ, σ, shape) return t end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 59bbfe0509..7b44e8bcd9 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -42,11 +42,32 @@ using Reactant: ProbProg trace = ProbProg.createTrace() - result = Array( - @jit optimize = :probprog sync = true simulate_model(trace, seed, μ, σ, shape) - ) + result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape)) - ProbProg.print_trace(trace) @test size(result) == shape + @test haskey(trace, :s) + @test haskey(trace, :t) + @test size(trace[:s]) == shape + @test size(trace[:t]) == shape + end + + @testset "correctness" begin + op(x, y) = x * y' + function fake_model(x, y) + return ProbProg.sample!(op, x, y; symbol=:matmul) + end + + trace = ProbProg.createTrace() + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test Array( + @jit optimize = :probprog ProbProg.simulate!(fake_model, x_ra, y_ra; trace) + ) == op(x, y) + + @test haskey(trace, :matmul) + @test trace[:matmul] == op(x, y) end end From 1c5297cab8f1e09ea6861fe54397997581d8fc51 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 17 Jun 2025 13:12:45 -0500 Subject: [PATCH 32/47] minor changes --- src/ProbProg.jl | 10 ++++---- test/probprog/blr.jl | 49 +++++++++++++++++++++------------------ test/probprog/simulate.jl | 6 ++--- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index bfc6b7d942..f876d43122 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -3,10 +3,6 @@ module ProbProg using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray using Enzyme -function createTrace() - return Dict{Symbol,Any}() -end - function addSampleToTraceLowered( trace_ptr_ptr::Ptr{Ptr{Any}}, symbol_ptr_ptr::Ptr{Ptr{Any}}, @@ -26,6 +22,8 @@ function addSampleToTraceLowered( Float32 elseif datatype_width == 64 Float64 + elseif datatype_width == 1 + Bool else @ccall printf("Unsupported datatype width: %d\n"::Cstring, datatype_width::Cint)::Cvoid return nothing @@ -268,6 +266,10 @@ end return result end +function create_trace() + return Dict{Symbol,Any}() +end + function print_trace(trace::Dict{Symbol,Any}) println("### Probabilistic Program Trace ###") for (symbol, sample) in trace diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 4f0836b76c..3e0e040963 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -1,33 +1,33 @@ -using Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random using Reactant: ProbProg -using Libdl: Libdl -normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) +function normal(rng, μ, σ, shape) + return μ .+ σ .* randn(rng, shape) +end -function blr(seed, N, K) - function model(seed, N, K) - rng = Random.default_rng() - Random.seed!(rng, seed) +function bernoulli_logit(rng, logit, shape) + return rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) +end - # α ~ Normal(0, 10, size = 1) - α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) +function blr(seed, N, K) + rng = Random.default_rng() + Random.seed!(rng, seed) - # β ~ Normal(0, 2.5, size = K) - β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) + # α ~ Normal(0, 10, size = 1) + α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) - # X ~ Normal(0, 10, size = (N, K)) - X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) # TODO: double check transpose + # β ~ Normal(0, 2.5, size = K) + β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) - # μ = α .+ X * β - μ = α .+ X * β + # X ~ Normal(0, 10, size = (N, K)) + X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) - ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) + # μ = α .+ X * β + μ = α .+ X * β - return μ - end + Y = ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) - return ProbProg.simulate!(model, seed, N, K) + return Y end @testset "BLR" begin @@ -35,6 +35,11 @@ end K = 3 # number of features seed = Reactant.to_rarray(UInt64[1, 4]) - X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, N, K)) - ProbProg.print_trace(X) + trace = ProbProg.create_trace() + + @test size( + Array(@jit optimize = :probprog ProbProg.simulate!(blr, seed, N, K; trace)) + ) == (N,) + + ProbProg.print_trace(trace) end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 7b44e8bcd9..403505e1b0 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -22,7 +22,7 @@ using Reactant: ProbProg μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.createTrace() + trace = ProbProg.create_trace() before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape) @test contains(repr(before), "enzyme.simulate") @@ -40,7 +40,7 @@ using Reactant: ProbProg μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.createTrace() + trace = ProbProg.create_trace() result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape)) @@ -57,7 +57,7 @@ using Reactant: ProbProg return ProbProg.sample!(op, x, y; symbol=:matmul) end - trace = ProbProg.createTrace() + trace = ProbProg.create_trace() x = reshape(collect(Float64, 1:12), (4, 3)) y = reshape(collect(Float64, 1:12), (4, 3)) x_ra = Reactant.to_rarray(x) From d707053c7f553828dea03f458584b1e5181fe356 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 17 Jun 2025 16:01:00 -0500 Subject: [PATCH 33/47] reorder --- src/ProbProg.jl | 91 ++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index f876d43122..230ce3c2af 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -52,10 +52,12 @@ function __init__() return nothing end -@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix::Symbol = gensym("generatearg") - resprefix::Symbol = gensym("generateresult") - resargprefix::Symbol = gensym("generateresarg") +@noinline function sample!( + f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") +) where {Nargs} + argprefix::Symbol = gensym("samplearg") + resprefix::Symbol = gensym("sampleresult") + resargprefix::Symbol = gensym("sampleresarg") mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, @@ -70,31 +72,45 @@ end resprefix, resargprefix, ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res + (; result, linear_args, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - batch_inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) if idx == 1 && fnwrap TracedUtils.push_val!(batch_inputs, f, path[3:end]) else - if fnwrap - idx -= 1 - end + idx -= fnwrap ? 1 : 0 TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) end end - gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + traced_output_indices = Int[] for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(gen_op, i) + if TracedUtils.has_idx(res, resprefix) + push!(traced_output_indices, i - 1) + end + end + + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + + sample_op = MLIR.Dialects.enzyme.sample( + batch_inputs; + outputs=out_tys, + fn=fn_attr, + symbol=symbol_addr, + traced_output_indices=traced_output_indices, + ) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(sample_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) @@ -116,12 +132,10 @@ end return result end -@noinline function sample!( - f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") -) where {Nargs} - argprefix::Symbol = gensym("samplearg") - resprefix::Symbol = gensym("sampleresult") - resargprefix::Symbol = gensym("sampleresarg") +@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, @@ -136,45 +150,31 @@ end resprefix, resargprefix, ) - (; result, linear_args, linear_results) = mlir_fn_res + (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + batch_inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) if idx == 1 && fnwrap TracedUtils.push_val!(batch_inputs, f, path[3:end]) else - idx -= fnwrap ? 1 : 0 + if fnwrap + idx -= 1 + end TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) end end - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - sym = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - - traced_output_indices = Int[] - for (i, res) in enumerate(linear_results) - if TracedUtils.has_idx(res, resprefix) - push!(traced_output_indices, i - 1) - end - end - - symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) - - sample_op = MLIR.Dialects.enzyme.sample( - batch_inputs; - outputs=out_tys, - fn=fn_attr, - symbol=symbol_addr, - traced_output_indices=traced_output_indices, - ) + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(sample_op, i) + resv = MLIR.IR.result(gen_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) @@ -278,5 +278,4 @@ function print_trace(trace::Dict{Symbol,Any}) end return println("### End of Trace ###") end - -end \ No newline at end of file +end From 91a0850ed245da954b3c294475418149ed14476a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 20 Jun 2025 14:17:03 -0500 Subject: [PATCH 34/47] API change --- src/ProbProg.jl | 50 ++++++++++++++++++----------- src/Reactant.jl | 2 +- test/probprog/blr.jl | 16 ++++------ test/probprog/generate.jl | 36 +++++++++++---------- test/probprog/sample.jl | 44 +++++++++++-------------- test/probprog/simulate.jl | 67 +++++++++++++++++---------------------- 6 files changed, 106 insertions(+), 109 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 230ce3c2af..0dc2bdeffb 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -1,8 +1,18 @@ module ProbProg -using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray +using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray, AbstractConcreteNumber +using ..Compiler: @jit using Enzyme +mutable struct ProbProgTrace + choices::Dict{Symbol,Any} + retval::Any + + function ProbProgTrace() + return new(Dict{Symbol,Any}(), nothing) + end +end + function addSampleToTraceLowered( trace_ptr_ptr::Ptr{Ptr{Any}}, symbol_ptr_ptr::Ptr{Ptr{Any}}, @@ -31,9 +41,9 @@ function addSampleToTraceLowered( typed_ptr = Ptr{julia_type}(sample_ptr) if num_dims == 0 - trace[symbol] = unsafe_load(typed_ptr) + trace.choices[symbol] = unsafe_load(typed_ptr) else - trace[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array))) + trace.choices[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array))) end return nothing @@ -52,7 +62,7 @@ function __init__() return nothing end -@noinline function sample!( +function sample( f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") ) where {Nargs} argprefix::Symbol = gensym("samplearg") @@ -132,7 +142,12 @@ end return result end -@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + res = @jit optimize = :probprog generate_internal(f, args...) + return res isa AbstractConcreteArray ? Array(res) : res +end + +function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") @@ -196,8 +211,18 @@ end return result end -@noinline function simulate!( - f::Function, args::Vararg{Any,Nargs}; trace::Dict{Symbol,Any} +function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + trace = ProbProgTrace() + + res = @jit optimize = :probprog sync = true simulate_internal(f, args...; trace) + + trace.retval = res isa AbstractConcreteArray ? Array(res) : res + + return trace +end + +function simulate_internal( + f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace ) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") @@ -266,16 +291,5 @@ end return result end -function create_trace() - return Dict{Symbol,Any}() -end -function print_trace(trace::Dict{Symbol,Any}) - println("### Probabilistic Program Trace ###") - for (symbol, sample) in trace - println(" $symbol:") - println(" Sample: $(sample)") - end - return println("### End of Trace ###") -end end diff --git a/src/Reactant.jl b/src/Reactant.jl index f0b6c044f1..48874e8f95 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -176,7 +176,6 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") -include("ProbProg.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} @@ -189,6 +188,7 @@ export OptimizeCommunicationOptions include("Compiler.jl") include("Overlay.jl") +include("ProbProg.jl") using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile export ConcreteRArray, diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 3e0e040963..7c53aaafd7 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -14,18 +14,18 @@ function blr(seed, N, K) Random.seed!(rng, seed) # α ~ Normal(0, 10, size = 1) - α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) + α = ProbProg.sample(normal, rng, 0, 10, (1,); symbol=:α) # β ~ Normal(0, 2.5, size = K) - β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) + β = ProbProg.sample(normal, rng, 0, 2.5, (K,); symbol=:β) # X ~ Normal(0, 10, size = (N, K)) - X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) + X = ProbProg.sample(normal, rng, 0, 10, (N, K); symbol=:X) # μ = α .+ X * β μ = α .+ X * β - Y = ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) + Y = ProbProg.sample(bernoulli_logit, rng, μ, (N,); symbol=:Y) return Y end @@ -35,11 +35,9 @@ end K = 3 # number of features seed = Reactant.to_rarray(UInt64[1, 4]) - trace = ProbProg.create_trace() + trace = ProbProg.simulate(blr, seed, N, K) - @test size( - Array(@jit optimize = :probprog ProbProg.simulate!(blr, seed, N, K; trace)) - ) == (N,) + @test size(Array(trace.retval)) == (N,) - ProbProg.print_trace(trace) + println(trace) end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index cefd648a93..9c93c7a6f5 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -1,18 +1,14 @@ -using Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random, Statistics using Reactant: ProbProg normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function generate_model(seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape) - t = ProbProg.sample!(normal, rng, s, σ, shape) - return t - end - - return ProbProg.generate!(model, seed, μ, σ, shape) +function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, s, σ, shape) + return t end @testset "Generate" begin @@ -25,6 +21,9 @@ end σ1 = Reactant.ConcreteRNumber(1.0) σ2 = Reactant.ConcreteRNumber(1.0) + generate_model(seed, μ, σ, shape) = + ProbProg.generate_internal(model, seed, μ, σ, shape) + model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ @@ -44,11 +43,15 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape) + before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal( + model, seed, μ, σ, shape + ) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog generate_model(seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.generate_internal( + model, seed, μ, σ, shape + ) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") end @@ -58,7 +61,7 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) + X = ProbProg.generate(model, seed, μ, σ, shape) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end @@ -66,7 +69,7 @@ end op(x, y) = x * y' function fake_model(x, y) - return ProbProg.sample!(op, x, y) + return ProbProg.sample(op, x, y) end x = reshape(collect(Float64, 1:12), (4, 3)) @@ -74,7 +77,6 @@ end x_ra = Reactant.to_rarray(x) y_ra = Reactant.to_rarray(y) - @test Array(@jit optimize = :probprog ProbProg.generate!(fake_model, x_ra, y_ra)) == - op(x, y) + @test ProbProg.generate(fake_model, x_ra, y_ra) == op(x, y) end end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index aabf476f94..9541b2feb8 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -1,29 +1,21 @@ -using Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random using Reactant: ProbProg -@noinline normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function sample1(seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape) - return s - end - - return ProbProg.generate!(model, seed, μ, σ, shape) +function one_sample(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape) + return s end -function sample2(seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - _ = ProbProg.sample!(normal, rng, μ, σ, shape) - t = ProbProg.sample!(normal, rng, μ, σ, shape) - return t - end - - return ProbProg.generate!(model, seed, μ, σ, shape) +function two_samples(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + _ = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, μ, σ, shape) + return t end @testset "test" begin @@ -32,19 +24,19 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false sample2(seed, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.generate_internal(one_sample, seed, μ, σ, shape) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog sample2(seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.generate_internal(two_samples, seed, μ, σ, shape) @test !contains(repr(after), "enzyme.sample") end - @testset "sample_normal" begin + @testset "rng_state" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = Array(@jit optimize = :probprog sample1(seed, μ, σ, shape)) - Y = Array(@jit optimize = :probprog sample2(seed, μ, σ, shape)) + X = ProbProg.generate(one_sample, seed, μ, σ, shape) + Y = ProbProg.generate(two_samples, seed, μ, σ, shape) @test !all(X .≈ Y) end end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 403505e1b0..a97fc5ae8d 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -1,37 +1,32 @@ -using Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random using Reactant: ProbProg -@testset "Simulate" begin - normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) - function simulate_model(trace, seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s) - t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t) - return t - end +function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s) + t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t) + return t +end - result = ProbProg.simulate!(model, seed, μ, σ, shape; trace) - return result - end - @testset "normal_hlo" begin - shape = (10000,) +@testset "Simulate" begin + @testset "simulate_hlo" begin + shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.create_trace() - - before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.simulate_internal( + model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace() + ) @test contains(repr(before), "enzyme.simulate") - @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog simulate_model(trace, seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.simulate_internal( + model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace() + ) @test !contains(repr(after), "enzyme.simulate") - @test !contains(repr(after), "enzyme.sample") - @test contains(repr(after), "enzyme_probprog_add_sample_to_trace") end @testset "normal_simulate" begin @@ -40,34 +35,30 @@ using Reactant: ProbProg μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.create_trace() - - result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape)) + trace = ProbProg.simulate(model, seed, μ, σ, shape) - @test size(result) == shape - @test haskey(trace, :s) - @test haskey(trace, :t) - @test size(trace[:s]) == shape - @test size(trace[:t]) == shape + @test size(trace.retval) == shape + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + @test size(trace.choices[:s]) == shape + @test size(trace.choices[:t]) == shape end @testset "correctness" begin op(x, y) = x * y' function fake_model(x, y) - return ProbProg.sample!(op, x, y; symbol=:matmul) + return ProbProg.sample(op, x, y; symbol=:matmul) end - trace = ProbProg.create_trace() x = reshape(collect(Float64, 1:12), (4, 3)) y = reshape(collect(Float64, 1:12), (4, 3)) x_ra = Reactant.to_rarray(x) y_ra = Reactant.to_rarray(y) - @test Array( - @jit optimize = :probprog ProbProg.simulate!(fake_model, x_ra, y_ra; trace) - ) == op(x, y) + trace = ProbProg.simulate(fake_model, x_ra, y_ra) - @test haskey(trace, :matmul) - @test trace[:matmul] == op(x, y) + @test Array(trace.retval) == op(x, y) + @test haskey(trace.choices, :matmul) + @test trace.choices[:matmul] == op(x, y) end end From 561b051b6c0801160a5e8a5df824547924171651 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 20 Jun 2025 14:19:51 -0500 Subject: [PATCH 35/47] better print --- src/ProbProg.jl | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 0dc2bdeffb..3373831a53 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -291,5 +291,71 @@ function simulate_internal( return result end +# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 +function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + + indent_vert_str = join(indent_vert) + indent_str = join(indent) + indent_last_str = join(indent_last) + + sorted_choices = sort(collect(trace.choices); by=x -> x[1]) + n = length(sorted_choices) + + if trace.retval !== nothing + n += 1 + end + + cur = 1 + + if trace.retval !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") + cur += 1 + end + + for (key, value) in sorted_choices + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) + println(io, "ProbProgTrace:") + if isempty(trace.choices) && trace.retval === nothing + println(io, " (empty)") + else + _show_pretty(io, trace, 0, ()) + end +end + +function Base.show(io::IO, trace::ProbProgTrace) + if get(io, :compact, false) + choices_count = length(trace.choices) + has_retval = trace.retval !== nothing + print(io, "ProbProgTrace($(choices_count) choices") + if has_retval + print(io, ", retval=$(trace.retval)") + end + print(io, ")") + else + show(io, MIME"text/plain"(), trace) + end +end end From 99d7608c10cb14d3c033864f2996c6320a21458a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Jun 2025 17:47:19 -0500 Subject: [PATCH 36/47] unconstrained real generate op --- src/ProbProg.jl | 109 +++++++++++++++++++++++++++++++++----- test/probprog/generate.jl | 57 ++++---------------- 2 files changed, 105 insertions(+), 61 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 3373831a53..1b7986d3ab 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -1,15 +1,22 @@ module ProbProg -using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray, AbstractConcreteNumber +using ..Reactant: + MLIR, + TracedUtils, + AbstractConcreteArray, + AbstractConcreteNumber, + AbstractRNG, + TracedRArray using ..Compiler: @jit using Enzyme mutable struct ProbProgTrace choices::Dict{Symbol,Any} retval::Any + weight::Any function ProbProgTrace() - return new(Dict{Symbol,Any}(), nothing) + return new(Dict{Symbol,Any}(), nothing, nothing) end end @@ -63,7 +70,10 @@ function __init__() end function sample( - f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, ) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") @@ -102,6 +112,7 @@ function sample( sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + # Specify which outputs to add to the trace. traced_output_indices = Int[] for (i, res) in enumerate(linear_results) if TracedUtils.has_idx(res, resprefix) @@ -109,13 +120,60 @@ function sample( end end + # Specify which inputs to pass to logpdf. + traced_input_indices = Int[] + for (i, a) in enumerate(linear_args) + idx, _ = TracedUtils.get_argidx(a, argprefix) + if fnwrap && idx == 1 # TODO: add test for fnwrap + continue + end + + if fnwrap + idx -= 1 + end + + if !(args[idx] isa AbstractRNG) + push!(traced_input_indices, i - 1) + end + end + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + # Construct MLIR attribute if Julia logpdf function is provided. + logpdf_attr = nothing + if logpdf !== nothing + # Just to get static information about the sample. TODO: kwargs? + example_sample = f(args...) + + # Remove AbstractRNG from `f`'s argument list if present, assuming that + # logpdf parameters follows `(sample, args...)` convention. + logpdf_args = (example_sample,) + if !isempty(args) && args[1] isa AbstractRNG + logpdf_args = (example_sample, Base.tail(args)...) # TODO: kwargs? + end + + logpdf_mlir = invokelatest( + TracedUtils.make_mlir_fn, + logpdf, + logpdf_args, + (), + string(logpdf), + false; + do_transpose=false, + args_in_result=:all, + ) + + logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name") + logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym)) + end + sample_op = MLIR.Dialects.enzyme.sample( batch_inputs; outputs=out_tys, fn=fn_attr, + logpdf=logpdf_attr, symbol=symbol_addr, + traced_input_indices=traced_input_indices, traced_output_indices=traced_output_indices, ) @@ -143,11 +201,19 @@ function sample( end function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - res = @jit optimize = :probprog generate_internal(f, args...) - return res isa AbstractConcreteArray ? Array(res) : res + trace = ProbProgTrace() + + weight, res = @jit optimize = :probprog generate_internal(f, args...; trace) + + trace.retval = res isa AbstractConcreteArray ? Array(res) : res + trace.weight = Array(weight)[1] + + return trace, trace.weight end -function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function generate_internal( + f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace +) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") @@ -169,7 +235,8 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + f_out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + out_tys = [MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)); f_out_tys] fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) @@ -186,10 +253,17 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end end - gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) + + # Output: (weight, f's outputs...) + gen_op = MLIR.Dialects.enzyme.generate( + batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr + ) + + weight = TracedRArray(MLIR.IR.result(gen_op, 1)) for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(gen_op, i) + resv = MLIR.IR.result(gen_op, i + 1) # to skip weight if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) @@ -208,7 +282,7 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end end - return result + return weight, result end function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} @@ -299,7 +373,6 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) LAST = '\u2514' indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) @@ -320,6 +393,10 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) n += 1 end + if trace.weight !== nothing + n += 1 + end + cur = 1 if trace.retval !== nothing @@ -328,6 +405,12 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) cur += 1 end + if trace.weight !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") + cur += 1 + end + for (key, value) in sorted_choices print(io, indent_vert_str) print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") @@ -337,7 +420,7 @@ end function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) println(io, "ProbProgTrace:") - if isempty(trace.choices) && trace.retval === nothing + if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing println(io, " (empty)") else _show_pretty(io, trace, 0, ()) @@ -350,7 +433,7 @@ function Base.show(io::IO, trace::ProbProgTrace) has_retval = trace.retval !== nothing print(io, "ProbProgTrace($(choices_count) choices") if has_retval - print(io, ", retval=$(trace.retval)") + print(io, ", retval=$(trace.retval), weight=$(trace.weight)") end print(io, ")") else diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 9c93c7a6f5..8c1a8917a4 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -2,81 +2,42 @@ using Reactant, Test, Random, Statistics using Reactant: ProbProg normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ, shape) - t = ProbProg.sample(normal, rng, s, σ, shape) + s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t, logpdf=normal_logpdf) return t end @testset "Generate" begin - @testset "deterministic" begin - shape = (10000,) - seed1 = Reactant.to_rarray(UInt64[1, 4]) - seed2 = Reactant.to_rarray(UInt64[1, 4]) - μ1 = Reactant.ConcreteRNumber(0.0) - μ2 = Reactant.ConcreteRNumber(1000.0) - σ1 = Reactant.ConcreteRNumber(1.0) - σ2 = Reactant.ConcreteRNumber(1.0) - - generate_model(seed, μ, σ, shape) = - ProbProg.generate_internal(model, seed, μ, σ, shape) - - model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) - - @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ - Array(model_compiled(seed1, μ1, σ1, shape)) - @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = - 0.05 - @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = - 0.05 - @test !(all( - Array(model_compiled(seed1, μ1, σ1, shape)) .≈ - Array(model_compiled(seed2, μ2, σ2, shape)), - )) - end @testset "hlo" begin - shape = (10000,) + shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal( - model, seed, μ, σ, shape + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() ) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") after = @code_hlo optimize = :probprog ProbProg.generate_internal( - model, seed, μ, σ, shape + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() ) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") end @testset "normal" begin - shape = (10000,) + shape = (1000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = ProbProg.generate(model, seed, μ, σ, shape) - @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 - end - - @testset "correctness" begin - op(x, y) = x * y' - - function fake_model(x, y) - return ProbProg.sample(op, x, y) - end - - x = reshape(collect(Float64, 1:12), (4, 3)) - y = reshape(collect(Float64, 1:12), (4, 3)) - x_ra = Reactant.to_rarray(x) - y_ra = Reactant.to_rarray(y) - - @test ProbProg.generate(fake_model, x_ra, y_ra) == op(x, y) + trace, weight = ProbProg.generate(model, seed, μ, σ, shape) + @test mean(trace.retval) ≈ 0.0 atol = 0.05 rtol = 0.05 end end From b13f8bf58a700876a76d47b28069e5acfeab7346 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Jun 2025 17:47:41 -0500 Subject: [PATCH 37/47] probprog postpasses --- src/Compiler.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index a4e2ff44a5..bd496ba0f1 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1183,6 +1183,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -1753,7 +1754,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - "probprog", + probprog_pass, opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", @@ -1767,7 +1768,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - "probprog", + probprog_pass, opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", @@ -1794,7 +1795,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - "probprog", + probprog_pass, opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", @@ -1811,7 +1812,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - "probprog", + probprog_pass, opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", From 6e4dc0c4d56efe81f984ce3a42e5e56076f5814b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Jun 2025 13:21:24 -0500 Subject: [PATCH 38/47] bug fix for alising outputs --- src/ProbProg.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 1b7986d3ab..18de59758a 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -139,6 +139,25 @@ function sample( symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + # (out_idx1, in_idx1, out_idx2, in_idx2, ...) + alias_pairs = Int64[] + for (out_idx, res) in enumerate(linear_results) + if TracedUtils.has_idx(res, argprefix) + in_idx = nothing + for (i, arg) in enumerate(linear_args) + if TracedUtils.has_idx(arg, argprefix) && + TracedUtils.get_idx(arg, argprefix) == TracedUtils.get_idx(res, argprefix) + in_idx = i - 1 + break + end + end + @assert in_idx !== nothing "Unable to find operand for aliased result" + push!(alias_pairs, out_idx - 1) + push!(alias_pairs, in_idx) + end + end + alias_attr = MLIR.IR.DenseArrayAttribute(alias_pairs) + # Construct MLIR attribute if Julia logpdf function is provided. logpdf_attr = nothing if logpdf !== nothing @@ -175,6 +194,8 @@ function sample( symbol=symbol_addr, traced_input_indices=traced_input_indices, traced_output_indices=traced_output_indices, + alias_map=alias_attr, + name=Base.String(symbol), ) for (i, res) in enumerate(linear_results) From 5b5c1d15938b33336428b65f3750e63064f661e3 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Jun 2025 13:26:48 -0500 Subject: [PATCH 39/47] generate op with constraints --- deps/ReactantExtra/API.cpp | 14 ++++++++++ src/ProbProg.jl | 52 +++++++++++++++++++++++++++++++++----- test/probprog/generate.jl | 19 ++++++++++++++ 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 911b21ae77..9658522757 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -353,6 +353,20 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } +extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeConstraintAttrGet( + MlirContext ctx, uint64_t symbol, MlirAttribute values) { + mlir::Attribute vals = unwrap(values); + auto arr = llvm::dyn_cast(vals); + if (!arr) { + ReactantThrowError( + "enzymeConstraintAttrGet: `values` must be an ArrayAttr"); + return MlirAttribute{nullptr}; + } + mlir::Attribute attr = + mlir::enzyme::ConstraintAttr::get(unwrap(ctx), symbol, arr); + return wrap(attr); +} + // Create profiler session and start profiling extern "C" tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level, diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 18de59758a..6e161676f1 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -146,7 +146,8 @@ function sample( in_idx = nothing for (i, arg) in enumerate(linear_args) if TracedUtils.has_idx(arg, argprefix) && - TracedUtils.get_idx(arg, argprefix) == TracedUtils.get_idx(res, argprefix) + TracedUtils.get_idx(arg, argprefix) == + TracedUtils.get_idx(res, argprefix) in_idx = i - 1 break end @@ -221,10 +222,12 @@ function sample( return result end -function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs} trace = ProbProgTrace() - weight, res = @jit optimize = :probprog generate_internal(f, args...; trace) + weight, res = @jit sync = true optimize = :probprog generate_internal( + f, args...; trace, constraints + ) trace.retval = res isa AbstractConcreteArray ? Array(res) : res trace.weight = Array(weight)[1] @@ -233,7 +236,7 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end function generate_internal( - f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace + f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace, constraints=nothing ) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") @@ -276,9 +279,46 @@ function generate_internal( trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) - # Output: (weight, f's outputs...) + constraints_attr = nothing + if constraints !== nothing && !isempty(constraints) + constraint_attrs = MLIR.IR.Attribute[] + + for (sym, constraint) in constraints + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + + if !(constraint isa AbstractArray) + error( + "Constraints must be an array (one element per traced output) of arrays" + ) + end + + sym_constraint_attrs = MLIR.IR.Attribute[] + for oc in constraint + if !(oc isa AbstractArray) + error("Per-output constraints must be arrays") + end + + push!(sym_constraint_attrs, MLIR.IR.DenseElementsAttribute(oc)) + end + + cattr_ptr = @ccall MLIR.API.mlir_c.enzymeConstraintAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, + sym_addr::UInt64, + MLIR.IR.Attribute(sym_constraint_attrs)::MLIR.API.MlirAttribute, + )::MLIR.API.MlirAttribute + + push!(constraint_attrs, MLIR.IR.Attribute(cattr_ptr)) + end + + constraints_attr = MLIR.IR.Attribute(constraint_attrs) + end + gen_op = MLIR.Dialects.enzyme.generate( - batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr + batch_inputs; + outputs=out_tys, + fn=fname, + trace=trace_addr, + constraints=constraints_attr, ) weight = TracedRArray(MLIR.IR.result(gen_op, 1)) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 8c1a8917a4..5ed4f662fc 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -40,4 +40,23 @@ end trace, weight = ProbProg.generate(model, seed, μ, σ, shape) @test mean(trace.retval) ≈ 0.0 atol = 0.05 rtol = 0.05 end + + @testset "constraints" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + s_constraint = fill(0.1, shape) + constraints = Dict(:s => [s_constraint]) + + trace, weight = ProbProg.generate(model, seed, μ, σ, shape; constraints) + + @test trace.choices[:s] == s_constraint + + expected_weight = + normal_logpdf(s_constraint, 0.0, 1.0, shape) + + normal_logpdf(trace.choices[:t], s_constraint, 1.0, shape) + @test weight ≈ expected_weight atol = 1e-6 + end end From 1ad167a3d3943672cc43bcc3c2a470d7ac25eb0e Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Jun 2025 14:58:13 -0500 Subject: [PATCH 40/47] untraced call --- src/ProbProg.jl | 69 +++++++++++++++++++++++++++++++++++++++++ test/probprog/sample.jl | 8 ++--- 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 6e161676f1..262482690c 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -222,6 +222,75 @@ function sample( return result end +function call(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + res = @jit optimize = :probprog call_internal(f, args...) + return res isa AbstractConcreteArray ? Array(res) : res +end + +function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("callarg") + resprefix::Symbol = gensym("callresult") + resargprefix::Symbol = gensym("callresarg") + + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, + f, + args, + (), + string(f), + false; + do_transpose=false, + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fname) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(call_op, i) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + else + TracedUtils.set!(res, (), resv) + end + end + + return result +end + function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs} trace = ProbProgTrace() diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 9541b2feb8..904d2a7ccd 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -24,9 +24,9 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.generate_internal(one_sample, seed, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.call_internal(one_sample, seed, μ, σ, shape) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog ProbProg.generate_internal(two_samples, seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.call_internal(two_samples, seed, μ, σ, shape) @test !contains(repr(after), "enzyme.sample") end @@ -35,8 +35,8 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = ProbProg.generate(one_sample, seed, μ, σ, shape) - Y = ProbProg.generate(two_samples, seed, μ, σ, shape) + X = ProbProg.call(one_sample, seed, μ, σ, shape) + Y = ProbProg.call(two_samples, seed, μ, σ, shape) @test !all(X .≈ Y) end end From 8f66b5f8813120d6d58b1ba9e9138ac68ef86eb7 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Jun 2025 17:19:02 -0500 Subject: [PATCH 41/47] working metropolis hastings (with hacks) --- src/ProbProg.jl | 46 ++++++++++++++++-- test/probprog/linear_regression.jl | 77 ++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 4 deletions(-) create mode 100644 test/probprog/linear_regression.jl diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 262482690c..c7a04abe24 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -14,10 +14,14 @@ mutable struct ProbProgTrace choices::Dict{Symbol,Any} retval::Any weight::Any + fn::Union{Nothing,Function} + args::Union{Nothing,Tuple} - function ProbProgTrace() - return new(Dict{Symbol,Any}(), nothing, nothing) + function ProbProgTrace(fn::Function, args::Tuple) + return new(Dict{Symbol,Any}(), nothing, nothing, fn, args) end + + ProbProgTrace() = new(Dict{Symbol,Any}(), nothing, nothing, nothing, ()) end function addSampleToTraceLowered( @@ -292,7 +296,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs} - trace = ProbProgTrace() + trace = ProbProgTrace(f, (args...,)) weight, res = @jit sync = true optimize = :probprog generate_internal( f, args...; trace, constraints @@ -416,7 +420,7 @@ function generate_internal( end function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - trace = ProbProgTrace() + trace = ProbProgTrace(f, (args...,)) res = @jit optimize = :probprog sync = true simulate_internal(f, args...; trace) @@ -571,4 +575,38 @@ function Base.show(io::IO, trace::ProbProgTrace) end end +struct Selection + symbols::Vector{Symbol} +end + +select(symbol::Symbol) = Selection([symbol]) + +choicemap() = Dict{Symbol,Any}() +get_choices(trace::ProbProgTrace) = trace.choices + +function metropolis_hastings(trace::ProbProgTrace, sel::Selection) + if trace.fn === nothing + error("MH requires a trace with fn and args recorded") + end + + constraints = Dict{Symbol,Any}() + for (sym, val) in trace.choices + sym in sel.symbols && continue + constraints[sym] = [val] + end + + new_trace, _ = generate(trace.fn, trace.args...; constraints) + rng_state = new_trace.retval[1] # TODO: this is a temporary hack + + log_alpha = new_trace.weight - trace.weight + + if log(rand()) < log_alpha + new_trace.args = (rng_state, new_trace.args[2:end]...) + return (new_trace, true) + else + trace.args = (rng_state, trace.args[2:end]...) + return (trace, false) + end +end + end diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl new file mode 100644 index 0000000000..095e4d8aac --- /dev/null +++ b/test/probprog/linear_regression.jl @@ -0,0 +1,77 @@ +using Reactant, Test, Random +using Reactant: ProbProg + +# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) + +function my_model(seed, xs) + rng = Random.default_rng() + Random.seed!(rng, seed) + + slope = ProbProg.sample( + normal, rng, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf + ) + intercept = ProbProg.sample( + normal, rng, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf + ) + + ys = ProbProg.sample( + normal, + rng, + slope .* xs .+ intercept, + 1.0, + (length(xs),); + symbol=:ys, + logpdf=normal_logpdf, + ) + + return rng.seed, ys +end + +function my_inference_program(xs, ys, num_iters) + xs_r = Reactant.to_rarray(xs) + + constraints = ProbProg.choicemap() + constraints[:ys] = [ys] + + seed = Reactant.to_rarray(UInt64[1, 4]) + + trace, _ = ProbProg.generate(my_model, seed, xs_r; constraints) + trace.args = (trace.retval[1], trace.args[2:end]...) # TODO: this is a temporary hack + + for i in 1:num_iters + trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:slope)) + trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:intercept)) + choices = ProbProg.get_choices(trace) + @show i, choices[:slope], choices[:intercept] + end + + choices = ProbProg.get_choices(trace) + return (choices[:slope], choices[:intercept]) +end + +@testset "linear_regression" begin + @testset "simulate" begin + seed = Reactant.to_rarray(UInt64[1, 4]) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + xs_r = Reactant.to_rarray(xs) + + trace = ProbProg.simulate(my_model, seed, xs_r) + + @test haskey(trace.choices, :slope) + @test haskey(trace.choices, :intercept) + @test haskey(trace.choices, :ys) + end + + @testset "inference" begin + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] + + slope, intercept = my_inference_program(xs, ys, 1000) + + @show slope, intercept + end +end \ No newline at end of file From 850e3c4e42a5e7601ff0b49312861c7030370f16 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 13:49:05 -0500 Subject: [PATCH 42/47] set julia rng --- test/probprog/linear_regression.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl index 095e4d8aac..f246fead50 100644 --- a/test/probprog/linear_regression.jl +++ b/test/probprog/linear_regression.jl @@ -55,6 +55,7 @@ end @testset "linear_regression" begin @testset "simulate" begin seed = Reactant.to_rarray(UInt64[1, 4]) + Random.seed!(42) # For Julia side RNG xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] xs_r = Reactant.to_rarray(xs) @@ -74,4 +75,4 @@ end @show slope, intercept end -end \ No newline at end of file +end From e1b3bcb2d0e02c15ea1fec2fd3882db7a48efbbf Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 13:49:27 -0500 Subject: [PATCH 43/47] remove print --- test/probprog/blr.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 7c53aaafd7..615edb842d 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -38,6 +38,4 @@ end trace = ProbProg.simulate(blr, seed, N, K) @test size(Array(trace.retval)) == (N,) - - println(trace) end From 659b9637fee1d1da31c0c6c04d24bf5e7f09328a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 14:12:19 -0500 Subject: [PATCH 44/47] less iterations. hiding prints --- test/probprog/linear_regression.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl index f246fead50..a0efed9416 100644 --- a/test/probprog/linear_regression.jl +++ b/test/probprog/linear_regression.jl @@ -45,7 +45,7 @@ function my_inference_program(xs, ys, num_iters) trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:slope)) trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:intercept)) choices = ProbProg.get_choices(trace) - @show i, choices[:slope], choices[:intercept] + # @show i, choices[:slope], choices[:intercept] end choices = ProbProg.get_choices(trace) @@ -71,8 +71,8 @@ end xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] - slope, intercept = my_inference_program(xs, ys, 1000) + slope, intercept = my_inference_program(xs, ys, 5) - @show slope, intercept + # @show slope, intercept end end From 537de49a6c205b9c8d14c7d5d84aae96a246e8e8 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 14:12:45 -0500 Subject: [PATCH 45/47] add probprog test group --- test/runtests.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 411cf443ea..e7998129f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,4 +60,12 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Lux Integration" include("nn/lux.jl") end end + + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog" + @safetestset "ProbProg Sample" include("probprog/sample.jl") + @safetestset "ProbProg BLR" include("probprog/blr.jl") + @safetestset "ProbProg Simulate" include("probprog/simulate.jl") + @safetestset "ProbProg Generate" include("probprog/generate.jl") + @safetestset "ProbProg Linear Regression" include("probprog/linear_regression.jl") + end end From 8260fee3d8ab2cc451c246e7dee5b2b1e9391e1e Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 14:37:21 -0500 Subject: [PATCH 46/47] format --- test/probprog/sample.jl | 8 ++++++-- test/probprog/simulate.jl | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 904d2a7ccd..ef212a63bf 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -24,9 +24,13 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.call_internal(one_sample, seed, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.call_internal( + one_sample, seed, μ, σ, shape + ) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog ProbProg.call_internal(two_samples, seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.call_internal( + two_samples, seed, μ, σ, shape + ) @test !contains(repr(after), "enzyme.sample") end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index a97fc5ae8d..3fbdfdd1ad 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -19,12 +19,12 @@ end σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = false ProbProg.simulate_internal( - model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace() + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() ) @test contains(repr(before), "enzyme.simulate") after = @code_hlo optimize = :probprog ProbProg.simulate_internal( - model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace() + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() ) @test !contains(repr(after), "enzyme.simulate") end From 0f9416668baf4d9b5002aaf2e263474037e75f34 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 14:46:34 -0500 Subject: [PATCH 47/47] add probprog compile opt --- src/CompileOptions.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index 30dfda915f..9b01785c11 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -221,6 +221,8 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :probprog, + :probprog_no_lowering, ] end