diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index c0a05848f280..43c857b39e0d 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -58,6 +58,8 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/TimeProfiler.h" +#include "llvm/Demangle/Demangle.h" + #if LLVM_VERSION_MAJOR >= 14 #define addAttribute addAttributeAtIndex #define hasAttribute hasAttributeAtIndex @@ -125,6 +127,10 @@ llvm::cl::opt llvm::cl::opt EnzymePrintDiffUse("enzyme-print-diffuse", cl::init(false), cl::Hidden, cl::desc("Print differential use analysis")); + +llvm::cl::opt + EnzymeRustDeallocName("rust-dealloc-name", cl::init(""), cl::Hidden, + cl::desc("Name of Rust deallocation function")); } SmallVector MD_ToCopy = { @@ -9472,17 +9478,35 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, GradientUtils *gutils) { assert(isAllocationFunction(allocationfn, TLI)); - if (allocationfn == "__rust_alloc" || allocationfn == "__rust_alloc_zeroed") { +#if LLVM_VERSION_MAJOR >= 17 + std::string demangledName = llvm::demangle(allocationfn); + if (demangledName == "__rustc::__rust_alloc" || demangledName == "__rustc::__rust_alloc_zeroed" ) { Type *VoidTy = Type::getVoidTy(tofree->getContext()); Type *IntPtrTy = orig->getType(); Type *RustSz = orig->getArgOperand(0)->getType(); Type *inTys[3] = {IntPtrTy, RustSz, RustSz}; auto FT = FunctionType::get(VoidTy, inTys, false); + if (EnzymeRustDeallocName == "") { + // Rust's (de)alloc names aren't stable. We expect rustc to set them + // for us, but if it fails to do so we instead search for it here. + for (auto &F : *builder.GetInsertBlock()->getParent()->getParent()) { + auto demangledName = llvm::demangle(F.getName()); + if (demangledName == "__rustc::__rust_dealloc") { + EnzymeRustDeallocName = F.getName(); + break; + } + } + if (EnzymeRustDeallocName == "") { + // If we can't find it, use the raw __rust_dealloc as a fallback. + // FIXME: Make this a hard error once we pass the right name from rustc. + EnzymeRustDeallocName = "__rust_dealloc"; + } + } Value *freevalue = builder.GetInsertBlock() ->getParent() ->getParent() - ->getOrInsertFunction("__rust_dealloc", FT) + ->getOrInsertFunction(EnzymeRustDeallocName, FT) .getCallee(); Value *vals[3]; vals[0] = builder.CreatePointerCast(tofree, IntPtrTy); @@ -9506,6 +9530,7 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, builder.Insert(freecall); return freecall; } +#endif if (allocationfn == "julia.gc_alloc_obj" || allocationfn == "jl_gc_alloc_typed" || allocationfn == "ijl_gc_alloc_typed" || diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h index 27697bc41a53..8291e0a63dd7 100644 --- a/enzyme/Enzyme/LibraryFuncs.h +++ b/enzyme/Enzyme/LibraryFuncs.h @@ -26,6 +26,7 @@ #include #include #include +#include "llvm/Demangle/Demangle.h" #include #include #include @@ -55,8 +56,11 @@ static inline bool isAllocationFunction(const llvm::StringRef name, return true; if (name == "__size_returning_new_experiment") return true; - if (name == "__rust_alloc" || name == "__rust_alloc_zeroed") +#if LLVM_VERSION_MAJOR >= 17 + std::string demangledName = llvm::demangle(name.str()); + if (demangledName == "__rustc::__rust_alloc" || demangledName == "__rustc::__rust_alloc_zeroed") return true; +#endif if (name == "julia.gc_alloc_obj" || name == "jl_gc_alloc_typed" || name == "ijl_gc_alloc_typed") return true; @@ -131,7 +135,8 @@ static inline bool isDeallocationFunction(const llvm::StringRef name, return true; if (name == "_mlir_memref_to_llvm_free") return true; - if (name == "__rust_dealloc") + std::string demangledName = llvm::demangle(name.str()); + if (demangledName == "__rustc::__rust_dealloc") return true; if (name == "swift_release") return true; @@ -207,7 +212,12 @@ static inline void zeroKnownAllocation(llvm::IRBuilder<> &bb, assert(isAllocationFunction(funcName, TLI)); // Don't re-zero an already-zero buffer - if (funcName == "calloc" || funcName == "__rust_alloc_zeroed") +#if LLVM_VERSION_MAJOR >= 17 + std::string demangledName = llvm::demangle(funcName.str()); + if (demangledName == "__rustc::__rust_alloc_zeroed") + return; +#endif + if (funcName == "calloc") return; Value *allocSize = argValues[0];